Chinese-CLIP 模型代码笔记:数据输入 / Notes about the Chinese-CLIP model code:inputs

7月 29, 2024·
Junhong Liu
Junhong Liu
· 3 分钟阅读时长

本博客所展示的模型来自论文《Chinese CLIP: Contrastive Vision-Language Pretraining in Chinese》,源码的GitHub 网址

1. 数据格式

在将数据输入到模型前,想要将数据处理成统一的模型能接受的格式。在论文所展示的代码中,原始数据文件的格式如下所示:

# 文件名:train_imgs.tsv
# 数据类型:图片
# 组织形式:{商品图片id} + \t + {图片base64编码内容}  
1000002	/9j/4AAQSkZJ...YQj7314oA//2Q==  
1000016	/9j/4AAQSkZJ...SSkHcOnegD/2Q==  
1000033	/9j/4AAQSkZJ...FhRRRWx4p//2Q==  
# 文件名:train_texts.jsonl
# 数据类型:文本
{"query_id": 8426, "query_text": "胖妹妹松紧腰长裤", "item_ids": [42967]}  
{"query_id": 8427, "query_text": "大码长款棉麻女衬衫", "item_ids": [63397]}  
{"query_id": 8428, "query_text": "高级感托特包斜挎", "item_ids": [1076345, 517602]}  

2. 代码解读

在运行模型前需先运行 data_process.sh 脚本来将上述格式的数据转化为模型能接受的数据输入格式,该脚本调用了 python 解释器运行 build_lmdb_dataset.py 文件。

python /raid/ljh/home/game/frame/Chinese-CLIP-master/cn_clip/preprocess/build_lmdb_dataset.py \
    --data_dir /raid/ljh/home/game/datasets/conv \
    --splits train # train,valid,test

data_process.sh 脚本主要用来启动 build_lmdb_dataset.py 文件并传入参数,所以现在我们着重来分析 build_lmdb_dataset.py 文件。build_lmdb_dataset.py 文件主要用来对原始数据进行处理,将其拆分成多个对应的图文对,并将其保存为 LMDB 用来加快数据读取速度。

# -*- coding: utf-8 -*-
'''
This script serializes images and image-text pair annotations into LMDB files,
which supports more convenient dataset loading and random access to samples during training 
compared with TSV and Jsonl data files.
'''

import argparse
import os
from tqdm import tqdm
import lmdb
import json
import pickle


def parse_args():
    # 创建 ArgumentParser() 参数对象 parser
    parser = argparse.ArgumentParser()
    # 调用 add_argument() 方法往参数对象 parser 添加参数
    parser.add_argument(
        "--data_dir", type=str, required=True, help="the directory which stores the image tsvfiles and the text jsonl annotations"
    )
    parser.add_argument(
        "--splits", type=str, required=True, help="specify the dataset splits which this script processes, concatenated by comma \
            (e.g. train,valid,test)"
    )
    parser.add_argument(
        "--lmdb_dir", type=str, default=None, help="specify the directory which stores the output lmdb files. \
            If set to None, the lmdb_dir will be set to {args.data_dir}/lmdb"
    )
    return parser.parse_args() # 运行解析器并将提取的数据放入 argparse.Namespace 对象


