自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

機(jī)器學(xué)習(xí) | 從0開始大模型之模型LoRA訓(xùn)練

人工智能 機(jī)器學(xué)習(xí)
LoRA 的背后的主要思想是模型微調(diào)期間權(quán)重的變化也具有較低的內(nèi)在維度,具體來說,如果W??代表單層的權(quán)重,ΔW??代表模型自適應(yīng)過程中權(quán)重的變化,作者提出ΔW??是一個(gè)低秩矩陣,即:rank(ΔW??) << min(n,k) 。

1、LoRA是如何實(shí)現(xiàn)的?

在深入了解 LoRA 之前,我們先回顧一下一些基本的線性代數(shù)概念。

1.1、秩

給定矩陣中線性獨(dú)立的列(或行)的數(shù)量,稱為矩陣的秩,記為 rank(A) 。

  • 矩陣的秩小于或等于列(或行)的數(shù)量,rank(A) ≤ min{m, n}
  • 滿秩矩陣是所有的行或者列都獨(dú)立,rank(A) = min{m, n}
  • 不滿秩矩陣是滿秩矩陣的反面是不滿秩,即 rank(A) < min(m, n),矩陣的列(或行)不是彼此線性獨(dú)立的

舉個(gè)兩個(gè)秩的例子:

不滿秩不滿秩

滿秩滿秩

1.2、秩相關(guān)屬性

從上面的秩的介紹中可以看出,矩陣的秩可以被理解為它所表示的特征空間的維度,在這種情況下,特定大小的低秩矩陣比相同維度的滿秩矩陣封裝更少的特征(或更低維的特征空間)。與之相關(guān)的屬性如下:

  • 矩陣的秩受其行數(shù)和列數(shù)中最小值的約束,rank(A) ≤ min{m, n};
  • 兩個(gè)矩陣的乘積的秩受其各自秩的最小值的約束,給定矩陣 A 和 B,其中 rank(A) = m 且 rank(A) = n,則 rank(AB) ≤ min{m, n};

1.3、LoRA

LoRA(Low rand adaption) 是微軟研究人員提出的一種高效的微調(diào)技術(shù),用于使大型模型適應(yīng)特定任務(wù)和數(shù)據(jù)集。LoRA 的背后的主要思想是模型微調(diào)期間權(quán)重的變化也具有較低的內(nèi)在維度,具體來說,如果W??代表單層的權(quán)重,ΔW??代表模型自適應(yīng)過程中權(quán)重的變化,作者提出ΔW??是一個(gè)低秩矩陣,即:rank(ΔW??) << min(n,k) 。

為什么?模型有了基座以后,如果強(qiáng)調(diào)學(xué)習(xí)少量的特征,那么就可以大大減少參數(shù)的更新量,而ΔW??就可以實(shí)現(xiàn),這樣就可以認(rèn)為ΔW??是一個(gè)低秩矩陣。

實(shí)現(xiàn)原理ΔW??是一個(gè)更新矩陣,然后ΔW??根據(jù)秩的屬性,又可以拆分兩個(gè)低秩矩陣的乘積,即:B?? 和 A?? ,其中 r << min{n,k} 。這意味著網(wǎng)絡(luò)中權(quán)重 Wx = Wx + ΔWx = Wx + B??A??x,由于 r 很小,所以 B??A?? 的參數(shù)數(shù)量非常少,所以只需要更新很少的參數(shù)。

LoRALoRA

2、peft庫

LoRA 訓(xùn)練非常方便,只需要借助 https://huggingface.co/blog/zh/peft 庫,這是 huggingface 提供的,使用方法如下:

# 引入庫
from peft import get_peft_model, LoraConfig, TaskType

# 創(chuàng)建對(duì)應(yīng)的配置
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q", "v"],
    lora_dropout=0.01,
    bias="none"
    task_type="SEQ_2_SEQ_LM",
)

