GitHub 7.5k star量,各種視覺(jué)Transformer的PyTorch實(shí)現(xiàn)合集整理好了
近一兩年,Transformer 跨界 CV 任務(wù)不再是什么新鮮事了。
自 2020 年 10 月谷歌提出 Vision Transformer (ViT) 以來(lái),各式各樣視覺(jué) Transformer 開(kāi)始在圖像合成、點(diǎn)云處理、視覺(jué) - 語(yǔ)言建模等領(lǐng)域大顯身手。
之后,在 PyTorch 中實(shí)現(xiàn) Vision Transformer 成為了研究熱點(diǎn)。GitHub 中也出現(xiàn)了很多優(yōu)秀的項(xiàng)目,今天要介紹的就是其中之一。
該項(xiàng)目名為「vit-pytorch」,它是一個(gè) Vision Transformer 實(shí)現(xiàn),展示了一種在 PyTorch 中僅使用單個(gè) transformer 編碼器來(lái)實(shí)現(xiàn)視覺(jué)分類(lèi) SOTA 結(jié)果的簡(jiǎn)單方法。
項(xiàng)目當(dāng)前的 star 量已經(jīng)達(dá)到了 7.5k,創(chuàng)建者為 Phil Wang,ta 在 GitHub 上有 147 個(gè)資源庫(kù)。
項(xiàng)目地址:https://github.com/lucidrains/vit-pytorch
項(xiàng)目作者還提供了一段動(dòng)圖展示:
項(xiàng)目介紹
首先來(lái)看 Vision Transformer-PyTorch 的安裝、使用、參數(shù)、蒸餾等步驟。
第一步是安裝:
$ pip install vit-pytorch
第二步是使用:
import torch
from vit_pytorch import ViT
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
第三步是所需參數(shù),包括如下:
- image_size:圖像大小
- patch_size:patch 數(shù)量
- num_classes:分類(lèi)類(lèi)別的數(shù)量
- dim:線(xiàn)性變換 nn.Linear(..., dim) 后輸出張量的最后維
- depth:Transformer 塊的數(shù)量
- heads:多頭注意力層中頭的數(shù)量
- mlp_dim:MLP(前饋)層的維數(shù)
- channels:圖像通道的數(shù)量
- dropout:Dropout rate
- emb_dropout:嵌入 dropout rate
- ……
最后是蒸餾,采用的流程出自 Facebook AI 和索邦大學(xué)的論文《Training data-efficient image transformers & distillation through attention》。
論文地址:https://arxiv.org/pdf/2012.12877.pdf
從 ResNet50(或任何教師網(wǎng)絡(luò))蒸餾到 vision transformer 的代碼如下:
import torchfrom torchvision.models import resnet50from vit_pytorch.distill import DistillableViT, DistillWrapperteacher = resnet50(pretrained = True)
v = DistillableViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
distiller = DistillWrapper(
student = v,
teacher = teacher,
temperature = 3, # temperature of distillationalpha = 0.5, # trade between main loss and distillation losshard = False # whether to use soft or hard distillation
)
img = torch.randn(2, 3, 256, 256)labels = torch.randint(0, 1000, (2,))
loss = distiller(img, labels)loss.backward()
# after lots of training above ...pred = v(img) # (2, 1000)
除了 Vision Transformer 之外,該項(xiàng)目還提供了 Deep ViT、CaiT、Token-to-Token ViT、PiT 等其他 ViT 變體模型的 PyTorch 實(shí)現(xiàn)。
對(duì) ViT 模型 PyTorch 實(shí)現(xiàn)感興趣的讀者可以參閱原項(xiàng)目。