if __name__ == "__main__":
    args = parse_args()
    # assert(断言),条件为 false 时执行
    assert os.path.isdir(args.data_dir), "The data_dir does not exist! Please check the input args..."

    # read specified dataset splits 
    # str.strip() 函数,删除字符串首尾的字符,默认包括("\n","\r","\t"," ")
    # str.split() 函数,按字符切分
    specified_splits = list(set(args.splits.strip().split(",")))
    # str.format() 格式化函数
    # str.join(sequence) sequence -- 要连接的元素序列,返回通过指定字符连接序列中元素后生成的新字符串。
    print("Dataset splits to be processed: {}".format(", ".join(specified_splits)))

    # build LMDB data files
    if args.lmdb_dir is None:
        # os.path.join() 拼接文件路径
        args.lmdb_dir = os.path.join(args.data_dir, "lmdb")
    for split in specified_splits:
        # open new LMDB files
        lmdb_split_dir = os.path.join(args.lmdb_dir, split)
        if os.path.isdir(lmdb_split_dir):
            print("We will overwrite an existing LMDB file {}".format(lmdb_split_dir))
        # os.makedirs()创建目录,exist_ok=True 路径存在时不报错
        os.makedirs(lmdb_split_dir, exist_ok=True)
        lmdb_img = os.path.join(lmdb_split_dir, "imgs")
        # lmdb.open() 打开一个现有或生成一个空lmab数据库
        # map_size定义最大储存容量,单位是kb,以下定义1TB容量 
        env_img = lmdb.open(lmdb_img, map_size=1024**4)
        # 建立事务,参数write设置为True才可以写入 
        txn_img = env_img.begin(write=True)
        lmdb_pairs = os.path.join(lmdb_split_dir, "pairs")
        env_pairs = lmdb.open(lmdb_pairs, map_size=1024**4)
        txn_pairs = env_pairs.begin(write=True)

        # write LMDB file storing (image_id, text_id, text) pairs
        pairs_annotation_path = os.path.join(args.data_dir, "{}_texts.jsonl".format(split))
        with open(pairs_annotation_path, "r", encoding="utf-8") as fin_pairs:
            write_idx = 0
            for line in tqdm(fin_pairs):
                line = line.strip()
                obj = json.loads(line)

                ### Modified
                # for field in ("text_id", "text", "image_ids"):
                for field in ("query_id", "query_text", "item_ids"):
                    assert field in obj, "Field {} does not exist in line {}. \
                        Please check the integrity of the text annotation Jsonl file."
                    
                ### Modified
                # for image_id in obj["image_ids"]:
                    # dump = pickle.dumps((image_id, obj['text_id'], obj['text'])) # encoded (image_id, text_id, text)
                for image_id in obj["item_ids"]:
                    # pickle.dumps:将 python 对象序列化,用于传输和保存,便于不同 python 程序间对象传输
                    dump = pickle.dumps((image_id, obj['query_id'], obj['query_text']))
                    # .put() 对数据进行插入和修改
                    txn_pairs.put(key="{}".format(write_idx).encode('utf-8'), value=dump)  
                    write_idx += 1
                    if write_idx % 5000 == 0:
                        # .commit() 提交更改
                        txn_pairs.commit()
                        txn_pairs = env_pairs.begin(write=True)
            txn_pairs.put(key=b'num_samples',
                    value="{}".format(write_idx).encode('utf-8'))
            txn_pairs.commit()
            env_pairs.close()
        print("Finished serializing {} {} split pairs into {}.".format(write_idx, split, lmdb_pairs))

        # write LMDB file storing image base64 strings
        base64_path = os.path.join(args.data_dir, "{}_imgs.tsv".format(split))
        with open(base64_path, "r", encoding="utf-8") as fin_imgs:
            write_idx = 0
            for line in tqdm(fin_imgs):
                line = line.strip()
                image_id, b64 = line.split("\t")
                txn_img.put(key="{}".format(image_id).encode('utf-8'), value=b64.encode("utf-8"))
                write_idx += 1
                if write_idx % 1000 == 0:
                    txn_img.commit()
                    txn_img = env_img.begin(write=True)
            txn_img.put(key=b'num_images',
                    value="{}".format(write_idx).encode('utf-8'))
            txn_img.commit()
            env_img.close()                
        print("Finished serializing {} {} split images into {}.".format(write_idx, split, lmdb_img))

    print("done!")

该模块较简单,我们直接来看这条语句 if __name__ == "__main__":。以往的面向对象的编程语言都有一个 “main” 函数作为程序的入口, 但 Python 的不同之处在于 Python 的源码文件(.py)除了可以作为程序入口被直接运行外,还可以作为模块被其他程序调用。所以 if __name__ == "__main__": 的作用是当该文件被直接运行时,if __name__ == "__main__": 所包含的代码块将被运行;当该文件被以模块的形式被导入时,if __name__ == "__main__": 所包含的代码块则不会被运行。