# 包裝模型
model = AutoModelForSeq2SeqLM.from_pretrained(
    "t5-small",
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

LoraConfig 詳細(xì)參數(shù)如下:

  • r:秩,即上面的r,默認(rèn)為8;
  • target_modules:對(duì)特定的模塊進(jìn)行微調(diào),默認(rèn)為None,支持nn.Linear、nn.Embedding和nn.Conv2d;
  • lora_alpha:ΔW 按 α / r 縮放,其中 α 是常數(shù),默認(rèn)為8;
  • task_type:任務(wù)類型,支持包括 CAUSAL_LM、FEATURE_EXTRACTION、QUESTION_ANS、SEQ_2_SEQ_LM、SEQ_CLS 和 TOKEN_CLS 等;
  • lora_dropout:Dropout 概率,默認(rèn)為0,通過在訓(xùn)練過程中以 dropout 概率隨機(jī)選擇要忽略的神經(jīng)元來減少過度擬合的技術(shù);
  • bias:是否添加偏差,默認(rèn)為 "none";

3、訓(xùn)練

使用 peft 庫對(duì)SFT全量訓(xùn)練修改如下:

def init_model():
    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    def find_all_linear_names(model):
        cls = torch.nn.Linear
        lora_module_names = set()
        for name, module in model.named_modules():
            if isinstance(module, cls):
                names = name.split('.')
                lora_module_names.add(names[0] if len(names) == 1 else names[-1])

        return list(lora_module_names)

    model = Transformer(lm_config)
    ckp = f'./out/pretrain_{lm_config.dim}.pth.{batch_size}'
    state_dict = torch.load(ckp, map_locatinotallow=device_type)
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict, strict=False)

    target_modules = find_all_linear_names(model)
    peft_config = LoraConfig(
        r=8,
        target_modules=target_modules
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    print(f'LLM總參數(shù)量:{count_parameters(model) / 1e6:.3f} 百萬')
    model = model.to(device_type)
    return model

只需要修改模型初始化部分,其他不變,訓(xùn)練過程和之前一樣,這里不再贅述。

參考

(1)https://cloud.tencent.com/developer/article/2372297

(2)http://www.bimant.com/blog/lora-deep-dive/

(3)https://blog.csdn.net/shebao3333/article/details/134523779

責(zé)任編輯:武曉燕 來源: 周末程序猿
相關(guān)推薦

2024-11-26 09:33:44

2024-11-04 00:24:56

2025-04-03 15:40:41

機(jī)器學(xué)習(xí)大模型DeepSeek

2024-12-09 00:00:10

2025-01-10 08:38:10

2025-04-03 15:46:53

2020-08-10 15:05:02

機(jī)器學(xué)習(xí)人工智能計(jì)算機(jī)

2017-03-24 15:58:46

互聯(lián)網(wǎng)

2022-03-28 09:00:00

SQL數(shù)據(jù)庫機(jī)器學(xué)習(xí)

2023-06-24 19:59:40

2017-07-11 10:19:24

淺層模型機(jī)器學(xué)習(xí)優(yōu)化算法

2022-09-06 08:00:00

機(jī)器學(xué)習(xí)金融數(shù)據(jù)科學(xué)

2018-11-07 09:00:00

機(jī)器學(xué)習(xí)模型Amazon Sage

2024-08-20 07:55:03

2017-10-09 12:55:29

機(jī)器學(xué)習(xí)KaggleStacking

2024-06-21 11:44:17

2018-05-16 09:26:41

基線模型機(jī)器學(xué)習(xí)AI

2020-10-13 07:00:00

機(jī)器學(xué)習(xí)人工智能

2022-09-19 15:37:51

人工智能機(jī)器學(xué)習(xí)大數(shù)據(jù)

2018-03-09 09:00:00

前端JavaScript機(jī)器學(xué)習(xí)
點(diǎn)贊
收藏

51CTO技術(shù)棧公眾號(hào)