自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

對(duì)抗生成網(wǎng)絡(luò)GAN 原創(chuàng) 精華

發(fā)布于 2024-11-26 09:06
瀏覽
0收藏

前言

在前一階段課程中,我們學(xué)習(xí)了圖像分割中的語義分割、實(shí)例分割、旋轉(zhuǎn)目標(biāo)檢測(cè)等。這些圖像分割算法都是有監(jiān)督學(xué)習(xí),而GAN(生成對(duì)抗網(wǎng)絡(luò))是一種特別的學(xué)習(xí)策略,其核心思想非常值得借鑒,所以本章將以GAN網(wǎng)絡(luò)的代碼為切入口,了解掌握其核心思想。

學(xué)習(xí)策略

人工智能方面的學(xué)習(xí)策略有兩種:有監(jiān)督學(xué)習(xí)和無監(jiān)督學(xué)習(xí)。

有監(jiān)督學(xué)習(xí)

定義:有監(jiān)督學(xué)習(xí)是使用帶標(biāo)簽的數(shù)據(jù)集進(jìn)行訓(xùn)練。每個(gè)輸入數(shù)據(jù)都有對(duì)應(yīng)的輸出標(biāo)簽,模型通過學(xué)習(xí)輸入與輸出之間的關(guān)系來進(jìn)行預(yù)測(cè)。

舉個(gè)例子:孩子的個(gè)人成長(zhǎng),有經(jīng)驗(yàn)的家長(zhǎng)為期規(guī)劃了發(fā)展的路線,孩子在規(guī)劃下有計(jì)劃地學(xué)習(xí)成長(zhǎng),這屬于有監(jiān)督學(xué)習(xí)。

無監(jiān)督學(xué)習(xí)

定義:無監(jiān)督學(xué)習(xí)使用沒有標(biāo)簽的數(shù)據(jù)集進(jìn)行訓(xùn)練。模型試圖發(fā)現(xiàn)數(shù)據(jù)中的模式或結(jié)構(gòu),而不依賴于任何預(yù)先定義的標(biāo)簽。

同樣的例子:孩子在無監(jiān)督學(xué)習(xí)下,是沒有家長(zhǎng)為期進(jìn)行規(guī)劃,而是經(jīng)歷社會(huì)"捶打"(做得好了有加分、做不好扣分),最終學(xué)習(xí)成長(zhǎng)起來。

GAN的基礎(chǔ)介紹

在上述的兩種學(xué)習(xí)策略中,有一種特殊的、獨(dú)立的學(xué)習(xí)策略:GAN(生成對(duì)抗網(wǎng)絡(luò))。

它由兩個(gè)網(wǎng)絡(luò)(生成器和判別器),通過對(duì)抗在競(jìng)爭(zhēng)中共同發(fā)展。

  • G:生成器(造假)
  • D:鑒別器(打假)
  • 訓(xùn)練過程:

兩個(gè)網(wǎng)絡(luò)剛開始都沒有任何能力

在競(jìng)爭(zhēng)中共同發(fā)展

最后兩個(gè)網(wǎng)絡(luò)能力都得到提升

舉個(gè)例子:GAN網(wǎng)絡(luò)就像警察和小偷,警察和小偷之間互相對(duì)抗。

GAN示例

為了對(duì)GAN網(wǎng)絡(luò)有個(gè)直觀印象,我們可以參考Github上一個(gè)開源項(xiàng)目,對(duì)GAN有個(gè)初步認(rèn)知。

頁面地址:https://poloclub.github.io/ganlab/

示例目的

  • 在頁面中添加一個(gè)手寫數(shù)字圖像
  • 通過訓(xùn)練模型來模擬手寫數(shù)字圖像
  • 從而達(dá)到新圖像與原來的風(fēng)格類似,分不出真假

對(duì)抗生成網(wǎng)絡(luò)GAN-AI.x社區(qū)

核心思想

對(duì)抗生成網(wǎng)絡(luò)GAN-AI.x社區(qū)

論文地址:https://arxiv.org/pdf/1406.2661

生成器(Generator):

  • 作用:負(fù)責(zé)憑空編造假的數(shù)據(jù)出來。
  • 目的:通過機(jī)器生成假的數(shù)據(jù)(大部分情況下是圖像),最終目的是“騙過”判別器。
  • 過程:G是一個(gè)生成圖片的網(wǎng)絡(luò),它接收一個(gè)隨機(jī)的噪聲z,通過這個(gè)噪聲生成圖片,記做G(z)。

