自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

自然語言生成任務(wù)中的五種采樣方法介紹和Pytorch代碼實現(xiàn)

人工智能
在自然語言生成任務(wù)(NLG)中,采樣方法是指從生成模型中獲取文本輸出的一種技術(shù)。本文將介紹常用的5中方法并用Pytorch進(jìn)行實現(xiàn)。

在自然語言生成任務(wù)(NLG)中,采樣方法是指從生成模型中獲取文本輸出的一種技術(shù)。本文將介紹常用的5中方法并用Pytorch進(jìn)行實現(xiàn)。

1、Greedy Decoding

Greedy Decoding在每個時間步選擇當(dāng)前條件概率最高的詞語作為輸出,直到生成結(jié)束。在貪婪解碼中,生成模型根據(jù)輸入序列,逐個時間步地預(yù)測輸出序列中的每個詞語。在每個時間步,模型根據(jù)當(dāng)前的隱藏狀態(tài)和已生成的部分序列計算每個詞語的條件概率分布,模型選擇具有最高條件概率的詞語作為當(dāng)前時間步的輸出。這個詞語成為下一個時間步的輸入,生成過程持續(xù)直到滿足某種終止條件,比如生成了指定長度的序列或者生成了特殊的結(jié)束標(biāo)記。

這種方法簡單高效,每個時間步只需計算當(dāng)前條件概率最高的詞語,因此計算速度較快。但是由于每個時間步只考慮當(dāng)前條件概率最高的詞語,貪婪解碼可能會陷入局部最優(yōu)解,而無法獲得全局最優(yōu)解。這可能導(dǎo)致生成的文本缺乏多樣性或不準(zhǔn)確。

盡管貪婪解碼存在一些局限性,但它仍然是許多序列生成任務(wù)中常用的一種方法,特別是在對速度要求較高或者任務(wù)較為簡單的情況下。

 def greedy_decoding(input_ids, max_tokens=300):
     with torch.inference_mode():
         for _ in range(max_tokens):
             outputs = model(input_ids)
             next_token_logits = outputs.logits[:, -1, :]
             next_token = torch.argmax(next_token_logits, dim=-1)
             if next_token == tokenizer.eos_token_id:
                 break
             input_ids = torch.cat([input_ids, rearrange(next_token, 'c -> 1 c')], dim=-1)
         generated_text = tokenizer.decode(input_ids[0])
     return generated_text

2、Beam Search

束搜索(Beam Search)是貪婪解碼的一種擴(kuò)展,通過在每個時間步保留多個候選序列來克服貪婪解碼的局部最優(yōu)問題。

在每個時間步保留概率最高的前幾個候選詞語,然后在下一個時間步基于這些候選詞語繼續(xù)擴(kuò)展,直到生成結(jié)束。束搜索通過考慮多個候選詞語路徑,可以在一定程度上增加生成文本的多樣性。

在束搜索中,模型在每個時間步會生成多個候選序列,而不是僅選擇一個最優(yōu)序列。模型會根據(jù)當(dāng)前已生成的部分序列和隱藏狀態(tài),預(yù)測下一個時間步可能的詞語,并計算每個詞語的條件概率分布。

上圖的每一步中,只保留兩條最可能的路徑(根據(jù)beam =2),而所有其他都被丟棄。此過程將繼續(xù)進(jìn)行,直到滿足停止條件,該停止條件可以是生成序列結(jié)束令牌或達(dá)到最大序列長度的模型。最終輸出將是最后一組路徑中具有最高總體概率的序列。

 from einops import rearrange
 import torch.nn.functional as F
 
 def beam_search(input_ids, max_tokens=100, beam_size=2):
     beam_scores = torch.zeros(beam_size).to(device)
     beam_sequences = input_ids.clone()
     active_beams = torch.ones(beam_size, dtype=torch.bool)
     for step in range(max_tokens):
         outputs = model(beam_sequences)
         logits = outputs.logits[:, -1, :]
         probs = F.softmax(logits, dim=-1)
         top_scores, top_indices = torch.topk(probs.flatten(), k=beam_size, sorted=False)
         beam_indices = top_indices // probs.shape[-1]
         token_indices = top_indices % probs.shape[-1]
         beam_sequences = torch.cat([
             beam_sequences[beam_indices],
             token_indices.unsqueeze(-1)
        ], dim=-1)
         beam_scores = top_scores
         active_beams = ~(token_indices == tokenizer.eos_token_id)
         if not active_beams.any():
             print("no active beams")
             break
     best_beam = beam_scores.argmax()
     best_sequence = beam_sequences[best_beam]
     generated_text = tokenizer.decode(best_sequence)
     return generated_text

