學習ChatGPT,AI繪畫引入人類反饋會怎樣?
最近,深度生成模型在根據文本 prompt 生成高質量圖像方面取得了顯著成功,部分原因在于深度生成模型擴展到了大規(guī)模網絡數據集(如 LAION)。但是,一些重大挑戰(zhàn)依然存在,因而大規(guī)模文本到圖像模型無法生成與文本 prompt 完全對齊的圖像。舉例而言,當前的文本到圖像模型往往無法生成可靠的視覺文本,并在組合式圖像生成方面存在困難。
回到語言建模領域,從人類反饋中學習已經成為一種用來「對齊模型行為與人類意圖」的強大解決方案。這類方法通過人類對模型輸出的反饋,首先學習一個旨在反映人類在任務中所關心內容的獎勵函數,然后通過一種強化學習算法(如近端策略優(yōu)化 PPO)使用學得的獎勵函數來優(yōu)化語言模型。這種帶有人類反饋框架的強化學習(RLHF)已經成功地將大規(guī)模語言模型(例如 GPT-3)與復雜的人類質量評估結合起來。
近日,受 RLHF 在語言領域的成功,谷歌研究院和加州伯克利的研究者提出了使用人類反饋來對齊文本到圖像模型的微調方法。
論文地址:https://arxiv.org/pdf/2302.12192v1.pdf
本文方法如下圖 1 所示,主要分為 3 個步驟。
第一步:首先從「設計用來測試文本到圖像模型輸出對齊的」一組文本 prompt 中生成不同的圖像。具體地,檢查預訓練模型更容易出錯的 prompt—— 生成具有特定顏色、數量和背景的對象,然后收集用于評估模型輸出的二元人類反饋。
第二步:使用了人工標記的數據集,訓練一個獎勵函數來預測給定圖像和文本 prompt 的人類反饋。研究者提出了一項輔助任務,在一組擾動文本 prompt 中識別原始文本 prompt,以更有效地將人類反饋用于獎勵學習。這一技術改進了獎勵函數對未見過圖像和文本 prompt 的泛化表現。
第三步:通過獎勵加權似然最大化更新文本到圖像模型,以更好地使它與人類反饋保持一致。與之前使用強化學習進行優(yōu)化的工作不同,研究者使用半監(jiān)督學習來更新模型,以測量模型輸出質量即學得的獎勵函數。
研究者使用帶有人類反饋的 27000 個圖像 - 文本對來微調 Stable Diffusion 模型,結果顯示微調后的模型在生成具有特定顏色、數量和背景的對象方面實現顯著提升。圖像 - 文本對齊方面實現了高達 47% 的改進,但圖像保真度略有下降。
此外,組合式生成結果也得到了改進,即在給定未見過顏色、數量和背景 prompt 組合時可以更好地生成未見過的對象。他們還觀察到,學得的獎勵函數比測試文本 prompt 上的 CLIP 分數更符合人類對對齊的評估。
不過,論文一作 Kimin Lee 也表示,本文的結果并沒有解決現有文本到圖像模型中所有的失效模型,仍存在諸多挑戰(zhàn)。他們希望這項工作能夠突出從人類反饋中學習在對齊文生圖模型中的應用潛力。
方法介紹
為了將生成圖像與文本 prompt 對齊,該研究對預訓練模型進行了一系列微調,過程如上圖 1 所示。首先從一組文本 prompt 中生成相應的圖像,這一過程旨在測試文生圖模型的各種性能;然后是人類評分員對這些生成的圖像提供二進制反饋;接下來,該研究訓練了一個獎勵模型來預測以文本 prompt 和圖像作為輸入的人類反饋;最后,該研究使用獎勵加權對數似然對文生圖模型進行微調,以改善文本 - 圖像對齊。
人類數據收集
為了測試文生圖模型的功能,該研究考慮了三類文本 prompt:指定數量(specified count)、顏色、背景。對于每個類別,該研究對每個描述該物體的單詞或短語兩兩進行組合來生成 prompt,例如將綠色(顏色)與一只狗(數量)組合。此外,該研究還考慮了三個類別的組合(例如,在一個城市中兩只染著綠顏色的狗)。下表 1 更好的闡述了數據集分類。每一個 prompt 會被用來生成 60 張圖像,模型主要為 Stable Diffusion v1.5 。
人類反饋
接下來對生成的圖像進行人類反饋。由同一個 prompt 生成的 3 張圖像會被呈遞給打標簽人員,并要求他們評估生成的每幅圖像是否與 prompt 保持一致,評價標準為 good 或 bad。由于這項任務比較簡單,用二元反饋就可以了。
獎勵學習
為了更好的評價圖像 - 文本對齊,該研究使用獎勵函數來衡量,該函數可以將圖像 x 的 CLIP 嵌入和文本 prompt z 映射到標量值。之后其被用來預測人類反饋 k_y ∈ {0, 1} (1 = good, 0 = bad) 。
從形式上來講,就是給定人類反饋數據集 D^human = {(x, z, y)},獎勵函數通過最小化均方誤差 (MSE) 來訓練:
此前,已經有研究表明數據增強方法可以顯著提高數據效率和模型學習性能,為了有效地利用反饋數據集,該研究設計了一個簡單的數據增強方案和獎勵學習的輔助損失(auxiliary loss)。該研究在輔助任務中使用增強 prompt,即對原始 prompt 進行分類獎勵學習。Prompt 分類器使用獎勵函數,如下所示:
輔助損失為:
最后是更新文生圖模型。由于模型生成的數據集多樣性是有限的,可能導致過擬合。為了緩解這一點,該研究還最小化了預訓練損失,如下所示:
實驗結果
實驗部分旨在測試人類反饋參與模型微調的有效性。實驗用到的模型為 Stable Diffusion v1.5 ;數據集信息如表 1(參見上文)和表 2 所示,表 2 顯示了由多個人類標簽者提供的反饋分布。
人類對文本 - 圖像對齊的評分(評估指標為顏色、物體數量)。如圖 4 所示,本文方法顯著提高了圖像 - 文本對齊,具體來說,模型生成的圖像中有 50% 的樣本獲得至少三分之二的贊成票(投票數量為 7 票或更多贊成票),然而,微調會稍微降低圖像保真度(15% 比 10%)。
圖 2 顯示了來自原始模型和本文經過微調的對應模型的圖像示例??梢钥吹皆寄P蜕闪巳鄙偌毠?jié)(例如,顏色、背景或計數)的圖像(圖 2 (a)),本文模型生成的圖像符合 prompt 指定的顏色、計數和背景。值得注意的是,本文模型還能生成沒有見過的文本 prompt 圖像,并且質量非常高(圖 2 (b))。
獎勵學習的結果。圖 3 (a) 為模型在見過的文本 prompt 和未見文本 prompt 中的評分。有獎勵(綠色)比 CLIP 分數(紅色)更符合典型的人類意圖。