判別器(Discriminator):

  • 作用:負(fù)責(zé)判斷傳來的數(shù)據(jù)是真還是假。
  • 目的:判斷這張圖像是真實(shí)的還是機(jī)器生成的,目的是找出生成器做的“假數(shù)據(jù)”。
  • 過程:D是一個(gè)判別網(wǎng)絡(luò),判別一張圖片是不是“真實(shí)的”。它的輸入?yún)?shù)是x,x代表一張圖片,輸出D(x)代表x為真實(shí)圖片的概率,如果為1,就代表100%是真實(shí)的圖片,而輸出為0,就代表不可能是真實(shí)的圖片。

接下來,我們通過詳細(xì)了解GAN網(wǎng)絡(luò)的代碼,深入了解其運(yùn)行過程。

引入依賴

import torch
from torch import nn
from torch.nn import functional as F

import torchvision
from torchvision import transforms
from torchvision import datasets
from torchvision.utils import save_image
from torch.utils.data importDataLoader

import os
import numpy as np
import matplotlib.pyplot as plt

# 判斷當(dāng)前設(shè)備是否GPU
device = torch.device('cuda'if torch.cuda.is_available()else'cpu')
device

讀取數(shù)據(jù)集

# 加載并預(yù)處理圖像
data = datasets.MNIST(root="data", 
                      train=True, 
                      transform = transforms.Compose(transforms=[transforms.ToTensor(),
                      transforms.Normalize(mean=[0.5], std=[0.5])]),
                      download=True)

# 封裝成 DataLoader
data_loader = DataLoader(dataset=data, batch_size=100, shuffle=True)
  • 備注:上述??transform = transforms.Compose?? 的作用主要是進(jìn)行數(shù)據(jù)增強(qiáng),詳細(xì)內(nèi)容在補(bǔ)充知識(shí)部分展開介紹。

定義模型

定義生成器

"""
    定義生成器
"""

classGenerator(nn.Module):
"""
        定義一個(gè)圖像生成
        輸入:一個(gè)向量
        輸出:一個(gè)向量(代表圖像)
    """
def__init__(self, in_features=100, out_features=28 * 28):
"""
            掛載超參數(shù)
        """
