自然語言生成任務(wù)中的五種采樣方法介紹和Pytorch代碼實現(xiàn)
在自然語言生成任務(wù)(NLG)中,采樣方法是指從生成模型中獲取文本輸出的一種技術(shù)。本文將介紹常用的5中方法并用Pytorch進(jìn)行實現(xiàn)。
1、Greedy Decoding
Greedy Decoding在每個時間步選擇當(dāng)前條件概率最高的詞語作為輸出,直到生成結(jié)束。在貪婪解碼中,生成模型根據(jù)輸入序列,逐個時間步地預(yù)測輸出序列中的每個詞語。在每個時間步,模型根據(jù)當(dāng)前的隱藏狀態(tài)和已生成的部分序列計算每個詞語的條件概率分布,模型選擇具有最高條件概率的詞語作為當(dāng)前時間步的輸出。這個詞語成為下一個時間步的輸入,生成過程持續(xù)直到滿足某種終止條件,比如生成了指定長度的序列或者生成了特殊的結(jié)束標(biāo)記。
這種方法簡單高效,每個時間步只需計算當(dāng)前條件概率最高的詞語,因此計算速度較快。但是由于每個時間步只考慮當(dāng)前條件概率最高的詞語,貪婪解碼可能會陷入局部最優(yōu)解,而無法獲得全局最優(yōu)解。這可能導(dǎo)致生成的文本缺乏多樣性或不準(zhǔn)確。
盡管貪婪解碼存在一些局限性,但它仍然是許多序列生成任務(wù)中常用的一種方法,特別是在對速度要求較高或者任務(wù)較為簡單的情況下。
def greedy_decoding(input_ids, max_tokens=300):
with torch.inference_mode():
for _ in range(max_tokens):
outputs = model(input_ids)
next_token_logits = outputs.logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1)
if next_token == tokenizer.eos_token_id:
break
input_ids = torch.cat([input_ids, rearrange(next_token, 'c -> 1 c')], dim=-1)
generated_text = tokenizer.decode(input_ids[0])
return generated_text
2、Beam Search
束搜索(Beam Search)是貪婪解碼的一種擴(kuò)展,通過在每個時間步保留多個候選序列來克服貪婪解碼的局部最優(yōu)問題。
在每個時間步保留概率最高的前幾個候選詞語,然后在下一個時間步基于這些候選詞語繼續(xù)擴(kuò)展,直到生成結(jié)束。束搜索通過考慮多個候選詞語路徑,可以在一定程度上增加生成文本的多樣性。
在束搜索中,模型在每個時間步會生成多個候選序列,而不是僅選擇一個最優(yōu)序列。模型會根據(jù)當(dāng)前已生成的部分序列和隱藏狀態(tài),預(yù)測下一個時間步可能的詞語,并計算每個詞語的條件概率分布。
上圖的每一步中,只保留兩條最可能的路徑(根據(jù)beam =2),而所有其他都被丟棄。此過程將繼續(xù)進(jìn)行,直到滿足停止條件,該停止條件可以是生成序列結(jié)束令牌或達(dá)到最大序列長度的模型。最終輸出將是最后一組路徑中具有最高總體概率的序列。
from einops import rearrange
import torch.nn.functional as F
def beam_search(input_ids, max_tokens=100, beam_size=2):
beam_scores = torch.zeros(beam_size).to(device)
beam_sequences = input_ids.clone()
active_beams = torch.ones(beam_size, dtype=torch.bool)
for step in range(max_tokens):
outputs = model(beam_sequences)
logits = outputs.logits[:, -1, :]
probs = F.softmax(logits, dim=-1)
top_scores, top_indices = torch.topk(probs.flatten(), k=beam_size, sorted=False)
beam_indices = top_indices // probs.shape[-1]
token_indices = top_indices % probs.shape[-1]
beam_sequences = torch.cat([
beam_sequences[beam_indices],
token_indices.unsqueeze(-1)
], dim=-1)
beam_scores = top_scores
active_beams = ~(token_indices == tokenizer.eos_token_id)
if not active_beams.any():
print("no active beams")
break
best_beam = beam_scores.argmax()
best_sequence = beam_sequences[best_beam]
generated_text = tokenizer.decode(best_sequence)
return generated_text
3、Temperature Sampling
溫度參數(shù)采樣(Temperature Sampling)常用于基于概率的生成模型,如語言模型。它通過引入一個稱為“溫度”(Temperature)的參數(shù)來調(diào)整模型輸出的概率分布,從而控制生成文本的多樣性。
在溫度參數(shù)采樣中,模型在每個時間步生成詞語時,會計算出詞語的條件概率分布。然后模型將這個條件概率分布中的每個詞語的概率值除以溫度參數(shù),對結(jié)果進(jìn)行歸一化處理,獲得新的歸一化概率分布。較高的溫度值會使概率分布更平滑,從而增加生成文本的多樣性。低概率的詞語也有較高的可能性被選擇;而較低的溫度值則會使概率分布更集中,更傾向于選擇高概率的詞語,因此生成的文本更加確定性。最后模型根據(jù)這個新的歸一化概率分布進(jìn)行隨機(jī)采樣,選擇生成的詞語。
import torch
import torch.nn.functional as F
def temperature_sampling(logits, temperature=1.0):
logits = logits / temperature
probabilities = F.softmax(logits, dim=-1)
sampled_token = torch.multinomial(probabilities, 1)
return sampled_token.item()
4、Top-K Sampling
Top-K 采樣(在每個時間步選擇條件概率排名前 K 的詞語,然后在這 K 個詞語中進(jìn)行隨機(jī)采樣。這種方法既能保持一定的生成質(zhì)量,又能增加文本的多樣性,并且可以通過限制候選詞語的數(shù)量來控制生成文本的多樣性。
這個過程使得生成的文本在保持一定的生成質(zhì)量的同時,也具有一定的多樣性,因為在候選詞語中仍然存在一定的競爭性。
參數(shù) K 控制了在每個時間步中保留的候選詞語的數(shù)量。較小的 K 值會導(dǎo)致更加貪婪的行為,因為只有少數(shù)幾個詞語參與隨機(jī)采樣,而較大的 K 值會增加生成文本的多樣性,但也會增加計算開銷。
def top_k_sampling(input_ids, max_tokens=100, top_k=50, temperature=1.0):
for _ in range(max_tokens):
with torch.inference_mode():
outputs = model(input_ids)
next_token_logits = outputs.logits[:, -1, :]
top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
top_k_probs = F.softmax(top_k_logits / temperature, dim=-1)
next_token_index = torch.multinomial(top_k_probs, num_samples=1)
next_token = top_k_indices.gather(-1, next_token_index)
input_ids = torch.cat([input_ids, next_token], dim=-1)
generated_text = tokenizer.decode(input_ids[0])
return generated_text
5、Top-P (Nucleus) Sampling:
Nucleus Sampling(核采樣),也被稱為Top-p Sampling旨在在保持生成文本質(zhì)量的同時增加多樣性。這種方法可以視作是Top-K Sampling的一種變體,它在每個時間步根據(jù)模型輸出的概率分布選擇概率累積超過給定閾值p的詞語集合,然后在這個詞語集合中進(jìn)行隨機(jī)采樣。這種方法會動態(tài)調(diào)整候選詞語的數(shù)量,以保持一定的文本多樣性。
在Nucleus Sampling中,模型在每個時間步生成詞語時,首先按照概率從高到低對詞匯表中的所有詞語進(jìn)行排序,然后模型計算累積概率,并找到累積概率超過給定閾值p的最小詞語子集,這個子集就是所謂的“核”(nucleus)。模型在這個核中進(jìn)行隨機(jī)采樣,根據(jù)詞語的概率分布來選擇最終輸出的詞語。這樣做可以保證所選詞語的總概率超過了閾值p,同時也保持了一定的多樣性。
參數(shù)p是Nucleus Sampling中的重要參數(shù),它決定了所選詞語的概率總和。p的值會被設(shè)置在(0,1]之間,表示詞語總概率的一個下界。
Nucleus Sampling 能夠保持一定的生成質(zhì)量,因為它在一定程度上考慮了概率分布。通過選擇概率總和超過給定閾值p的詞語子集進(jìn)行隨機(jī)采樣,Nucleus Sampling 能夠增加生成文本的多樣性。
def top_p_sampling(input_ids, max_tokens=100, top_p=0.95):
with torch.inference_mode():
for _ in range(max_tokens):
outputs = model(input_ids)
next_token_logits = outputs.logits[:, -1, :]
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
sorted_probabilities = F.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(sorted_probabilities, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 0] = False
indices_to_remove = sorted_indices[sorted_indices_to_remove]
next_token_logits.scatter_(-1, indices_to_remove[None, :], float('-inf'))
probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
input_ids = torch.cat([input_ids, next_token], dim=-1)
generated_text = tokenizer.decode(input_ids[0])
return generated_text
總結(jié)
自然語言生成任務(wù)中,采樣方法是非常重要的。選擇合適的采樣方法可以在一定程度上影響生成文本的質(zhì)量、多樣性和效率。上面介紹的幾種采樣方法各有特點,適用于不同的應(yīng)用場景和需求。
貪婪解碼是一種簡單直接的方法,適用于速度要求較高的情況,但可能導(dǎo)致生成文本缺乏多樣性。束搜索通過保留多個候選序列來克服貪婪解碼的局部最優(yōu)問題,生成的文本質(zhì)量更高,但計算開銷較大。Top-K 采樣和核采樣可以控制生成文本的多樣性,適用于需要平衡質(zhì)量和多樣性的場景。溫度參數(shù)采樣則可以根據(jù)溫度參數(shù)靈活調(diào)節(jié)生成文本的多樣性,適用于需要平衡多樣性和質(zhì)量的任務(wù)。