自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

終于把神經(jīng)網(wǎng)絡(luò)中的知識蒸餾搞懂了!??!

人工智能
在深度學(xué)習(xí)中,尤其是在計(jì)算機(jī)視覺和自然語言處理等任務(wù)中,深度神經(jīng)網(wǎng)絡(luò)(DNN)常常有非常龐大的參數(shù)量。

大家好,我是小寒

今天給大家分享神經(jīng)網(wǎng)絡(luò)中的一個關(guān)鍵知識點(diǎn),知識蒸餾

知識蒸餾是一種模型壓縮方法,用于將大型神經(jīng)網(wǎng)絡(luò)(教師模型)中的知識轉(zhuǎn)移到較小的神經(jīng)網(wǎng)絡(luò)(學(xué)生模型)中。

這一技術(shù)能夠在保持或接近原始模型性能的情況下,顯著減小模型的體積,從而提升推理效率。

知識蒸餾在很多場景中非常有用,尤其是在計(jì)算資源有限或需要部署到邊緣設(shè)備的應(yīng)用中。

知識蒸餾的背景和動機(jī)

在深度學(xué)習(xí)中,尤其是在計(jì)算機(jī)視覺和自然語言處理等任務(wù)中,深度神經(jīng)網(wǎng)絡(luò)(DNN)常常有非常龐大的參數(shù)量。盡管這些大型模型(如BERT、ResNet等)能夠取得非常好的性能,但它們也面臨著存儲、計(jì)算和延遲等挑戰(zhàn)。為了克服這一問題,知識蒸餾被提出作為一種方法,通過訓(xùn)練較小的學(xué)生模型來模擬大型教師模型的行為。

知識蒸餾的基本概念

  1. 教師模型(Teacher Model)
    通常是一個預(yù)訓(xùn)練的、復(fù)雜的深度神經(jīng)網(wǎng)絡(luò),具有較高的精度,但計(jì)算和存儲開銷較大。
  2. 學(xué)生模型(Student Model)
    學(xué)生模型相對簡單,參數(shù)較少,推理速度更快,目標(biāo)是通過知識蒸餾從教師模型中獲取知識,提升其性能。
  3. 軟標(biāo)簽(Soft Labels)
    軟標(biāo)簽是教師模型輸出的概率分布,而非簡單的類別標(biāo)簽。
    教師模型通常使用 softmax 層生成的概率分布作為軟標(biāo)簽,這些分布包含了類別間的相對關(guān)系。
  4. 溫度(Temperature)
    在蒸餾過程中,通常使用一個溫度參數(shù)來調(diào)節(jié)教師模型輸出的概率分布的“平滑程度”。較高的溫度會使得輸出分布更加平滑,從而讓學(xué)生模型學(xué)習(xí)到更多的類間關(guān)系。

知識蒸餾的流程

  1. 訓(xùn)練教師模型
    首先訓(xùn)練一個大型的、高性能的教師模型。
    該模型在給定的訓(xùn)練數(shù)據(jù)集上表現(xiàn)非常好,具有高精度,但計(jì)算開銷較大。
  2. 生成軟標(biāo)簽
    用教師模型對訓(xùn)練數(shù)據(jù)進(jìn)行預(yù)測,得到每個樣本的類別概率分布(即軟標(biāo)簽)。
    可以使用 softmax 函數(shù)將教師模型的原始輸出轉(zhuǎn)換為概率分布,并通過調(diào)節(jié)溫度參數(shù)來控制這些概率分布的平滑度。
  3. 訓(xùn)練學(xué)生模型
    使用教師模型生成的軟標(biāo)簽來訓(xùn)練一個較小的學(xué)生模型。
    學(xué)生模型的目標(biāo)是模仿教師模型的輸出,從而盡可能地學(xué)習(xí)到教師模型的知識。
    訓(xùn)練過程中,學(xué)生模型同時會使用真實(shí)標(biāo)簽(硬標(biāo)簽)和軟標(biāo)簽進(jìn)行監(jiān)督學(xué)習(xí)。
  4. 損失函數(shù)設(shè)計(jì)知識蒸餾的損失函數(shù)通常由兩個部分組成。傳統(tǒng)的監(jiān)督損失:計(jì)算學(xué)生模型輸出與真實(shí)標(biāo)簽之間的交叉熵。蒸餾損失:計(jì)算學(xué)生模型輸出與教師模型輸出之間的差異,通常使用 KL 散度度量兩個概率分布之間的差異。因此,知識蒸餾的損失函數(shù)通常是這兩個損失的加權(quán)和:

溫度的作用

在知識蒸餾中,溫度 T 控制了教師模型輸出的“軟標(biāo)簽”分布的平滑程度。

較高的溫度會使得輸出的概率分布更加平滑,減少類間的差異,使學(xué)生模型能夠?qū)W習(xí)到更多的類之間的相似性。

  • 在高溫度下,教師模型的輸出概率分布更加平滑,類之間的概率差異較小。
  • 在低溫度下,輸出概率分布變得更加尖銳,教師模型的預(yù)測結(jié)果接近于硬標(biāo)簽。

