ReFT(表征微調(diào)):比PeFT效果更好的新的大語(yǔ)言模型微調(diào)技術(shù)
ReFT(Representation Finetuning)是一種突破性的方法,有望重新定義我們對(duì)大型語(yǔ)言模型進(jìn)行微調(diào)的方式。
這是由斯坦福大學(xué)的研究人員剛剛(4月)發(fā)布在arxiv上的論文,ReFT與傳統(tǒng)的基于權(quán)重的微調(diào)方法大有不同,它提供了一種更高效和有效的方法來(lái)適應(yīng)這些大規(guī)模的模型,以適應(yīng)新的任務(wù)和領(lǐng)域!
在介紹這篇論文之前,我們先看看PeFT。
參數(shù)高效微調(diào) PeFT
參數(shù)高效微調(diào)方法(Parameter-Efficient Fine-Tuning,PEFT)僅微調(diào)少量或額外的模型參數(shù),固定大部分預(yù)訓(xùn)練參數(shù),大大降低了計(jì)算和存儲(chǔ)成本,同時(shí)最先進(jìn)的 PEFT 技術(shù)也能實(shí)現(xiàn)了與全量微調(diào)相當(dāng)?shù)男阅堋?/span>
在PeFT的思想之上就產(chǎn)生了我們非常熟悉的LoRA,還有各種LoRA的變體,除了有名的LoRA之外常用的PeFT方法還有:
Prefix Tuning:通過(guò)virtual token構(gòu)造連續(xù)型隱式prompt ,這是21年斯坦福發(fā)布的方法。
P-Tuning V1/V2:這是清華大學(xué)在21年提出的將自然語(yǔ)言的離散模版轉(zhuǎn)化為可訓(xùn)練的隱式prompt (連續(xù)參數(shù)優(yōu)化問(wèn)題),V2版在輸入前面的每層加入可微調(diào)的參數(shù),增強(qiáng)了V1版的性能。
然后就是我們熟悉的也是最長(zhǎng)用的LoRA,這里就不多介紹了,我們可以狹義理解為L(zhǎng)oRA是目前最好的PeFT方法,這樣可以對(duì)我們下面介紹的ReFT更好的對(duì)比。
表征微調(diào) ReFT
ReFT (Representation Finetuning)是一組專注于在推理過(guò)程中對(duì)語(yǔ)言模型的隱藏表示學(xué)習(xí)干預(yù)的方法,而不是直接修改其權(quán)重。
與更新模型整個(gè)參數(shù)集的傳統(tǒng)微調(diào)方法不同,ReFT通過(guò)策略性地操縱模型表示的一小部分來(lái)操作,指導(dǎo)其行為以更有效地解決下游任務(wù)。
ReFT背后的核心思想受到最近語(yǔ)言模型可解釋性研究的啟發(fā):在這些模型學(xué)習(xí)的表示中編碼了豐富的語(yǔ)義信息。通過(guò)干預(yù)這些表示,ReFT旨在解鎖和利用這些編碼知識(shí),實(shí)現(xiàn)更高效和有效的模型適應(yīng)。
ReFT的一個(gè)關(guān)鍵優(yōu)點(diǎn)是它的參數(shù)效率:傳統(tǒng)的微調(diào)方法需要更新模型參數(shù)的很大一部分,這可能是計(jì)算昂貴和資源密集的,特別是對(duì)于具有數(shù)十億參數(shù)的大型語(yǔ)言模型。ReFT方法通常需要訓(xùn)練數(shù)量級(jí)更少的參數(shù),從而獲得更快的訓(xùn)練時(shí)間和更少的內(nèi)存需求。
ReFT與PeFT有何不同
ReFT與傳統(tǒng)PEFT方法在幾個(gè)關(guān)鍵方面有所不同:
1、干預(yù)目標(biāo)
PEFT方法,例如,LoRA、DoRA和prefix-tuning,側(cè)重于修改模型的權(quán)重或引入額外的權(quán)重矩陣。而ReFT方法不直接修改模型的權(quán)重;它們會(huì)干預(yù)模型在向前傳遞期間計(jì)算的隱藏表示。
2、適應(yīng)機(jī)制
像LoRA和DoRA這樣的PEFT方法學(xué)習(xí)權(quán)重更新或模型權(quán)重矩陣的低秩近似值。然后在推理期間將這些權(quán)重更新合并到基本模型的權(quán)重中,從而不會(huì)產(chǎn)生額外的計(jì)算開(kāi)銷。ReFT方法學(xué)習(xí)干預(yù),在推理過(guò)程中在特定層和位置操縱模型的表示。此干預(yù)過(guò)程會(huì)產(chǎn)生一些計(jì)算開(kāi)銷,但可以實(shí)現(xiàn)更有效的適應(yīng)。
3、動(dòng)機(jī)
PEFT方法的主要?jiǎng)訖C(jī)是對(duì)參數(shù)有效適應(yīng)的需求,減少了調(diào)優(yōu)大型語(yǔ)言模型的計(jì)算成本和內(nèi)存需求。另一方面,ReFT方法受到最近語(yǔ)言模型可解釋性研究的啟發(fā),該研究表明,在這些模型學(xué)習(xí)的表示中編碼了豐富的語(yǔ)義信息。ReFT的目標(biāo)是利用和利用這些編碼的知識(shí)來(lái)更有效地適應(yīng)模型。
4.參數(shù)效率
PEFT和ReFT方法都是為了參數(shù)效率而設(shè)計(jì)的,但ReFT方法在實(shí)踐中證明了更高的參數(shù)效率。例如LoReFT(低秩線性子空間ReFT)方法通常需要訓(xùn)練的參數(shù)比最先進(jìn)的PEFT方法(LoRA)少10-50倍,同時(shí)在各種NLP基準(zhǔn)測(cè)試中獲得具有競(jìng)爭(zhēng)力或更好的性能。
5、可解釋性
雖然PEFT方法主要側(cè)重于有效的適應(yīng),但ReFT方法在可解釋性方面提供了額外的優(yōu)勢(shì)。通過(guò)干預(yù)已知編碼特定語(yǔ)義信息的表示,ReFT方法可以深入了解語(yǔ)言模型如何處理和理解語(yǔ)言,從而可能導(dǎo)致更透明和值得信賴的人工智能系統(tǒng)。
ReFT架構(gòu)
ReFT模型體系結(jié)構(gòu)定義了干預(yù)的一般概念,這基本上意味著在模型向前傳遞期間對(duì)隱藏表示的修改。我們首先考慮一個(gè)基于transformer的語(yǔ)言模型,該模型生成標(biāo)記序列的上下文化表示。
給定一個(gè)n個(gè)輸入令牌序列x = (x?,…,xn),模型首先將其嵌入到一個(gè)表示列表中,就h?,…,hn。然后m層連續(xù)計(jì)算第j個(gè)隱藏表示,每一個(gè)隱藏的表示都是一個(gè)向量h∈λ,其中d是表示的維數(shù)。
ReFT定義了一個(gè)干預(yù)的概念,它在模型向前傳遞期間修改隱藏的表示。
干預(yù)I是一個(gè)元組?Φ, P, L?,它封裝了由基于transformer的LM計(jì)算的表示的單個(gè)推理時(shí)間的干預(yù)動(dòng)作,這個(gè)函數(shù)包含了三個(gè)參數(shù):
干預(yù)函數(shù)Φ:用學(xué)習(xí)到的參數(shù)Φ (Φ)來(lái)表示。
干預(yù)所應(yīng)用的一組輸入位置P≤{1,…,n}。
對(duì)層L∈{1,…,m}進(jìn)行干預(yù)。
然后,干預(yù)的動(dòng)作如下:
h??? ← (Φ(h_p???) if p ∈ P else h_p???)_{p∈1,…,n}
該干預(yù)在前向傳播計(jì)算完后立即進(jìn)行,所以會(huì)影響到后續(xù)層中計(jì)算的表示。
為了提高計(jì)算的效率,也可以將干預(yù)的權(quán)重進(jìn)行低秩分解,也就是得到了低秩線性子空間ReFT (LoReFT)。
在上面的公式中使用學(xué)習(xí)到的投影源Rs = Wh +b。LoReFT編輯R列的R維子空間中的表示,來(lái)或取從我們的線性投影Wh +b中獲得的值。
對(duì)于生成任務(wù),ReFT論文使用語(yǔ)言建模的訓(xùn)練目標(biāo),重點(diǎn)是在所有輸出位置上使用最小化交叉熵?fù)p失。
pyreft庫(kù)代碼示例
斯坦福大學(xué)的研究人員在發(fā)布論文的同時(shí)還發(fā)布了pyreft庫(kù),這是一個(gè)建立在pyvene之上用于在任意PyTorch模型上執(zhí)行和訓(xùn)練激活干預(yù)的庫(kù)。
pyreft可以兼容HuggingFace上可用的任何預(yù)訓(xùn)練語(yǔ)言模型,并且可以使用ReFT方法進(jìn)行微調(diào)。以下是如何將lama- 27b模型的第19層輸出進(jìn)行單一干預(yù)的代碼示例:
import torch
import transformers
from pyreft import (
get_reft_model,
ReftConfig,
LoreftIntervention,
ReftTrainerForCausalLM
)
# Loading HuggingFace model
model_name_or_path = "yahma/llama-7b-hf"
model = transformers.AutoModelForCausalLM.from_pretrained(
model_name_or_path, torch_dtype=torch.bfloat16, device_map="cuda"
)
# Wrap the model with rank-1 constant reFT
reft_config = ReftConfig(
representations={
"layer": 19,
"component": "block_output",
"intervention": LoreftIntervention(
embed_dim=model.config.hidden_size, low_rank_dimension=1
),
}
)
reft_model = get_reft_model(model, reft_config)
reft_model.print_trainable_parameters()
剩下的代碼就和HuggingFace訓(xùn)練模型沒(méi)有任何的區(qū)別了,我們來(lái)做一個(gè)完整的演示:
from pyreft import (
ReftTrainerForCausalLM,
make_last_position_supervised_data_module
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_name_or_path, model_max_length=2048, padding_side="right", use_fast=False)
tokenizer.pad_token = tokenizer.unk_token
# get training data to train our intervention to remember the following sequence
memo_sequence = """
Welcome to the Natural Language Processing Group at Stanford University!
We are a passionate, inclusive group of students and faculty, postdocs
and research engineers, who work together on algorithms that allow computers
to process, generate, and understand human languages. Our interests are very
broad, including basic scientific research on computational linguistics,
machine learning, practical applications of human language technology,
and interdisciplinary work in computational social science and cognitive
science. We also develop a wide variety of educational materials
on NLP and many tools for the community to use, including the Stanza
toolkit which processes text in over 60 human languages.
"""
data_module = make_last_position_supervised_data_module(
tokenizer=tokenizer,
model=model,
inputs=["GO->"],
outputs=[memo_sequence])
# train
training_args = transformers.TrainingArguments(
num_train_epochs=1000.0,
output_dir="./tmp",
learning_rate=2e-3,
logging_steps=50)
trainer = ReftTrainerForCausalLM(
model=reft_model, tokenizer=tokenizer,
args=training_args, **data_module)
_ = trainer.train()
一旦完成訓(xùn)練,就可以檢查模型信息:
prompt = tokenizer("GO->", return_tensors="pt").to("cuda")
base_unit_location = prompt["input_ids"].shape[-1] - 1 # last position
_, reft_response = reft_model.generate(
prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
intervene_on_prompt=True, max_new_tokens=512, do_sample=False,
eos_token_id=tokenizer.eos_token_id, early_stopping=True
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))
LoReFT的性能測(cè)試
最后我們來(lái)看看它在各種NLP基準(zhǔn)測(cè)試中的卓越表現(xiàn),以下是斯坦福大學(xué)的研究人員展示的數(shù)據(jù)。
LoReFT在8個(gè)具有挑戰(zhàn)性的數(shù)據(jù)集上獲得了最先進(jìn)的性能,包括BoolQ、PIQA、SIQA、HellaSwag、WinoGrande、ARC-e、ARC-c和OBQA。盡管使用的參數(shù)比現(xiàn)有的PEFT方法少得多(少10-50倍),但LoReFT的性能還是大大超過(guò)了所有其他方法,展示了它有效捕獲和利用大型語(yǔ)言模型中編碼的常識(shí)性知識(shí)的能力。
雖然LoReFT在數(shù)學(xué)推理任務(wù)上沒(méi)有超過(guò)現(xiàn)有的PEFT方法,但它在AQuA、GSM8K、MAWPS和SVAMP等數(shù)據(jù)集上展示了具有競(jìng)爭(zhēng)力的性能。研究人員指出LoReFT的性能隨著模型尺寸的增大而提高,這表明它的能力隨著語(yǔ)言模型的不斷增長(zhǎng)而擴(kuò)大。
在指令遵循領(lǐng)域,LoReFT取得了顯著的結(jié)果,在Alpaca-Eval v1.0基準(zhǔn)測(cè)試上優(yōu)于所有的微調(diào)方法,包括完全微調(diào)(這個(gè)要注重說(shuō)明)。當(dāng)在llama - 27b模型上訓(xùn)練時(shí),LoReFT的比GPT-3.5 Turbo模型的還要好1%,同時(shí)使用的參數(shù)比其他PEFT方法少得多。
LoReFT還展示了其在自然語(yǔ)言理解任務(wù)中的能力,當(dāng)應(yīng)用于RoBERTa-base和RoBERTa-large模型時(shí),在GLUE基準(zhǔn)測(cè)試中實(shí)現(xiàn)了與現(xiàn)有PEFT方法相當(dāng)?shù)男阅堋?/span>
當(dāng)在參數(shù)數(shù)量上與之前最有效的PEFT方法相匹配時(shí),LoReFT在各種任務(wù)中獲得了相似的分?jǐn)?shù),包括情感分析和自然語(yǔ)言推理。
總結(jié)
ReFT特別是LoReFT的成功,對(duì)自然語(yǔ)言處理的未來(lái)和大型語(yǔ)言模型的實(shí)際應(yīng)用具有重要意義。ReFT的參數(shù)效率使其成為一種使大型語(yǔ)言模型適應(yīng)特定的任務(wù)或領(lǐng)域,同時(shí)最大限度地減少計(jì)算資源和訓(xùn)練時(shí)間的有效的解決方案。
并且ReFT還提供了一個(gè)獨(dú)特的視角來(lái)增強(qiáng)大型語(yǔ)言模型的可解釋性。在常識(shí)推理、算術(shù)推理和指令遵循等任務(wù)中的成功表明該方法的有效性。目前來(lái)看ReFT有望開(kāi)啟新的可能性,克服傳統(tǒng)調(diào)優(yōu)方法的局限性。