自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

從零開始微調Embedding模型:基于BERT的實戰(zhàn)教程

發(fā)布于 2025-4-14 01:31
瀏覽
0收藏

背景

在理解與學會了Naive RAG的框架流程后,就很自然地關注到embedding模型,與問題相關的文本召回,也有很多論文在做這方面的創(chuàng)新。

以前一直不知道embedding模型是如何微調出來的,一直聽說是微調BERT,但是不知道是怎么微調出來的。直到在B站上看到bge模型微調的視頻[參考資料4]才理解。

于是便想著自己也微調出一個 embedding模型。涉及到下面三個階段:

  • 數據集制作
  • 模型訓練
  • 評估

微調實戰(zhàn)

安裝包 

pip install -U FlagEmbedding[finetune]

項目基于 https://github.com/FlagOpen/FlagEmbedding,若遇到環(huán)境報錯,可參考該項目的環(huán)境,完成python環(huán)境設置

FlagEmbedding論文:C-Pack: Packed Resources For General Chinese Embeddings , 也稱 C-METB

介紹 

你可以閱讀參考資料[1]和[2],先嘗試實現一次官方的微調教程。

官方微調的模型是BAAI/bge-large-en-v1.5,我選擇直接微調BERT模型,這樣感受微調的效果更明顯。僅僅是出于學習的目的,我才選擇微調BERT,如果大家打算用于生產環(huán)境,還是要選擇微調現成的embedding模型。因為embedding模型也分為預訓練與微調兩個階段,我們不做預訓練。

embedding 模型需要通過encode方法把文本變成向量,而BERT模型沒有encode方法。故要使用FlagEmbedding導入原生的BERT模型。

from FlagEmbedding.inference.embedder.encoder_only.base import BaseEmbedder

# 省略數據集加載代碼

bert_embedding = BaseEmbedder("bert-base-uncased")
# get the embedding of the corpus
corpus_embeddings = bert_embedding.encode(corpus_dataset["text"])

print("shape of the corpus embeddings:", corpus_embeddings.shape)
print("data type of the embeddings: ", corpus_embeddings.dtype)

可瀏覽:eval_raw_bert.ipynb

項目文件介紹

數據集構建:

  • build_train_dataset.ipynb?: 構建訓練集數據,隨機采樣負樣本數據通過修改?neg_num?的值,構架了training_neg_10.json和training_neg_50.json兩個訓練的數據集,比較增加負樣本的數量是否能提高模型召回的效果(實驗結果表明:這里的效果并不好,提升不明顯)。
  • build_eval_dataset.ipynb: 構建測試集數據,評估模型召回的效果。與FlagEmbedding數據集構建結構不同,我個人用這種數據集樣式更方便,不需要像FlagEmbedding一樣從下標讀出正確的樣本的數據。

模型訓練:

  • finetune_neg10.sh
  • finetune_neg50.sh

finetune_neg10.sh的代碼如下:

torchrun --nproc_per_node=1 \
    -m FlagEmbedding.finetune.embedder.encoder_only.base \
    --model_name_or_path bert-base-uncased \
    --train_data ./ft_data/training_neg_10.json \
    --train_group_size 8 \
    --query_max_len 512 \
    --passage_max_len 512 \
    --pad_to_multiple_of 8 \
    --query_instruction_for_retrieval 'Represent this sentence for searching relevant passages: ' \
    --query_instruction_format '{}{}' \
    --output_dir ./output/bert-base-uncased_neg10 \
    --overwrite_output_dir \
    --learning_rate 1e-5 \
    --fp16 \
    --num_train_epochs 3 \
    --per_device_train_batch_size 4 \
    --warmup_ratio 0.1 \
    --logging_steps 200 \
    --save_steps 2000 \
    --temperature 0.02 \
    --sentence_pooling_method cls \
    --normalize_embeddings True \
    --kd_loss_type kl_div

bash finetune_neg10.sh > finetune_neg10.log 2>&1 & 把訓練的日志保存到 finetune_neg10.log 日志文件中,訓練用時6分鐘。

neg10代表每條數據10個負樣本,neg50代表每條數據50個負樣本。

評估:

評估是在所有語料上完成的評估,并不是在指定的固定數量的負樣本上完成的評估。由于是在全部語料上完成召回,故使用到了faiss向量數據庫。

  • eval_raw_bert.ipynb: 評估BERT原生模型
  • eval_train_neg10.ipynb: 評估基于10條負樣本微調后的模型
  • eval_train_neg50.ipynb: 評估基于50條負樣本微調后的模型
  • eval_bge_m3.ipynb: 評估 BAAI 現在表現效果好的 BGE-M3 模型

結論:通過評估結果,可看出BERT經過微調后的提升明顯,但依然達不到BGE-M3 模型的效果。

