LLM模型貪婪、溫度、Top-k、核采樣方式的區(qū)別(附代碼與示例)
在自然語言生成任務(wù)中,不同的采樣技術(shù)用于從語言模型的輸出中選擇下一個生成的單詞或詞語。這些技術(shù)包括貪婪采樣、溫度采樣、Top-k采樣和核(Nucleus)采樣。它們在選擇生成單詞的過程中有不同的策略,本文將介紹這四種采樣方式的區(qū)別。
1. 貪婪采樣 (Greedy Sampling)
貪婪采樣是一種直接選擇最可能的下一個詞的策略。
具體步驟為:從模型輸出的logits中,找到概率最大的那個詞,直接選擇它作為輸出。
實現(xiàn)代碼:
class GreedySampler(Sampler):
def __call__(self, logits: torch.Tensor):
return logits.argmax(dim=-1)
優(yōu)點:
- 簡單且計算效率高。
- 保證每一步選擇最有可能的結(jié)果。
缺點:
- 可能會導(dǎo)致生成的文本非常重復(fù)和缺乏多樣性。
- 貪婪采樣只關(guān)注當(dāng)前概率最大的詞,忽略了其他潛在的好選擇,容易陷入局部最優(yōu)解。
2. 帶溫度的采樣 (Temperature Sampling)
溫度采樣通過引入一個溫度參數(shù)來調(diào)整輸出概率的分布,以控制生成文本的多樣性。溫度 T 的作用是平滑或銳化概率分布:
- 當(dāng) T = 1 時,采樣為標(biāo)準(zhǔn)隨機采樣。
- 當(dāng) T < 1 時,概率分布變得更尖銳,模型更傾向于選擇最可能的詞。
- 當(dāng) T > 1 時,概率分布變得更加平滑,模型會更多地探索低概率的詞。
實現(xiàn)代碼
class TemperatureSampler(Sampler):
def __init__(self, temperature: float = 1.0):
self.temperature = temperature
def __call__(self, logits: torch.Tensor):
dist = Categorical(logits=logits / self.temperature)
return dist.sample()
優(yōu)點:
- 提供了生成文本的多樣性,尤其是在溫度高時。
- 通過調(diào)整溫度參數(shù),可以控制探索(隨機性)與利用(選擇高概率詞)之間的平衡。
缺點:
- 溫度的選擇需要仔細調(diào)節(jié),不同任務(wù)或場景下對溫度的需求可能不同。
- 溫度過低時,生成的文本趨向于貪婪采樣;溫度過高時,生成的文本可能過于隨機。
3. Top-k采樣
Top-k采樣限制了每次生成時候的候選詞數(shù)量,模型只會從概率前k個最高的詞中進行采樣,而忽略其他可能性較小的詞。
實現(xiàn)代碼:
class TopKSampler(Sampler):
def __init__(self, k: int, sampler: Sampler):
self.k = k
self.sampler = sampler
def __call__(self, logits: torch.Tensor):
zeros = logits.new_ones(logits.shape) * float('-inf')
values, indices = torch.topk(logits, self.k, dim=-1)
zeros.scatter_(-1, indices, values)
return self.sampler(zeros)
優(yōu)點:
- 提供了對生成詞匯的嚴(yán)格控制,避免生成概率非常低的詞。
- 通過限制候選詞的數(shù)量,避免了一些罕見或不合邏輯的詞被選中。
缺點:
- 需要設(shè)定一個合適的 k 值,如果 k 值太小,生成的文本可能會缺乏多樣性;如果 k 值太大,則效果與標(biāo)準(zhǔn)采樣相似。
4. 核采樣 (Nucleus Sampling)
核采樣是一種自適應(yīng)的采樣方法,它選擇的候選詞集合 V(p) 是滿足累計概率和大于或等于給定閾值 p 的最小詞匯子集。與Top-k采樣不同,核采樣的候選詞數(shù)量不是固定的,而是基于累計概率動態(tài)確定的。
示例
假設(shè)同樣的語境:“今天的天氣很”,但這次我們將會有不同的詞匯及其概率分布,我們也會使用不同的閾值 ( p ) 來展示如何動態(tài)確定選詞數(shù)量。
(1) 模型預(yù)測的詞匯概率
- 好:0.4
- 冷:0.3
- 熱:0.2
- 潮濕:0.05
- 多變:0.03
- 干燥:0.02
(2) 排序與累積概率
按概率從高到低排序并計算累積概率:
- 好:0.4
- 冷:0.7 (0.4 + 0.3)
- 熱:0.9 (0.7 + 0.2)
- 潮濕:0.95 (0.9 + 0.05)
- 多變:0.98 (0.95 + 0.03)
- 干燥:1.00 (0.98 + 0.02)
(3) 確定核集合
這次,我們將選擇不同的閾值 ( p ) 來觀察核集合如何變化:
- **當(dāng) ( p = 0.7 )**:核集合包括:“好”和“冷”,因為它們的累積概率首次超過 0.7。
- **當(dāng) ( p = 0.9 )**:核集合擴展到:“好”,“冷”,和“熱”,因為它們的累積概率首次超過 0.9。
- **當(dāng) ( p = 0.95 )**:核集合進一步擴展到:“好”,“冷”,“熱”和“潮濕”,因為這是累積概率首次超過 0.95。
(4) 抽樣
在每種情況下,我們從對應(yīng)的核集合中隨機選取一個詞作為下一個詞。選擇的范圍和多樣性取決于 ( p ) 值的大小,而詞的數(shù)量是根據(jù)這個閾值動態(tài)確定的,不是固定的。
實現(xiàn)代碼
class NucleusSampler(Sampler):
"""
## Nucleus 采樣器
Nucleus 采樣器根據(jù)給定的概率 p 選擇詞匯的一個子集,并從中進行采樣。
"""
def __init__(self, p: float, sampler: Sampler):
"""
### 初始化
:param p: 要選擇的令牌概率之和,即 p 值。
:param sampler: 用于從選定令牌中進行采樣的采樣器。
"""
# 保存 p 值
self.p = p
# 保存采樣器
self.sampler = sampler
# 初始化 softmax 層,用于將 logits 轉(zhuǎn)換為概率
self.softmax = nn.Softmax(dim=-1)
def __call__(self, logits: torch.Tensor):
"""
### 從 logits 中進行 Nucleus 采樣
:param logits: 輸入的 logits 張量,形狀為 (batch_size, num_tokens)。
:return: 采樣得到的令牌索引,形狀為 (batch_size,)。
"""
# 獲取概率 P(x_i | x_1:i-1)
probs = self.softmax(logits)
# 按降序?qū)Ω怕蔬M行排序,并獲取排序后的索引
sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
# 按排序順序獲取概率的累積總和
cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)
# 找出累積總和小于 p 的令牌
nucleus = cum_sum_probs < self.p
# 在前面加一個 True,這樣我們可以在累積概率小于 p 的最小令牌數(shù)量之后添加一個令牌
nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)
# 獲取對數(shù)概率并掩蓋非核部分
sorted_log_probs = torch.log(sorted_probs)
sorted_log_probs[~nucleus] = float('-inf')
# 使用采樣器從排序后的對數(shù)概率中進行采樣
sampled_sorted_indexes = self.sampler(sorted_log_probs)
# 獲取實際的索引
res = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))
# 返回采樣得到的令牌索引
return res.squeeze(-1)
優(yōu)點:
- 靈活性強,自動調(diào)整候選詞集合,避免了固定的詞數(shù)限制。
- 在生成文本時能夠更好地平衡多樣性與高概率詞的利用,表現(xiàn)優(yōu)于Top-k采樣。
缺點:
- 參數(shù) p 的選擇需要調(diào)節(jié),不同任務(wù)可能需要不同的 p 值。
- 計算復(fù)雜度較高,尤其是當(dāng)處理較大的詞匯表時。
總結(jié)
采樣方法 | 優(yōu)點 | 缺點 |
貪婪采樣 | 簡單、高效,始終選擇最有可能的詞 | 文本生成可能單一,缺乏多樣性 |
溫度采樣 | 通過調(diào)整溫度控制多樣性,適應(yīng)性強 | 溫度的調(diào)節(jié)需要謹(jǐn)慎,過高或過低的溫度可能產(chǎn)生不理想的結(jié)果 |
Top-k采樣 | 控制候選詞數(shù)量,避免選擇低概率詞 |
|
核采樣 | 動態(tài)選擇候選詞集合,更靈活,生成文本質(zhì)量較高 | 參數(shù) |
每種采樣方式都有其適用的場景,根據(jù)具體的應(yīng)用和對生成文本的要求,可以選擇不同的采樣策略。