對(duì)抗生成網(wǎng)絡(luò)GAN 原創(chuàng) 精華
前言
在前一階段課程中,我們學(xué)習(xí)了圖像分割中的語義分割、實(shí)例分割、旋轉(zhuǎn)目標(biāo)檢測(cè)等。這些圖像分割算法都是有監(jiān)督學(xué)習(xí),而GAN(生成對(duì)抗網(wǎng)絡(luò))是一種特別的學(xué)習(xí)策略,其核心思想非常值得借鑒,所以本章將以GAN網(wǎng)絡(luò)的代碼為切入口,了解掌握其核心思想。
學(xué)習(xí)策略
人工智能方面的學(xué)習(xí)策略有兩種:有監(jiān)督學(xué)習(xí)和無監(jiān)督學(xué)習(xí)。
有監(jiān)督學(xué)習(xí)
定義:有監(jiān)督學(xué)習(xí)是使用帶標(biāo)簽的數(shù)據(jù)集進(jìn)行訓(xùn)練。每個(gè)輸入數(shù)據(jù)都有對(duì)應(yīng)的輸出標(biāo)簽,模型通過學(xué)習(xí)輸入與輸出之間的關(guān)系來進(jìn)行預(yù)測(cè)。
舉個(gè)例子:孩子的個(gè)人成長(zhǎng),有經(jīng)驗(yàn)的家長(zhǎng)為期規(guī)劃了發(fā)展的路線,孩子在規(guī)劃下有計(jì)劃地學(xué)習(xí)成長(zhǎng),這屬于有監(jiān)督學(xué)習(xí)。
無監(jiān)督學(xué)習(xí)
定義:無監(jiān)督學(xué)習(xí)使用沒有標(biāo)簽的數(shù)據(jù)集進(jìn)行訓(xùn)練。模型試圖發(fā)現(xiàn)數(shù)據(jù)中的模式或結(jié)構(gòu),而不依賴于任何預(yù)先定義的標(biāo)簽。
同樣的例子:孩子在無監(jiān)督學(xué)習(xí)下,是沒有家長(zhǎng)為期進(jìn)行規(guī)劃,而是經(jīng)歷社會(huì)"捶打"(做得好了有加分、做不好扣分),最終學(xué)習(xí)成長(zhǎng)起來。
GAN的基礎(chǔ)介紹
在上述的兩種學(xué)習(xí)策略中,有一種特殊的、獨(dú)立的學(xué)習(xí)策略:GAN(生成對(duì)抗網(wǎng)絡(luò))。
它由兩個(gè)網(wǎng)絡(luò)(生成器和判別器),通過對(duì)抗在競(jìng)爭(zhēng)中共同發(fā)展。
- G:生成器(造假)
- D:鑒別器(打假)
- 訓(xùn)練過程:
兩個(gè)網(wǎng)絡(luò)剛開始都沒有任何能力
在競(jìng)爭(zhēng)中共同發(fā)展
最后兩個(gè)網(wǎng)絡(luò)能力都得到提升
舉個(gè)例子:GAN網(wǎng)絡(luò)就像警察和小偷,警察和小偷之間互相對(duì)抗。
GAN示例
為了對(duì)GAN網(wǎng)絡(luò)有個(gè)直觀印象,我們可以參考Github上一個(gè)開源項(xiàng)目,對(duì)GAN有個(gè)初步認(rèn)知。
頁面地址:https://poloclub.github.io/ganlab/
示例目的
- 在頁面中添加一個(gè)手寫數(shù)字圖像
- 通過訓(xùn)練模型來模擬手寫數(shù)字圖像
- 從而達(dá)到新圖像與原來的風(fēng)格類似,分不出真假
核心思想
論文地址:https://arxiv.org/pdf/1406.2661
生成器(Generator):
- 作用:負(fù)責(zé)憑空編造假的數(shù)據(jù)出來。
- 目的:通過機(jī)器生成假的數(shù)據(jù)(大部分情況下是圖像),最終目的是“騙過”判別器。
- 過程:G是一個(gè)生成圖片的網(wǎng)絡(luò),它接收一個(gè)隨機(jī)的噪聲z,通過這個(gè)噪聲生成圖片,記做G(z)。
判別器(Discriminator):
- 作用:負(fù)責(zé)判斷傳來的數(shù)據(jù)是真還是假。
- 目的:判斷這張圖像是真實(shí)的還是機(jī)器生成的,目的是找出生成器做的“假數(shù)據(jù)”。
- 過程:D是一個(gè)判別網(wǎng)絡(luò),判別一張圖片是不是“真實(shí)的”。它的輸入?yún)?shù)是x,x代表一張圖片,輸出D(x)代表x為真實(shí)圖片的概率,如果為1,就代表100%是真實(shí)的圖片,而輸出為0,就代表不可能是真實(shí)的圖片。
接下來,我們通過詳細(xì)了解GAN網(wǎng)絡(luò)的代碼,深入了解其運(yùn)行過程。
引入依賴
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
from torchvision import transforms
from torchvision import datasets
from torchvision.utils import save_image
from torch.utils.data importDataLoader
import os
import numpy as np
import matplotlib.pyplot as plt
# 判斷當(dāng)前設(shè)備是否GPU
device = torch.device('cuda'if torch.cuda.is_available()else'cpu')
device
讀取數(shù)據(jù)集
# 加載并預(yù)處理圖像
data = datasets.MNIST(root="data",
train=True,
transform = transforms.Compose(transforms=[transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5])]),
download=True)
# 封裝成 DataLoader
data_loader = DataLoader(dataset=data, batch_size=100, shuffle=True)
- 備注:上述?
?transform = transforms.Compose?
? 的作用主要是進(jìn)行數(shù)據(jù)增強(qiáng),詳細(xì)內(nèi)容在補(bǔ)充知識(shí)部分展開介紹。
定義模型
定義生成器
"""
定義生成器
"""
classGenerator(nn.Module):
"""
定義一個(gè)圖像生成
輸入:一個(gè)向量
輸出:一個(gè)向量(代表圖像)
"""
def__init__(self, in_features=100, out_features=28 * 28):
"""
掛載超參數(shù)
"""
# 先初始化父類,再初始化子類
super(Generator, self).__init__()
self.in_features = in_features
self.out_features = out_features
# 第一個(gè)隱藏層
self.hidden0 = nn.Linear(in_features=self.in_features, out_features=256)
# 第二個(gè)隱藏層
self.hidden1 = nn.Linear(in_features=256, out_features=512)
# 第三個(gè)隱藏層
self.hidden2 = nn.Linear(in_features=512, out_features=self.out_features)
defforward(self, x):
# 第一層 [b, 100] --> [b, 256]
h = self.hidden0(x)
h = F.leaky_relu(input=h, negative_slope=0.2)
# 第二層 [b, 256] --> [b, 512]
h = self.hidden1(h)
h = F.leaky_relu(input=h, negative_slope=0.2)
# 第三層 [b, 512] --> [b, 28 * 28]
h = self.hidden2(h)
# 壓縮數(shù)據(jù)的變化范圍
o = torch.tanh(h)
return o
定義鑒別器
"""
定義一個(gè)鑒別器
"""
classDiscriminator(nn.Module):
"""
本質(zhì):二分類分類器
輸入:一個(gè)對(duì)象
輸出:真品還是贗品
"""
def__init__(self, in_features=28*28, out_features=1):
super(Discriminator, self).__init__()
self.in_features=in_features
self.out_features=out_features
# 第一個(gè)隱藏層
self.hidden0= nn.Linear(in_features=self.in_features, out_features=512)
# 第二個(gè)隱藏層
self.hidden1= nn.Linear(in_features=512, out_features=256)
# 第三個(gè)隱藏層
self.hidden2= nn.Linear(in_features=256, out_features=32)
# 第四個(gè)隱藏層
self.hidden3= nn.Linear(in_features=32, out_features=self.out_features)
defforward(self, x):
# 第一層
h = self.hidden0(x)
h = F.leaky_relu(input=h, negative_slope=0.2)
h = F.dropout(input=h, p=0.2)
# 第二層
h = self.hidden1(h)
h = F.leaky_relu(input=h, negative_slope=0.2)
h = F.dropout(input=h, p=0.2)
# 第三層
h = self.hidden2(h)
h = F.leaky_relu(input=h, negative_slope=0.2)
h = F.dropout(input=h, p=0.2)
# 第四層
h = self.hidden3(h)
# 輸出概率
o = torch.sigmoid(h)
return o
構(gòu)建模型
"""
構(gòu)建模型
"""
# 定義一個(gè)生成器
generator = Generator(in_features=100, out_features=784)
generator.to(device=device)
# 定義一個(gè)鑒別器
discriminator = Discriminator(in_features=784, out_features=1)
discriminator.to(device=device)
定義優(yōu)化器
"""
定義優(yōu)化器
"""
# 定義一個(gè)生成器的優(yōu)化器
g_optimizer = torch.optim.Adam(params=generator.parameters(), lr=1e-4)
# 定義一個(gè)鑒別的優(yōu)化器
d_optimizer = torch.optim.Adam(params=discriminator.parameters(), lr=1e-4)
定義損失函數(shù)
"""
定義一個(gè)損失函數(shù)
"""
loss_fn = nn.BCELoss()
籌備訓(xùn)練
定義訓(xùn)練輪次
# 定義訓(xùn)練輪次
num_epochs = 1000
獲取數(shù)據(jù)的標(biāo)簽
"""
獲取數(shù)據(jù)的標(biāo)簽
"""
defget_real_data_labels(size):
"""
獲取真實(shí)數(shù)據(jù)的標(biāo)簽
"""
labels = torch.ones(size,1, device=device)
return labels
defget_fake_data_labels(size):
"""
獲取虛假數(shù)據(jù)的標(biāo)簽
"""
labels = torch.zeros(size,1, device=device)
return labels
定義噪聲生成器
"""
噪聲生成器
"""
defget_noise(size):
"""
給生成器準(zhǔn)備數(shù)據(jù)
- 100維度的向量
"""
X = torch.randn(size,100, device=device)
return X
# 獲取一批測(cè)試數(shù)據(jù)
num_test_samples =16
test_noise = get_noise(num_test_samples)
噪聲生成器的作用:因?yàn)槲覀冃枰O(jiān)控模型訓(xùn)練的效果,所以將噪聲固定下來,在訓(xùn)練過程中看同樣的噪聲最后給出的結(jié)果是否變得越來越好。
訓(xùn)練模型
"""
訓(xùn)練過程
"""
g_losses =[]
d_losses =[]
for epoch inrange(1, num_epochs+1):
print(f"當(dāng)前正在進(jìn)行 第 {epoch} 輪 ....")
# 設(shè)置訓(xùn)練模式
generator.train()
discriminator.train()
# 遍歷真實(shí)的圖像
for batch_idx,(batch_real_data, _)inenumerate(data_loader):
"""
1, 先訓(xùn)練鑒別器
鑒別器就是一個(gè)二分類問題
- 給一批真數(shù)據(jù),輸出真
- 給一批假數(shù)據(jù),輸出假
"""
# 1.1 準(zhǔn)備數(shù)據(jù)
# 圖像轉(zhuǎn)向量 [b, 1, 28, 28] ---> [b, 784]
# 從數(shù)據(jù)集中獲取100個(gè)真實(shí)的手寫數(shù)字圖像
real_data = batch_real_data.view(batch_real_data.size(0),-1).to(device=device)
# 噪聲[b, 100]
# 隨機(jī)生成100個(gè)100維度的噪聲,用于生成假圖像
noise = get_noise(real_data.size(0))
# 根據(jù)噪聲,生成假數(shù)據(jù)
# [b, 100] --> [b, 784]
fake_data = generator(noise).detach()
# 1.2 訓(xùn)練過程
# 鑒別器的優(yōu)化器梯度情況
d_optimizer.zero_grad()
# 對(duì)真實(shí)數(shù)據(jù)鑒別
real_pred = discriminator(real_data)
# 計(jì)算真實(shí)數(shù)據(jù)的誤差
real_loss = loss_fn(real_pred, get_real_data_labels(real_data.size(0)))
# 真實(shí)數(shù)據(jù)的梯度回傳
real_loss.backward()
# 對(duì)假數(shù)據(jù)鑒別
fake_pred = discriminator(fake_data)
# 計(jì)算假數(shù)據(jù)的誤差
fake_loss = loss_fn(fake_pred, get_fake_data_labels(fake_data.size(0)))
# 假數(shù)據(jù)梯度回傳
fake_loss.backward()
# 梯度更新
d_optimizer.step()
# ----------------
d_losses.append((real_loss + fake_loss).item())
# print(f"鑒別器的損失:{real_loss + fake_loss}")
"""2, 再訓(xùn)練生成器"""
# 獲取生成器的生成結(jié)果
fake_pred = generator(get_noise(real_data.size(0)))
# 生產(chǎn)器梯度清空
g_optimizer.zero_grad()
# 把假數(shù)據(jù)讓鑒別器鑒別一下
# 把discriminator requires_grad = False
# 設(shè)置為不可學(xué)習(xí)
for param in discriminator.parameters():
param.requires_grad =False
d_pred = discriminator(fake_pred)
# 設(shè)置為可學(xué)習(xí)
for param in discriminator.parameters():
param.requires_grad =True
# 計(jì)算損失
# 把一個(gè)假東西,給專家看,專家說是真的,這個(gè)時(shí)候,造假的水平就可以了
g_loss = loss_fn(d_pred, get_real_data_labels(d_pred.size(0)))
# 梯度回傳
g_loss.backward()
# 參數(shù)更新
g_optimizer.step()
# print(f"生成器誤差:{g_loss}")
g_losses.append(g_loss.item())
# 每輪訓(xùn)練之后,觀察生成器的效果
generator.eval()
with torch.no_grad():
# 正向推理
img_pred = generator(test_noise)
img_pred = img_pred.view(img_pred.size(0),28,28).cpu().data
# 畫圖
display.clear_output(wait=True)
# 設(shè)置畫圖的大小
fig = plt.figure(1, figsize=(12,8))
# 劃分為 4 x 4 的 網(wǎng)格
gs = gridspec.GridSpec(4,4)
# 遍歷每一個(gè)
for i inrange(4):
for j inrange(4):
# 取每一個(gè)圖
X = img_pred[i *4+ j,:,:]
# 添加一個(gè)對(duì)應(yīng)網(wǎng)格內(nèi)的子圖
ax = fig.add_subplot(gs[i, j])
# 在子圖內(nèi)繪制圖像
ax.matshow(X, cmap=plt.get_cmap("Greys"))
# ax.set_xlabel(f"{label}")
ax.set_xticks(())
ax.set_yticks(())
plt.show()
運(yùn)行結(jié)果:
核心代碼說明:
訓(xùn)練過程
- 隨機(jī)生成一組潛在向量z,并使用生成器生成一組假數(shù)據(jù)。
- 將一組真實(shí)數(shù)據(jù)和一組假數(shù)據(jù)作為輸入,訓(xùn)練判別器。
- 使用生成器生成一組新的假數(shù)據(jù),并訓(xùn)練判別器。
- 重復(fù)步驟2和3,直到生成器生成的假數(shù)據(jù)與真實(shí)數(shù)據(jù)的分布相似。
核心代碼
- ?
?fake_data = generator(noise).detach()?
?:
作用:是生成器生成一組假數(shù)據(jù),并使用detach()方法將其從計(jì)算圖中分離出來,防止梯度回傳。
說明:(因?yàn)樵谟?xùn)練鑒別器時(shí),生成器只是工具人,其前向傳播過程中記錄的梯度信息不會(huì)被使用,所以不需要記錄梯度信息)
- ?
?g_loss = loss_fn(d_pred, get_real_data_labels(d_pred.size(0)))?
? 這里是體現(xiàn)對(duì)抗的核心代碼,即:生成器訓(xùn)練的好不好,是要與真實(shí)數(shù)據(jù)的判別結(jié)果越接近越好。
補(bǔ)充知識(shí)
數(shù)據(jù)增強(qiáng)
在人工智能模型的訓(xùn)練中,采集樣本是需要成本的,所以為了提升樣本的豐富性,一般會(huì)采用數(shù)據(jù)增強(qiáng)的方式。
- 方式:在樣本固定的基礎(chǔ)上,通過軟件模擬,來生成假數(shù)據(jù),豐富樣本的多樣性
- 本質(zhì):給樣本加上適當(dāng)?shù)脑肼?,模擬出不同場(chǎng)景的樣本
- 說明:數(shù)據(jù)增強(qiáng)只發(fā)生在模型訓(xùn)練中,為了增加訓(xùn)練樣本的多樣性
transform介紹
在 PyTorch 中,transform 主要用于數(shù)據(jù)預(yù)處理和增強(qiáng),特別是在圖像處理任務(wù)中。transform 是 torchvision 庫的一部分,能夠?qū)?shù)據(jù)集中的圖像進(jìn)行各種轉(zhuǎn)換,以便更好地適應(yīng)模型訓(xùn)練的需求。以下是 transform 的主要作用
import torch
from torchvision import datasets, transforms
from PIL import Image
import matplotlib.pyplot as plt
# 讀取本地下載的一張圖片
img = Image.open('girl.png')
img
重設(shè)圖片尺寸
resize = transforms.Resize((300, 200))
resize_img = resize(img)
resize_img
運(yùn)行效果:
中心裁剪
centercrop = transforms.CenterCrop(size=(200, 200))
center_img = centercrop(img)
center_img
運(yùn)行效果:
隨機(jī)調(diào)整亮度、飽和度、對(duì)比度等
color_jitter = transforms.ColorJitter(brightness=0.5,
contrast=0.5,
saturation=0.5,
hue=0.5)
color_jitter(img)
運(yùn)行效果:
隨機(jī)旋轉(zhuǎn)
random_rotation = transforms.RandomRotation(degrees=10)
random_rotation(img)
運(yùn)行效果:
組合變換
Compose:可以將多個(gè)變換組合在一起,形成一個(gè)轉(zhuǎn)換管道,方便批量處理。例如:
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize((128, 128)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(), # 將PIL Image轉(zhuǎn)換為Tensor
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # 將數(shù)據(jù)歸一化到[-1, 1]之間
])
內(nèi)容小結(jié)
- GAN(生成對(duì)抗網(wǎng)絡(luò))是一種特殊的學(xué)習(xí)策略,它由生成器和判別器組成,生成器生成假數(shù)據(jù),判別器判斷真假。
- 生成器(Generator)通過機(jī)器生成假的數(shù)據(jù)(大部分情況下是圖像),最終目的是“騙過”判別器。
- 判別器(Discriminator):判斷這張圖像是真實(shí)的還是機(jī)器生成的,目的是找出生成器做的“假數(shù)據(jù)”。
- 訓(xùn)練過程是:先訓(xùn)練判別器,再訓(xùn)練生成器。
- 訓(xùn)練判別器時(shí),生成器是"工具人",所以需要使用detach()方法,將生成器生成的假數(shù)據(jù)從計(jì)算圖中分離出來,防止梯度回傳。
- 訓(xùn)練生成器時(shí),判別器是"工具人",為了避免整個(gè)梯度消失,需要使用param.requires_grad = False設(shè)置為不可學(xué)習(xí),判別完之后再使用param.requires_grad = True設(shè)置為可學(xué)習(xí)。
- 在人工智能模型訓(xùn)練過程中,通常會(huì)使用數(shù)據(jù)增強(qiáng)的方式,在樣本固定的基礎(chǔ)上,通過軟件模擬,來生成假數(shù)據(jù),豐富樣本的多樣性。
- transform:在 PyTorch 中,transform 主要用于數(shù)據(jù)預(yù)處理和增強(qiáng),特別是在圖像處理任務(wù)中。
?
本文轉(zhuǎn)載自公眾號(hào)一起AI技術(shù) 作者:Dongming
