從零開(kāi)始構(gòu)建 DINO:自監(jiān)督視覺(jué) Transformer
DINO模型輸出的狗沖刺
無(wú)標(biāo)簽自蒸餾(DINO)
《從幾個(gè)“補(bǔ)丁”中重建完整圖像 | 構(gòu)建可擴(kuò)展學(xué)習(xí)器的掩模自編碼器》這邊文章講了如何構(gòu)建可擴(kuò)展學(xué)習(xí)器,這是我對(duì)視覺(jué)變換器系列的繼續(xù),其中我解釋了最重要的架構(gòu)及其從零開(kāi)始的實(shí)現(xiàn)。
自監(jiān)督學(xué)習(xí)
自監(jiān)督學(xué)習(xí)(SSL)是一種機(jī)器學(xué)習(xí)類型,模型通過(guò)無(wú)需手動(dòng)標(biāo)記的示例來(lái)學(xué)習(xí)理解數(shù)據(jù)。相反,它從數(shù)據(jù)本身生成其監(jiān)督信號(hào)。當(dāng)標(biāo)記數(shù)據(jù)有限且獲取成本高昂時(shí),這種方法非常有益。在SSL中,學(xué)習(xí)過(guò)程涉及創(chuàng)建任務(wù),其中輸入數(shù)據(jù)可以用來(lái)預(yù)測(cè)數(shù)據(jù)本身的某些部分。常見(jiàn)的技術(shù)包括:
- 對(duì)比學(xué)習(xí):模型通過(guò)區(qū)分相似和不相似的數(shù)據(jù)對(duì)來(lái)學(xué)習(xí)。
- 預(yù)測(cè)任務(wù):模型從其他部分預(yù)測(cè)輸入數(shù)據(jù)的一部分,例如預(yù)測(cè)句子中的下一個(gè)詞或從其周圍環(huán)境中預(yù)測(cè)詞的上下文。
DINO模型
DINO(無(wú)標(biāo)簽蒸餾)模型是一種應(yīng)用于視覺(jué)變換器(ViTs)的尖端自監(jiān)督學(xué)習(xí)方法。它代表了計(jì)算機(jī)視覺(jué)領(lǐng)域的一個(gè)重大進(jìn)步,使模型能夠在不需要任何標(biāo)記數(shù)據(jù)的情況下學(xué)習(xí)有效的圖像表示。由Facebook AI Research(FAIR)的研究人員開(kāi)發(fā),DINO利用學(xué)生-教師框架和創(chuàng)新的訓(xùn)練技術(shù),在各種視覺(jué)任務(wù)上取得了卓越的性能。
學(xué)生-教師網(wǎng)絡(luò)
在DINO模型中,學(xué)生-教師網(wǎng)絡(luò)是實(shí)現(xiàn)無(wú)需標(biāo)記數(shù)據(jù)的自監(jiān)督學(xué)習(xí)的核心機(jī)制。這個(gè)框架涉及兩個(gè)網(wǎng)絡(luò):學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)。兩個(gè)網(wǎng)絡(luò)都是視覺(jué)變換器,它們被設(shè)計(jì)用來(lái)通過(guò)將圖像處理為序列塊來(lái)處理圖像,類似于變換器處理文本序列的方式。
學(xué)生網(wǎng)絡(luò)的任務(wù)是從輸入圖像中學(xué)習(xí)生成有意義的表示。另一方面,教師網(wǎng)絡(luò)提供目標(biāo)表示,學(xué)生網(wǎng)絡(luò)旨在匹配這些表示。教師網(wǎng)絡(luò)不是一個(gè)靜態(tài)實(shí)體;它通過(guò)逐漸整合學(xué)生網(wǎng)絡(luò)的參數(shù)隨時(shí)間演變。這是通過(guò)一種稱為指數(shù)移動(dòng)平均的技術(shù)完成的,其中教師的參數(shù)被更新為其當(dāng)前參數(shù)和學(xué)生參數(shù)的加權(quán)平均值。
目標(biāo)是最小化學(xué)生表示和教師表示之間的差異,這些表示是針對(duì)相同增強(qiáng)圖像視圖的。這通常是通過(guò)使用一個(gè)損失函數(shù)來(lái)實(shí)現(xiàn)的,該函數(shù)鼓勵(lì)學(xué)生和教師輸出之間的對(duì)齊,同時(shí)確保不同圖像的表示保持不同。
通過(guò)根據(jù)學(xué)生網(wǎng)絡(luò)的學(xué)習(xí)進(jìn)度不斷更新教師網(wǎng)絡(luò),并訓(xùn)練學(xué)生網(wǎng)絡(luò)以匹配教師的輸出,DINO有效地利用了兩個(gè)網(wǎng)絡(luò)的優(yōu)勢(shì)。教師網(wǎng)絡(luò)為學(xué)生提供了穩(wěn)定和一致的目標(biāo),而學(xué)生網(wǎng)絡(luò)推動(dòng)了學(xué)習(xí)過(guò)程。這種協(xié)作設(shè)置允許模型在無(wú)需手動(dòng)標(biāo)簽的情況下從數(shù)據(jù)中學(xué)習(xí)強(qiáng)大和不變的特征,從而實(shí)現(xiàn)有效的自監(jiān)督學(xué)習(xí)。
學(xué)生和教師的增強(qiáng)輸入
在DINO模型中,X1和X2(見(jiàn)上圖)指的是同一原始圖像X的不同增強(qiáng)視圖。這些視圖分別用作學(xué)生和教師網(wǎng)絡(luò)的輸入。目標(biāo)是讓學(xué)生網(wǎng)絡(luò)學(xué)習(xí)在這些增強(qiáng)下產(chǎn)生一致的表示。學(xué)生和教師模型根據(jù)以下策略接收不同的增強(qiáng):
- 全局裁剪:從原始圖像創(chuàng)建兩個(gè)全局裁剪。這些是覆蓋圖像大部分的較大裁剪,通常與原始圖像有很高的重疊。除了其他增強(qiáng)(如顏色抖動(dòng)、高斯模糊、翻轉(zhuǎn)等)之外。
- 局部裁剪:除了全局裁剪外,教師網(wǎng)絡(luò)還接收幾個(gè)局部裁剪。這些是關(guān)注圖像不同部分的較小裁剪,捕捉更多局部細(xì)節(jié)。
我們將如何為參數(shù)圖像定義這些增強(qiáng),這些圖像包含我們?cè)谟?xùn)練期間想要轉(zhuǎn)換的一批圖像。
# These augmentations are defined exactly as proposed in the paper
def global_augment(images):
global_transform = transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.4, 1.0)), # Larger crops
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1), # Color jittering
transforms.RandomGrayscale(p=0.2),
transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
return torch.stack([global_transform(img) for img in images])
def multiple_local_augments(images, num_crops=6):
size = 96 # Smaller crops for local
local_transform = transforms.Compose([
transforms.RandomResizedCrop(size, scale=(0.05, 0.4)), # Smaller, more concentrated crops
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1), # Same level of jittering
transforms.RandomGrayscale(p=0.2),
transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Apply the transformation multiple times to the same image
return torch.stack([local_transform(img) for img in images])
蒸餾損失
在這里,我們希望使用某種距離度量來(lái)計(jì)算學(xué)生輸出和教師輸出之間的損失。我們這樣做:
- 獲取教師預(yù)測(cè)輸出的中心化Softmax,然后應(yīng)用銳化。
- 獲取學(xué)生的Softmax預(yù)測(cè),然后應(yīng)用銳化。
def distillation_loss(student_output, teacher_output, center, tau_s, tau_t):
"""
Calculates distillation loss with centering and sharpening (function H in pseudocode).
"""
# Detach teacher output to stop gradients.
teacher_output = teacher_output.detach()
# Center and sharpen teacher's outputs
teacher_probs = F.softmax((teacher_output - center) / tau_t, dim=1)
# Sharpen student's outputs
student_probs = F.log_softmax(student_output / tau_s, dim=1)
# Calculate cross-entropy loss between students' and teacher's probabilities.
loss = - (teacher_probs * student_probs).sum(dim=1).mean()
return loss
- 中心化:中心化教師的輸出確保學(xué)生模型更多地關(guān)注教師輸出分布中最顯著的特征或區(qū)別。通過(guò)中心化分布,鼓勵(lì)學(xué)生更多地關(guān)注對(duì)準(zhǔn)確預(yù)測(cè)至關(guān)重要的顯著特征,而不是受數(shù)據(jù)中的變化或偏差的影響。這有助于更有效的知識(shí)傳遞,并可能導(dǎo)致學(xué)生模型的性能提高。
- 銳化:銳化涉及放大數(shù)據(jù)分布中的特定特征,旨在強(qiáng)調(diào)教師模型突出的區(qū)分。這個(gè)過(guò)程使學(xué)生模型能夠?qū)W⒂趯W(xué)習(xí)教師預(yù)測(cè)中存在的復(fù)雜細(xì)節(jié),這對(duì)于在數(shù)據(jù)集上準(zhǔn)確復(fù)制其輸出至關(guān)重要。
訓(xùn)練DINO模型
闡明DINO偽代碼的圖像,取自官方論文
有3個(gè)重要的步驟需要強(qiáng)調(diào):
(1) 獲取學(xué)生和教師架構(gòu)的不同輸入(x1,x2)的增強(qiáng)。
(2) 我們之前討論的蒸餾損失函數(shù),注意它是如何計(jì)算不同增強(qiáng)輸入的架構(gòu)的蒸餾損失的,即gs({x1, x2})和gt({x1, x2})。
(3) 更新(a)學(xué)生參數(shù)(b)教師參數(shù)和(c)中心。這里的關(guān)鍵是我們對(duì)更新教師參數(shù)執(zhí)行指數(shù)移動(dòng)平均更新。
- 教師參數(shù):EMA應(yīng)用于教師模型的參數(shù)。而不是在每次訓(xùn)練迭代中直接更新教師參數(shù),EMA隨時(shí)間維護(hù)這些參數(shù)的移動(dòng)平均值。這個(gè)移動(dòng)平均值作為教師模型的更平滑、更穩(wěn)定的表示,可以幫助指導(dǎo)學(xué)生模型的訓(xùn)練。
- 中心:此外,在DINO的一些實(shí)現(xiàn)中,EMA也用于更新中心。中心代表教師輸出分布的平均值,用于歸一化目的。通過(guò)應(yīng)用EMA更新中心,它在整個(gè)訓(xùn)練過(guò)程中逐漸演變,為歸一化提供更穩(wěn)定的參考點(diǎn)。
DINO模型
class DINO(nn.Module):
def __init__(self, student_arch: Callable, teacher_arch: Callable, device: torch.device):
"""
Args:
student_arch (nn.Module): ViT Network for student_arch
teacher_arch (nn.Module): ViT Network for teacher_arch
device: torch.device ('cuda' or 'cpu')
"""
super(DINO, self).__init__()
self.student = student_arch().to(device)
self.teacher = teacher_arch().to(device)
self.teacher.load_state_dict(self.student.state_dict())
# Initialize center as buffer to avoid backpropagation
self.register_buffer('center', torch.zeros(1, student_arch().output_dim))
# Ensure the teacher parameters do not get updated during backprop
for param in self.teacher.parameters():
param.requires_grad = False
@staticmethod
def distillation_loss(student_output, teacher_output, center, tau_s, tau_t):
"""
Calculates distillation loss with centering and sharpening (function H in pseudocode).
"""
# Detach teacher output to stop gradients.
teacher_output = teacher_output.detach()
# Center and sharpen teacher's outputs
teacher_probs = F.softmax((teacher_output - center) / tau_t, dim=1)
# Sharpen student's outputs
student_probs = F.log_softmax(student_output / tau_s, dim=1)
# Calculate cross-entropy loss between student's and teacher's probabilities.
loss = - (teacher_probs * student_probs).sum(dim=1).mean()
return loss
def teacher_update(self, beta: float):
for teacher_params, student_params in zip(self.teacher.parameters(), self.student.parameters()):
teacher_params.data.mul_(beta).add_(student_params.data, alpha=(1 - beta))
為了更新教師的參數(shù),我們使用論文中提出公式,即gt.param = gt.param*beta + gs.param*(1 — beta),其中beta是移動(dòng)平均衰減,gt、gs分別是相應(yīng)的教師和學(xué)生架構(gòu)。
進(jìn)一步,我們?cè)赺_init__下看到,教師的參數(shù)已設(shè)置為“required_grads = False”,因?yàn)槲覀儾幌M诜聪騻鞑テ陂g更新它們,而是應(yīng)用移動(dòng)平均更新。
此外,在PyTorch中將變量初始化為bugger是一種常見(jiàn)方法,用于將其保持在梯度圖之外,并不參與反向傳播。
Dino模型進(jìn)一步需要如下調(diào)用
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dino = DINO(ViT(), ViT(), device)
在這里,我們傳遞學(xué)生和教師架構(gòu),這不過(guò)是標(biāo)準(zhǔn)的視覺(jué)變換器,即ViT-B/16或ViT-L/16,正如第一篇論文中提出的。
最終訓(xùn)練
現(xiàn)在可以將整個(gè)實(shí)現(xiàn)放入訓(xùn)練循環(huán)中,正如論文中提出的。
def train_dino(dino: DINO,
data_loader: DataLoader,
optimizer: Optimizer,
device: torch.device,
num_epochs,
tps=0.9,
tpt= 0.04,
beta= 0.9,
m= 0.9,
):
"""
Args:
dino: DINO Module
data_loader (nn.Module): Dataloader for training
optimizer (nn.optimizer): Optimizer for optimization (SGD etc.)
defice (torch.device): 'cuda', 'cpu'
num_epochs: Number of Epochs
tps (float): tau for sharpening student logits
tpt: for sharpening teacher logits
beta (float): moving average decay
m (float): center moveing average decay
"""
for epoch in range(num_epochs):
print(f"Epoch: {epoch+1}/{len(num_epochs)}")
for x in data_loader:
x1, x2 = global_augment(x), multiple_local_augments(x)
student_output1, student_output2 = dino.student(x1.to(device)), dino.student(x2.to(device))
with torch.no_grad():
teacher_output1, teacher_output2 = dino.teacher(x1.to(device)), dino.teacher(x2.to(device))
# Compute distillation loss
loss = (dino.distillation_loss(teacher_output1, student_output2, dino.center, tps, tpt) +
dino.distillation_loss(teacher_output2, student_output1, dino.center, tps, tpt)) / 2
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Update the teacher network parameters
dino.teacher_update(beta)
# Update the center
with torch.no_grad():
dino.center = m * dino.center + (1 - m) * torch.cat([teacher_output1, teacher_output2], dim=0).mean(dim=0)
(1) 我們用不同的全局和局部增強(qiáng)計(jì)算x1和x2。
(2) 之后,我們根據(jù)論文中提出的,為學(xué)生和教師模型獲取輸出,回想上面的算法循環(huán)圖。
(3) 在這里,我們將torch設(shè)置為no_grad()函數(shù),以確保教師的參數(shù)不會(huì)通過(guò)反向傳播更新。
(4) 最后,我們?cè)俅胃鶕?jù)論文中提出的方法計(jì)算蒸餾損失。
(5) 在蒸餾損失中,我們首先中心化教師模型的輸出,這樣學(xué)生模型就不容易崩潰,也不會(huì)只學(xué)習(xí)不重要的特征,或者比另一個(gè)特征更多地學(xué)習(xí)一個(gè)特征,而是專注于從教師模型中學(xué)習(xí)最獨(dú)特和潛在的特征。
(6) 然后我們銳化特征,以便在計(jì)算損失時(shí),我們現(xiàn)在能夠比較兩個(gè)特征(學(xué)生和教師的)具有非常不同的數(shù)據(jù)分布,這意味著銳化后,更重要的特征會(huì)被銳化,而不太重要的特征則不會(huì),這將創(chuàng)建一個(gè)更獨(dú)特的特征圖,使學(xué)生更容易學(xué)習(xí)。
(7) 然后我們執(zhí)行反向傳播并執(zhí)行optimizer.step(),更新學(xué)生模型并通過(guò)之前實(shí)現(xiàn)的指數(shù)移動(dòng)平均更新教師網(wǎng)絡(luò)。
(8) 作為最后一步,我們將再次將torch設(shè)置為no_grad()并通過(guò)移動(dòng)平均更新中心。我們根據(jù)教師的輸出更新中心,因此它與訓(xùn)練過(guò)程中輸出數(shù)據(jù)分布的變化保持一致。
就這樣,這就是如何從零開(kāi)始訓(xùn)練DINO模型。到目前為止,在視覺(jué)變換器系列中,我們已經(jīng)實(shí)現(xiàn)了標(biāo)準(zhǔn)的ViT、Swin、CvT、Mae和DINO(自監(jiān)督)。希望你喜歡閱讀這篇文章。
# Create your own CustomDataset and dataloader
dataloader = DataLoader(CustomDataset, batch_size=32, shuffle=True)
optimizer = torch.optim.AdamW(dino.parameters(), lr=1e-4)
train_dino(dino,
DataLoader=dataloader,
Optimizer=optimizer,
device=device,
num_epochs=300,
tps=0.9,
tpt= 0.04,
beta= 0.9,
m= 0.9)