CycleGAN生成對(duì)抗網(wǎng)絡(luò)圖像處理工具
1. GAN簡(jiǎn)介
"干飯人,干飯魂,干飯都是人上人"。
此GAN飯人非彼干飯人。本文要講的GAN是Goodfellow2014提出的生成產(chǎn)生對(duì)抗模型,即Generative Adversarial Nets。那么GAN到底有什么神奇的地方?
常規(guī)的深度學(xué)習(xí)任務(wù)如圖像分類,目標(biāo)檢測(cè)以及語(yǔ)義分割或者實(shí)例分割,這些任務(wù)的結(jié)果都可以歸結(jié)為預(yù)測(cè)。圖像分類是預(yù)測(cè)單一的類別,目標(biāo)檢測(cè)是預(yù)測(cè)bbox和類別,語(yǔ)義分割或者實(shí)例分割是預(yù)測(cè)每個(gè)像素的類別。而GAN是生成一個(gè)新的東西如一個(gè)圖片。
GAN的原理用一句話來說明:
- 通過對(duì)抗的方式,去學(xué)習(xí)數(shù)據(jù)分布的生成式模型。GAN是無監(jiān)督的過程,能夠捕捉數(shù)據(jù)集的分布,以便于可以從隨機(jī)噪聲中生成同樣分布的數(shù)據(jù)
GAN的組成:判別式模型和生成式模型的左右手博弈
- D判別式模型:學(xué)習(xí)真假邊界,判斷數(shù)據(jù)是真的還是假的
- G生成式模型:學(xué)習(xí)數(shù)據(jù)分布并生成數(shù)據(jù)
GAN經(jīng)典的loss如下(minmax體現(xiàn)的就是對(duì)抗)
2. 實(shí)戰(zhàn)cycleGAN 風(fēng)格轉(zhuǎn)換
了解了GAN的作用,來體驗(yàn)的GAN的神奇效果。這里以cycleGAN為例子來實(shí)現(xiàn)圖像的風(fēng)格轉(zhuǎn)換。所謂的風(fēng)格轉(zhuǎn)換就是改變?cè)紙D片的風(fēng)格,如下圖左邊是原圖,中間是風(fēng)格圖(梵高畫),生成后是右邊的具有梵高風(fēng)格的原圖,可以看到總體上生成后的圖保留大部分原圖的內(nèi)容。
2.1 cycleGAN簡(jiǎn)介
cycleGAN本質(zhì)上和GAN是一樣的,是學(xué)習(xí)數(shù)據(jù)集中潛在的數(shù)據(jù)分布。GAN是從隨機(jī)噪聲生成同分布的圖片,cycleGAN是在有意義的圖上加上學(xué)習(xí)到的分布從而生成另一個(gè)領(lǐng)域的圖。cycleGAN假設(shè)image-to-image的兩個(gè)領(lǐng)域存在的潛在的聯(lián)系。
眾所周知,GAN的映射函數(shù)很難保證生成圖片的有效性。cycleGAN利用cycle consistency來保證生成的圖片與輸入圖片的結(jié)構(gòu)上一致性。我們看下cycleGAN的結(jié)構(gòu):
特點(diǎn)總結(jié)如下:
- 兩路GAN:兩個(gè)生成器[ G:X->Y , F:Y->X ] 和兩個(gè)判別器[Dx, Dy], G和Dy目的是生成的對(duì)象,Dy(正類是Y領(lǐng)域)無法判別。同理F和Dx也是一樣的。
- cycle consistency:G是生成Y的生成器, F是生成X的生成器,cycle consistency是為了約束G和F生成的對(duì)象的范圍, 是的G生成的對(duì)象通過F生成器能夠回到原始的領(lǐng)域如:x->G(x)->F(G(x))=x
對(duì)抗loss如下:
2.2 實(shí)現(xiàn)cycleGAN
2.2.1 生成器
從上面簡(jiǎn)介中生成器有兩個(gè)生成器,一個(gè)是正向,一個(gè)是反向的。結(jié)構(gòu)是參考論文Perceptual Losses for Real-Time Style Transfer and Super-Resolution: Supplementary Material。大致可以分為:下采樣 + residual 殘差block + 上采樣,如下圖(摘自論文):
實(shí)現(xiàn)上下采樣是stride=2的卷積, 上采樣用nn.Upsample:
- # 殘差block
- class ResidualBlock(nn.Module):
- def __init__(self, in_features):
- super(ResidualBlock, self).__init__()
- self.block = nn.Sequential(
- nn.ReflectionPad2d(1),
- nn.Conv2d(in_features, in_features, 3),
- nn.InstanceNorm2d(in_features),
- nn.ReLU(inplace=True),
- nn.ReflectionPad2d(1),
- nn.Conv2d(in_features, in_features, 3),
- nn.InstanceNorm2d(in_features),
- )
- def forward(self, x):
- return x + self.block(x)
- class GeneratorResNet(nn.Module):
- def __init__(self, input_shape, num_residual_blocks):
- super(GeneratorResNet, self).__init__()
- channels = input_shape[0]
- # Initial convolution block
- out_features = 64
- model = [
- nn.ReflectionPad2d(channels),
- nn.Conv2d(channels, out_features, 7),
- nn.InstanceNorm2d(out_features),
- nn.ReLU(inplace=True),
- ]
- in_features = out_features
- # Downsampling
- for _ in range(2):
- out_features *= 2
- model += [
- nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
- nn.InstanceNorm2d(out_features),
- nn.ReLU(inplace=True),
- ]
- in_features = out_features
- # Residual blocks
- for _ in range(num_residual_blocks):
- model += [ResidualBlock(out_features)]
- # Upsampling
- for _ in range(2):
- out_features //= 2
- model += [
- nn.Upsample(scale_factor=2),
- nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
- nn.InstanceNorm2d(out_features),
- nn.ReLU(inplace=True),
- ]
- in_features = out_features
- # Output layer
- model += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()]
- self.model = nn.Sequential(*model)
- def forward(self, x):
- return self.model(x)
2.2.2 判別器
傳統(tǒng)的GAN 判別器輸出的是一個(gè)值,判斷真假的程度。而patchGAN輸出是N*N值,每一個(gè)值代表著原始圖像上的一定大小的感受野,直觀上就是對(duì)原圖上crop下可重復(fù)的一部分區(qū)域進(jìn)行判斷真假,可以認(rèn)為是一個(gè)全卷積網(wǎng)絡(luò),最早是在pix2pix提出(Image-to-Image Translation with Conditional Adversarial Networks)。好處是參數(shù)少,另外一個(gè)從局部可以更好的抓取高頻信息。
- class Discriminator(nn.Module):
- def __init__(self, input_shape):
- super(Discriminator, self).__init__()
- channels, height, width = input_shape
- # Calculate output shape of image discriminator (PatchGAN)
- self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)
- def discriminator_block(in_filters, out_filters, normalize=True):
- """Returns downsampling layers of each discriminator block"""
- layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
- if normalize:
- layers.append(nn.InstanceNorm2d(out_filters))
- layers.append(nn.LeakyReLU(0.2, inplace=True))
- return layers
- self.model = nn.Sequential(
- *discriminator_block(channels, 64, normalize=False),
- *discriminator_block(64, 128),
- *discriminator_block(128, 256),
- *discriminator_block(256, 512),
- nn.ZeroPad2d((1, 0, 1, 0)),
- nn.Conv2d(512, 1, 4, padding=1)
- )
- def forward(self, img):
- return self.model(img)
2.2.3 訓(xùn)練
loss和模型初始化
- # Losses
- criterion_GAN = torch.nn.MSELoss()
- criterion_cycle = torch.nn.L1Loss()
- criterion_identity = torch.nn.L1Loss()
- cuda = torch.cuda.is_available()
- input_shape = (opt.channels, opt.img_height, opt.img_width)
- # Initialize generator and discriminator
- G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
- G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
- D_A = Discriminator(input_shape)
- D_B = Discriminator(input_shape)
優(yōu)化器和訓(xùn)練策略
- # Optimizers
- optimizer_G = torch.optim.Adam(
- itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
- )
- optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
- optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
- # Learning rate update schedulers
- lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
- optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
- )
- lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
- optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
- )
- lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
- optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
- )
訓(xùn)練迭代
- 訓(xùn)練數(shù)據(jù)是成對(duì)的數(shù)據(jù),但是是非配對(duì)的數(shù)據(jù),即A和B是沒有直接的聯(lián)系的。A是原圖,B是風(fēng)格圖
- 生成器訓(xùn)練
- GAN loss:判別器判別A和B生成的兩個(gè)圖fake_A、fake_B與GT的loss
- Cycle loss:反過來fake_A和fake_B 生成的圖與A和B像素上差異
- 判別器訓(xùn)練:
- loss_real: 判別A/B和GT的MSELoss
- loss_fake:判別生成的fake_A/fake_B與GT的MSELoss
- for epoch in range(opt.epoch, opt.n_epochs):
- for i, batch in enumerate(dataloader):
- # 數(shù)據(jù)是成對(duì)的數(shù)據(jù),但是是非配對(duì)的數(shù)據(jù),即A和B是沒有直接的聯(lián)系的
- real_A = Variable(batch["A"].type(Tensor))
- real_B = Variable(batch["B"].type(Tensor))
- # Adversarial ground truths
- valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
- fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)
- # ------------------
- # Train Generators
- # ------------------
- G_AB.train()
- G_BA.train()
- optimizer_G.zero_grad()
- # Identity loss
- loss_id_A = criterion_identity(G_BA(real_A), real_A)
- loss_id_B = criterion_identity(G_AB(real_B), real_B)
- loss_identity = (loss_id_A + loss_id_B) / 2
- # GAN loss
- fake_B = G_AB(real_A)
- loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
- fake_A = G_BA(real_B)
- loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
- loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
- # Cycle loss
- recov_A = G_BA(fake_B)
- loss_cycle_A = criterion_cycle(recov_A, real_A)
- recov_B = G_AB(fake_A)
- loss_cycle_B = criterion_cycle(recov_B, real_B)
- loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
- # Total loss
- loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity
- loss_G.backward()
- optimizer_G.step()
- # -----------------------
- # Train Discriminator A
- # -----------------------
- optimizer_D_A.zero_grad()
- # Real loss
- loss_real = criterion_GAN(D_A(real_A), valid)
- # Fake loss (on batch of previously generated samples)
- # fake_A_ = fake_A_buffer.push_and_pop(fake_A)
- loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
- # Total loss
- loss_D_A = (loss_real + loss_fake) / 2
- loss_D_A.backward()
- optimizer_D_A.step()
- # -----------------------
- # Train Discriminator B
- # -----------------------
- optimizer_D_B.zero_grad()
- # Real loss
- loss_real = criterion_GAN(D_B(real_B), valid)
- # Fake loss (on batch of previously generated samples)
- # fake_B_ = fake_B_buffer.push_and_pop(fake_B)
- loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
- # Total loss
- loss_D_B = (loss_real + loss_fake) / 2
- loss_D_B.backward()
- optimizer_D_B.step()
- loss_D = (loss_D_A + loss_D_B) / 2
- # --------------
- # Log Progress
- # --------------
- # Determine approximate time left
- batches_done = epoch * len(dataloader) + i
- batches_left = opt.n_epochs * len(dataloader) - batches_done
- time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
- prev_time = time.time()
- # Update learning rates
- lr_scheduler_G.step()
- lr_scheduler_D_A.step()
- lr_scheduler_D_B.step()
2.2.4 結(jié)果展示
本文訓(xùn)練的是莫奈風(fēng)格的轉(zhuǎn)變,如下圖:第一二行是莫奈風(fēng)格畫轉(zhuǎn)換為普通照片,第三四行為普通照片轉(zhuǎn)換為莫奈風(fēng)格畫
再來看實(shí)際手機(jī)拍攝圖片:
2.2.5 cycleGAN其他用途
3. 總結(jié)
本文詳細(xì)介紹了GAN的其中一種應(yīng)用cycleGAN,并將它應(yīng)用到圖像風(fēng)格的轉(zhuǎn)換??偨Y(jié)如下:
- GAN是學(xué)習(xí)數(shù)據(jù)中分布,并生成同樣分布但全新的數(shù)據(jù)
- CycleGAN是兩路GAN:兩個(gè)生成器和兩個(gè)判別器;為了保證生成器的生成的圖片與輸入圖存在一定的關(guān)系,不是隨機(jī)生產(chǎn)的圖片, 引入cycle consistency,判定A->fake_B->recove_A和A的差異
- 生成器:下采樣 + residual 殘差block + 上采樣
- 判別器: 不是一個(gè)圖生成一個(gè)判定值,而是patchGAN方式,生成很N*N個(gè)值,而后取均值