機(jī)器學(xué)習(xí) | 從0開始大模型之模型LoRA訓(xùn)練
1、LoRA是如何實(shí)現(xiàn)的?
在深入了解 LoRA 之前,我們先回顧一下一些基本的線性代數(shù)概念。
1.1、秩
給定矩陣中線性獨(dú)立的列(或行)的數(shù)量,稱為矩陣的秩,記為 rank(A) 。
- 矩陣的秩小于或等于列(或行)的數(shù)量,rank(A) ≤ min{m, n}
- 滿秩矩陣是所有的行或者列都獨(dú)立,rank(A) = min{m, n}
- 不滿秩矩陣是滿秩矩陣的反面是不滿秩,即 rank(A) < min(m, n),矩陣的列(或行)不是彼此線性獨(dú)立的
舉個(gè)兩個(gè)秩的例子:
不滿秩
滿秩
1.2、秩相關(guān)屬性
從上面的秩的介紹中可以看出,矩陣的秩可以被理解為它所表示的特征空間的維度,在這種情況下,特定大小的低秩矩陣比相同維度的滿秩矩陣封裝更少的特征(或更低維的特征空間)。與之相關(guān)的屬性如下:
- 矩陣的秩受其行數(shù)和列數(shù)中最小值的約束,rank(A) ≤ min{m, n};
- 兩個(gè)矩陣的乘積的秩受其各自秩的最小值的約束,給定矩陣 A 和 B,其中 rank(A) = m 且 rank(A) = n,則 rank(AB) ≤ min{m, n};
1.3、LoRA
LoRA(Low rand adaption) 是微軟研究人員提出的一種高效的微調(diào)技術(shù),用于使大型模型適應(yīng)特定任務(wù)和數(shù)據(jù)集。LoRA 的背后的主要思想是模型微調(diào)期間權(quán)重的變化也具有較低的內(nèi)在維度,具體來說,如果W??代表單層的權(quán)重,ΔW??代表模型自適應(yīng)過程中權(quán)重的變化,作者提出ΔW??是一個(gè)低秩矩陣,即:rank(ΔW??) << min(n,k) 。
為什么?模型有了基座以后,如果強(qiáng)調(diào)學(xué)習(xí)少量的特征,那么就可以大大減少參數(shù)的更新量,而ΔW??就可以實(shí)現(xiàn),這樣就可以認(rèn)為ΔW??是一個(gè)低秩矩陣。
實(shí)現(xiàn)原理ΔW??是一個(gè)更新矩陣,然后ΔW??根據(jù)秩的屬性,又可以拆分兩個(gè)低秩矩陣的乘積,即:B?? 和 A?? ,其中 r << min{n,k} 。這意味著網(wǎng)絡(luò)中權(quán)重 Wx = Wx + ΔWx = Wx + B??A??x,由于 r 很小,所以 B??A?? 的參數(shù)數(shù)量非常少,所以只需要更新很少的參數(shù)。
LoRA
2、peft庫
LoRA 訓(xùn)練非常方便,只需要借助 https://huggingface.co/blog/zh/peft 庫,這是 huggingface 提供的,使用方法如下:
# 引入庫
from peft import get_peft_model, LoraConfig, TaskType
# 創(chuàng)建對(duì)應(yīng)的配置
peft_config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=["q", "v"],
lora_dropout=0.01,
bias="none"
task_type="SEQ_2_SEQ_LM",
)
# 包裝模型
model = AutoModelForSeq2SeqLM.from_pretrained(
"t5-small",
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
LoraConfig 詳細(xì)參數(shù)如下:
- r:秩,即上面的r,默認(rèn)為8;
- target_modules:對(duì)特定的模塊進(jìn)行微調(diào),默認(rèn)為None,支持nn.Linear、nn.Embedding和nn.Conv2d;
- lora_alpha:ΔW 按 α / r 縮放,其中 α 是常數(shù),默認(rèn)為8;
- task_type:任務(wù)類型,支持包括 CAUSAL_LM、FEATURE_EXTRACTION、QUESTION_ANS、SEQ_2_SEQ_LM、SEQ_CLS 和 TOKEN_CLS 等;
- lora_dropout:Dropout 概率,默認(rèn)為0,通過在訓(xùn)練過程中以 dropout 概率隨機(jī)選擇要忽略的神經(jīng)元來減少過度擬合的技術(shù);
- bias:是否添加偏差,默認(rèn)為 "none";
3、訓(xùn)練
使用 peft 庫對(duì)SFT全量訓(xùn)練修改如下:
def init_model():
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def find_all_linear_names(model):
cls = torch.nn.Linear
lora_module_names = set()
for name, module in model.named_modules():
if isinstance(module, cls):
names = name.split('.')
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
return list(lora_module_names)
model = Transformer(lm_config)
ckp = f'./out/pretrain_{lm_config.dim}.pth.{batch_size}'
state_dict = torch.load(ckp, map_locatinotallow=device_type)
unwanted_prefix = '_orig_mod.'
for k, v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict, strict=False)
target_modules = find_all_linear_names(model)
peft_config = LoraConfig(
r=8,
target_modules=target_modules
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
print(f'LLM總參數(shù)量:{count_parameters(model) / 1e6:.3f} 百萬')
model = model.to(device_type)
return model
只需要修改模型初始化部分,其他不變,訓(xùn)練過程和之前一樣,這里不再贅述。
參考
(1)https://cloud.tencent.com/developer/article/2372297
(2)http://www.bimant.com/blog/lora-deep-dive/
(3)https://blog.csdn.net/shebao3333/article/details/134523779