深度學(xué)習(xí)架構(gòu)的超級(jí)英雄——BatchNorm2d 原創(chuàng)
本文旨在探索2D批處理規(guī)范化在深度學(xué)習(xí)架構(gòu)中的關(guān)鍵作用,并通過簡(jiǎn)單的例子來解釋該技術(shù)的內(nèi)部工作原理。
由作者本人創(chuàng)建的圖像
深度學(xué)習(xí)(DL)已經(jīng)改變了卷積神經(jīng)網(wǎng)絡(luò)(CNN)和生成式人工智能(Gen AI)發(fā)展的游戲規(guī)則。這種深度學(xué)習(xí)模型可以從多維空間數(shù)據(jù)(如圖像)中提取復(fù)雜的模式和特征,并進(jìn)行預(yù)測(cè)。輸入數(shù)據(jù)中的模式越復(fù)雜,模型架構(gòu)就越復(fù)雜。
盡管存在很多方法可以加速模型訓(xùn)練收斂并提高模型推理性能;但是,批量歸一化2D(BN2D)已成為這方面的超級(jí)英雄。這篇文章旨在展示如何將BN2D集成到DL架構(gòu)中,從而實(shí)現(xiàn)更快的收斂和更好的推理。
了解一下BN2D
BN2D是一種批量應(yīng)用于多維空間輸入(如圖像)的歸一化技術(shù),以歸一化它們的維度(通道)值,從而使這些批次的維度具有0的平均值和1的方差。
合并BN2D組件的主要目的是防止來自網(wǎng)絡(luò)內(nèi)先前層的輸入數(shù)據(jù)中跨維度或通道的內(nèi)部協(xié)變量偏移。當(dāng)維度數(shù)據(jù)的分布由于在訓(xùn)練周期(epoch)對(duì)網(wǎng)絡(luò)參數(shù)進(jìn)行的更新而改變時(shí),會(huì)發(fā)生跨維度的內(nèi)部協(xié)變量偏移。例如,卷積層中的N個(gè)過濾器產(chǎn)生N維激活作為輸出。該層維護(hù)其過濾器的權(quán)重和偏差參數(shù),這些參數(shù)隨著每個(gè)訓(xùn)練周期而逐漸更新。
作為這些更新的結(jié)果,來自一個(gè)過濾器的激活可以具有與來自相同卷積層的另一過濾器的激活明顯不同的分布。這種分布上的差異表明,來自一個(gè)過濾器的激活與來自另一過濾器的激活在很大程度上不同。當(dāng)將這種尺度大不相同的維度數(shù)據(jù)輸入到網(wǎng)絡(luò)中的下一層時(shí),該層的可學(xué)習(xí)性受到阻礙,因?yàn)榕c尺度較小的維度相比,尺度較大的維度的權(quán)重在梯度下降期間需要更大的更新。
另一個(gè)可能的后果是,尺度較小的權(quán)重梯度可能消失,而尺度較大的權(quán)重梯度則可能爆炸式增長(zhǎng)。當(dāng)網(wǎng)絡(luò)遇到這種學(xué)習(xí)障礙時(shí),梯度下降將在更大的尺度上振蕩,嚴(yán)重阻礙學(xué)習(xí)收斂和訓(xùn)練穩(wěn)定性。BN2D通過將維度數(shù)據(jù)標(biāo)準(zhǔn)化為平均值為0、標(biāo)準(zhǔn)偏差為1的標(biāo)準(zhǔn)尺度,有效地緩解了這一現(xiàn)象,并促進(jìn)了訓(xùn)練過程中更快的收斂,減少了實(shí)現(xiàn)最佳性能所需的訓(xùn)練周期(epoch)數(shù)量。因此,通過簡(jiǎn)化網(wǎng)絡(luò)的訓(xùn)練階段,該技術(shù)確保網(wǎng)絡(luò)可以專注于學(xué)習(xí)更復(fù)雜和抽象的特征,從而從輸入數(shù)據(jù)中提取更豐富的表示。
在標(biāo)準(zhǔn)實(shí)際應(yīng)用中,BN2D實(shí)例插入卷積層后面,但插入到激活層(如ReLU)前,如圖1中的示例DL網(wǎng)絡(luò)所示:
圖1:一個(gè)深度CNN的示例(圖片由作者本人創(chuàng)建)
BN2D內(nèi)部工作原理
圖2中顯示了一批簡(jiǎn)單的多維空間數(shù)據(jù)的示例(如僅使用了3個(gè)通道的圖像),以說明BN2D技術(shù)的內(nèi)部工作原理。
圖2:BN2D的內(nèi)部工作原理(作者本人創(chuàng)建的圖像)
如圖2所示,BN2D的功能是在每個(gè)維度或通道處理一個(gè)批次。如果輸入批次具有N個(gè)維度或通道,則BN2D實(shí)例將具有N個(gè)BN2D層。在示例情況下,紅色、綠色和藍(lán)色通道的單獨(dú)處理意味著對(duì)應(yīng)的BN2D實(shí)例具有3個(gè)BN2D層。
圖3:BN2D使用的公式(作者本人創(chuàng)建的圖像)
在訓(xùn)練過程中,BN2D計(jì)算每個(gè)批次維度的平均值和方差,并使用圖3所示的訓(xùn)練時(shí)間公式對(duì)值進(jìn)行歸一化,如圖2所示。預(yù)設(shè)的ε是分母中的一個(gè)常數(shù),以避免出現(xiàn)被零除的錯(cuò)誤。BN2D實(shí)例維護(hù)每個(gè)維度或BN2D層的可學(xué)習(xí)參數(shù)——比例(γ)和偏移(β),這些參數(shù)在訓(xùn)練優(yōu)化期間更新。BN2D實(shí)例還維護(hù)每個(gè)BN2D層的移動(dòng)平均值和方差,如圖2所示,它們?cè)谟?xùn)練過程中使用圖3所示的公式進(jìn)行更新。預(yù)設(shè)動(dòng)量(α)用作指數(shù)平均因子。
在推理過程中,使用如圖3所示的推理時(shí)間公式,BN2D實(shí)例使用特定維度的移動(dòng)平均值、移動(dòng)方差以及學(xué)習(xí)的比例(γ)和偏移(β)參數(shù)對(duì)每個(gè)維度的值進(jìn)行歸一化。圖2中顯示了批量輸入中每個(gè)維度的訓(xùn)練時(shí)間批量歸一化計(jì)算示例。圖2中的示例還說明了BN2D實(shí)例的輸出,該實(shí)例包含跨維度或通道獨(dú)立規(guī)范化的整個(gè)批次。用于完成圖2所示示例的PyTorch Jupyter筆記本文件可在以下GitHub存儲(chǔ)庫(kù)中找到:?https://github.com/kbmurali/hindi_hw_digits/blob/main/how_batch_norm2d_works.ipynb?。
使用BN2D
為了檢查在DL網(wǎng)絡(luò)架構(gòu)中結(jié)合BN2D實(shí)例的預(yù)期性能改進(jìn),我們使用一個(gè)簡(jiǎn)單的(類似玩具的)圖像數(shù)據(jù)集來構(gòu)建具有和不具有BN2D的相對(duì)比較簡(jiǎn)單的DL網(wǎng)絡(luò),來實(shí)現(xiàn)所屬類別的預(yù)測(cè)。以下是使用BN2D預(yù)期帶來的幾個(gè)關(guān)鍵DL模型性能改進(jìn)方面:
1. 改進(jìn)的泛化:BN2D引入的歸一化有望提高DL模型的泛化能力。在該示例中,當(dāng)在網(wǎng)絡(luò)中引入BN2D層時(shí),預(yù)計(jì)會(huì)提高推理時(shí)間分類精度。
2. 更快的收斂:引入BN2D層有望促進(jìn)訓(xùn)練過程中更快的收斂,減少實(shí)現(xiàn)最佳性能所需的訓(xùn)練周期數(shù)量。在該示例中,在引入BN2D層之后的早期訓(xùn)練周期開始,預(yù)計(jì)訓(xùn)練損失會(huì)降低。
3. 更平滑的梯度下降:由于BN2D將維度數(shù)據(jù)標(biāo)準(zhǔn)化為標(biāo)準(zhǔn)尺度,平均值為0,標(biāo)準(zhǔn)偏差為1,因此有望將梯度下降在較大規(guī)模維度上振蕩的可能性降至最低,梯度下降有望順利進(jìn)行。
示例數(shù)據(jù)集
Kaggle在地址?https://www.kaggle.com/datasets/suvooo/hindi-character-recognition/data?處發(fā)布的印地語手寫數(shù)字(0-9)數(shù)據(jù)(GNU許可證)將用于訓(xùn)練和測(cè)試包含和不包含BN2D的卷積DL模型。讀者可參考本文頂部的橫幅圖片,了解印地語數(shù)字是如何書寫的。DL模型網(wǎng)絡(luò)是使用PyTorch DL模塊構(gòu)建而成的。選擇手寫的印地語數(shù)字而不是英語數(shù)字是基于與后者相比的復(fù)雜性。印地語數(shù)字的邊緣檢測(cè)比英語更具挑戰(zhàn)性,因?yàn)橛〉卣Z數(shù)字中的曲線多于直線。此外,根據(jù)一個(gè)人的寫作風(fēng)格,同一個(gè)數(shù)字可能會(huì)有更多的變化。
而且,我還開發(fā)了一個(gè)實(shí)用的Python函數(shù),以便對(duì)數(shù)字?jǐn)?shù)據(jù)的訪問更符合PyTorch數(shù)據(jù)集/數(shù)據(jù)加載器的操作,如下面的代碼片段所示。訓(xùn)練數(shù)據(jù)集有17000個(gè)樣本,而測(cè)試數(shù)據(jù)集有3000個(gè)。請(qǐng)注意,在將圖像加載為PyTorch張量時(shí)應(yīng)用了PyTorch灰度轉(zhuǎn)換器。另外,我還專門開發(fā)了一個(gè)名為“ml_utils.py”的實(shí)用程序模塊,用于打包使用基于PyTorch張量的操作運(yùn)行控制訓(xùn)練周期、訓(xùn)練和測(cè)試深度學(xué)習(xí)模型的函數(shù)。其中,訓(xùn)練和測(cè)試函數(shù)還能夠捕獲模型度量指標(biāo),以幫助評(píng)估模型的性能。Python筆記本文件和實(shí)用程序模塊均可以在作者的公共GitHub存儲(chǔ)庫(kù)中訪問,其鏈接如下:
https://github.com/kbmurali/hindi_hw_digits?。
import torch
import torch.nn as nn
from torch.utils.data import *
import torchvision
from torchvision import transforms
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns
from ml_utils import *
from hindi.datasets import Digits
set_seed( 5842 )
batch_size = 32
img_transformer = transforms.Compose([
transforms.Grayscale(),
transforms.ToTensor()
])
train_dataset = Digits( "./data", train=True, transform=img_transformer, download=True )
test_dataset = Digits( "./data", train=False, transform=img_transformer, download=True )
train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True )
test_loader = DataLoader( test_dataset, batch_size=batch_size )
DL模型示例
第一個(gè)DL模型將包括具有16個(gè)過濾器的三個(gè)卷積層,每個(gè)過濾器的核大小為3,填充為1,以便產(chǎn)生“相同”的卷積。每個(gè)卷積的激活函數(shù)是修正線性單元(ReLU)。池大小為2的最大池化層被放置在完全連接層之前,導(dǎo)致產(chǎn)生一個(gè)輸出10個(gè)類別的softmax層。該模型的網(wǎng)絡(luò)架構(gòu)如圖4所示。下面的代碼片段顯示了相應(yīng)的PyTorch模型定義。
圖4:沒有BN2D的卷積網(wǎng)絡(luò)(圖片由作者本人創(chuàng)建)
device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
loss_func = nn.CrossEntropyLoss()
input_channels = 1
classes = 10
filters = 16
kernel_size = 3
padding = kernel_size//2
pool_size = 2
original_pixels_per_channel = 32*32
three_convs_model = nn.Sequential(
nn.Conv2d( input_channels, filters, kernel_size, padding=padding ), # 1x32x32 => 16x32x32
nn.ReLU(inplace=True), #16x32x32 => 16x32x32
nn.Conv2d(filters, filters, kernel_size, padding=padding ), # 16x32x32 => 16x32x32
nn.ReLU(inplace=True), #16x32x32 => 16x32x32
nn.Conv2d(filters, filters, kernel_size, padding=padding ), # 16x32x32 => 16x32x32
nn.ReLU(inplace=True), #16x32x32 => 16x32x32
nn.MaxPool2d(pool_size), # 16x32x32 => 16x16x16
nn.Flatten(), # 16x16x16 => 4096
nn.Linear( 4096, classes) # 1024 => 10
)
第二個(gè)DL模型與第一個(gè)DL模型共享相似的結(jié)構(gòu),但在卷積之后和激活之前引入了BN2D實(shí)例。該模型的網(wǎng)絡(luò)架構(gòu)如圖5所示。下面的代碼片段顯示了相應(yīng)的PyTorch模型定義。
圖5:帶BN2D的卷積網(wǎng)絡(luò)(圖片由作者本人創(chuàng)建)
three_convs_wth_bn_model = nn.Sequential(
nn.Conv2d( input_channels, filters, kernel_size, padding=padding ), # 1x32x32 => 16x32x32
nn.BatchNorm2d( filters ), #16x32x32 => 16x32x32
nn.ReLU(inplace=True), #16x32x32 => 16x32x32
nn.Conv2d(filters, filters, kernel_size, padding=padding ), # 16x32x32 => 16x32x32
nn.BatchNorm2d( filters ), #16x32x32 => 16x32x32
nn.ReLU(inplace=True), #16x32x32 => 16x32x32
nn.Conv2d(filters, filters, kernel_size, padding=padding ), # 16x32x32 => 16x32x32
nn.BatchNorm2d( filters ), #16x32x32 => 16x32x32
nn.ReLU(inplace=True), #16x32x32 => 16x32x32
nn.MaxPool2d(pool_size), # 16x32x32 => 16x16x16
nn.Flatten(), # 16x16x16 => 4096
nn.Linear( 4096, classes) # 4096 => 10
)
我們使用以下代碼片段中所示的實(shí)用函數(shù),以便在示例印地語數(shù)字?jǐn)?shù)據(jù)集上訓(xùn)練兩個(gè)DL模型。注意,代碼中捕獲了來自最后卷積層中過濾器的兩個(gè)維度/通道的兩個(gè)樣本權(quán)重,以便可視化展示訓(xùn)練損失的梯度下降信息。
three_convs_model_results_df = train_model(
three_convs_model,
loss_func,
train_loader,
test_loader=test_loader,
score_funcs={'accuracy': accuracy_score},
device=device,
epochs=30,
capture_conv_sample_weights=True,
conv_index=4,
wx_flt_index=3,
wx_ch_index=4,
wx_ro_index=1,
wx_index=0,
wy_flt_index=3,
wy_ch_index=8,
wy_ro_index=1,
wy_index=0
)
three_convs_wth_bn_model_results_df = train_model(
three_convs_wth_bn_model,
loss_func,
train_loader,
test_loader=test_loader,
score_funcs={'accuracy': accuracy_score},
device=device,
epochs=30,
capture_conv_sample_weights=True,
conv_index=6,
wx_flt_index=3,
wx_ch_index=4,
wx_ro_index=1,
wx_index=0,
wy_flt_index=3,
wy_ch_index=8,
wy_ro_index=1,
wy_index=0
)
發(fā)現(xiàn)之1:提高了測(cè)試精度
通過引入BN2D實(shí)例,DL模型的測(cè)試精度更好,如圖6所示。對(duì)于具有BN2D的模型,測(cè)試精度隨著訓(xùn)練時(shí)期的增加而逐漸提高,而對(duì)于沒有BN2D的模式,測(cè)試精度則隨著訓(xùn)練周期而振蕩。在第30個(gè)訓(xùn)練周期結(jié)束時(shí),具有BN2D的模型的測(cè)試精度為99.1%,而不具有BN2D模型的測(cè)試準(zhǔn)確率為92.4%。這些結(jié)果表明,引入BN2D實(shí)例對(duì)模型的性能產(chǎn)生了積極影響,顯著提高了測(cè)試精度。
sns.lineplot( x='epoch', y='test accuracy', data=three_convs_model_results_df, label="Three Convs Without BN2D Model" )
sns.lineplot( x='epoch', y='test accuracy', data=three_convs_wth_bn_model_results_df, label="Three Convs Wth BN2D Model" )
圖6:訓(xùn)練周期的測(cè)試準(zhǔn)確性(圖片由作者本人創(chuàng)建)
發(fā)現(xiàn)之2:更快的收斂
引入BN2D實(shí)例時(shí),DL模型的訓(xùn)練損失要低得多,如圖7所示。大約訓(xùn)練到第3個(gè)周期時(shí),具有BN2D的模型表現(xiàn)出比沒有BN2D的模型更低的訓(xùn)練損失。較低的訓(xùn)練損失表明,BN2D有助于在訓(xùn)練過程中更快地收斂,可能會(huì)減少合理收斂的訓(xùn)練周期數(shù)量。
sns.lineplot( x='epoch', y='train loss', data=three_convs_model_results_df, label="Three Convs Without BN2D Model" )
sns.lineplot( x='epoch', y='train loss', data=three_convs_wth_bn_model_results_df, label="Three Convs Wth BN2D Model" )
圖7:在各個(gè)訓(xùn)練周期過程對(duì)應(yīng)的訓(xùn)練損失(圖片由作者本人創(chuàng)建)
發(fā)現(xiàn)之3:更平滑的梯度下降
如圖8所示,從帶有BN2D的模型的最后一次卷積中獲得的兩個(gè)樣本權(quán)重的損失函數(shù)顯示出比沒有BN2D的模型更平滑的梯度下降。沒有BN2D的模型的損失函數(shù)遵循相當(dāng)明顯的“之”字形的梯度下降。BN2D的更平滑的梯度下降表明,將維度數(shù)據(jù)標(biāo)準(zhǔn)化為平均值為0、標(biāo)準(zhǔn)偏差為1的標(biāo)準(zhǔn)規(guī)模,使得不同維度的權(quán)重可能趨于相似的規(guī)模,從而減少梯度下降的可能振蕩。
fig1 = draw_loss_descent( three_convs_model_results_df, title='Three Convs Model Without BN2D Training Loss' )
fig2 = draw_loss_descent( three_convs_wth_bn_model_results_df, title='Three Convs With BN2D Model Training Loss' )
圖8:樣本權(quán)重上的損失函數(shù)梯度下降(圖片由作者本人創(chuàng)建)
實(shí)際注意事項(xiàng)
雖然BN2D的好處是顯而易見的,但要真正使用它還需要仔細(xì)考量。權(quán)重的適當(dāng)初始化、適當(dāng)?shù)膶W(xué)習(xí)率以及在DL網(wǎng)絡(luò)中放置BN2D層是最大化其有效性的關(guān)鍵因素。雖然BN2D通??梢苑乐惯^度擬合,但在某些情況下,它甚至可能導(dǎo)致過度擬合。例如,如果將BN2D與另一種稱為Dropout的技術(shù)一起使用,則根據(jù)具體配置和數(shù)據(jù)集,組合可能會(huì)對(duì)過度擬合產(chǎn)生不同的影響。同樣,在小批量的情況下,批量平均值和方差可能不能很好地代表整個(gè)數(shù)據(jù)集的統(tǒng)計(jì)數(shù)據(jù),這可能會(huì)導(dǎo)致有噪聲的歸一化,這在防止過度擬合方面可能沒有那么有效。
結(jié)論
本文旨在展示在深度學(xué)習(xí)網(wǎng)絡(luò)中使用BN2D的背后邏輯。文中使用類玩具圖像數(shù)據(jù)的示例卷積模型僅用于展示在DL網(wǎng)絡(luò)架構(gòu)中結(jié)合BN2D實(shí)例情況下的預(yù)期結(jié)果的性能改進(jìn)。最后的結(jié)論是,跨空間和通道維度的BN2D歸一化帶來了訓(xùn)練穩(wěn)定性、更快的收斂性和增強(qiáng)的泛化能力,最終有助于深度學(xué)習(xí)模型的成功。希望這篇文章有助于您很好地理解BN2D的工作原理及其背后的實(shí)現(xiàn)邏輯。在開發(fā)更復(fù)雜的DL模型時(shí),我相信這種理解和直覺感受會(huì)助您一臂之力。
參考資料
1. “印地語字符識(shí)別”,解決分類梵文的問題。地址:?https://www.kaggle.com/datasets/suvooo/hindi-character-recognition/data?source=post_page-----b4eb869e8b60--------------------------------?。
2. “BatchNorm2d-PyTorch 2.1文檔”,讀者可加入PyTorch開發(fā)者社區(qū),貢獻(xiàn)、學(xué)習(xí)并回答您的問題。地址:?https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html?source=post_page-----b4eb869e8b60--------------------------------?。
3. “為什么在特征中使用2D批量歸一化和在分類器中使用1D?”,討論BatchNorm2d和BatchNorm1d之間有什么區(qū)別?為什么在特征中使用BatchNorm2d而BatchNorm1d是……地址:?https://discuss.pytorch.org/t/why-2d-batch-normalisation-is-used-in-features-and-1d-in-classifiers/88360/3?source=post_page-----b4eb869e8b60--------------------------------?。
4. “Keras文檔:BatchNormalization層”,Keras文檔有關(guān)內(nèi)容,地址:?https://keras.io/api/layers/normalization_layers/batch_normalization/?source=post_page-----b4eb869e8b60--------------------------------?。
譯者介紹
朱先忠,51CTO社區(qū)編輯,51CTO專家博客、講師,濰坊一所高校計(jì)算機(jī)教師,自由編程界老兵一枚。
原文標(biāo)題:Exploring the Superhero Role of 2D Batch Normalization in Deep Learning Architectures,作者:Murali Kashaboina
鏈接:??https://towardsdatascience.com/exploring-the-superhero-role-of-2d-batch-normalization-in-deep-learning-architectures-b4eb869e8b60?。
