VAE變分自編碼器原理解析看這一篇就夠了!另附Python代碼實(shí)現(xiàn)
變分自編碼器是近些年較火的一個(gè)生成模型,我個(gè)人認(rèn)為其本質(zhì)上仍然是一個(gè)概率圖模型,只是在此基礎(chǔ)上引入了神經(jīng)網(wǎng)絡(luò)。本文將就變分自編碼器(VAE)進(jìn)行簡(jiǎn)單的原理講解和數(shù)學(xué)推導(dǎo)。
論文:https://arxiv.org/abs/1312.6114
視頻:https://www.bilibili.com/video/BV1op421S7Ep/
引入
高斯混合模型
生成模型,可以簡(jiǎn)單的理解為生成數(shù)據(jù) (不 止 , 但 我 們 暫 且 就 這 么 理 解 它) 。假如現(xiàn)在我們有樣本數(shù)據(jù),而我們發(fā)現(xiàn)這些樣本符合正態(tài)分布且樣本具有充分的代表性,因此我們計(jì)算出樣本的均值和方差,就能得到樣本的概率分布。然后從正態(tài)分布中抽樣,就能得到樣本。這 種 生 成 樣 本 的 過 程 就 是 生 成 過 程 。
可是,假如我們的數(shù)據(jù)長(zhǎng)這樣
很顯然,它的數(shù)據(jù)是由兩個(gè)不同的正態(tài)分布構(gòu)成。我們可以計(jì)算出這些樣本的概率分布。但是一種更為常見的方法就是將其當(dāng)作是兩個(gè)正態(tài)分布。我們引入一個(gè)隱變量z。
假 設(shè) z 的 取 值 為 0,1 ,如果z為0,我們就從藍(lán)色的概率分布中抽樣;否則為1,則從橙色的概率分布中抽樣。這就是生成過程。
但是這個(gè)隱變量z是什么?它其實(shí)就是隱藏特征 訓(xùn) 練 數(shù) 據(jù)x 的 抽 象 出 來 的 特 征 ,比如,如果x偏小,我們則認(rèn)為它數(shù)據(jù)藍(lán)色正太分布,否則為橙色。這個(gè) "偏 小" 就是特征,我們把它的取值為0,1(0代表偏小,1代表偏大)。
那這種模型我們?nèi)绾稳∮?xùn)練它呢?如何去找出這個(gè)z呢?一 種 很 直 觀 的 方 法 就 是 重 構(gòu) 代 價(jià) 最 小 ,我們希望,給一個(gè)訓(xùn)練數(shù)據(jù)x,由x去預(yù)測(cè)隱變量z,再由隱變量z預(yù)測(cè)回x,得到的誤差最小。比如假如我們是藍(lán)色正態(tài)分布,去提取特征z,得到的z再返回來預(yù)測(cè)x,結(jié)果得到的卻是橙色的正態(tài)分布,這是不可取的。其模型圖如下
這個(gè)模型被稱為GMM高斯混合模型
變分自編碼器(VAE)
那它和VAE有什么關(guān)聯(lián)呢?其實(shí)VAE的模型圖跟這個(gè)原理差不多。只是有些許改變, 隱 變 量Z 的 維 度 是 連 續(xù) 且 高 維 的 , 不 再 是 簡(jiǎn) 單 的 離 散 分 布 ,因?yàn)榧偃缥覀兩傻氖菆D片,我們需要提取出來的特征明顯是要很多的,傳統(tǒng)的GMM無法做到。
也就是將訓(xùn)練樣本x給神經(jīng)網(wǎng)絡(luò),讓神經(jīng)網(wǎng)絡(luò)計(jì)算出均值和協(xié)方差矩陣.
取log 的 原 因 是 傳 統(tǒng) 的 神 經(jīng) 網(wǎng) 絡(luò) 輸 出 值 總 是 有 正 有 負(fù) 。有了這兩個(gè)值就可以在對(duì)應(yīng)的高斯分布中采樣,得到隱變量z。再讓z經(jīng)過神經(jīng)網(wǎng)絡(luò)重構(gòu)回樣本,得到新樣本。這就是整個(gè)VAE的大致過程了。
再次強(qiáng)調(diào), 訓(xùn) 練 過 程 我 們 希 望 每 次 重 構(gòu) 的 時(shí) 候 , 新 樣 本 和 訓(xùn) 練 樣 本 盡 可 能 的 相 似。
為什么協(xié)方差會(huì)變成0?因?yàn)椴蓸泳哂须S機(jī)性,也就是存在噪聲,噪聲是肯定會(huì)增加重構(gòu)的誤差的。神經(jīng)網(wǎng)絡(luò)為了讓誤差最小,是肯定讓這個(gè)隨機(jī)性越小越好,因?yàn)橹挥羞@樣,才能重構(gòu)誤差最小。
但是我們肯定是希望有隨機(jī)性的,為什么?因?yàn)橛须S機(jī)性,我們才可以生成不同的樣本?。?/p>
所以,有KL散度去衡量?jī)蓚€(gè)概率分布的相似性
KL散度是大于等于0的值,越小則證明越相似
所以,我們就是兩個(gè)優(yōu)化目標(biāo) ① 最 小 化 重 構(gòu) 代 價(jià) ② 最 小 化 上 述 的 散 度
依照這兩個(gè)條件,建立目標(biāo)函數(shù),直接梯度下降 其 實(shí) 還 需 要 重 參 數(shù) 化 , 后 面 會(huì) 講 到 ,刷刷刷地往下降,最終收斂。
下面,我們就對(duì)其進(jìn)行簡(jiǎn)單的數(shù)學(xué)推導(dǎo),并以此推導(dǎo)出目標(biāo)函數(shù)
原理推導(dǎo)
引入目標(biāo)函數(shù)
以VAE的簡(jiǎn)略圖為例
設(shè)我們有N個(gè)樣本
現(xiàn)在,我們先單獨(dú)看看里面某一個(gè)樣本的似然,某個(gè)樣本記為x
所以左邊等于右邊
按照上面提到的,我們可以把第一步改成
更一般地,我們把它們寫成一起
由于KL散度是大于等于0的,所以第①項(xiàng),就被稱為變分下界。
好,現(xiàn)在我們只需要最大化其變分下界(以下省略掉參數(shù))
細(xì)化目標(biāo)函數(shù)
先 來 看 KL 散 度
可以分為三部分
如果你熟悉高斯分布的高階矩的話,式A和式C完全就是二階原點(diǎn)矩和中心距,是直接可以的得出答案的。
當(dāng)然了,其實(shí)我們也可以不對(duì)其概率分布進(jìn)行約束,歸根究底,其讓然是最小重構(gòu)代價(jià),那么我們的目標(biāo)函數(shù)如果可以充分表達(dá)出“最小重構(gòu)代價(jià)”,那么是什么又有何關(guān)系呢?
可以看到,這就是一個(gè)均方差
重參數(shù)化技巧
有了目標(biāo)函數(shù),理論上我們直接梯度下降就可以了。然而,別忘了,我們是從中采樣出z來??墒俏覀儏s是用的神經(jīng)網(wǎng)絡(luò)去計(jì)算的均值和方差,得到的高斯分布再去采樣,這種情況是不可導(dǎo)的。中間都已經(jīng)出現(xiàn)了一個(gè)斷層了。神經(jīng)網(wǎng)絡(luò)是一層套一層的計(jì)算。而采樣計(jì)算了一層之后,從這一層中去采樣新的值,再計(jì)算下一層。因此,采樣本身是不可導(dǎo)的。
代碼實(shí)現(xiàn)
效果一般,不曉得論文里面用了什么手段,效果看起來比這個(gè)好。(這個(gè)結(jié)果甚至還是我加了一層隱藏層的)
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from torch import nn
from tqdm import tqdm
import matplotlib.pyplot as plt
class VAE(nn.Module):
def __init__(self,input_dim,hidden_dim,gaussian_dim):
super().__init__()
#編碼器
#隱藏層
self.fc1=nn.Sequential(
nn.Linear(in_features=input_dim,out_features=hidden_dim),
nn.Tanh(),
nn.Linear(in_features=hidden_dim, out_features=256),
nn.Tanh(),
)
#μ和logσ^2
self.mu=nn.Linear(in_features=256,out_features=gaussian_dim)
self.log_sigma=nn.Linear(in_features=256,out_features=gaussian_dim)
#解碼(重構(gòu))
self.fc2=nn.Sequential(
nn.Linear(in_features=gaussian_dim,out_features=256),
nn.Tanh(),
nn.Linear(in_features=256, out_features=512),
nn.Tanh(),
nn.Linear(in_features=512,out_features=input_dim),
nn.Sigmoid() #圖片被轉(zhuǎn)為為0,1的值了,故用此函數(shù)
)
def forward(self,x):
#隱藏層
h=self.fc1(x)
#計(jì)算期望和log方差
mu=self.mu(h)
log_sigma=self.log_sigma(h)
#重參數(shù)化
h_sample=self.reparameterization(mu,log_sigma)
#重構(gòu)
recnotallow=self.fc2(h_sample)
return reconsitution,mu,log_sigma
def reparameterization(self,mu,log_sigma):
#重參數(shù)化
sigma=torch.exp(log_sigma*0.5) #計(jì)算σ
e=torch.randn_like(input=sigma,device=device)
result=mu+e*sigma #依據(jù)重參數(shù)化技巧可得
return result
def predict(self,new_x): #預(yù)測(cè)
recnotallow=self.fc2(new_x)
return reconsitution
def train():
transformer = transforms.Compose([
transforms.ToTensor(),
]) #歸一化
data = MNIST("./data", transform=transformer,download=True) #載入數(shù)據(jù)
dataloader = DataLoader(data, batch_size=128, shuffle=True) #寫入加載器
model = VAE(784, 512, 20).to(device) #初始化模型
optimer = torch.optim.Adam(model.parameters(), lr=1e-3) #初始化優(yōu)化器
loss_fn = nn.MSELoss(reductinotallow="sum") #均方差損失
epochs = 100 #訓(xùn)練100輪
for epoch in torch.arange(epochs):
all_loss = 0
dataloader_len = len(dataloader.dataset)
for data in tqdm(dataloader, desc="第{}輪梯度下降".format(epoch)):
sample, label = data
sample = sample.to(device)
sample = sample.reshape(-1, 784) #重塑
result, mu, log_sigma = model(sample) #預(yù)測(cè)
loss_likelihood = loss_fn(sample, result) #計(jì)算似然損失
#計(jì)算KL損失
loss_KL = torch.pow(mu, 2) + torch.exp(log_sigma) - log_sigma - 1
#總損失
loss = loss_likelihood + 0.5 * torch.sum(loss_KL)
#梯度歸0并反向傳播和更新
optimer.zero_grad()
loss.backward()
optimer.step()
with torch.no_grad():
all_loss += loss.item()
print("函數(shù)損失為:{}".format(all_loss / dataloader_len))
torch.save(model, "./model/VAE.pth")
if __name__ == '__main__':
#是否有閑置GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#訓(xùn)練
train()
#載入模型,預(yù)測(cè)
model=torch.load("./model/VAE (1).pth",map_locatinotallow="cpu")
#預(yù)測(cè)20個(gè)樣本
x=torch.randn(size=(20,20))
result=model.predict(x).detach().numpy()
result=result.reshape(-1,28,28)
#繪圖
for i in range(20):
plt.subplot(4,5,i+1)
plt.imshow(result[i])
plt.gray()
plt.show()
VAE有很多的變種優(yōu)化,感興趣的讀者自行查閱。
結(jié)束
以上,就是VAE的原理和推導(dǎo)過程了。能力有限,過程并不嚴(yán)謹(jǐn),如有問題,還望指出。
本文轉(zhuǎn)自 AI生成未來 ,作者:篝火者2312