微調硬件配置要求 

微調過程中GPU顯存占用達到了9G左右

設備只有一臺GPU

debug 重要代碼分析【選看】

下述代碼是舊版本的代碼,不是最新的FlagEmbedding的代碼:

  • 視頻教程,bge模型微調流程:https://www.bilibili.com/video/BV1eu4y1x7ix/

推薦使用23年10月份的代碼進行debug,關注核心代碼。新版的加了抽象類與繼承,增加了很多額外的東西,使用早期版本debug起來更聚焦一些。

從零開始微調Embedding模型:基于BERT的實戰(zhàn)教程-AI.x社區(qū)

由于需要傳遞參數再運行腳本,需要在pycharm配置一些與運行相關的參數:

從零開始微調Embedding模型:基于BERT的實戰(zhàn)教程-AI.x社區(qū)

從零開始微調Embedding模型:基于BERT的實戰(zhàn)教程-AI.x社區(qū)

下述是embedding計算損失的核心代碼,這里的query與passage都是batch_size數量的輸入,如果只是一條query與passage,大家理解起來就容易很多。由于這里是batch_size數量的輸入,代碼中涉及到矩陣運算會給大家?guī)砝斫饫щy。

比較難理解的是下述代碼,這里的target 其實就是label:

target = torch.arange(
                scores.size(0), device=scores.device, dtype=torch.long
            )
target = target * (p_reps.size(0) // q_reps.size(0))

p_reps 是相關文本矩陣, q_reps 是問題矩陣。每一個問題都對應固定數量的相關文本。p_reps.size(0) // q_reps.size(0) 是每個問題對應的相關文本的數量。下一行的target 乘以 相關文本的塊數,得到query對應的 Gold Truth(也稱 pos 文本)的下標,因為在每個相關文本中,第一個位置都是正確文本,其后是負樣本,這些 Gold Truth 下標之間的距離是固定,通過乘法就可以計算出每個 Gold Truth 的下標。

額外補充【選看】:

在微調的過程中,不要錯誤的以為每個問題只和自己的相關文本計算score。真實的情況是,在batch_size的數據中,每個問題會與所有的相關文本計算score。根據上述代碼可看出 target 最大的取值是:query的數量 x 相關文本數量,這也印證了每個問題會與所有的相關文本都計算score。故我們在隨機采樣負樣本的時候,負樣本數量設置的太小也不用太擔心,因為在計算過程中負樣本的數量會乘以 batch_size。

【注意】:query的數量 = batch_size

  • 損失函數

從零開始微調Embedding模型:基于BERT的實戰(zhàn)教程-AI.x社區(qū)

image-20250405174449800

從零開始微調Embedding模型:基于BERT的實戰(zhàn)教程-AI.x社區(qū)

image-20250407143100543

def compute_loss(self, scores, target):
    return self.cross_entropy(scores, target)

C-METB 論文中,關于損失函數的介紹,公式看起來很復雜,本質就是cross_entropy。

資源分享

上述的代碼開源在github平臺,為了不增大github倉庫的容量,數據集沒有上傳到github平臺。若希望直接獲得完整的項目文件夾,從下述提供的網盤分享鏈接進行下載:

  • github開源地址:https://github.com/JieShenAI/csdn/tree/main/25/04/embedding_finetune
  • 通過網盤分享的文件:embedding_finetune.zip鏈接: https://pan.baidu.com/s/1CDRpkkjS1-0jtmIBiTWx1A 提取碼: free

最新的代碼,請以 github 的鏈接為準,網盤分享的文件,本意只是為了存儲數據,避免增加github倉庫的容量

CSDN: https://jieshen.blog.csdn.net/article/details/147043668

參考資料

[1] BAAI官方微調教程: ??https://github.com/FlagOpen/FlagEmbedding/blob/master/Tutorials/7_Fine-tuning/7.1.2_Fine-tune.ipynb??

[2] BAAI官方評估教程:??https://github.com/FlagOpen/FlagEmbedding/blob/master/Tutorials/4_Evaluation/4.1.1_Evaluation_MSMARCO.ipynb??

[3] 多文檔知識圖譜問答:??https://jieshen.blog.csdn.net/article/details/146390208??

[4] bge模型微調流程:??https://www.bilibili.com/video/BV1eu4y1x7ix/??

[5] FlagEmbedding 舊版本可用于debug的代碼:https://github.com/FlagOpen/FlagEmbedding/blob/9b6e521bcb7583ed907f044ca092daef0ee90431/FlagEmbedding/baai_general_embedding/finetune/run.py

本文轉載自??AI悠閑區(qū)??,作者:jieshenai

已于2025-4-14 10:07:43修改
收藏
回復
舉報
回復
相關推薦