# 先初始化父類,再初始化子類
super(Generator, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

# 第一個(gè)隱藏層
        self.hidden0 = nn.Linear(in_features=self.in_features, out_features=256)

# 第二個(gè)隱藏層
        self.hidden1 = nn.Linear(in_features=256, out_features=512)

# 第三個(gè)隱藏層
        self.hidden2 = nn.Linear(in_features=512, out_features=self.out_features)

defforward(self, x):

# 第一層 [b, 100] --> [b, 256]
        h = self.hidden0(x)
        h = F.leaky_relu(input=h, negative_slope=0.2)

# 第二層 [b, 256] --> [b, 512]
        h = self.hidden1(h)
        h = F.leaky_relu(input=h, negative_slope=0.2)

# 第三層 [b, 512] --> [b, 28 * 28]
        h = self.hidden2(h)

# 壓縮數(shù)據(jù)的變化范圍
        o = torch.tanh(h)

return o

定義鑒別器

"""
    定義一個(gè)鑒別器
"""

classDiscriminator(nn.Module):
"""
        本質(zhì):二分類分類器
        輸入:一個(gè)對(duì)象
        輸出:真品還是贗品
    """
def__init__(self, in_features=28*28, out_features=1):
super(Discriminator, self).__init__()

        self.in_features=in_features
        self.out_features=out_features

# 第一個(gè)隱藏層
        self.hidden0= nn.Linear(in_features=self.in_features, out_features=512)

# 第二個(gè)隱藏層
        self.hidden1= nn.Linear(in_features=512, out_features=256)

# 第三個(gè)隱藏層
        self.hidden2= nn.Linear(in_features=256, out_features=32)

# 第四個(gè)隱藏層
        self.hidden3= nn.Linear(in_features=32, out_features=self.out_features)


defforward(self, x):

# 第一層
        h = self.hidden0(x)
        h = F.leaky_relu(input=h, negative_slope=0.2)
        h = F.dropout(input=h, p=0.2)

# 第二層
        h = self.hidden1(h)
        h = F.leaky_relu(input=h, negative_slope=0.2)
        h = F.dropout(input=h, p=0.2)

# 第三層
        h = self.hidden2(h)
        h = F.leaky_relu(input=h, negative_slope=0.2)
        h = F.dropout(input=h, p=0.2)

# 第四層
        h = self.hidden3(h)

# 輸出概率
        o = torch.sigmoid(h)

return o

構(gòu)建模型

"""
    構(gòu)建模型
"""
# 定義一個(gè)生成器
generator = Generator(in_features=100, out_features=784)
generator.to(device=device)

# 定義一個(gè)鑒別器
discriminator = Discriminator(in_features=784, out_features=1)
discriminator.to(device=device)

定義優(yōu)化器

"""
    定義優(yōu)化器
"""

# 定義一個(gè)生成器的優(yōu)化器
g_optimizer = torch.optim.Adam(params=generator.parameters(), lr=1e-4)

# 定義一個(gè)鑒別的優(yōu)化器
d_optimizer = torch.optim.Adam(params=discriminator.parameters(), lr=1e-4)

定義損失函數(shù)

"""
    定義一個(gè)損失函數(shù)
"""
loss_fn = nn.BCELoss()

籌備訓(xùn)練

定義訓(xùn)練輪次

# 定義訓(xùn)練輪次
num_epochs = 1000

獲取數(shù)據(jù)的標(biāo)簽

"""
    獲取數(shù)據(jù)的標(biāo)簽
"""

defget_real_data_labels(size):
"""
        獲取真實(shí)數(shù)據(jù)的標(biāo)簽
    """
    labels = torch.ones(size,1, device=device)

return labels

defget_fake_data_labels(size):
"""
        獲取虛假數(shù)據(jù)的標(biāo)簽
    """
    labels = torch.zeros(size,1, device=device)

return labels

定義噪聲生成器

"""
    噪聲生成器
"""
defget_noise(size):
"""
        給生成器準(zhǔn)備數(shù)據(jù)
        - 100維度的向量
    """
    X = torch.randn(size,100, device=device)

return X

# 獲取一批測(cè)試數(shù)據(jù)

num_test_samples =16
test_noise = get_noise(num_test_samples)

噪聲生成器的作用:因?yàn)槲覀冃枰O(jiān)控模型訓(xùn)練的效果,所以將噪聲固定下來,在訓(xùn)練過程中看同樣的噪聲最后給出的結(jié)果是否變得越來越好。

訓(xùn)練模型

"""
    訓(xùn)練過程
"""

g_losses =[]
d_losses =[]
for epoch inrange(1, num_epochs+1):

print(f"當(dāng)前正在進(jìn)行 第 {epoch} 輪 ....")

# 設(shè)置訓(xùn)練模式
    generator.train()
    discriminator.train()

# 遍歷真實(shí)的圖像
for batch_idx,(batch_real_data, _)inenumerate(data_loader):
"""
        1, 先訓(xùn)練鑒別器
            鑒別器就是一個(gè)二分類問題
            - 給一批真數(shù)據(jù),輸出真
            - 給一批假數(shù)據(jù),輸出假
        
        """

# 1.1 準(zhǔn)備數(shù)據(jù)
# 圖像轉(zhuǎn)向量 [b, 1, 28, 28] ---> [b, 784]
# 從數(shù)據(jù)集中獲取100個(gè)真實(shí)的手寫數(shù)字圖像
        real_data = batch_real_data.view(batch_real_data.size(0),-1).to(device=device)

# 噪聲[b, 100]
# 隨機(jī)生成100個(gè)100維度的噪聲,用于生成假圖像
        noise = get_noise(real_data.size(0))

# 根據(jù)噪聲,生成假數(shù)據(jù) 
# [b, 100] --> [b, 784]
        fake_data = generator(noise).detach()


# 1.2 訓(xùn)練過程

# 鑒別器的優(yōu)化器梯度情況
        d_optimizer.zero_grad()

# 對(duì)真實(shí)數(shù)據(jù)鑒別
        real_pred = discriminator(real_data)

# 計(jì)算真實(shí)數(shù)據(jù)的誤差
        real_loss = loss_fn(real_pred, get_real_data_labels(real_data.size(0)))

# 真實(shí)數(shù)據(jù)的梯度回傳
        real_loss.backward()


# 對(duì)假數(shù)據(jù)鑒別
        fake_pred = discriminator(fake_data)

# 計(jì)算假數(shù)據(jù)的誤差
        fake_loss = loss_fn(fake_pred, get_fake_data_labels(fake_data.size(0)))