通過調(diào)節(jié)溫度,可以讓學(xué)生模型更好地學(xué)習(xí)到教師模型的知識。

知識蒸餾的優(yōu)點(diǎn)

  1. 模型壓縮
    通過蒸餾,學(xué)生模型通常比教師模型更小,參數(shù)數(shù)量更少,可以大幅度降低計(jì)算和存儲開銷。
  2. 提高推理速度
    由于學(xué)生模型體積較小,推理速度較快,適合部署到移動設(shè)備或資源有限的邊緣設(shè)備上。

案例分享

以下是一個基于 PyTorch 實(shí)現(xiàn)的簡單示例代碼,展示了如何進(jìn)行神經(jīng)網(wǎng)絡(luò)中的知識蒸餾。

首先,定義教師模型和學(xué)生模型。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.nn import functional as F

# 教師模型(較大網(wǎng)絡(luò))
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(7*7*64, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(-1, 7*7*64)  # Flatten the tensor
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 學(xué)生模型(較小網(wǎng)絡(luò))
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(7*7*32, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(-1, 7*7*32)  # Flatten the tensor
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

接下來,定義蒸餾損失函數(shù)。

def distillation_loss(y_student, y_teacher, T=2.0, alpha=0.7):
    # 計(jì)算軟標(biāo)簽的交叉熵?fù)p失
    soft_loss = nn.KLDivLoss(reduction='batchmean')(
        F.log_softmax(y_student / T, dim=1),
        F.softmax(y_teacher / T, dim=1)
    )

    # 計(jì)算真實(shí)標(biāo)簽的交叉熵?fù)p失
    hard_loss = F.cross_entropy(y_student, torch.argmax(y_teacher, dim=1))

    # 綜合蒸餾損失
    return alpha * soft_loss + (1 - alpha) * hard_loss

接下來定義一個訓(xùn)練函數(shù),其中教師模型先訓(xùn)練好,然后使用蒸餾損失訓(xùn)練學(xué)生模型。

def train(model, device, train_loader, optimizer, epoch, teacher_model=None, T=2.0, alpha=0.7):
    model.train()
    running_loss = 0.0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()

        # 教師模型和學(xué)生模型的輸出
        output = model(data)
        with torch.no_grad():  # 教師模型在蒸餾時不更新參數(shù)
            teacher_output = teacher_model(data)

        # 計(jì)算蒸餾損失
        loss = distillation_loss(output, teacher_output, T, alpha)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Train Epoch: {epoch} \tLoss: {running_loss / len(train_loader):.6f}")
    
    
batch_size = 64
epochs = 10
lr = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 數(shù)據(jù)加載
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST('.', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('.', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# 初始化教師模型和學(xué)生模型
teacher_model = TeacherModel().to(device)
student_model = StudentModel().to(device)

# 教師模型訓(xùn)練(簡單訓(xùn)練)
optimizer_teacher = optim.Adam(teacher_model.parameters(), lr=lr)
teacher_model.train()
for epoch in range(1, epochs + 1):
    train_teacher(teacher_model, device, train_loader, optimizer_teacher, epoch)

# 學(xué)生模型訓(xùn)練(蒸餾)
optimizer_student = optim.Adam(student_model.parameters(), lr=lr)
student_model.train()
for epoch in range(1, epochs + 1):
     train(student_model, device, train_loader, optimizer_student, epoch, teacher_mod


責(zé)任編輯:武曉燕 來源: 程序員學(xué)長
相關(guān)推薦

2024-09-12 08:28:32

2024-10-17 13:05:35

神經(jīng)網(wǎng)絡(luò)算法機(jī)器學(xué)習(xí)深度學(xué)習(xí)

2024-07-24 08:04:24

神經(jīng)網(wǎng)絡(luò)激活函數(shù)

2024-11-07 08:26:31

神經(jīng)網(wǎng)絡(luò)激活函數(shù)信號

2024-09-20 07:36:12

2024-10-28 00:38:10

2024-11-15 13:20:02

2025-02-21 08:29:07

2024-12-12 00:29:03

2024-10-05 23:00:35

2024-09-26 07:39:46

2024-08-01 08:41:08

2024-12-03 08:16:57

2024-10-16 07:58:48

2024-07-17 09:32:19

2024-09-23 09:12:20

2025-02-17 13:09:59

深度學(xué)習(xí)模型壓縮量化

2024-10-14 14:02:17

機(jī)器學(xué)習(xí)評估指標(biāo)人工智能

2024-11-05 12:56:06

機(jī)器學(xué)習(xí)函數(shù)MSE

2024-08-23 09:06:35

機(jī)器學(xué)習(xí)混淆矩陣預(yù)測
點(diǎn)贊
收藏

51CTO技術(shù)棧公眾號