StaR | 用少量推理數(shù)據(jù)讓模型學會通用推理能力,顯著提升模型復雜推理
一、概述
?Title:STaR: Bootstrapping Reasoning With Reasoning
?URL:?? https://arxiv.org/abs/2203.14465??
?Authors:Eric Zelikman, Yuhuai Wu, Jesse Mu, Noah D. Goodman
?Code:?? https://github.com/ezelikman/STaR??
1 Motivation
?Step-by-step推理步驟生成可以提升語言模型在復雜推理任務(如數(shù)學或常識問答)上的性能,但是當前要讓LLM能生成rationale推理過程,要么需要構建龐大的推理數(shù)據(jù)集,要么在只使用少量示例(但推理時犧牲了準確性)。
?需要一種方法來利用少量的推理示例和大量未經(jīng)過推理的數(shù)據(jù)來提升模型的推理能力。
2 Methods
1 省流版總結:
- 使用少量推理示例(few-shot)引導語言模型生成多個問題的推理Rational過程。
- 對于模型生成的錯誤答案,通過提供正確答案(Hint)來生成新的推理過程(稱為“rationalization”)。
- 在所有最終生成正確答案的推理上微調(diào)模型(Finetune)。
- 重復上述過程,直到performance不再提升(注意每次都使用original的預模型進行continually training來避免overfitting)。
2 專業(yè)版總結:
本文提出了一種名為“Self-Taught Reasoner”(STaR)的方法來解決語言模型在復雜推理任務上性能提升的問題。**STaR方法的核心思想是通過迭代地利用少量推理示例(rationales)和大量無推理數(shù)據(jù)集,逐步引導模型提升進行復雜推理的能力。**具體來說,STaR方法包括以下幾個步驟:
- Rationale Generation Bootstrapping:首先,使用少量帶有推理過程的示例作為提示,引導預訓練的大型語言模型(LLM)生成多個問題的推理過程。這個過程被稱為“rationale generation”。
- Filtering and Finetuning:接著,只保留那些生成了正確答案的推理過程,并在這些數(shù)據(jù)上對模型進行微調(diào)(finetune)。這一步驟的目的是強化模型生成高質(zhì)量推理過程的能力。
- Rationalization:對于模型未能正確回答的問題,STaR采用一種稱為“rationalization”的技術。在這個階段,模型被提供正確答案作為提示,然后生成一個合理的推理過程來解釋這個答案。這樣做可以讓模型從錯誤中學習,并改進其推理策略。
- Iterative Improvement:重復上述過程,每次都使用上一輪微調(diào)后的模型來生成新的訓練數(shù)據(jù)。通過這種方式,模型逐漸學習如何更好地生成推理過程,并解決越來越復雜的問題。
- 5.Performance Evaluation:在每次迭代后,評估模型在測試集上的性能,直到性能達到飽和或不再顯著提升。
3 Rationalization指的是什么?
Q1:為什么要用Rationalization?
? 直接讓LLM生成推理思考過程,這些思考過程有些是對的,有些是錯的,直接拿正確的思考過程,來訓練llm生成rational,由于沒有增量信息,會導致模型不能從failed example中學習,這樣就不能讓模型具備對new problems進行推理的能力。
Q2: 如何生成Rational
? 如下圖所示,直接讓LLM生成推理過程,對于failed的例子,加上label作為hint,基于hint,可以生成正確的推理過程。
3 Conclusion
? STaR顯著提升了在多個數(shù)據(jù)集上的性能,相對于直接預測最終答案的模型,其效果更加突出。
? 在CommonsenseQA數(shù)據(jù)集上的表現(xiàn)與微調(diào)一個大30倍的最先進語言模型相當。
? STaR使得模型能夠通過學習自身生成的推理步驟逐步提升推理能力。
二、詳細內(nèi)容
1 實驗設計
數(shù)據(jù)集:
- 算術問題:使用隨機生成的加法問題來測試STaR在處理數(shù)字運算任務上的性能。
- 常識問答(CommonsenseQA):使用CommonsenseQA(CQA)數(shù)據(jù)集,這是一個多項選擇的常識推理任務,測試STaR在自然語言推理上的能力。
- 小學數(shù)學(Grade School Math, GSM8K):使用GSM8K數(shù)據(jù)集,包含小學水平的數(shù)學問題,這些問題以自然語言的形式表述,需要進行多步計算來得出答案。
Baseline:模型采用的是6B的開源模型(GPT-J),其checkpoint和fine-tuning code都開源了。
2 Rationalization能快速提升accuracy(從失敗中學習能快速成長!!!)
說明;rationalization指的就是對于failed的example,加上hint,生成正確的推理過程數(shù)據(jù)并用于訓練。
結論:隨著STaR算法迭代次數(shù)的增加,模型在算術任務上的準確率逐漸提高。特別是在使用rationalization的情況下,準確率提升更加塊。
3 STaR + rationalization比直接FT和few-shot效果好很多
? CQA數(shù)據(jù)集
? GSM8K數(shù)據(jù)集
說明:
? Direct Finetuned:不輸出中間推理過程
? STaR without rationalization:不從失敗樣例中學習(以label作為hint生成推理過程用于ft)
? STaR with rationalization:從失敗中學習
結論1:生成中間推理過程能顯著提升最終的精度,例如就算使用100%的數(shù)據(jù),不加推理過程,精度只能到60%,加上后用更少的數(shù)據(jù)卻能更高的精度(大于68%)。
結論2:rationalization從失敗中學習能進一步提升精度。
三、總結
STaR方法的關鍵在于,它允許模型通過自我生成的推理過程來自我改進,而不需要人工標注大量的推理數(shù)據(jù)集。此外,**通過rationalization技術,STaR能夠確保模型從其錯誤中學習,從而提高整體的推理能力。**論文的實驗結果表明,STaR在多個數(shù)據(jù)集上的性能顯著優(yōu)于直接預測答案的模型,并且與使用30倍更大模型的微調(diào)性能相當。
本文轉載自??NLP PaperWeekly??,作者: NLP PaperWeekly ????
