Chinese-CLIP Model Code Notes: Model Finetuning

July 29, 2024 · Junhong Liu · 7 min read

The model covered in this post comes from the paper "Chinese CLIP: Contrastive Vision-Language Pretraining in Chinese"; the source code is available on GitHub.

1. Bash Script

Below is the bash script that launches the finetuning program. Its main job is to pass arguments to the training script; the parameter names are largely self-explanatory, so they are not described one by one here.

finetune.sh
#!/usr/bin/env bash

# Guide:
# This script supports distributed training on multi-gpu workers (as well as single-worker training). 
# Please set the options below according to the comments. 
# For multi-gpu workers training, these options should be manually set for each worker. 
# After setting the options, please run the script on each worker.
# Command: bash run_scripts/muge_finetune_vit-b-16_rbt-base.sh ${DATAPATH}

### PAY ATTENTION
# Before running, cd to the directory containing the Python source files

# Number of GPUs per GPU worker
GPUS_PER_NODE=8
# Number of GPU workers, for single-worker training, please set to 1
WORKER_CNT=1
# The ip address of the rank-0 worker, for single-worker training, please set to localhost
export MASTER_ADDR=localhost
# The port for communication
export MASTER_PORT=8510
# The rank of this worker, should be in {0, ..., WORKER_CNT-1}, for single-worker training, please set to 0
export RANK=0

export PYTHONPATH=${PYTHONPATH}:`pwd`/cn_clip/

DATAPATH=/raid/ljh/home/game

# data options
train_data=${DATAPATH}/datasets/processed_data/lmdb/train #train
val_data=${DATAPATH}/datasets/processed_data/lmdb/valid # if val_data is not specified, the validation will be automatically disabled

# restore options
resume=${DATAPATH}/experiments/256batch_4epoch_processed_data/checkpoints/epoch2.pt # or specify your custom ckpt path to resume
reset_data_offset="--reset-data-offset"
reset_optimizer="--reset-optimizer"
# reset_optimizer=""

# output options
output_base_dir=${DATAPATH}/experiments/
name=256batch_6epoch_processed_data_freeze_vision
save_step_frequency=999999 # disable it
save_epoch_frequency=1
log_interval=1
report_training_batch_acc="--report-training-batch-acc"
# report_training_batch_acc=""

# training hyper-params
context_length=60
warmup=200
batch_size=256
valid_batch_size=256
accum_freq=1
lr=2e-5
wd=0.001
max_epochs=6 # or you can alternatively specify --max-steps
max_steps=999999 # disable it
valid_step_interval=999999 # disable it
valid_epoch_interval=1
vision_model=ViT-H-14
text_model=RoBERTa-wwm-ext-large-chinese
mask_ratio=0 # FLIP strategy: set a nonzero ratio to randomly mask image patches
use_augment="--use-augment"
# use_augment=""

# python3 -m torch.distributed.launch --use_env --nproc_per_node=${GPUS_PER_NODE} --nnodes=${WORKER_CNT} --node_rank=${RANK} \
# --use-flash-attention: speed up computation
# --grad-checkpointing: recompute activations during backward instead of storing them (saves memory)
# --freeze-vision: freeze the vision tower
torchrun --nproc_per_node=${GPUS_PER_NODE} --nnodes=${WORKER_CNT} --node_rank=${RANK} \
          --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} cn_clip/training/main.py \
          --train-data=${train_data} \
          --val-data=${val_data} \
          --resume=${resume} \
          ${reset_data_offset} \
          ${reset_optimizer} \
          --logs=${output_base_dir} \
          --name=${name} \
          --save-step-frequency=${save_step_frequency} \
          --save-epoch-frequency=${save_epoch_frequency} \
          --log-interval=${log_interval} \
          ${report_training_batch_acc} \
          --context-length=${context_length} \
          --warmup=${warmup} \
          --batch-size=${batch_size} \
          --valid-batch-size=${valid_batch_size} \
          --valid-step-interval=${valid_step_interval} \
          --valid-epoch-interval=${valid_epoch_interval} \
          --accum-freq=${accum_freq} \
          --lr=${lr} \
          --wd=${wd} \
          --max-epochs=${max_epochs} \
          --vision-model=${vision_model} \
          --mask-ratio=${mask_ratio} \
          ${use_augment} \
          --text-model=${text_model} \
          --freeze-vision \
          --grad-checkpointing \
        # --use-flash-attention \
        # --max-epochs=${max_epochs} \
        # --max-steps=${max_steps} \
Related training configuration

Model finetuning

Here we describe the training steps so that other users can understand the model details and finetune with the provided Chinese CLIP pretrained checkpoints. Example training scripts, run_scripts/muge_finetune_vit-b-16_rbt-base.sh and run_scripts/flickr30k_finetune_vit-b-16_rbt-base.sh, are provided for the two downstream retrieval datasets, MUGE and Flickr30K-CN. The scripts support both single-machine (single- or multi-GPU) and multi-machine distributed training. Before running, fill in the distributed configuration according to the comments at the top of the script, then start training with the commands below (for multi-machine training, run the commands on every machine). If GPU memory is insufficient, consider enabling the activation-recomputation (gradient checkpointing) option. Logs and model checkpoint files produced during training are saved automatically under the user-specified directory:

cd Chinese-CLIP/
bash run_scripts/muge_finetune_vit-b-16_rbt-base.sh ${DATAPATH}

The relevant training configuration options include:

Distributed
WORKER_CNT: number of machines used for training.
GPUS_PER_NODE: number of GPUs on each machine.

Training/validation data
train-data: LMDB directory of the training data; see above for the preprocessing pipeline that produces the LMDB files.
val-data: LMDB directory of the validation data; if set to None, validation during training is disabled.
num-workers: number of DataLoader worker processes for the training set; default 4.
valid-num-workers: number of DataLoader worker processes for the validation set (if validating); default 1.

Training hyperparameters
vision-model: vision backbone, chosen from ["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"].
text-model: text backbone, chosen from ["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"].
context-length: length of the text input sequence.
warmup: number of warmup steps.
batch-size: per-GPU batch size during training. (Make sure the total number of training samples > batch-size * number of GPUs, i.e. at least one full training batch.)
lr: learning rate.
wd: weight decay.
max-steps: number of training steps; alternatively, specify the number of epochs via max-epochs.
freeze-vision: whether to freeze the vision backbone.
use-augment: whether to apply AutoAugment data augmentation to images.
valid-batch-size: per-GPU batch size during validation. (Make sure the total number of validation samples > batch-size * number of GPUs, i.e. at least one full validation batch.)
valid-step-interval / valid-epoch-interval: validation frequency in steps/epochs; set to -1 to disable validation during training.
grad-checkpointing: use activation recomputation, which does not store intermediate results during the forward pass, trading training time for lower memory usage; useful when GPU memory is insufficient. (A store_true flag: just add --grad-checkpointing in the script; currently requires PyTorch > 1.8.0.)
mask-ratio: following the FLIP strategy, randomly mask a given fraction of image patches during finetuning to reduce memory usage and speed up training. Default 0.0, i.e. disabled.
use-flash-attention: use FlashAttention, which significantly speeds up Chinese-CLIP finetuning and reduces memory usage without affecting results. (A store_true flag: after setting up the environment, add --use-flash-attention in the script; see flash_attention.md for details.)
accum-freq: gradient accumulation frequency, default 1. Setting it to an integer greater than 1 enables gradient accumulation for contrastive learning, simulating a larger batch size. With a per-GPU batch size of m, the total batch size is accum_freq * m * number of GPUs.
gather-with-grad: whether to gather features with full gradients in distributed training; disabled by default.
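The accum-freq arithmetic above can be sanity-checked with a few lines of Python; the numbers plugged in below are the values from finetune.sh (batch_size=256, 8 GPUs per node) and are purely illustrative.

```python
# Effective global batch size under gradient accumulation:
#   global = accum_freq * per_gpu_batch * num_gpus
def effective_batch_size(accum_freq: int, per_gpu_batch: int, num_gpus: int) -> int:
    return accum_freq * per_gpu_batch * num_gpus

# Values from finetune.sh: batch_size=256, accum_freq=1, GPUS_PER_NODE=8
print(effective_batch_size(1, 256, 8))  # 2048
# Simulating a 4x larger batch without extra memory: accum_freq=4
print(effective_batch_size(4, 256, 8))  # 8192
```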

Output options
name: output identifier. Hyperparameter logs, training logs, and produced checkpoints are all stored under ${DATAPATH}/experiments/${name}/.
save-step-frequency / save-epoch-frequency: checkpoint-saving interval in steps or epochs.
report-training-batch-acc: whether the log reports training image-to-text and text-to-image batch accuracy.

Checkpoint loading options
resume: path of the checkpoint to load. The example script points at a pretrained checkpoint; you can also point it at your own finetuned checkpoint to continue training.
reset-data-offset: whether to reset the data offset to the beginning of the dataset instead of resuming from the previous data breakpoint. Recommended if the batch size or the number of GPUs has changed.
reset-optimizer: whether to reset the optimizer state instead of loading it from the checkpoint.
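To make the three restore options concrete, here is a minimal, framework-free sketch of how a resume routine typically applies them. The checkpoint layout and the helper name are assumptions for illustration; the actual logic lives in cn_clip/training/main.py.

```python
# Hypothetical sketch; not the actual Chinese-CLIP implementation.
def resume_from_checkpoint(checkpoint, model_state, reset_optimizer, reset_data_offset):
    # Model weights are always restored from the checkpoint.
    model_state.update(checkpoint["state_dict"])
    # --reset-optimizer: discard the saved optimizer state and start fresh.
    optimizer_state = None if reset_optimizer else checkpoint.get("optimizer")
    # --reset-data-offset: restart from the beginning of the dataset instead
    # of the sample offset recorded in the checkpoint.
    data_offset = 0 if reset_data_offset else checkpoint.get("data_offset", 0)
    return optimizer_state, data_offset

ckpt = {"state_dict": {"w": 1.0}, "optimizer": {"step": 500}, "data_offset": 12800}
opt_state, offset = resume_from_checkpoint(ckpt, {}, reset_optimizer=True, reset_data_offset=True)
print(opt_state, offset)  # None 0
```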

When training finishes, the log is stored automatically at ${DATAPATH}/experiments/${name}/out_${timestamp}.log. The training log format looks like this:

2022-12-11,20:40:34 | INFO | Rank 0 | Global Steps: 1/735 | Train Epoch: 1 [1024/250880 (0%)] | Loss: 2.371020 | Image2Text Acc: 49.90 | Text2Image Acc: 48.73 | Data Time: 1.039s | Batch Time: 3.625s | LR: 0.000000 | logit_scale: 4.605 | Global Batch Size: 1024

The validation log format looks like this:

2022-12-11,20:42:47 | INFO | Rank 0 | Validation Result (epoch 1 @ 150 steps) | Valid Loss: 0.502810 | Image2Text Acc: 84.95 | Text2Image Acc: 84.26 | logit_scale: 4.605 | Valid Batch Size: 128
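Because the log lines have a fixed layout, metrics can be pulled out with a small regular expression. The helper below is my own illustration (not part of the repo) and parses the validation line shown above:

```python
import re

# Illustrative helper: extract metrics from a Chinese-CLIP validation log line.
PATTERN = re.compile(
    r"Valid Loss: (?P<loss>[\d.]+) \| "
    r"Image2Text Acc: (?P<i2t>[\d.]+) \| "
    r"Text2Image Acc: (?P<t2i>[\d.]+)"
)

line = ("2022-12-11,20:42:47 | INFO | Rank 0 | Validation Result (epoch 1 @ 150 steps) | "
        "Valid Loss: 0.502810 | Image2Text Acc: 84.95 | Text2Image Acc: 84.26 | "
        "logit_scale: 4.605 | Valid Batch Size: 128")

metrics = {k: float(v) for k, v in PATTERN.search(line).groupdict().items()}
print(metrics)  # {'loss': 0.50281, 'i2t': 84.95, 't2i': 84.26}
```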

2. Finetuning Framework

Before the bash script runs main.py, the arguments defined in the script must be loaded. This is handled by params.py.

params.py
import argparse


def get_default_params(model_name):
    # Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
    if model_name in ["RN50", "RN101", "RN50x4"]:
        return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}
    elif model_name in ["ViT-B-32", "ViT-B-16", "ViT-H-14"]:
        return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
    elif model_name in ["ViT-L-14", "ViT-L-14-336"]:
        return {"lr": 4.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
    else:
        return {}


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train-data",
        type=str,
        required=True,
        help="Path to the LMDB directory with training data split",
    )
    parser.add_argument(
        "--val-data",
        type=str,
        default=None,
        help="Path to the LMDB directory with validation data split, default to None which disables validation",
    )
    parser.add_argument(
        "--num-workers", type=int, default=4, help="The number of workers for training dataloader."
    )
    parser.add_argument(
        "--valid-num-workers", type=int, default=1, help="The number of workers for validation dataloader (if making validation)."
    )
    parser.add_argument(
        "--logs",
        type=str,
        default="./logs/",
        help="Where to store logs. Use None to avoid storing logs.",
    )
    parser.add_argument(
        "--name",
        type=str,
        default="train_clip",
        help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
    )
    parser.add_argument(
        "--log-interval", type=int, default=10, help="How often to log loss info."
    )
    parser.add_argument(
        "--report-training-batch-acc", default=False, action="store_true", help="Whether to report training batch accuracy."
    )
    parser.add_argument(
        "--batch-size", type=int, default=64, help="Batch size for training per GPU."
    )
    parser.add_argument(
        "--valid-batch-size", type=int, default=64, help="Batch size for validation per GPU."
    )
    parser.add_argument(
        "--max-steps", type=int, default=None, help="Number of steps to train for (in higher priority to --max_epochs)."
    )
    parser.add_argument(
        "--max-epochs", type=int, default=32, help="Number of full epochs to train for (only works if --max_steps is None)."
    )
    parser.add_argument(
        "--valid-step-interval", type=int, default=None, help="The step interval for validation (default to None which disables validation between steps)."
    )
    parser.add_argument(
        "--valid-epoch-interval", type=int, default=1, help="The epoch interval for validation (default to 1, set None to disable validation between epochs)."
    )
    parser.add_argument(
        "--context-length", type=int, default=52, help="The maximum length of input text (include [CLS] & [SEP] tokens). Default to 52."
    )
    parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
    parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
    parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
    parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
    parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
    parser.add_argument(
        "--warmup", type=int, default=500, help="Number of steps to warmup for."
    )
    parser.add_argument("--use-bn-sync",
        default=False,
        action="store_true",
        help="Whether to use batch norm sync."
    )
    parser.add_argument("--use-augment",
        default=False,
        action="store_true",
        help="Whether to use image augment."
    )
    parser.add_argument(
        "--skip-scheduler",
        action="store_true",
        default=False,
        help="Use this flag to skip the learning rate decay.",
    )
    parser.add_argument(
        "--save-epoch-frequency", type=int, default=1, help="How often to save checkpoints by epochs."
    )
    parser.add_argument(
        "--save-step-frequency", type=int, default=-1, help="How often to save checkpoints by steps."
    )
    parser.add_argument(
        "--resume",
        default=None,
        type=str,
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "--reset-optimizer",
        action="store_true",
        default=False,
        help="If resumed from a checkpoint, whether to reset the optimizer states.",
    )
    parser.add_argument(
        "--reset-data-offset",
        action="store_true",
        default=False,
        help="If resumed from a checkpoint, whether to reset the dataset offset to the beginning.",
    )    
    parser.add_argument(
        "--precision",
        choices=["amp", "fp16", "fp32"],
        default="amp",
        help="Floating point precision."
    )
    parser.add_argument(
        "--vision-model",
        choices=["ViT-B-32", "ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"],
        default="ViT-B-16",
        help="Name of the vision backbone to use.",
    )
    parser.add_argument(
        "--mask-ratio",
        default=0,
        type=float,
        help="Random mask ratio of patches during finetuning. Default to zero which does not mask any patches.",
    )
    parser.add_argument(
        "--clip-weight-path",
        default=None,
        type=str,
        help="The path of openai pretrained weight, used to initialize the image encoder, should be set to None if you do not use pretrained CLIP",
    )    
    parser.add_argument(
        "--freeze-vision",
        action="store_true",
        default=False,
        help="Freeze the weight of vision encoder.",
    )
    parser.add_argument(
        "--text-model",
        choices=["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"],
        default="RoBERTa-wwm-ext-base-chinese",
        help="Name of the text backbone to use.",
    )    
    parser.add_argument(
        "--bert-weight-path",
        default=None,
        type=str,
        help="The path of bert pretrained weight, used to initialize the text encoder, should be set to None if you do not use pretrained BERT",
    )
    parser.add_argument(
        "--grad-checkpointing",
        default=False,
        action='store_true',
        help="Enable gradient checkpointing.",
    )
    parser.add_argument(
        "--use-flash-attention",
        default=False,
        action="store_true",
        help="Enable flash attention."
    )
    parser.add_argument(
        "--accum-freq",
        type=int,
        default=1,
        help="Update the model every --accum-freq steps."
    )
    parser.add_argument(
        "--gather-with-grad",
        default=False,
        action="store_true",
        help="enable full distributed gradient for feature gather"
    )
    # arguments for distributed training
    parser.add_argument(
        "--skip-aggregate",
        default=False,
        action="store_true",
        help="whether to aggregate features across gpus before computing the loss"
    )
    parser.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="If true, more information is logged."
    )
    parser.add_argument(
        "--seed", 
        type=int, 
        default=123, 
        help="Random seed."
    )
    # arguments for distillation
    parser.add_argument(
        "--distllation",  # (sic) flag name as spelled in the upstream repo
        default=False,
        action="store_true",
        help="If true, finetune with knowledge distillation from a teacher model."
    )
    parser.add_argument(
        "--teacher-model-name",
        type=str,
        default=None,
        help="The name of teacher model."
    )
    parser.add_argument(
        "--kd_loss_weight",
        type=float,
        default=0.5,
        help="Weight of KD loss."
    )
    args = parser.parse_args()
    args.aggregate = not args.skip_aggregate

    # If some params are not passed, we use the default values based on model name.
    default_params = get_default_params(args.vision_model)
    for name, val in default_params.items():
        # getattr returns the value of an object's attribute;
        # setattr is its counterpart: it sets an attribute's value (the attribute need not already exist)
        if getattr(args, name) is None:
            setattr(args, name, val)

    return args
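The default-filling loop at the end of parse_args is worth seeing in isolation: any hyperparameter left unset (None) on the command line is overwritten with the model-specific default from get_default_params. The standalone sketch below reproduces that behavior with a trimmed-down argument set:

```python
import argparse

def get_default_params(model_name):
    # Same lookup as in params.py (values from the CLIP paper).
    if model_name in ["ViT-B-32", "ViT-B-16", "ViT-H-14"]:
        return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
    return {}

parser = argparse.ArgumentParser()
parser.add_argument("--vision-model", default="ViT-B-16")
parser.add_argument("--lr", type=float, default=None)
parser.add_argument("--beta1", type=float, default=None)
parser.add_argument("--beta2", type=float, default=None)
parser.add_argument("--eps", type=float, default=None)

# --lr is not given, so it stays None until the loop below fills it in;
# --beta1 is given explicitly and must survive untouched.
args = parser.parse_args(["--vision-model", "ViT-B-16", "--beta1", "0.95"])

for name, val in get_default_params(args.vision_model).items():
    if getattr(args, name) is None:  # fill only values the user did not set
        setattr(args, name, val)

print(args.lr, args.beta1)  # 0.0005 0.95
```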

After that, the program looks for available GPUs, sets up the distributed worker group, and then creates the log-monitoring process.

logger.py
import logging  # Python's logging module
from logging import Filter  # Filter class, used to filter/modify log records
from logging.handlers import QueueHandler, QueueListener  # handler and listener for queue-based logging

import torch  # PyTorch
import torch.distributed as dist  # PyTorch distributed computing module
import torch.multiprocessing as mp  # PyTorch multiprocessing module
from torch.multiprocessing import Queue  # Queue class from the multiprocessing module

logging.raiseExceptions = False  # do not raise exceptions on logging-internal errors

def setup_primary_logging(log_file, level, rank):  # set up the logging system for the primary process
    log_queue = Queue(-1)  # create an unbounded log queue

    formatter = logging.Formatter(
        '%(asctime)s | %(levelname)s | %(message)s', 
        datefmt='%Y-%m-%d,%H:%M:%S')  # log formatter: timestamp, level, and message

    if rank == 0:  # primary process only
        file_handler = logging.FileHandler(filename=log_file)  # file handler that writes logs to a file
        file_handler.setFormatter(formatter)  # attach the formatter to the file handler
        file_handler.setLevel(level)  # set the file handler's log level
    
    stream_handler = logging.StreamHandler()  # stream handler that writes logs to the console
    stream_handler.setFormatter(formatter)  # attach the formatter to the stream handler
    stream_handler.setLevel(level)  # set the stream handler's log level

    if rank == 0:  # primary process
        listener = QueueListener(log_queue, file_handler, stream_handler)  # listener that drains the queue into both handlers
    else:  # worker processes
        listener = QueueListener(log_queue, stream_handler)  # listener with console output only

    listener.start()  # start the listener

    return log_queue  # return the queue so other processes can log into it


class WorkerLogFilter(Filter):  # log filter that rewrites log messages
    def __init__(self, rank=-1):  # optionally record the process rank
        super().__init__()
        self._rank = rank  # store the rank

    def filter(self, record):  # modify the log record
        if self._rank != -1:  # if a rank was given
            record.msg = f"Rank {self._rank} | {record.msg}"  # prefix the message with the rank
        return True  # always return True, i.e. let the record through


def setup_worker_logging(rank, log_queue, level):  # set up the logging system for a worker process
    queue_handler = QueueHandler(log_queue)  # handler that pushes records onto the shared queue

    worker_filter = WorkerLogFilter(rank)  # filter carrying this process's rank
    queue_handler.addFilter(worker_filter)  # attach the filter to the handler

    queue_handler.setLevel(level)  # set the queue handler's log level

    root_logger = logging.getLogger()  # get the root logger
    if len(root_logger.handlers) > 0:  # if the root logger already has a handler
        root_logger.removeHandler(root_logger.handlers[0])  # remove the existing handler
    root_logger.addHandler(queue_handler)  # add the new queue handler
    root_logger.setLevel(level)  # set the root logger's level
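The queue-based pattern above can be demonstrated in a single process with only the standard library; the sketch below substitutes queue.Queue for torch.multiprocessing.Queue (they expose the same interface here) and captures the output in a StringIO instead of the console:

```python
import io
import logging
from logging import Filter
from logging.handlers import QueueHandler, QueueListener
from queue import Queue  # stand-in for torch.multiprocessing.Queue

class WorkerLogFilter(Filter):
    """Prefix each record with the worker's rank, as in logger.py."""
    def __init__(self, rank=-1):
        super().__init__()
        self._rank = rank

    def filter(self, record):
        if self._rank != -1:
            record.msg = f"Rank {self._rank} | {record.msg}"
        return True

buf = io.StringIO()  # capture log output for demonstration
stream_handler = logging.StreamHandler(buf)
stream_handler.setFormatter(logging.Formatter('%(levelname)s | %(message)s'))

# Listener side: drain the shared queue into the real handler.
log_queue = Queue(-1)
listener = QueueListener(log_queue, stream_handler)
listener.start()

# Worker side: push records onto the shared queue, tagged with rank 0.
queue_handler = QueueHandler(log_queue)
queue_handler.addFilter(WorkerLogFilter(rank=0))
logger = logging.getLogger("demo")
logger.setLevel(logging.INFO)
logger.addHandler(queue_handler)

logger.info("training started")
listener.stop()  # flush remaining records and stop the listener thread
print(buf.getvalue().strip())  # INFO | Rank 0 | training started
```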

Next, the program loads the CLIP model configuration files and builds the model.
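As a rough preview of that step: Chinese-CLIP keeps one JSON config file per backbone and merges the vision and text configs into a single set of model kwargs. The file names and keys below are illustrative assumptions, not the exact contents of the repo's config files:

```python
import json
import os
import tempfile

# Hypothetical per-backbone config files mimicking the repo's layout
# (cn_clip/clip/model_configs/<model>.json); keys are illustrative.
config_dir = tempfile.mkdtemp()
with open(os.path.join(config_dir, "ViT-B-16.json"), "w") as f:
    json.dump({"embed_dim": 512, "image_resolution": 224, "vision_layers": 12}, f)
with open(os.path.join(config_dir, "RoBERTa-wwm-ext-base-chinese.json"), "w") as f:
    json.dump({"vocab_size": 21128, "text_attention_heads": 12}, f)

def load_model_info(config_dir, vision_model, text_model):
    # Read both JSON files and merge them into one kwargs dict for the model.
    with open(os.path.join(config_dir, f"{vision_model}.json")) as f:
        info = json.load(f)
    with open(os.path.join(config_dir, f"{text_model}.json")) as f:
        info.update(json.load(f))
    return info

info = load_model_info(config_dir, "ViT-B-16", "RoBERTa-wwm-ext-base-chinese")
print(sorted(info))
```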