只需幾個(gè)演示就能對齊大模型,楊笛一團(tuán)隊(duì)提出的DITTO竟如此高效
養(yǎng)育孩子時(shí),古往今來人們都會(huì)談到一種重要方法:以身作則。也就是讓自己成為孩子模仿學(xué)習(xí)的范例,而不是單純地告訴他們應(yīng)該怎么做。在訓(xùn)練大語言模型(LLM)時(shí),我們或許也能采用這樣的方法 —— 向模型進(jìn)行演示。
近日,斯坦福大學(xué)楊笛一團(tuán)隊(duì)提出了一種新框架 DITTO,可通過少量演示(用戶提供的期望行為示例)來將 LLM 與特定設(shè)置對齊。這些示例可以從用戶現(xiàn)有的交互日志獲取,也能通過直接編輯 LLM 的輸出得到。這樣就可以讓模型針對不同的用戶和任務(wù)高效地理解并對齊用戶偏好。
- 論文標(biāo)題:Show, Don’t Tell: Aligning Language Models with Demonstrated Feedback
- 論文地址:https://arxiv.org/pdf/2406.00888
DITTO 可基于少量演示(少于 10)自動(dòng)創(chuàng)建一個(gè)包含大量偏好比較數(shù)據(jù)的數(shù)據(jù)集(這個(gè)過程被稱為 scaffold),其具體做法是默認(rèn)這一點(diǎn):相比于原始 LLM 及早期迭代版本的輸出,用戶更偏好演示。然后,將演示與模型輸出組成數(shù)據(jù)對,得到增強(qiáng)數(shù)據(jù)集。之后便可以使用 DPO 等對齊算法來更新語言模型。
此外,該團(tuán)隊(duì)還發(fā)現(xiàn),DITTO 可被視為一種在線模仿學(xué)習(xí)算法,其中從 LLM 采樣的數(shù)據(jù)會(huì)被用于區(qū)分專家行為。從這一角度出發(fā),該團(tuán)隊(duì)證明 DITTO 可通過外推實(shí)現(xiàn)超越專家的表現(xiàn)。
該團(tuán)隊(duì)也通過實(shí)驗(yàn)驗(yàn)證了 DITTO 的效果。
DITTO 框架
為了對齊 LLM,此前的各類方法往往需要使用成千上萬對比較數(shù)據(jù),而 DITTO 僅需使用少量演示就能修改模型的行為。這種低成本的快速適應(yīng)之所以能實(shí)現(xiàn),主要得益于該團(tuán)隊(duì)的核心見解:可通過演示輕松獲取在線比較數(shù)據(jù)。
符號(hào)和背景
語言模型可被視為一個(gè)策略 π(y|x),這會(huì)得到 prompt x 和完成結(jié)果 y 的一個(gè)分布。RLHF 的目標(biāo)是訓(xùn)練 LLM 以最大化一個(gè)獎(jiǎng)勵(lì)函數(shù) r (x, y),其評估的是 prompt - 完成結(jié)果對 (x, y) 的質(zhì)量。通常來說,還會(huì)添加一個(gè) KL 散度,以防止更新后的模型偏離基礎(chǔ)語言模型(π_ref)太遠(yuǎn)??傮w而言,RLHF 方法優(yōu)化的目標(biāo)為:
這是最大化在 prompt 分布 p 上的預(yù)期獎(jiǎng)勵(lì),而 p 則受 α 調(diào)節(jié)的 KL 約束的影響。通常而言,優(yōu)化這一目標(biāo)使用的是形式為 {(x, y^w, y^l )} 的比較數(shù)據(jù)集,其中「獲勝」的完成結(jié)果 y^w 優(yōu)于「失敗」的完成結(jié)果 y^l,記為 y^w ? y^l。
另外,這里把小型專家演示數(shù)據(jù)集記為 D_E,并假設(shè)這些演示是由專家策略 π_E 生成的,其能最大化預(yù)測獎(jiǎng)勵(lì)。DITTO 能直接使用語言模型輸出和專家演示來生成比較數(shù)據(jù)。也就是說,不同于合成數(shù)據(jù)的生成范式,DITTO 無需在給定任務(wù)上已經(jīng)表現(xiàn)很好的模型。
關(guān)鍵思路
DITTO 的關(guān)鍵見解在于語言模型本身,再加上專家演示,可以得到用于對齊的比較數(shù)據(jù)集,這樣就無需收集大量成對的偏好數(shù)據(jù)了。這會(huì)得到一個(gè)類似對比的目標(biāo),其中專家演示是正例。
生成比較。假定我們從專家策略采樣了一個(gè)完成結(jié)果 y^E ~ π_E (?|x) 。那么可以認(rèn)為,從其它策略 π 采樣的樣本對應(yīng)的獎(jiǎng)勵(lì)都低于或等于從 π_E 采樣的樣本的獎(jiǎng)勵(lì)?;谶@一觀察,該團(tuán)隊(duì)構(gòu)建了比較數(shù)據(jù) (x, y^E, y^π ),其中 y^E ? y^π。盡管這樣的比較數(shù)據(jù)源自策略而不是各個(gè)樣本,但之前已有研究證明了這種方法的有效性。對 DITTO 來說,一個(gè)很自然的做法就是使用這個(gè)數(shù)據(jù)集以及一個(gè)現(xiàn)成可用的 RLHF 算法來優(yōu)化 (1) 式。這樣做能在提升專家響應(yīng)的概率同時(shí)降低當(dāng)前模型樣本的概率,這不同于標(biāo)準(zhǔn)微調(diào)方法 —— 只會(huì)做前者。關(guān)鍵在于,通過使用來自 π 的樣本,可使用少量演示就構(gòu)建出無邊界的偏好數(shù)據(jù)集。但是,該團(tuán)隊(duì)發(fā)現(xiàn),通過考慮學(xué)習(xí)過程的時(shí)間方面,還能做到更好。
從比較到排名。僅使用來自專家和單個(gè)策略 π 的比較數(shù)據(jù),可能不足以獲得優(yōu)良性能。這樣做只會(huì)降低特定 π 的可能性,導(dǎo)致過擬合問題 —— 這也困擾著少數(shù)據(jù)情況下的 SFT。該團(tuán)隊(duì)提出還可以考慮 RLHF 期間隨時(shí)間而學(xué)習(xí)到的所有策略所生成的數(shù)據(jù),這類似于強(qiáng)化學(xué)習(xí)中的 replay(重放)。
令第一輪迭代時(shí)的初始策略為 π_0。通過采樣該策略可得到一個(gè)數(shù)據(jù)集 D_0。然后可以基于此生成一個(gè)用于 RLHF 的比較數(shù)據(jù)集,可記為 D_E ? D_0。使用這些導(dǎo)出的比較數(shù)據(jù),可以對 π_0 進(jìn)行更新而得到 π_1。根據(jù)定義,也成立。之后,繼續(xù)使用 π_1 生成比較數(shù)據(jù),并且 D_E ? D_1。繼續(xù)這一過程,不斷使用之前的所有策略生成越來越多樣化的比較數(shù)據(jù)。該團(tuán)隊(duì)將這些比較數(shù)據(jù)稱為「重放比較數(shù)據(jù)(replay comparisons)」。
盡管這種方法理論上說得通,但如果 D_E 較小,卻可能出現(xiàn)過擬合。但是,如果假設(shè)每一輪迭代后策略都會(huì)獲得提升,則也可在訓(xùn)練期間考慮策略之間的比較。不同于與專家的比較,我們并不能保證每一輪迭代之后策略都更好,但該團(tuán)隊(duì)發(fā)現(xiàn)模型每次迭代后總體依然是提升的,這可能是是因?yàn)楠?jiǎng)勵(lì)建模和 (1) 式都是凸的。這樣便可以依照以下的排名來采樣比較數(shù)據(jù):
通過添加這些「模型間」和「重放」比較數(shù)據(jù),得到的效果是早期樣本(比如 D_1 中的樣本)的似然會(huì)比后期的(如 D_t 中的)壓得更低,從而使隱含的獎(jiǎng)勵(lì)圖景變得平滑。在實(shí)踐實(shí)現(xiàn)中,該團(tuán)隊(duì)的做法是除了使用與專家的比較數(shù)據(jù),也聚合了一些這些模型間比較數(shù)據(jù)。
一個(gè)實(shí)踐算法。在實(shí)踐中,DITTO 算法是一個(gè)迭代過程,其由三個(gè)簡單的組件構(gòu)成,如算法 1 所示。
首先,在專家演示集上運(yùn)行監(jiān)督式微調(diào),執(zhí)行數(shù)量有限的梯度步驟。將這設(shè)為初始策略 π_0. 第二步,采樣比較數(shù)據(jù):在訓(xùn)練過程中,對于 D_E 中的 N 個(gè)演示中的每一個(gè),通過從 π_t 采樣 M 個(gè)完成結(jié)果而構(gòu)建一個(gè)新的數(shù)據(jù)集 D_t,然后根據(jù)策略 (2) 式將它們添加到排名中。當(dāng)從 (2) 式采樣比較數(shù)據(jù)時(shí),每一批 B 都由 70% 的「在線」比較數(shù)據(jù) D_E ? D_t、20% 的「重放」比較數(shù)據(jù) D_E ? D_{i<t} 以及 10% 的「模型間比較數(shù)據(jù)」D_{i≤t} ? D_{j<i}。最后,使用 RLHF 更新該策略。具體來說,該團(tuán)隊(duì)的做法是使用 DPO 損失函數(shù),通過前述流程采樣的數(shù)據(jù)批來更新 π_t,從而得到 π_{t+1}:
其中 σ 是來自 Bradley-Terry 偏好模型的 logistic 函數(shù)。在每次更新期間,來自 SFT 策略的參考模型都不會(huì)更新,以避免偏離初始化過遠(yuǎn)。
將 DITTO 推導(dǎo)成在線模仿學(xué)習(xí)
DITTO 可通過在線模仿學(xué)習(xí)角度推導(dǎo)出來,其中組合使用專家演示和在線數(shù)據(jù)來同時(shí)學(xué)習(xí)獎(jiǎng)勵(lì)函數(shù)和策略。具體來說,策略玩家會(huì)最大化預(yù)期獎(jiǎng)勵(lì) ?? (π, r),而獎(jiǎng)勵(lì)玩家則會(huì)最小化在在線數(shù)據(jù)集 D^π 上的損失 min_r L (D^π , r) 更具體而言,該團(tuán)隊(duì)的做法是使用 (1) 式中的策略目標(biāo)和標(biāo)準(zhǔn)的獎(jiǎng)勵(lì)建模損失來實(shí)例化該優(yōu)化問題:
推導(dǎo) DITTO,簡化 (3) 式的第一步是解決其內(nèi)部策略最大化問題。幸運(yùn)的是,該團(tuán)隊(duì)基于之前的研究發(fā)現(xiàn)策略目標(biāo) ??_KL 有一個(gè)閉式解,其形式為,其中 Z (x) 用于歸一化分布的配分函數(shù)。值得注意的是,這會(huì)在策略和獎(jiǎng)勵(lì)函數(shù)之間建立一種雙射關(guān)系,這可以被用于消除內(nèi)部優(yōu)化。通過重新排列這個(gè)解,可將獎(jiǎng)勵(lì)函數(shù)寫成:
此外,之前有研究表明這種重新參數(shù)化可以表示任意獎(jiǎng)勵(lì)函數(shù)。于是,通過代入到 (3) 式,可以將變量 r 變成 π,從而得到 DITTO 目標(biāo):
請注意,類似于 DPO,這里是隱式地估計(jì)獎(jiǎng)勵(lì)函數(shù)。而不同于 DPO 的地方是 DITTO 依賴一個(gè)在線的偏好數(shù)據(jù)集 D^π。
為什么 DITTO 比僅使用 SFT 好?
DITTO 表現(xiàn)更好的一個(gè)原因是:通過生成比較數(shù)據(jù),其使用的數(shù)據(jù)量遠(yuǎn)多于 SFT。另一個(gè)原因是在某些情況下,在線模仿學(xué)習(xí)方法的表現(xiàn)會(huì)超過演示者,而 SFT 只能模仿演示。
實(shí)驗(yàn)結(jié)果
該團(tuán)隊(duì)也進(jìn)行了實(shí)證研究,證明了 DITTO 的有效性。實(shí)驗(yàn)的具體設(shè)置請參閱原論文,我們這里僅關(guān)注實(shí)驗(yàn)結(jié)果。
基于靜態(tài)基準(zhǔn)的研究結(jié)果
靜態(tài)基準(zhǔn)的評估使用了 GPT-4,結(jié)果見表 1。
平均而言,DITTO 勝過其它所有方法:在 CMCC 上平均勝率為 71.67%,在 CCAT50 上平均勝率為 82.50%;總體平均勝率為 77.09%。在 CCAT50 上,對于所有作者,DITTO 僅在其中一個(gè)上沒有取得全面優(yōu)勝。在 CMCC 上,對于所有作者,DITTO 全面勝過其中一半基準(zhǔn),之后是 few-shot prompting 贏得 3 成。盡管 SFT 的表現(xiàn)很不錯(cuò),但 DITTO 相較于其的平均勝率提升了 11.7%。
用戶研究:測試泛化到自然任務(wù)的能力
總體而言,用戶研究的結(jié)果與在靜態(tài)基準(zhǔn)上的結(jié)果一致。DITTO 在對齊演示的偏好方面優(yōu)于對比方法,如表 2 所示:其中 DITTO (72.1% 勝率) > SFT (60.1%) > few-shot (48.1%) > self-prompt (44.2%) > zero-shot (25.0%)。
DITTO 在什么時(shí)候有用?
在使用 DITTO 之前,用戶必須考慮一些前提條件,從他們有多少演示到必須從語言模型采樣多少負(fù)例。該團(tuán)隊(duì)探索了這些決定的影響,并重點(diǎn)關(guān)注了 CMCC,因?yàn)槠涓采w的任務(wù)超過 CCAT。此外,他們還分析了演示與成對反饋的樣本效率。
算法擾動(dòng)
該團(tuán)隊(duì)對 DITTO 的組件進(jìn)行了消融研究。
如圖 2(左)所示,增加 DITTO 的迭代次數(shù)通常可以提升性能。
可以看到,當(dāng)?shù)螖?shù)從 1 次提升到 4 次,GPT-4 評估的勝率會(huì)有 31.5% 的提升。這樣的提升是非單調(diào)的 —— 在第 2 次迭代時(shí),性能稍有降低(-3.4%)。這是因?yàn)樵缙诘牡赡軙?huì)得到噪聲更大的樣本,從而降低性能。另一方面,如圖 2(中)所示,增加負(fù)例數(shù)量會(huì)使 DITTO 性能單調(diào)提升。此外,隨著采樣的負(fù)例增多,DITTO 性能的方差會(huì)下降。
另外,如表 3 所示,對 DITTO 的消融研究發(fā)現(xiàn),去掉其任何組件都會(huì)導(dǎo)致性能下降。
比如如果放棄在線方式的迭代式采樣,相比于使用 DITTO,勝率會(huì)從 70.1% 降至 57.3%。而如果在在線過程中持續(xù)更新 π_ref,則會(huì)導(dǎo)致性能大幅下降:從 70.1% 降至 45.8%。該團(tuán)隊(duì)猜想其原因是:更新 π_ref 可能會(huì)導(dǎo)致過擬合。最后,我們也能從表 3 中看到重放和策略間比較數(shù)據(jù)的重要性。
樣本效率
DITTO 的一大關(guān)鍵優(yōu)勢是其樣本效率。該團(tuán)隊(duì)對此進(jìn)行了評估,結(jié)果見圖 2(右);同樣,這里報(bào)告的是歸一化后的勝率。
首先可以看到,DITTO 的勝率一開始會(huì)快速提升。在演示數(shù)量從 1 變成 3 時(shí),每次增加都會(huì)讓歸一化性能大幅提升(0% → 5% → 11.9%)。
但是,當(dāng)演示數(shù)量進(jìn)一步增加時(shí),收益增幅降低了(從 4 增至 7 時(shí)為 11.9% → 15.39%),這說明隨著演示數(shù)量增加,DITTO 的性能會(huì)飽和。
另外,該團(tuán)隊(duì)猜想,不止演示數(shù)量會(huì)影響 DITTO 的性能,演示質(zhì)量也會(huì),但這還留待未來研究。
成對偏好與演示相比如何?
DITTO 的一個(gè)核心假設(shè)是樣本效率源自于演示。理論上講,如果用戶心中有一套完美的演示集合,通過標(biāo)注許多成對的偏好數(shù)據(jù)也能實(shí)現(xiàn)類似的效果。
該團(tuán)隊(duì)做了一個(gè)近似實(shí)驗(yàn),使用從指令遵從 Mistral 7B 采樣的輸出,讓一位提供了用戶研究的演示的作者也標(biāo)注了 500 對偏好數(shù)據(jù)。
總之,他們構(gòu)建了一個(gè)成對的偏好數(shù)據(jù)集 D_pref = {(x, y^i , y^j )},其中 y^i ? y^j。然后他們計(jì)算了采樣自兩個(gè)模型的 20 對結(jié)果的勝率情況 —— 其一是使用 DITTO 在 4 個(gè)演示上訓(xùn)練的,其二是僅使用 DPO 在 {0...500} 偏好數(shù)據(jù)對訓(xùn)練的。
當(dāng)僅從 π_ref 采樣成對偏好數(shù)據(jù)時(shí),可以觀察到生成的數(shù)據(jù)對位于演示的分布外 —— 成對的偏好不涉及用戶演示的行為(圖 3 中 Base policy 的結(jié)果,藍(lán)色)。即使當(dāng)他們使用用戶演示對 π_ref 進(jìn)行微調(diào)時(shí),仍然需要超過 500 對偏好數(shù)據(jù)才能比肩 DITTO 的性能(圖 3 中 Demo-finetuned policy 的結(jié)果,橙色)。