3、Temperature Sampling

溫度參數(shù)采樣(Temperature Sampling)常用于基于概率的生成模型,如語言模型。它通過引入一個稱為“溫度”(Temperature)的參數(shù)來調(diào)整模型輸出的概率分布,從而控制生成文本的多樣性。

在溫度參數(shù)采樣中,模型在每個時間步生成詞語時,會計算出詞語的條件概率分布。然后模型將這個條件概率分布中的每個詞語的概率值除以溫度參數(shù),對結(jié)果進(jìn)行歸一化處理,獲得新的歸一化概率分布。較高的溫度值會使概率分布更平滑,從而增加生成文本的多樣性。低概率的詞語也有較高的可能性被選擇;而較低的溫度值則會使概率分布更集中,更傾向于選擇高概率的詞語,因此生成的文本更加確定性。最后模型根據(jù)這個新的歸一化概率分布進(jìn)行隨機(jī)采樣,選擇生成的詞語。

 import torch
 import torch.nn.functional as F
 
 def temperature_sampling(logits, temperature=1.0):
     logits = logits / temperature
     probabilities = F.softmax(logits, dim=-1)
     sampled_token = torch.multinomial(probabilities, 1)
     return sampled_token.item()

4、Top-K Sampling

Top-K 采樣(在每個時間步選擇條件概率排名前 K 的詞語,然后在這 K 個詞語中進(jìn)行隨機(jī)采樣。這種方法既能保持一定的生成質(zhì)量,又能增加文本的多樣性,并且可以通過限制候選詞語的數(shù)量來控制生成文本的多樣性。

這個過程使得生成的文本在保持一定的生成質(zhì)量的同時,也具有一定的多樣性,因為在候選詞語中仍然存在一定的競爭性。

參數(shù) K 控制了在每個時間步中保留的候選詞語的數(shù)量。較小的 K 值會導(dǎo)致更加貪婪的行為,因為只有少數(shù)幾個詞語參與隨機(jī)采樣,而較大的 K 值會增加生成文本的多樣性,但也會增加計算開銷。

 def top_k_sampling(input_ids, max_tokens=100, top_k=50, temperature=1.0):
    for _ in range(max_tokens):
         with torch.inference_mode():
             outputs = model(input_ids)
             next_token_logits = outputs.logits[:, -1, :]
             top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
             top_k_probs = F.softmax(top_k_logits / temperature, dim=-1)
             next_token_index = torch.multinomial(top_k_probs, num_samples=1)
             next_token = top_k_indices.gather(-1, next_token_index)
             input_ids = torch.cat([input_ids, next_token], dim=-1)
     generated_text = tokenizer.decode(input_ids[0])
     return generated_text

5、Top-P (Nucleus) Sampling:

Nucleus Sampling(核采樣),也被稱為Top-p Sampling旨在在保持生成文本質(zhì)量的同時增加多樣性。這種方法可以視作是Top-K Sampling的一種變體,它在每個時間步根據(jù)模型輸出的概率分布選擇概率累積超過給定閾值p的詞語集合,然后在這個詞語集合中進(jìn)行隨機(jī)采樣。這種方法會動態(tài)調(diào)整候選詞語的數(shù)量,以保持一定的文本多樣性。

在Nucleus Sampling中,模型在每個時間步生成詞語時,首先按照概率從高到低對詞匯表中的所有詞語進(jìn)行排序,然后模型計算累積概率,并找到累積概率超過給定閾值p的最小詞語子集,這個子集就是所謂的“核”(nucleus)。模型在這個核中進(jìn)行隨機(jī)采樣,根據(jù)詞語的概率分布來選擇最終輸出的詞語。這樣做可以保證所選詞語的總概率超過了閾值p,同時也保持了一定的多樣性。

參數(shù)p是Nucleus Sampling中的重要參數(shù),它決定了所選詞語的概率總和。p的值會被設(shè)置在(0,1]之間,表示詞語總概率的一個下界。

Nucleus Sampling 能夠保持一定的生成質(zhì)量,因為它在一定程度上考慮了概率分布。通過選擇概率總和超過給定閾值p的詞語子集進(jìn)行隨機(jī)采樣,Nucleus Sampling 能夠增加生成文本的多樣性。

 def top_p_sampling(input_ids, max_tokens=100, top_p=0.95):
     with torch.inference_mode():
         for _ in range(max_tokens):
                 outputs = model(input_ids)
                 next_token_logits = outputs.logits[:, -1, :]
                 sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                 sorted_probabilities = F.softmax(sorted_logits, dim=-1) 
                 cumulative_probs = torch.cumsum(sorted_probabilities, dim=-1)
                 sorted_indices_to_remove = cumulative_probs > top_p
                 sorted_indices_to_remove[..., 0] = False 
                 indices_to_remove = sorted_indices[sorted_indices_to_remove]
                 next_token_logits.scatter_(-1, indices_to_remove[None, :], float('-inf'))
                 probs = F.softmax(next_token_logits, dim=-1)
                 next_token = torch.multinomial(probs, num_samples=1)
                 input_ids = torch.cat([input_ids, next_token], dim=-1)
         generated_text = tokenizer.decode(input_ids[0])
     return generated_text

總結(jié)

自然語言生成任務(wù)中,采樣方法是非常重要的。選擇合適的采樣方法可以在一定程度上影響生成文本的質(zhì)量、多樣性和效率。上面介紹的幾種采樣方法各有特點,適用于不同的應(yīng)用場景和需求。

貪婪解碼是一種簡單直接的方法,適用于速度要求較高的情況,但可能導(dǎo)致生成文本缺乏多樣性。束搜索通過保留多個候選序列來克服貪婪解碼的局部最優(yōu)問題,生成的文本質(zhì)量更高,但計算開銷較大。Top-K 采樣和核采樣可以控制生成文本的多樣性,適用于需要平衡質(zhì)量和多樣性的場景。溫度參數(shù)采樣則可以根據(jù)溫度參數(shù)靈活調(diào)節(jié)生成文本的多樣性,適用于需要平衡多樣性和質(zhì)量的任務(wù)。

責(zé)任編輯:華軒 來源: DeepHub IMBA
相關(guān)推薦

2009-11-25 14:25:14

PHP自然語言排序

2020-02-25 12:00:53

自然語言開源工具

2020-02-25 23:28:50

工具代碼開發(fā)

2023-06-26 15:11:30

智能家居自然語言

2021-11-12 15:43:10

Python自然語言數(shù)據(jù)

2009-11-25 14:31:43

PHP自然語言倒序

2024-10-16 21:05:07

2020-04-24 10:53:08

自然語言處理NLP是人工智能

2024-08-07 10:39:47

ChatGPT自然語言企業(yè)數(shù)據(jù)

2017-04-10 16:15:55

人工智能深度學(xué)習(xí)應(yīng)用

2021-05-13 07:17:13

Snownlp自然語言處理庫

2022-09-23 11:16:26

自然語言人工智能

2017-10-19 17:05:58

深度學(xué)習(xí)自然語言

2024-04-24 11:38:46

語言模型NLP人工智能

2017-06-23 19:08:23

大數(shù)據(jù)PyTorch自然語言

2009-07-03 09:29:24

KeelKit

2023-04-20 15:13:11

PytorchGrad-CAM架構(gòu)

2024-04-09 14:31:37

大語言模型Pytorch

2024-02-05 14:18:07

自然語言處理
點贊
收藏

51CTO技術(shù)棧公眾號