# 假數(shù)據(jù)梯度回傳
        fake_loss.backward()

# 梯度更新
        d_optimizer.step()

# ----------------
        d_losses.append((real_loss + fake_loss).item())
# print(f"鑒別器的損失:{real_loss + fake_loss}")


"""2, 再訓(xùn)練生成器"""

# 獲取生成器的生成結(jié)果
        fake_pred = generator(get_noise(real_data.size(0)))

# 生產(chǎn)器梯度清空
        g_optimizer.zero_grad()

# 把假數(shù)據(jù)讓鑒別器鑒別一下
# 把discriminator requires_grad = False
# 設(shè)置為不可學(xué)習(xí)
for param in discriminator.parameters():
            param.requires_grad =False

        d_pred = discriminator(fake_pred)

# 設(shè)置為可學(xué)習(xí)
for param in discriminator.parameters():
            param.requires_grad =True

# 計(jì)算損失
# 把一個(gè)假東西,給專家看,專家說是真的,這個(gè)時(shí)候,造假的水平就可以了
        g_loss = loss_fn(d_pred, get_real_data_labels(d_pred.size(0)))

# 梯度回傳
        g_loss.backward()

# 參數(shù)更新
        g_optimizer.step()

# print(f"生成器誤差:{g_loss}")
        g_losses.append(g_loss.item())

# 每輪訓(xùn)練之后,觀察生成器的效果
    generator.eval()

with torch.no_grad():

# 正向推理
        img_pred = generator(test_noise)
        img_pred = img_pred.view(img_pred.size(0),28,28).cpu().data

# 畫圖
        display.clear_output(wait=True)

# 設(shè)置畫圖的大小
        fig = plt.figure(1, figsize=(12,8))
# 劃分為 4 x 4 的 網(wǎng)格
        gs = gridspec.GridSpec(4,4)

# 遍歷每一個(gè)
for i inrange(4):
for j inrange(4):
# 取每一個(gè)圖
                X = img_pred[i *4+ j,:,:]
# 添加一個(gè)對(duì)應(yīng)網(wǎng)格內(nèi)的子圖
                ax = fig.add_subplot(gs[i, j])
# 在子圖內(nèi)繪制圖像
                ax.matshow(X, cmap=plt.get_cmap("Greys"))
# ax.set_xlabel(f"{label}")
                ax.set_xticks(())
                ax.set_yticks(())
        plt.show()

運(yùn)行結(jié)果:

對(duì)抗生成網(wǎng)絡(luò)GAN-AI.x社區(qū)

核心代碼說明:

訓(xùn)練過程

  • 隨機(jī)生成一組潛在向量z,并使用生成器生成一組假數(shù)據(jù)。
  • 將一組真實(shí)數(shù)據(jù)和一組假數(shù)據(jù)作為輸入,訓(xùn)練判別器。
  • 使用生成器生成一組新的假數(shù)據(jù),并訓(xùn)練判別器。
  • 重復(fù)步驟2和3,直到生成器生成的假數(shù)據(jù)與真實(shí)數(shù)據(jù)的分布相似。

對(duì)抗生成網(wǎng)絡(luò)GAN-AI.x社區(qū)

核心代碼

  • ??fake_data = generator(noise).detach()??:

作用:是生成器生成一組假數(shù)據(jù),并使用detach()方法將其從計(jì)算圖中分離出來,防止梯度回傳。

說明:(因?yàn)樵谟?xùn)練鑒別器時(shí),生成器只是工具人,其前向傳播過程中記錄的梯度信息不會(huì)被使用,所以不需要記錄梯度信息)

  • ??g_loss = loss_fn(d_pred, get_real_data_labels(d_pred.size(0)))?? 這里是體現(xiàn)對(duì)抗的核心代碼,即:生成器訓(xùn)練的好不好,是要與真實(shí)數(shù)據(jù)的判別結(jié)果越接近越好。

補(bǔ)充知識(shí)

數(shù)據(jù)增強(qiáng)

在人工智能模型的訓(xùn)練中,采集樣本是需要成本的,所以為了提升樣本的豐富性,一般會(huì)采用數(shù)據(jù)增強(qiáng)的方式。

  • 方式:在樣本固定的基礎(chǔ)上,通過軟件模擬,來生成假數(shù)據(jù),豐富樣本的多樣性
  • 本質(zhì):給樣本加上適當(dāng)?shù)脑肼?,模擬出不同場(chǎng)景的樣本
  • 說明:數(shù)據(jù)增強(qiáng)只發(fā)生在模型訓(xùn)練中,為了增加訓(xùn)練樣本的多樣性

