CLIP的Finetune (Train):
OpenAI 的 CLIP 第三方库最好通过 pip 从 GitHub 仓库安装:pip install git+https://github.com/openai/CLIP.git
0x00 导入包
import json
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import clip
from transformers import CLIPProcessor, CLIPModel
# NOTE(review): this Hugging Face CLIP model/processor pair is separate from the
# OpenAI `clip` RN50 model (`net`) that is built further down; the training
# forward pass only ever calls `net`, so confirm these two are still needed.
model, processor = (
    CLIPModel.from_pretrained("openai/clip-vit-base-patch32"),
    CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32"),
)
0x01 读入图像信息
json_path = '/kaggle/input/indo-fashion-dataset/train_data.json'
# The annotation file holds one JSON object per line (JSON Lines).
# An explicit UTF-8 encoding keeps loading independent of the platform's
# default locale encoding (e.g. GBK / cp1252 on Windows).
with open(json_path, 'r', encoding='utf-8') as f:
    # Skip blank lines (such as a trailing empty line): json.loads rejects them.
    input_data = [json.loads(line) for line in f if line.strip()]

# Fail loudly rather than silently training on nothing.
if not input_data:
    raise ValueError(f"No records found in {json_path}")

# See some example (falls back to the first record for one-line files).
print(input_data[1] if len(input_data) > 1 else input_data[0])
0x02 构建图像-文本对
要点:
- 读入一个列表,存储图像所在的位置
- 读入一个列表,用于描述图像
- 两个列表按下标一一对应:同一下标处的图像与文本构成一个图像-文本对
- 图像需要处理成RGB通道
- 使用 ResNet50(RN50)版本的 CLIP 作为网络架构,并使用其配套的图像预处理函数
class image_title_dataset(torch.utils.data.Dataset):
    """Paired image/caption dataset for CLIP fine-tuning.

    Item ``i`` is the preprocessed image stored at ``list_image_path[i]``
    together with the tokenized caption built from ``list_txt[i]``.

    Args:
        list_image_path: Image file paths, one per caption.
        list_txt: Caption strings; must have the same length as
            ``list_image_path``.
        preprocess: Callable applied to each RGB PIL image (here, the
            transform returned by ``clip.load``).

    Raises:
        ValueError: If ``list_image_path`` and ``list_txt`` differ in length.
    """

    def __init__(self, list_image_path, list_txt, preprocess):
        # Validate pairing up front: a mismatch would otherwise show up later
        # as an IndexError in __getitem__ or as silently ignored images.
        if len(list_image_path) != len(list_txt):
            raise ValueError(
                f"list_image_path has {len(list_image_path)} entries but "
                f"list_txt has {len(list_txt)}"
            )
        self.image_paths = list_image_path
        # Tokenize every caption once; truncate=True cuts over-long captions
        # down to the 77-token context length.
        self.title = clip.tokenize(list_txt, context_length=77, truncate=True)
        self.preprocess = preprocess

    def __len__(self):
        return len(self.title)

    def __getitem__(self, index):
        # Use the context manager so the underlying file handle is closed;
        # the image is converted and preprocessed before the file is released.
        with Image.open(self.image_paths[index]) as img:
            image = self.preprocess(img.convert('RGB'))
        title = self.title[index]
        return image, title
# Parallel lists: caption text[i] describes the image stored at img_path[i].
text = [data['product_title'] for data in input_data]
img_path = ['/kaggle/input/indo-fashion-dataset/' + data['image_path']
            for data in input_data]

# Load CLIP's ResNet-50 variant plus its matching image transform
# (jit=False: load the non-JIT version of the model), then wrap the pairs.
device = 'cpu'
net, preprocess = clip.load("RN50", device=device, jit=False)
ds = image_title_dataset(img_path, text, preprocess=preprocess)
0x03 微调
要点:
- batch_size 根据自己的机器调
- num_workers 在我的环境中设为大于 0 时经常出问题,所以直接设为 0(只用主进程加载数据)
- Adam优化器
- 没有加入学习率调度器(lr scheduler),可以自行加入
- 只考虑了 CPU 微调(速度较慢,并不推荐;这样做只是因为我的电脑环境经常崩溃)
- GPU 微调时,clip.load 加载的权重为 fp16:需要在 optimizer.step() 之前把参数和梯度转为 fp32,step 之后再调用 clip.model.convert_weights(net) 转回 fp16
- 损失函数为图像→文本与文本→图像两个方向交叉熵损失的均值
from tqdm import tqdm

# Shuffle so each batch mixes unrelated items: the in-batch contrastive loss
# below treats every other pair in the batch as a negative example.
train_loader = DataLoader(ds,
                          batch_size=4,
                          shuffle=True,
                          num_workers=0,
                          pin_memory=False)

# Optimize `net`, the model the forward pass actually runs. The Hugging Face
# `model` loaded near the top of the file is never called in this loop, so
# optimizing its parameters would leave `net` unchanged.
optimizer = torch.optim.Adam(net.parameters(),
                             lr=5e-5,
                             betas=(0.9, 0.98),
                             eps=1e-6,
                             weight_decay=0.2)

loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()

n_epochs = 30
for epoch in range(n_epochs):
    pbar = tqdm(train_loader, total=len(train_loader))
    # Distinct loop names so the caption list `text` is not overwritten.
    for batch_images, batch_texts in pbar:
        optimizer.zero_grad()
        images = batch_images.to(device)
        texts = batch_texts.to(device)

        logits_per_image, logits_per_text = net(images, texts)
        # Row i's positive match is column i (the i-th pair in the batch).
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
        # Symmetric loss: mean of image->text and text->image cross-entropy.
        total_loss = (loss_img(logits_per_image, ground_truth)
                      + loss_txt(logits_per_text, ground_truth)) * 0.5
        total_loss.backward()
        # NOTE(review): only the CPU path is handled. On GPU the weights are
        # presumably fp16 and would need converting to fp32 before step() and
        # back via clip.model.convert_weights(net) afterwards — confirm.
        optimizer.step()

        pbar.set_description(f"Epoch {epoch}/{n_epochs}, Loss: {total_loss.item():.4f}")
0xFF 走一步看一步吧