Chinese-CLIP Model Code Notes: Fine-tuning
The model discussed in this post comes from the paper "Chinese CLIP: Contrastive Vision-Language Pretraining in Chinese"; the source code is available on GitHub.
1. The Bash Script
Below is the bash script that launches the fine-tuning program. Its main job is to pass arguments to the training program, and the argument names are largely self-explanatory, so they are not described one by one here.
finetune.sh
#!/usr/bin/env bash
# Guide:
# This script supports distributed training on multi-gpu workers (as well as single-worker training).
# Please set the options below according to the comments.
# For multi-gpu workers training, these options should be manually set for each worker.
# After setting the options, please run the script on each worker.
# Command: bash run_scripts/muge_finetune_vit-b-16_rbt-base.sh ${DATAPATH}
### PAY ATTENTION
# Before running, cd to the directory containing the Python source files
# Number of GPUs per GPU worker
GPUS_PER_NODE=8
# Number of GPU workers, for single-worker training, please set to 1
WORKER_CNT=1
# The ip address of the rank-0 worker, for single-worker training, please set to localhost
export MASTER_ADDR=localhost
# The port for communication
export MASTER_PORT=8510
# The rank of this worker, should be in {0, ..., WORKER_CNT-1}, for single-worker training, please set to 0
export RANK=0
export PYTHONPATH=${PYTHONPATH}:`pwd`/cn_clip/
DATAPATH=/raid/ljh/home/game
# data options
train_data=${DATAPATH}/datasets/processed_data/lmdb/train #train
val_data=${DATAPATH}/datasets/processed_data/lmdb/valid # if val_data is not specified, the validation will be automatically disabled
# restore options
resume=${DATAPATH}/experiments/256batch_4epoch_processed_data/checkpoints/epoch2.pt # or specify your custom ckpt path to resume
reset_data_offset="--reset-data-offset"
reset_optimizer="--reset-optimizer"
# reset_optimizer=""
# output options
output_base_dir=${DATAPATH}/experiments/
name=256batch_6epoch_processed_data_freeze_vision
save_step_frequency=999999 # disable it
save_epoch_frequency=1
log_interval=1
report_training_batch_acc="--report-training-batch-acc"
# report_training_batch_acc=""
# training hyper-params
context_length=60
warmup=200
batch_size=256
valid_batch_size=256
accum_freq=1
lr=2e-5
wd=0.001
max_epochs=6 # or you can alternatively specify --max-steps
max_steps=999999 # disable it
valid_step_interval=999999 # disable it
valid_epoch_interval=1
vision_model=ViT-H-14
text_model=RoBERTa-wwm-ext-large-chinese
mask_ratio=0 # FLIP strategy: ratio of image patches randomly masked during training (0 disables masking)
use_augment="--use-augment"
# use_augment=""
# python3 -m torch.distributed.launch --use_env --nproc_per_node=${GPUS_PER_NODE} --nnodes=${WORKER_CNT} --node_rank=${RANK} \
# --use-flash-attention: speed up computation with FlashAttention
# --grad-checkpointing: save memory by not storing intermediate activations in the forward pass
# --freeze-vision: freeze the vision encoder
torchrun --nproc_per_node=${GPUS_PER_NODE} --nnodes=${WORKER_CNT} --node_rank=${RANK} \
--master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} cn_clip/training/main.py \
--train-data=${train_data} \
--val-data=${val_data} \
--resume=${resume} \
${reset_data_offset} \
${reset_optimizer} \
--logs=${output_base_dir} \
--name=${name} \
--save-step-frequency=${save_step_frequency} \
--save-epoch-frequency=${save_epoch_frequency} \
--log-interval=${log_interval} \
${report_training_batch_acc} \
--context-length=${context_length} \
--warmup=${warmup} \
--batch-size=${batch_size} \
--valid-batch-size=${valid_batch_size} \
--valid-step-interval=${valid_step_interval} \
--valid-epoch-interval=${valid_epoch_interval} \
--accum-freq=${accum_freq} \
--lr=${lr} \
--wd=${wd} \
--max-epochs=${max_epochs} \
--vision-model=${vision_model} \
--mask-ratio=${mask_ratio} \
${use_augment} \
--text-model=${text_model} \
--freeze-vision \
          --grad-checkpointing
# --use-flash-attention \
# --max-epochs=${max_epochs} \
# --max-steps=${max_steps} \
Training configuration
Model fine-tuning
Here we describe the training procedure so that users can understand the model details and fine-tune from the pretrained Chinese CLIP checkpoints we provide. For the two downstream retrieval datasets, MUGE and Flickr30K-CN, sample training scripts are provided: run_scripts/muge_finetune_vit-b-16_rbt-base.sh and run_scripts/flickr30k_finetune_vit-b-16_rbt-base.sh. The scripts support both single-machine (single-GPU or multi-GPU) and multi-machine distributed training. Before running, fill in the distributed configuration following the guide comments at the top of the script, then start training with the commands below (for multi-machine training, run the command on every machine). If GPU memory is insufficient, consider enabling the gradient-checkpointing option. The logs and model checkpoints produced by training are saved automatically under the user-specified directory:
cd Chinese-CLIP/
bash run_scripts/muge_finetune_vit-b-16_rbt-base.sh ${DATAPATH}
The relevant training configuration options are:
Distributed
WORKER_CNT: the number of machines used for training.
GPUS_PER_NODE: the number of GPUs on each machine.
Training/validation data
train-data: LMDB directory of the training data; the preprocessing pipeline that produces the LMDB files is described above.
val-data: LMDB directory of the validation data. When set to None, validation during training is disabled.
num-workers: number of DataLoader worker processes for the training set, default 4.
valid-num-workers: number of DataLoader worker processes for the validation set (when validation is enabled), default 1.
Training hyperparameters
vision-model: the vision backbone, chosen from ["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"].
text-model: the text backbone, chosen from ["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"].
context-length: length of the text input sequence.
warmup: number of warmup steps.
batch-size: per-GPU batch size during training. (Make sure the total number of training samples > batch-size * number of GPUs, i.e., at least one full training batch.)
lr: learning rate.
wd: weight decay.
max-steps: number of training steps; alternatively, the number of training epochs can be specified with max-epochs.
freeze-vision: whether to freeze the vision backbone.
use-augment: whether to apply AutoAugment data augmentation to the images.
valid-batch-size: per-GPU batch size during validation. (Make sure the total number of validation samples > batch-size * number of GPUs, i.e., at least one full validation batch.)
valid-step-interval and valid-epoch-interval: validation frequency in steps/epochs; set to -1 to disable validation during training.
grad-checkpointing: enable gradient checkpointing, which does not store intermediate activations during the forward pass, trading training time for a smaller memory footprint; useful when GPU memory is limited. (A store_true flag: just add --grad-checkpointing to the script; currently requires PyTorch > 1.8.0.)
mask-ratio: following the FLIP strategy, randomly mask the given fraction of image patches during fine-tuning to reduce memory usage and speed up training. Default 0.0, i.e., the strategy is disabled.
use-flash-attention: use FlashAttention, which significantly speeds up Chinese-CLIP fine-tuning and lowers memory usage without affecting accuracy. (A store_true flag: after setting up the environment, add --use-flash-attention to the script; see flash_attention.md for details.)
accum-freq: gradient accumulation frequency, default 1. Setting an integer greater than 1 enables gradient accumulation for contrastive learning, simulating a larger batch size: with a per-GPU batch size of m, the total batch size is accum_freq * m * number of GPUs.
gather-with-grad: whether to gather features with full gradients in distributed training, disabled by default.
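The mask-ratio option above follows FLIP's idea of dropping a random subset of image patches before the vision encoder. As a rough illustration of the selection step only (plain Python, not the actual Chinese-CLIP implementation; the function name and patch count are made up here), a sketch:

```python
import random

def sample_visible_patches(num_patches, mask_ratio, seed=None):
    """Return indices of the patches kept (not masked) under FLIP-style masking.

    With mask_ratio = 0 every patch is kept, matching the script's default.
    """
    rng = random.Random(seed)
    num_keep = int(num_patches * (1 - mask_ratio))  # patches that stay visible
    indices = list(range(num_patches))
    rng.shuffle(indices)                            # uniformly random selection
    return sorted(indices[:num_keep])

# A 224x224 image split into 16x16 patches has (224 // 16) ** 2 = 196 patches.
visible = sample_visible_patches(196, mask_ratio=0.5, seed=0)
print(len(visible))  # 98 patches survive with mask_ratio = 0.5
```

Only the surviving patches are fed through the vision transformer, which is where the memory and speed savings come from.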
Output options
name: identifier of the output directory. The hyperparameter log, training log, and produced checkpoints are all stored under ${DATAPATH}/experiments/${name}/.
save-step-frequency and save-epoch-frequency: the interval, in steps or epochs, at which checkpoints are saved.
report-training-batch-acc: whether the log reports the in-batch image-to-text and text-to-image training accuracy.
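The in-batch accuracy reported by report-training-batch-acc can be understood as follows: in a batch of N matched image-text pairs, the correct text for image i is text i, so image-to-text accuracy is the fraction of rows of the N x N similarity matrix whose argmax lies on the diagonal. A minimal sketch in plain Python (hypothetical similarity values, not the actual training code):

```python
def in_batch_accuracy(sim):
    """Image-to-text accuracy: row i is correct if column i holds the row's max.

    sim[i][j] is the similarity between image i and text j; the matched pair
    sits on the diagonal. Text-to-image accuracy is the same computation on
    the transposed matrix.
    """
    n = len(sim)
    correct = sum(
        1 for i, row in enumerate(sim)
        if max(range(n), key=lambda j: row[j]) == i
    )
    return correct / n

# 3 image-text pairs; image 2 is confused with text 0.
sim = [
    [0.9, 0.1, 0.2],   # image 0 -> text 0  (correct)
    [0.3, 0.8, 0.1],   # image 1 -> text 1  (correct)
    [0.7, 0.2, 0.5],   # image 2 -> text 0  (wrong)
]
print(in_batch_accuracy(sim))  # 2/3
```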
Checkpoint loading options
resume: path of the checkpoint to load. The sample scripts set it to the pretrained checkpoint path; it can also be set to a checkpoint from your own fine-tuning run to continue training.
reset-data-offset: whether to reset the data offset and start from the beginning of the dataset rather than resuming from the data position recorded in the checkpoint. If hyperparameters such as the batch size or the number of GPUs have changed, enabling this option is recommended.
reset-optimizer: whether to discard the optimizer state stored in the checkpoint instead of restoring it.
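To make the interaction between resume, reset-optimizer, and reset-data-offset concrete, here is a schematic sketch of checkpoint-loading logic (a dictionary-based stand-in; the key names are assumptions for illustration, not Chinese-CLIP's actual checkpoint format):

```python
def restore_from_checkpoint(ckpt, reset_optimizer=False, reset_data_offset=False):
    """Pick out which parts of a saved checkpoint to reuse when resuming."""
    restored = {"model": ckpt["model"]}            # model weights are always loaded
    if not reset_optimizer:
        restored["optimizer"] = ckpt["optimizer"]  # keep optimizer state unless reset
    # With reset_data_offset, training restarts from the beginning of the dataset
    # instead of the sample position recorded in the checkpoint.
    restored["data_offset"] = 0 if reset_data_offset else ckpt["data_offset"]
    return restored

ckpt = {"model": {"w": 1.0}, "optimizer": {"step": 500}, "data_offset": 12800}
print(restore_from_checkpoint(ckpt, reset_optimizer=True, reset_data_offset=True))
# {'model': {'w': 1.0}, 'data_offset': 0}
```

Resetting both is what the sample script does: it reuses only the pretrained weights and starts a fresh optimization run over the fine-tuning data.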
After training finishes, the log is automatically saved to ${DATAPATH}/experiments/${name}/out_${timestamp}.log; training log lines look like this:
2022-12-11,20:40:34 | INFO | Rank 0 | Global Steps: 1/735 | Train Epoch: 1 [1024/250880 (0%)] | Loss: 2.371020 | Image2Text Acc: 49.90 | Text2Image Acc: 48.73 | Data Time: 1.039s | Batch Time: 3.625s | LR: 0.000000 | logit_scale: 4.605 | Global Batch Size: 1024
Validation log lines look like this:
2022-12-11,20:42:47 | INFO | Rank 0 | Validation Result (epoch 1 @ 150 steps) | Valid Loss: 0.502810 | Image2Text Acc: 84.95 | Text2Image Acc: 84.26 | logit_scale: 4.605 | Valid Batch Size: 128
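Since the log lines follow a fixed `Key: value` layout with `|` separators, the numeric metrics can be pulled out with a few lines of standard-library code, e.g. to plot loss curves afterwards (a sketch written against the sample line above, not an official utility):

```python
import re

def parse_log_line(line):
    """Extract numeric 'Key: value' metric pairs from one log line."""
    metrics = {}
    for field in line.split("|"):
        # Keep only fields whose value is a plain number (skips timestamps,
        # 'Global Steps: 1/735', '1.039s', etc.).
        m = re.match(r"\s*([\w ]+?):\s*([-+0-9.eE]+)\s*$", field)
        if m:
            metrics[m.group(1)] = float(m.group(2))
    return metrics

line = ("2022-12-11,20:40:34 | INFO | Rank 0 | Global Steps: 1/735 | "
        "Train Epoch: 1 [1024/250880 (0%)] | Loss: 2.371020 | "
        "Image2Text Acc: 49.90 | Text2Image Acc: 48.73 | "
        "Data Time: 1.039s | Batch Time: 3.625s | LR: 0.000000 | "
        "logit_scale: 4.605 | Global Batch Size: 1024")
print(parse_log_line(line)["Loss"])  # 2.37102
```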
2. Fine-tuning Framework
Before the bash script runs main.py, the arguments defined in the script need to be parsed and loaded. This is handled by the params.py file.
params.py
import argparse

def get_default_params(model_name):
    # Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
    if model_name in ["RN50", "RN101", "RN50x4"]:
        return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}
    elif model_name in ["ViT-B-32", "ViT-B-16", "ViT-H-14"]:
        return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
    elif model_name in ["ViT-L-14", "ViT-L-14-336"]:
        return {"lr": 4.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
    else:
        return {}
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train-data",
        type=str,
        required=True,
        help="Path to the LMDB directory with training data split",
    )
    parser.add_argument(
        "--val-data",
        type=str,
        default=None,
        help="Path to the LMDB directory with validation data split, default to None which disables validation",
    )
    parser.add_argument(
        "--num-workers", type=int, default=4, help="The number of workers for training dataloader."
    )
    parser.add_argument(
        "--valid-num-workers", type=int, default=1, help="The number of workers for validation dataloader (if making validation)."
    )
    parser.add_argument(
        "--logs",
        type=str,
        default="./logs/",
        help="Where to store logs. Use None to avoid storing logs.",
    )
    parser.add_argument(
        "--name",
        type=str,
        default="train_clip",
        help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
    )
    parser.add_argument(
        "--log-interval", type=int, default=10, help="How often to log loss info."
    )
    parser.add_argument(
        "--report-training-batch-acc", default=False, action="store_true", help="Whether to report training batch accuracy."
    )
    parser.add_argument(
        "--batch-size", type=int, default=64, help="Batch size for training per GPU."
    )
    parser.add_argument(
        "--valid-batch-size", type=int, default=64, help="Batch size for validation per GPU."
    )
    parser.add_argument(
        "--max-steps", type=int, default=None, help="Number of steps to train for (in higher priority to --max_epochs)."
    )
    parser.add_argument(
        "--max-epochs", type=int, default=32, help="Number of full epochs to train for (only works if --max_steps is None)."
    )
    parser.add_argument(
        "--valid-step-interval", type=int, default=None, help="The step interval for validation (default to None which disables validation between steps)."
    )
    parser.add_argument(
        "--valid-epoch-interval", type=int, default=1, help="The epoch interval for validation (default to 1, set None to disable validation between epochs)."
    )
    parser.add_argument(
        "--context-length", type=int, default=52, help="The maximum length of input text (include [CLS] & [SEP] tokens). Default to 52."
    )
    parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
    parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
    parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
    parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
    parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
    parser.add_argument(
        "--warmup", type=int, default=500, help="Number of steps to warmup for."
    )
    parser.add_argument(
        "--use-bn-sync",
        default=False,
        action="store_true",
        help="Whether to use batch norm sync."
    )
    parser.add_argument(
        "--use-augment",
        default=False,
        action="store_true",
        help="Whether to use image augment."
    )
    parser.add_argument(
        "--skip-scheduler",
        action="store_true",
        default=False,
        help="Use this flag to skip the learning rate decay.",
    )
    parser.add_argument(
        "--save-epoch-frequency", type=int, default=1, help="How often to save checkpoints by epochs."
    )
    parser.add_argument(
        "--save-step-frequency", type=int, default=-1, help="How often to save checkpoints by steps."
    )
    parser.add_argument(
        "--resume",
        default=None,
        type=str,
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "--reset-optimizer",
        action="store_true",
        default=False,
        help="If resumed from a checkpoint, whether to reset the optimizer states.",
    )
    parser.add_argument(
        "--reset-data-offset",
        action="store_true",
        default=False,
        help="If resumed from a checkpoint, whether to reset the dataset offset to the beginning.",
    )
    parser.add_argument(
        "--precision",
        choices=["amp", "fp16", "fp32"],
        default="amp",
        help="Floating point precision."
    )
    parser.add_argument(
        "--vision-model",
        choices=["ViT-B-32", "ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"],
        default="ViT-B-16",
        help="Name of the vision backbone to use.",
    )
    parser.add_argument(
        "--mask-ratio",
        default=0,
        type=float,
        help="Random mask ratio of patches during finetuning. Default to zero which does not mask any patches.",
    )
    parser.add_argument(
        "--clip-weight-path",
        default=None,
        type=str,
        help="The path of openai pretrained weight, used to initialize the image encoder, should be set to None if you do not use pretrained CLIP",
    )
    parser.add_argument(
        "--freeze-vision",
        action="store_true",
        default=False,
        help="Freeze the weight of vision encoder.",
    )
    parser.add_argument(
        "--text-model",
        choices=["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"],
        default="RoBERTa-wwm-ext-base-chinese",
        help="Name of the text backbone to use.",
    )
    parser.add_argument(
        "--bert-weight-path",
        default=None,
        type=str,
        help="The path of bert pretrained weight, used to initialize the text encoder, should be set to None if you do not use pretrained BERT",
    )
    parser.add_argument(
        "--grad-checkpointing",
        default=False,
        action="store_true",
        help="Enable gradient checkpointing.",
    )
    parser.add_argument(
        "--use-flash-attention",
        default=False,
        action="store_true",
        help="Enable flash attention."
    )
    parser.add_argument(
        "--accum-freq",
        type=int,
        default=1,
        help="Update the model every --accum-freq steps."
    )
    parser.add_argument(
        "--gather-with-grad",
        default=False,
        action="store_true",
        help="Enable full distributed gradient for feature gather."
    )
    # arguments for distributed training
    parser.add_argument(
        "--skip-aggregate",
        default=False,
        action="store_true",
        help="Whether to skip aggregating features across gpus before computing the loss."
    )
    parser.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="If true, more information is logged."
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=123,
        help="Random seed."
    )
    # arguments for distillation (the flag keeps the upstream spelling "distllation")
    parser.add_argument(
        "--distllation",
        default=False,
        action="store_true",
        help="If true, use knowledge distillation during training."
    )
    parser.add_argument(
        "--teacher-model-name",
        type=str,
        default=None,
        help="The name of teacher model."
    )
    parser.add_argument(
        "--kd_loss_weight",
        type=float,
        default=0.5,
        help="Weight of KD loss."
    )
    args = parser.parse_args()
    args.aggregate = not args.skip_aggregate

    # If some params are not passed, we use the default values based on model name.
    default_params = get_default_params(args.vision_model)
    for name, val in default_params.items():
        # getattr returns the value of the named attribute of an object;
        # setattr is its counterpart: it sets an attribute's value, and the
        # attribute does not need to exist beforehand.
        if getattr(args, name) is None:
            setattr(args, name, val)

    return args
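The getattr/setattr loop at the end of parse_args is a common pattern for filling in CLI defaults that depend on another argument (here, the vision backbone). A standalone demonstration of the same pattern with toy arguments (not the real parser):

```python
import argparse

# Hypothetical per-model defaults, standing in for get_default_params().
MODEL_DEFAULTS = {"ViT-B-16": {"lr": 5.0e-4}, "ViT-L-14": {"lr": 4.0e-4}}

def parse(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("--vision-model", default="ViT-B-16")
    parser.add_argument("--lr", type=float, default=None)  # None means "not given"
    args = parser.parse_args(argv)
    # Fill in the model-specific default only when the user did not pass --lr.
    for name, val in MODEL_DEFAULTS.get(args.vision_model, {}).items():
        if getattr(args, name) is None:
            setattr(args, name, val)
    return args

print(parse([]).lr)                # 0.0005 (model default kicks in)
print(parse(["--lr", "1e-5"]).lr)  # 1e-05  (explicit user value wins)
```

Using None as the "unset" sentinel is what lets the code distinguish "user passed nothing" from "user passed a value".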
After parsing the arguments, the program finds the available GPUs, sets up the distributed process group, and then creates the log-monitoring listener.
logger.py
import logging                                            # Python logging module
from logging import Filter                                # Filter class for filtering log records
from logging.handlers import QueueHandler, QueueListener  # queue-based logging handler and listener
import torch
import torch.distributed as dist                          # PyTorch distributed module
import torch.multiprocessing as mp                        # PyTorch multiprocessing module
from torch.multiprocessing import Queue

logging.raiseExceptions = False  # do not raise exceptions on logging errors

def setup_primary_logging(log_file, level, rank):  # set up logging in the primary process
    log_queue = Queue(-1)  # unbounded log queue
    formatter = logging.Formatter(
        '%(asctime)s | %(levelname)s | %(message)s',
        datefmt='%Y-%m-%d,%H:%M:%S')  # log format: timestamp, level, message
    if rank == 0:  # primary process only
        file_handler = logging.FileHandler(filename=log_file)  # write logs to a file
        file_handler.setFormatter(formatter)
        file_handler.setLevel(level)
    stream_handler = logging.StreamHandler()  # write logs to the console
    stream_handler.setFormatter(formatter)
    stream_handler.setLevel(level)
    if rank == 0:
        # listener that drains the queue into both the file and the console
        listener = QueueListener(log_queue, file_handler, stream_handler)
    else:
        # console-only listener for non-primary processes
        listener = QueueListener(log_queue, stream_handler)
    listener.start()  # start the listener thread
    return log_queue  # return the queue for other processes to log into

class WorkerLogFilter(Filter):  # filter that prefixes messages with the worker rank
    def __init__(self, rank=-1):
        super().__init__()
        self._rank = rank  # rank of this process

    def filter(self, record):
        if self._rank != -1:
            record.msg = f"Rank {self._rank} | {record.msg}"  # prepend the rank
        return True  # always keep the record

def setup_worker_logging(rank, log_queue, level):  # set up logging in a worker process
    queue_handler = QueueHandler(log_queue)  # handler that pushes records into the shared queue
    worker_filter = WorkerLogFilter(rank)    # filter carrying this worker's rank
    queue_handler.addFilter(worker_filter)
    queue_handler.setLevel(level)
    root_logger = logging.getLogger()  # the root logger
    if len(root_logger.handlers) > 0:
        root_logger.removeHandler(root_logger.handlers[0])  # drop any existing handler
    root_logger.addHandler(queue_handler)  # route all logging through the queue
    root_logger.setLevel(level)
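The queue-based setup above decouples log production from log writing: every process pushes records into a shared queue through a QueueHandler, and a single listener drains the queue into the file/console handlers. The same mechanism can be exercised with only the standard library (queue.Queue standing in for torch.multiprocessing.Queue, and a StringIO stream instead of the console):

```python
import io
import logging
import queue
from logging.handlers import QueueHandler, QueueListener

log_queue = queue.Queue(-1)          # shared, unbounded log queue
stream = io.StringIO()               # stand-in for the console
handler = logging.StreamHandler(stream)
handler.setFormatter(logging.Formatter('%(levelname)s | %(message)s'))

listener = QueueListener(log_queue, handler)  # drains the queue into the handler
listener.start()

logger = logging.getLogger("demo")
logger.setLevel(logging.INFO)
queue_handler = QueueHandler(log_queue)       # producers only touch the queue

class RankFilter(logging.Filter):             # same idea as WorkerLogFilter
    def filter(self, record):
        record.msg = f"Rank 0 | {record.msg}"
        return True

queue_handler.addFilter(RankFilter())
logger.addHandler(queue_handler)

logger.info("Global Steps: 1/735")
listener.stop()                               # flush remaining records
print(stream.getvalue().strip())              # INFO | Rank 0 | Global Steps: 1/735
```

Because workers never touch the file handler directly, this avoids interleaved or corrupted writes when many ranks log at once.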
Next, the program loads the CLIP model's configuration files and builds the model.