transform介紹

在 PyTorch 中,transform 主要用于數(shù)據(jù)預(yù)處理和增強(qiáng),特別是在圖像處理任務(wù)中。transform 是 torchvision 庫的一部分,能夠?qū)?shù)據(jù)集中的圖像進(jìn)行各種轉(zhuǎn)換,以便更好地適應(yīng)模型訓(xùn)練的需求。以下是 transform 的主要作用

import torch
from torchvision import datasets, transforms

from PIL import Image
import matplotlib.pyplot as plt

# 讀取本地下載的一張圖片
img = Image.open('girl.png')
img

對(duì)抗生成網(wǎng)絡(luò)GAN-AI.x社區(qū)

重設(shè)圖片尺寸

resize = transforms.Resize((300, 200))
resize_img = resize(img)
resize_img

運(yùn)行效果:

對(duì)抗生成網(wǎng)絡(luò)GAN-AI.x社區(qū)

中心裁剪

centercrop = transforms.CenterCrop(size=(200, 200))
center_img = centercrop(img)
center_img

運(yùn)行效果:

對(duì)抗生成網(wǎng)絡(luò)GAN-AI.x社區(qū)

隨機(jī)調(diào)整亮度、飽和度、對(duì)比度等

color_jitter = transforms.ColorJitter(brightness=0.5, 
                                      contrast=0.5, 
                                      saturation=0.5, 
                                      hue=0.5)
color_jitter(img)

運(yùn)行效果:

對(duì)抗生成網(wǎng)絡(luò)GAN-AI.x社區(qū)

隨機(jī)旋轉(zhuǎn)

random_rotation = transforms.RandomRotation(degrees=10)
random_rotation(img)

運(yùn)行效果:

對(duì)抗生成網(wǎng)絡(luò)GAN-AI.x社區(qū)

組合變換

Compose:可以將多個(gè)變換組合在一起,形成一個(gè)轉(zhuǎn)換管道,方便批量處理。例如:

from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),          # 將PIL Image轉(zhuǎn)換為Tensor
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # 將數(shù)據(jù)歸一化到[-1, 1]之間
])

內(nèi)容小結(jié)

  • GAN(生成對(duì)抗網(wǎng)絡(luò))是一種特殊的學(xué)習(xí)策略,它由生成器和判別器組成,生成器生成假數(shù)據(jù),判別器判斷真假。
  • 生成器(Generator)通過機(jī)器生成假的數(shù)據(jù)(大部分情況下是圖像),最終目的是“騙過”判別器。
  • 判別器(Discriminator):判斷這張圖像是真實(shí)的還是機(jī)器生成的,目的是找出生成器做的“假數(shù)據(jù)”。
  • 訓(xùn)練過程是:先訓(xùn)練判別器,再訓(xùn)練生成器。
  • 訓(xùn)練判別器時(shí),生成器是"工具人",所以需要使用detach()方法,將生成器生成的假數(shù)據(jù)從計(jì)算圖中分離出來,防止梯度回傳。
  • 訓(xùn)練生成器時(shí),判別器是"工具人",為了避免整個(gè)梯度消失,需要使用param.requires_grad = False設(shè)置為不可學(xué)習(xí),判別完之后再使用param.requires_grad = True設(shè)置為可學(xué)習(xí)。
  • 在人工智能模型訓(xùn)練過程中,通常會(huì)使用數(shù)據(jù)增強(qiáng)的方式,在樣本固定的基礎(chǔ)上,通過軟件模擬,來生成假數(shù)據(jù),豐富樣本的多樣性。
  • transform:在 PyTorch 中,transform 主要用于數(shù)據(jù)預(yù)處理和增強(qiáng),特別是在圖像處理任務(wù)中。

?

本文轉(zhuǎn)載自公眾號(hào)一起AI技術(shù) 作者:Dongming

原文鏈接:??https://mp.weixin.qq.com/s/tHEfP7_rTXkXWfWJZu9vYw??

?著作權(quán)歸作者所有,如需轉(zhuǎn)載,請(qǐng)注明出處,否則將追究法律責(zé)任
收藏
回復(fù)
舉報(bào)
回復(fù)
相關(guān)推薦