LLM實(shí)踐系列-細(xì)聊LLM的拒絕采樣
最近學(xué)強(qiáng)化的過(guò)程中,總是遇到“拒絕采樣”這個(gè)概念,我嘗試科普一下,爭(zhēng)取用最大白話的方式讓每個(gè)感興趣的同學(xué)都理解其中思想。
拒絕采樣是 LLM 從統(tǒng)計(jì)學(xué)借鑒過(guò)來(lái)的一個(gè)概念。其實(shí)大家很早就接觸過(guò)這個(gè)概念,每個(gè)刷過(guò) leetcode 的同學(xué)大概率都遇到過(guò)這樣一個(gè)問(wèn)題:“如何用一枚骰子獲得 1/7 的概率?”
答案很簡(jiǎn)單:把骰子扔兩次,獲得 6 * 6 = 36 種可能的結(jié)果,丟棄最后一個(gè)結(jié)果,剩下的 35 個(gè)結(jié)果平分成 7 份,對(duì)應(yīng)的概率值便為 1/7 。使用這種思想,我們可以利用一枚骰子獲得任意 1/N 的概率。
在這個(gè)問(wèn)題中,我們可以看到拒絕采樣的一些關(guān)鍵要素:
- 采樣:從易于采樣的分布(兩個(gè)骰子的所有可能結(jié)果)中生成樣本;
- 縮放:(扔兩次骰子)獲得更大的樣本分布;
- 拒絕:丟棄(拒絕)不符合條件的樣本(第36種情況);
- 接受:對(duì)于剩下的樣本,重新調(diào)整概率(通過(guò)分組),獲得目標(biāo)概率分布。
用大白話來(lái)總結(jié)就是:我們想獲得某個(gè)分布(1/7)的樣本,但卻沒(méi)有辦法。于是我們對(duì)另外一個(gè)分布(1/6)進(jìn)行采樣,但這個(gè)分布不能涵蓋原始分布,需要我們縮放這個(gè)分布(扔兩次)來(lái)包裹起來(lái)目標(biāo)分布。然后,我們以某種規(guī)則拒絕明顯不是目標(biāo)分布的采樣點(diǎn),剩下的采樣點(diǎn)就可以看作是從目標(biāo)分布采樣出來(lái)的了。
統(tǒng)計(jì)學(xué)的拒絕采樣
LLM 的拒絕采樣
LLM 的拒絕采樣操作起來(lái)非常簡(jiǎn)單:讓自己的模型針對(duì) prompt 生成多個(gè)候選 response,然后用 reward_model 篩選出來(lái)高質(zhì)量的 response (也可以是 pair 對(duì)),拿來(lái)再次進(jìn)行訓(xùn)練。
解剖這個(gè)過(guò)程:
- 提議分布是我們自己的模型,目標(biāo)分布是最好的語(yǔ)言模型;
- prompt + response = 一個(gè)采樣結(jié)果;
- do_sample 多次 = 縮放提議分布(也可以理解為扔多次骰子);
- 采樣結(jié)果得到 reward_model 的認(rèn)可 = 符合目標(biāo)分布。
經(jīng)過(guò)這一番操作,我們能獲得很多的訓(xùn)練樣本,“這些樣本既符合最好的語(yǔ)言模型的說(shuō)話習(xí)慣,又不偏離原始語(yǔ)言模型的表達(dá)習(xí)慣”,學(xué)習(xí)它們就能讓我們的模型更接近最好的語(yǔ)言模型。
統(tǒng)計(jì)學(xué)與 LLM 的映射關(guān)系
統(tǒng)計(jì)學(xué)的拒絕采樣有幾個(gè)關(guān)鍵要素:
- 原始分布采樣困難,提議分布采樣簡(jiǎn)單;
- 提議分布縮放后能涵蓋原始分布;
- 有辦法判斷從提議分布獲取的樣本是否屬于原始分布,這需要我們知道原始分布的密度函數(shù)。
LLM 的拒絕采樣也有幾個(gè)對(duì)應(yīng)的關(guān)鍵要素:
- 我們不知道最好的語(yǔ)言模型怎么說(shuō)話,但我們知道自己的語(yǔ)言模型如何說(shuō)話;
- 讓自己的語(yǔ)言模型反復(fù)說(shuō)話,得到的語(yǔ)料大概率會(huì)包括最好的語(yǔ)言模型的說(shuō)話方式;
- reward_model 可以判斷某句話是否屬于最好的語(yǔ)言模型的說(shuō)話方式。
目前為止,是不是看上去很有道理,很好理解。但其實(shí)這里有一個(gè)致命的邏輯漏洞:為什么我們的模型反復(fù) do_sample,就一定能覆蓋最好的語(yǔ)言模型呢?這不合邏輯啊,狗嘴里采樣多少次也吐不出象牙啊。
緊接著,就需要我們引出另一個(gè)概念了:RLHF 的優(yōu)化目標(biāo)是什么?
RLHF 與拒絕采樣
