簡單使用PyTorch搭建GAN模型
以往人們普遍認為生成圖像是不可能完成的任務,因為按照傳統(tǒng)的機器學習思路,我們根本沒有真值(ground truth)可以拿來檢驗生成的圖像是否合格。
2014年,Goodfellow等人則提出生成 對抗網絡(Generative Adversarial Network, GAN) ,能夠讓我們完全依靠機器學習來生成極為逼真的圖片。GAN的橫空出世使得整個人工智能行業(yè)都為之震動,計算機視覺和圖像生成領域發(fā)生了巨變。
本文將帶大家了解 GAN的工作原理 ,并介紹如何 通過PyTorch簡單上手GAN 。
GAN的原理
按照傳統(tǒng)的方法,模型的預測結果可以直接與已有的真值進行比較。然而,我們卻很難定義和衡量到底怎樣才算作是“正確的”生成圖像。
Goodfellow等人則提出了一個有趣的解決辦法:我們可以先訓練好一個分類工具,來自動區(qū)分生成圖像和真實圖像。這樣一來,我們就可以用這個分類工具來訓練一個生成網絡,直到它能夠輸出完全以假亂真的圖像,連分類工具自己都沒有辦法評判真假。
按照這一思路,我們便有了GAN:也就是一個 生成器(generator) 和一個 判別器(discriminator) 。生成器負責根據給定的數據集生成圖像,判別器則負責區(qū)分圖像是真是假。GAN的運作流程如上圖所示。
損失函數
在GAN的運作流程中,我們可以發(fā)現一個明顯的矛盾:同時優(yōu)化生成器和判別器是很困難的??梢韵胂?,這兩個模型有著完全相反的目標:生成器想要盡可能偽造出真實的東西,而判別器則必須要識破生成器生成的圖像。
為了說明這一點,我們設D(x)為判別器的輸出,即x是真實圖像的概率,并設G(z)為生成器的輸出。判別器類似于一種二進制的分類器,所以其目標是使該函數的結果最大化:這一函數本質上是非負的二元交叉熵損失函數。另一方面,生成器的目標是最小化判別器做出正確判斷的機率,因此它的目標是使上述函數的結果最小化。
因此,最終的損失函數將會是兩個分類器之間的極小極大博弈,表示如下:理論上來說,博弈的最終結果將是讓判別器判斷成功的概率收斂到0.5。然而在實踐中,極大極小博弈通常會導致網絡不收斂,因此仔細調整模型訓練的參數非常重要。
在訓練GAN時,我們尤其要注意學習率等超參數,學習率比較小時能讓GAN在輸入噪音較多的情況下也能有較為統(tǒng)一的輸出。
計算環(huán)境
庫
本文將指導大家通過PyTorch搭建整個程序(包括torchvision)。同時,我們將會使用Matplotlib來讓GAN的生成結果可視化。以下代碼能夠導入上述所有庫:
- """
- Import necessary libraries to create a generative adversarial network
- The code is mainly developed using the PyTorch library
- """
- import time
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from torch.utils.data import DataLoader
- from torchvision import datasets
- from torchvision.transforms import transforms
- from model import discriminator, generator
- import numpy as np
- import matplotlib.pyplot as plt
數據集
數據集對于訓練GAN來說非常重要,尤其考慮到我們在GAN中處理的通常是非結構化數據(一般是圖片、視頻等),任意一class都可以有數據的分布。這種數據分布恰恰是GAN生成輸出的基礎。
為了更好地演示GAN的搭建流程,本文將帶大家使用最簡單的MNIST數據集,其中含有6萬張手寫阿拉伯數字的圖片。
像 MNIST 這樣高質量的非結構化數據集都可以在 格物鈦 的 公開數據集 網站上找到。事實上,格物鈦Open Datasets平臺涵蓋了很多優(yōu)質的公開數據集,同時也可以實現 數據集托管及一站式搜索的功能 ,這對AI開發(fā)者來說,是相當實用的社區(qū)平臺。
硬件需求
一般來說,雖然可以使用CPU來訓練神經網絡,但最佳選擇其實是GPU,因為這樣可以大幅提升訓練速度。我們可以用下面的代碼來測試自己的機器能否用GPU來訓練:
- """
- Determine if any GPUs are available
- """
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
實現
網絡結構
由于數字是非常簡單的信息,我們可以將判別器和生成器這兩層結構都組建成全連接層(fully connected layers)。
我們可以用以下代碼在PyTorch中搭建判別器和生成器:
- """
- Network Architectures
- The following are the discriminator and generator architectures
- """
- class discriminator(nn.Module):
- def __init__(self):
- super(discriminator, self).__init__()
- self.fc1 = nn.Linear(784, 512)
- self.fc2 = nn.Linear(512, 1)
- self.activation = nn.LeakyReLU(0.1)
- def forward(self, x):
- x = x.view(-1, 784)
- x = self.activation(self.fc1(x))
- x = self.fc2(x)
- return nn.Sigmoid()(x)
- class generator(nn.Module):
- def __init__(self):
- super(generator, self).__init__()
- self.fc1 = nn.Linear(128, 1024)
- self.fc2 = nn.Linear(1024, 2048)
- self.fc3 = nn.Linear(2048, 784)
- self.activation = nn.ReLU()
- def forward(self, x):
- x = self.activation(self.fc1(x))
- x = self.activation(self.fc2(x))
- x = self.fc3(x)
- x = x.view(-1, 1, 28, 28)
- return nn.Tanh()(x)
訓練
在訓練GAN的時候,我們需要一邊優(yōu)化判別器,一邊改進生成器,因此每次迭代我們都需要同時優(yōu)化兩個互相矛盾的損失函數。
對于生成器,我們將輸入一些隨機噪音,讓生成器來根據噪音的微小改變輸出的圖像:
- """
- Network training procedure
- Every step both the loss for disciminator and generator is updated
- Discriminator aims to classify reals and fakes
- Generator aims to generate images as realistic as possible
- """
- for epoch in range(epochs):
- for idx, (imgs, _) in enumerate(train_loader):
- idx += 1
- # Training the discriminator
- # Real inputs are actual images of the MNIST dataset
- # Fake inputs are from the generator
- # Real inputs should be classified as 1 and fake as 0
- real_inputs = imgs.to(device)
- real_outputs = D(real_inputs)
- real_label = torch.ones(real_inputs.shape[0], 1).to(device)
- noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5
- noise = noise.to(device)
- fake_inputs = G(noise)
- fake_outputs = D(fake_inputs)
- fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device)
- outputs = torch.cat((real_outputs, fake_outputs), 0)
- targets = torch.cat((real_label, fake_label), 0)
- D_loss = loss(outputs, targets)
- D_optimizer.zero_grad()
- D_loss.backward()
- D_optimizer.step()
- # Training the generator
- # For generator, goal is to make the discriminator believe everything is 1
- noise = (torch.rand(real_inputs.shape[0], 128)-0.5)/0.5
- noise = noise.to(device)
- fake_inputs = G(noise)
- fake_outputs = D(fake_inputs)
- fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device)
- G_loss = loss(fake_outputs, fake_targets)
- G_optimizer.zero_grad()
- G_loss.backward()
- G_optimizer.step()
- if idx % 100 == 0 or idx == len(train_loader):
- print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item()))
- if (epoch+1) % 10 == 0:
- torch.save(G, 'Generator_epoch_{}.pth'.format(epoch))
- print('Model saved.')
結果
經過100個訓練時期之后,我們就可以對數據集進行可視化處理,直接看到模型從隨機噪音生成的數字:
我們可以看到,生成的結果和真實的數據非常相像??紤]到我們在這里只是搭建了一個非常簡單的模型,實際的應用效果會有非常大的上升空間。
不僅是有樣學樣
GAN和以往機器視覺專家提出的想法都不一樣,而利用GAN進行的具體場景應用更是讓許多人贊嘆深度網絡的無限潛力。下面我們來看一下兩個最為出名的GAN延申應用。
CycleGAN
朱俊彥等人2017年發(fā)表的CycleGAN能夠在沒有配對圖片的情況下將一張圖片從X域直接轉換到Y域,比如把馬變成斑馬、將熱夏變成隆冬、把莫奈的畫變成梵高的畫等等。這些看似天方夜譚的轉換CycleGAN都能輕松做到,并且結果非常準確。
GauGAN
英偉達則通過GAN讓人們能夠只需要寥寥數筆勾勒出自己的想法,便能得到一張極為逼真的真實場景圖片。雖然這種應用需要的計算成本極為高昂,但是GauGAN憑借它的轉換能力探索出了前所未有的研究和應用領域。
結語
相信看到這里,你已經知道了GAN的大致工作原理,并且能夠自己動手簡單搭建一個GAN了。