自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

圖像風(fēng)格遷移也有框架了:使用Python編寫,與PyTorch完美兼容,外行也能用

開發(fā) 開發(fā)工具
pystiche 是一個用 Python 編寫的 NST 框架,基于 PyTorch 構(gòu)建,并與之完全兼容。相關(guān)研究由 pyOpenSci 進(jìn)行同行評審,并發(fā)表在 JOSS 期刊 (Journal of Open Source Software) 上。

易于使用的神經(jīng)風(fēng)格遷移框架 pystiche。

將內(nèi)容圖片與藝術(shù)風(fēng)格圖片進(jìn)行融合,生成一張具有特定風(fēng)格的新圖,這種想法并不新鮮。早在 2015 年,Gatys、 Ecker 以及 Bethge 開創(chuàng)性地提出了神經(jīng)風(fēng)格遷移(Neural Style Transfer ,NST)。

不同于深度學(xué)習(xí),目前 NST 還沒有現(xiàn)成的庫或框架。因此,新的 NST 技術(shù)要么從頭開始實(shí)現(xiàn)所有內(nèi)容,要么基于現(xiàn)有的方法實(shí)現(xiàn)。但這兩種方法都有各自的缺點(diǎn):前者由于可重用部分的冗長實(shí)現(xiàn),限制了技術(shù)創(chuàng)新;后者繼承了 DL 硬件和軟件快速發(fā)展導(dǎo)致的技術(shù)債務(wù)。

最近,新項(xiàng)目 pystiche 很好地解決了這些問題,雖然它的核心受眾是研究人員,但其易于使用的用戶界面為非專業(yè)人員使用 NST 提供了可能。

pystiche 是一個用 Python 編寫的 NST 框架,基于 PyTorch 構(gòu)建,并與之完全兼容。相關(guān)研究由 pyOpenSci 進(jìn)行同行評審,并發(fā)表在 JOSS 期刊 (Journal of Open Source Software) 上。

論文地址:https://joss.theoj.org/papers/10.21105/joss.02761

項(xiàng)目地址:https://github.com/pmeier/pystiche

在深入實(shí)現(xiàn)之前,我們先來回顧一下 NST 的原理。它有兩種優(yōu)化方式:基于圖像的優(yōu)化和基于模型的優(yōu)化。雖然 pystiche 能夠很好地處理后者,但更為復(fù)雜,因此本文只討論基于圖像的優(yōu)化方法。

在基于圖像的方法中,將圖像的像素迭代調(diào)整訓(xùn)練,來擬合感知損失函數(shù)(perceptual loss)。感知損失是 NST 的核心部分,分為內(nèi)容損失(content loss)和風(fēng)格損失(style loss),這些損失評估輸出圖像與目標(biāo)圖像的匹配程度。與傳統(tǒng)的風(fēng)格遷移算法不同,感知損失包含一個稱為編碼器的多層模型,這就是 pystiche 基于 PyTorch 構(gòu)建的原因。

如何使用 pystiche

讓我們用一個例子介紹怎么使用 pystiche 生成神經(jīng)風(fēng)格遷移圖片。首先導(dǎo)入所需模塊,選擇處理設(shè)備。雖然 pystiche 的設(shè)計(jì)與設(shè)備無關(guān),但使用 GPU 可以將 NST 的速度提高幾個數(shù)量級。

模塊導(dǎo)入與設(shè)備選擇:

  1. import torch 
  2. import pystiche 
  3. from pystiche import demo, enc, loss, ops, optim 
  4.  
  5. print(f"pystiche=={pystiche.__version__}") 
  6.  
  7. device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

輸出:

  1. pystiche==0.7.0 

多層編碼器

content_loss 和 style_loss 是對圖像編碼進(jìn)行操作而不是圖像本身,這些編碼是由在不同層級的預(yù)訓(xùn)練編碼器生成的。pystiche 定義了 enc.MultiLayerEncoder 類,該類在單個前向傳遞中可以有效地處理編碼問題。該示例使用基于 VGG19 架構(gòu)的 vgg19_multi_layer_encoder。默認(rèn)情況下,它將加載 torchvision 提供的權(quán)重。

