在自定義數(shù)據(jù)集上實(shí)現(xiàn)OpenAI CLIP
在2021年1月,OpenAI宣布了兩個(gè)新模型:DALL-E和CLIP,它們都是以某種方式連接文本和圖像的多模態(tài)模型。CLIP全稱(chēng)是Contrastive Language–Image Pre-training,一種基于對(duì)比文本-圖像對(duì)的預(yù)訓(xùn)練方法。為什么要介紹CLIP呢?因?yàn)楝F(xiàn)在大火得Stable Diffusion 并不是單一模型,而是多個(gè)模型組成。其中會(huì)用到一個(gè) Text encoder 將用戶(hù)的文本輸入進(jìn)行編碼,這個(gè) text encoder 就是 CLIP 模型中 text encoder。
CLIP模型在訓(xùn)練時(shí),可以給它一個(gè)輸入句子,并提取最相關(guān)的圖像來(lái)配合它。CLIP學(xué)習(xí)了一個(gè)完整的句子和它所描述的圖像之間的關(guān)系。也就是說(shuō)它是在完整的句子上訓(xùn)練的,而不是像“汽車(chē)”、“狗”等離散的分類(lèi),這一點(diǎn)對(duì)于應(yīng)用至關(guān)重要。當(dāng)訓(xùn)練完整的短語(yǔ)時(shí),模型可以學(xué)習(xí)更多的東西,并識(shí)別照片和文本之間的模式。他們還證明,當(dāng)在相當(dāng)大的照片和與之相對(duì)應(yīng)的句子數(shù)據(jù)集上進(jìn)行訓(xùn)練時(shí),該模型是可以作為分類(lèi)器的。CLIP在發(fā)布的時(shí)候能在無(wú)任何微調(diào)的情況下(zero-shot ),在 ImageNet 數(shù)據(jù)集上的分類(lèi)表現(xiàn)超 ResNets-50 微調(diào)后的效果,也就是說(shuō)他是非常有用的。
所以在本文中,我們將使用PyTorch中從頭開(kāi)始實(shí)現(xiàn)CLIP模型,以便我們對(duì)CLIP有一個(gè)更好的理解
這里就需要用到2個(gè)庫(kù):timm和transformers,我們先導(dǎo)入代碼
import os
import cv2
import gc
import numpy as np
import pandas as pd
import itertools
from tqdm.autonotebook import tqdm
import albumentations as A
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
import timm
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
下一步就是預(yù)處理數(shù)據(jù)和通用配置config。config是一個(gè)普通的python文件,我們將所有的超參數(shù)放在里面,如果使用Jupyter Notebook的情況下,它是一個(gè)在Notebook開(kāi)頭定義的類(lèi)。
class CFG:
debug = False
image_path = "../input/flickr-image-dataset/flickr30k_images/flickr30k_images"
captions_path = "."
batch_size = 32
num_workers = 4
head_lr = 1e-3
image_encoder_lr = 1e-4
text_encoder_lr = 1e-5
weight_decay = 1e-3
patience = 1
factor = 0.8
epochs = 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = 'resnet50'
image_embedding = 2048
text_encoder_model = "distilbert-base-uncased"
text_embedding = 768
text_tokenizer = "distilbert-base-uncased"
max_length = 200
pretrained = True # for both image encoder and text encoder
trainable = True # for both image encoder and text encoder
temperature = 1.0
# image size
size = 224
# for projection head; used for both image and text encoders
num_projection_layers = 1
projection_dim = 256
dropout = 0.1
還有一些我們自定義指標(biāo)的輔助類(lèi)
class AvgMeter:
def __init__(self, name="Metric"):
self.name = name
self.reset()
def reset(self):
self.avg, self.sum, self.count = [0] * 3
def update(self, val, count=1):
self.count += count
self.sum += val * count
self.avg = self.sum / self.count
def __repr__(self):
text = f"{self.name}: {self.avg:.4f}"
return text
def get_lr(optimizer):
for param_group in optimizer.param_groups:
return param_group["lr"]
我們的目標(biāo)是描述圖像和句子。所以數(shù)據(jù)集必須同時(shí)返回句子和圖像。所以需要使用DistilBERT標(biāo)記器對(duì)句子(標(biāo)題)進(jìn)行標(biāo)記,然后將標(biāo)記id (input_ids)和注意掩碼提供給DistilBERT。DistilBERT比BERT 模型要小,但是模型的結(jié)果都差不多,所以我們選擇使用它。
下一步就是使用HuggingFace tokenizer進(jìn)行標(biāo)記化。在__init__中獲得的tokenizer對(duì)象,將在模型運(yùn)行時(shí)加載。標(biāo)題被填充并截?cái)嗟筋A(yù)定的最大長(zhǎng)度。在加載相關(guān)圖像之前,我們將在__getitem__中加載一個(gè)編碼的標(biāo)題,這是一個(gè)帶有鍵input_ids和attention_mask的字典,并對(duì)其進(jìn)行轉(zhuǎn)換和擴(kuò)充(如果有的話(huà))。然后把它變成一個(gè)張量,并以“image”作為鍵存儲(chǔ)在字典中。最后我們將標(biāo)題的原始文本與關(guān)鍵字“標(biāo)題”一起輸入字典。
class CLIPDataset(torch.utils.data.Dataset):
def __init__(self, image_filenames, captions, tokenizer, transforms):
"""
image_filenames and cpations must have the same length; so, if there are
multiple captions for each image, the image_filenames must have repetitive
file names
"""
self.image_filenames = image_filenames
self.captions = list(captions)
self.encoded_captions = tokenizer(
list(captions), padding=True, truncatinotallow=True, max_length=CFG.max_length
)
self.transforms = transforms
def __getitem__(self, idx):
item = {
key: torch.tensor(values[idx])
for key, values in self.encoded_captions.items()
}
image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = self.transforms(image=image)['image']
item['image'] = torch.tensor(image).permute(2, 0, 1).float()
item['caption'] = self.captions[idx]
return item
def __len__(self):
return len(self.captions)
def get_transforms(mode="train"):
if mode == "train":
return A.Compose(
[
A.Resize(CFG.size, CFG.size, always_apply=True),
A.Normalize(max_pixel_value=255.0, always_apply=True),
]
)
else:
return A.Compose(
[
A.Resize(CFG.size, CFG.size, always_apply=True),
A.Normalize(max_pixel_value=255.0, always_apply=True),
]
)
圖像和文本編碼器:我們將使用ResNet50作為圖像編碼器。
class ImageEncoder(nn.Module):
"""
Encode images to a fixed size vector
"""
def __init__(
self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
):
super().__init__()
self.model = timm.create_model(
model_name, pretrained, num_classes=0, global_pool="avg"
)
for p in self.model.parameters():
p.requires_grad = trainable
def forward(self, x):
return self.model(x)
使用DistilBERT作為文本編碼器。使用CLS令牌的最終表示來(lái)獲得句子的整個(gè)表示。
class TextEncoder(nn.Module):
def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
super().__init__()
if pretrained:
self.model = DistilBertModel.from_pretrained(model_name)
else:
self.model = DistilBertModel(cnotallow=DistilBertConfig())
for p in self.model.parameters():
p.requires_grad = trainable
# we are using the CLS token hidden representation as the sentence's embedding
self.target_token_idx = 0
def forward(self, input_ids, attention_mask):
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
last_hidden_state = output.last_hidden_state
return last_hidden_state[:, self.target_token_idx, :]
上面的代碼已經(jīng)將圖像和文本編碼為固定大小的向量(圖像2048,文本768),我們需要圖像和文本具有相似的尺寸,以便能夠比較它們,所以我們把2048維和768維向量投影到256維(projection_dim),只有維度相同我們才能比較它們。
class ProjectionHead(nn.Module):
def __init__(
self,
embedding_dim,
projection_dim=CFG.projection_dim,
dropout=CFG.dropout
):
super().__init__()
self.projection = nn.Linear(embedding_dim, projection_dim)
self.gelu = nn.GELU()
self.fc = nn.Linear(projection_dim, projection_dim)
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(projection_dim)
def forward(self, x):
projected = self.projection(x)
x = self.gelu(projected)
x = self.fc(x)
x = self.dropout(x)
x = x + projected
x = self.layer_norm(x)
return x
所以最后我們的CLIP模型就是這樣:
class CLIPModel(nn.Module):
def __init__(
self,
temperature=CFG.temperature,
image_embedding=CFG.image_embedding,
text_embedding=CFG.text_embedding,
):
super().__init__()
self.image_encoder = ImageEncoder()
self.text_encoder = TextEncoder()
self.image_projection = ProjectionHead(embedding_dim=image_embedding)
self.text_projection = ProjectionHead(embedding_dim=text_embedding)
self.temperature = temperature
def forward(self, batch):
# Getting Image and Text Features
image_features = self.image_encoder(batch["image"])
text_features = self.text_encoder(
input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
)
# Getting Image and Text Embeddings (with same dimension)
image_embeddings = self.image_projection(image_features)
text_embeddings = self.text_projection(text_features)
# Calculating the Loss
logits = (text_embeddings @ image_embeddings.T) / self.temperature
images_similarity = image_embeddings @ image_embeddings.T
texts_similarity = text_embeddings @ text_embeddings.T
targets = F.softmax(
(images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
)
texts_loss = cross_entropy(logits, targets, reductinotallow='none')
images_loss = cross_entropy(logits.T, targets.T, reductinotallow='none')
loss = (images_loss + texts_loss) / 2.0 # shape: (batch_size)
return loss.mean()
#這里還加了一個(gè)交叉熵函數(shù)
def cross_entropy(preds, targets, reductinotallow='none'):
log_softmax = nn.LogSoftmax(dim=-1)
loss = (-targets * log_softmax(preds)).sum(1)
if reduction == "none":
return loss
elif reduction == "mean":
return loss.mean()
這里需要說(shuō)明下,CLIP使用 symmetric cross entropy 作為損失函數(shù),可以降低噪音影響,提高模型魯棒性,我們這里為了簡(jiǎn)單只是用cross entropy 。
我們可以進(jìn)行測(cè)試:
# A simple Example
batch_size = 4
dim = 256
embeddings = torch.randn(batch_size, dim)
out = embeddings @ embeddings.T
print(F.softmax(out, dim=-1))
下一步就是訓(xùn)練了,有一些函數(shù)可以幫助我們加載訓(xùn)練和驗(yàn)證的dataloader
def make_train_valid_dfs():
dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
image_ids = np.arange(0, max_id)
np.random.seed(42)
valid_ids = np.random.choice(
image_ids, size=int(0.2 * len(image_ids)), replace=False
)
train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
return train_dataframe, valid_dataframe
def build_loaders(dataframe, tokenizer, mode):
transforms = get_transforms(mode=mode)
dataset = CLIPDataset(
dataframe["image"].values,
dataframe["caption"].values,
tokenizer=tokenizer,
transforms=transforms,
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=CFG.batch_size,
num_workers=CFG.num_workers,
shuffle=True if mode == "train" else False,
)
return dataloader
然后就是訓(xùn)練和評(píng)估
def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
loss_meter = AvgMeter()
tqdm_object = tqdm(train_loader, total=len(train_loader))
for batch in tqdm_object:
batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
loss = model(batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step == "batch":
lr_scheduler.step()
count = batch["image"].size(0)
loss_meter.update(loss.item(), count)
tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
return loss_meter
def valid_epoch(model, valid_loader):
loss_meter = AvgMeter()
tqdm_object = tqdm(valid_loader, total=len(valid_loader))
for batch in tqdm_object:
batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
loss = model(batch)
count = batch["image"].size(0)
loss_meter.update(loss.item(), count)
tqdm_object.set_postfix(valid_loss=loss_meter.avg)
return loss_meter
最后整合起來(lái)就是全部流程
def main():
train_df, valid_df = make_train_valid_dfs()
tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
train_loader = build_loaders(train_df, tokenizer, mode="train")
valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
model = CLIPModel().to(CFG.device)
params = [
{"params": model.image_encoder.parameters(), "lr": CFG.image_encoder_lr},
{"params": model.text_encoder.parameters(), "lr": CFG.text_encoder_lr},
{"params": itertools.chain(
model.image_projection.parameters(), model.text_projection.parameters()
), "lr": CFG.head_lr, "weight_decay": CFG.weight_decay}
]
optimizer = torch.optim.AdamW(params, weight_decay=0.)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
)
step = "epoch"
best_loss = float('inf')
for epoch in range(CFG.epochs):
print(f"Epoch: {epoch + 1}")
model.train()
train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
model.eval()
with torch.no_grad():
valid_loss = valid_epoch(model, valid_loader)
if valid_loss.avg < best_loss:
best_loss = valid_loss.avg
torch.save(model.state_dict(), "best.pt")
print("Saved Best Model!")
lr_scheduler.step(valid_loss.avg)
應(yīng)用:獲取圖像嵌入并找到匹配。
我們訓(xùn)練完成后如何實(shí)際應(yīng)用呢?我們需要編寫(xiě)一個(gè)函數(shù)加載訓(xùn)練后的模型,為其提供驗(yàn)證集中的圖像,并返回形狀(valid_set_size, 256)和模型本身的image_embeddings。
def get_image_embeddings(valid_df, model_path):
tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
model = CLIPModel().to(CFG.device)
model.load_state_dict(torch.load(model_path, map_locatinotallow=CFG.device))
model.eval()
valid_image_embeddings = []
with torch.no_grad():
for batch in tqdm(valid_loader):
image_features = model.image_encoder(batch["image"].to(CFG.device))
image_embeddings = model.image_projection(image_features)
valid_image_embeddings.append(image_embeddings)
return model, torch.cat(valid_image_embeddings)
_, valid_df = make_train_valid_dfs()
model, image_embeddings = get_image_embeddings(valid_df, "best.pt")
def find_matches(model, image_embeddings, query, image_filenames, n=9):
tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
encoded_query = tokenizer([query])
batch = {
key: torch.tensor(values).to(CFG.device)
for key, values in encoded_query.items()
}
with torch.no_grad():
text_features = model.text_encoder(
input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
)
text_embeddings = model.text_projection(text_features)
image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
dot_similarity = text_embeddings_n @ image_embeddings_n.T
values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
matches = [image_filenames[idx] for idx in indices[::5]]
_, axes = plt.subplots(3, 3, figsize=(10, 10))
for match, ax in zip(matches, axes.flatten()):
image = cv2.imread(f"{CFG.image_path}/{match}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
ax.imshow(image)
ax.axis("off")
plt.show()
調(diào)用方法如下:
find_matches(model,
image_embeddings,
query="one dog sitting on the grass",
image_filenames=valid_df['image'].values,
n=9)
可以看到我們自定義效果還是不錯(cuò)的(但是圖里面有個(gè)貓,哈)。也就是說(shuō)CLIP這種方法在小數(shù)據(jù)集上自定義也是可行的。
以下是本文的代碼和數(shù)據(jù)集:
https://www.kaggle.com/code/jyotidabas/simple-openai-clip-implementation