多層編碼器:

  1. multi_layer_encoder = enc.vgg19_multi_layer_encoder() 
  2. print(multi_layer_encoder) 

輸出:

  1. VGGMultiLayerEncoder( 
  2.   arch=vgg19framework=torchallow_inplace=True 
  3.   (preprocessing): TorchPreprocessing( 
  4.    (0): Normalize( 
  5.      mean=('0.485', '0.456', '0.406'), 
  6.      std=('0.229', '0.224', '0.225') 
  7.     ) 
  8.   ) 
  9.  (conv1_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  10.  (relu1_1): ReLU(inplace=True
  11.  (conv1_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  12.  (relu1_2): ReLU(inplace=True
  13.  (pool1): MaxPool2d(kernel_size=2stride=2padding=0dilation=1ceil_mode=False
  14.  (conv2_1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  15.  (relu2_1): ReLU(inplace=True
  16.  (conv2_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  17.  (relu2_2): ReLU(inplace=True
  18.  (pool2): MaxPool2d(kernel_size=2stride=2padding=0dilation=1ceil_mode=False
  19.  (conv3_1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  20.  (relu3_1): ReLU(inplace=True
  21.  (conv3_2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  22.  (relu3_2): ReLU(inplace=True
  23.  (conv3_3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  24.  (relu3_3): ReLU(inplace=True
  25.  (conv3_4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  26.  (relu3_4): ReLU(inplace=True
  27.  (pool3): MaxPool2d(kernel_size=2stride=2padding=0dilation=1ceil_mode=False
  28.  (conv4_1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  29.  (relu4_1): ReLU(inplace=True
  30.  (conv4_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  31.  (relu4_2): ReLU(inplace=True
  32.  (conv4_3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  33.  (relu4_3): ReLU(inplace=True
  34.  (conv4_4): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  35.  (relu4_4): ReLU(inplace=True
  36.  (pool4): MaxPool2d(kernel_size=2stride=2padding=0dilation=1ceil_mode=False
  37.  (conv5_1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  38.  (relu5_1): ReLU(inplace=True
  39.  (conv5_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  40.  (relu5_2): ReLU(inplace=True
  41.  (conv5_3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  42.  (relu5_3): ReLU(inplace=True
  43.  (conv5_4): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  44.  (relu5_4): ReLU(inplace=True
  45.  (pool5): MaxPool2d(kernel_size=2stride=2padding=0dilation=1ceil_mode=False

感知損失

pystiche 將內(nèi)容損失和風(fēng)格損失定義為操作符。使用 ops.FeatureReconstructionOperator 作為 content_loss,直接與編碼進(jìn)行對比。如果編碼器針對分類任務(wù)進(jìn)行過訓(xùn)練,如該示例中這些編碼表示內(nèi)容。對于content_layer,選擇 multi_layer_encoder 的較深層來獲取抽象的內(nèi)容表示,而不是許多不必要的細(xì)節(jié)。

  1. content_layer = "relu4_2" 
  2. encoder = multi_layer_encoder.extract_encoder(content_layer) 
  3. content_loss = ops.FeatureReconstructionOperator(encoder) 

pystiche 使用 ops.GramOperator 作為 style_loss 的基礎(chǔ),通過比較編碼各個通道之間的相關(guān)性來丟棄空間信息。這樣就可以在輸出圖像中的任意區(qū)域合成風(fēng)格元素,而不僅僅是風(fēng)格圖像中它們所在的位置。對于 ops.GramOperator,如果它在淺層和深層 style_layers 都能很好地運(yùn)行,則其性能達(dá)到最佳。

style_weight 可以控制模型對輸出圖像的重點(diǎn)——內(nèi)容或風(fēng)格。為了方便起見,pystiche 將所有內(nèi)容包裝在 ops.MultiLayerEncodingOperator 中,該操作處理在同一 multi_layer_encoder 的多個層上進(jìn)行操作的相同類型操作符的情況。

  1. style_layers = ("relu1_1", "relu2_1", "relu3_1", "relu4_1", "relu5_1") 
  2. style_weight = 1e3 
  3.  
  4.  
  5. def get_encoding_op(encoder, layer_weight): 
  6.     return ops.GramOperator(encoder, score_weight=layer_weight
  7.  
  8.  
  9. style_loss = ops.MultiLayerEncodingOperator( 
  10.     multi_layer_encoder, style_layers, get_encoding_op, score_weight=style_weight

loss.PerceptualLoss 結(jié)合了 content_loss 與 style_loss,將作為優(yōu)化的標(biāo)準(zhǔn)。

  1. criterion = loss.PerceptualLoss(content_loss, style_loss).to(device) 
  2. print(criterion) 

輸出:

  1. PerceptualLoss( 
  2.  (content_loss): FeatureReconstructionOperator( 
  3.    score_weight=1
  4.    encoder=VGGMultiLayerEncoder
  5.      layer=relu4_2
  6.      arch=vgg19
  7.      framework=torch
  8.      allow_inplace=True 
  9.    ) 
  10.  ) 
  11.  (style_loss): MultiLayerEncodingOperator( 
  12.    encoder=VGGMultiLayerEncoder
  13.      arch=vgg19
  14.      framework=torch
  15.      allow_inplace=True 
  16.  ), 
  17.  score_weight=1000 
  18.  (relu1_1): GramOperator(score_weight=0.2) 
  19.  (relu2_1): GramOperator(score_weight=0.2) 
  20.  (relu3_1): GramOperator(score_weight=0.2) 
  21.  (relu4_1): GramOperator(score_weight=0.2) 
  22.  (relu5_1): GramOperator(score_weight=0.2) 
  23.  ) 

圖像加載

首先加載并顯在 NST 需要的目標(biāo)圖片。因?yàn)?NST 占用內(nèi)存較多,故將圖像大小調(diào)整為 500 像素。

  1. size = 500 
  2. images = demo.images() 
  1. content_image = images["bird1"].read(sizesize=size, devicedevice=device) 
  2. criterion.set_content_image(content_image) 

內(nèi)容圖片

  1. style_image = images["paint"].read(sizesize=size, devicedevice=device) 
  2. criterion.set_style_image(style_image) 

風(fēng)格圖片

神經(jīng)風(fēng)格遷移

創(chuàng)建 input_image。從 content_image 開始執(zhí)行 NST,這樣可以實(shí)現(xiàn)快速收斂。image_optimization 函數(shù)是為了方便,也可以由手動優(yōu)化循環(huán)代替,且不受限制。如果沒有指定,則使用 torch.optim.LBFGS 作為優(yōu)化器。

  1. input_image = content_image.clone() 
  2. output_image = optim.image_optimization(input_image, criterion, num_steps=500

【本文是51CTO專欄機(jī)構(gòu)“機(jī)器之心”的原創(chuàng)譯文,微信公眾號“機(jī)器之心( id: almosthuman2014)”】 

戳這里,看該作者更多好文

責(zé)任編輯:趙寧寧 來源: 51CTO專欄
相關(guān)推薦

2024-12-11 15:15:42

2023-12-13 15:00:38

淺拷貝深拷貝Python

2009-09-17 08:35:56

Windows 7Itunes兼容性

2009-02-01 14:34:26

PythonUnix管道風(fēng)格

2009-11-27 09:05:19

Windows 7Chrome兼容性

2009-11-26 11:00:28

Chrome瀏覽器Windows 7

2018-05-07 14:11:15

RootAndroidXposed

2017-03-10 16:32:44

Apache Spar大數(shù)據(jù)工具

2025-03-26 02:00:00

2024-12-13 16:01:35

2023-12-04 09:00:00

PythonRuff

2020-11-20 11:05:39

編程工具開發(fā)

2023-09-27 23:08:08

Web前端Vue.jsVue3.0

2013-06-10 23:23:29

操作系統(tǒng)OS X

2009-07-15 11:00:48

proxool連接池

2021-06-18 15:15:51

機(jī)器學(xué)習(xí)Rust框架

2016-03-30 11:20:10

2022-10-30 15:00:40

小樣本學(xué)習(xí)數(shù)據(jù)集機(jī)器學(xué)習(xí)

2023-10-09 14:36:28

工具PLGEFK

2012-04-26 19:27:18

點(diǎn)贊
收藏

51CTO技術(shù)棧公眾號