自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

利用 PyTorch Lightning 搭建一個(gè)文本分類模型,你學(xué)會(huì)了嗎?

開發(fā) 前端
本項(xiàng)目展示了如何采用 PyTorch Lightning 進(jìn)行構(gòu)建、訓(xùn)練和部署文本分類模型的系統(tǒng)化方法。盡情地嘗試代碼,調(diào)整參數(shù),并試用不同的數(shù)據(jù)集或模型。

引言

在這篇博文[1]中,將逐步介紹如何使用 PyTorch Lightning 來構(gòu)建和部署一個(gè)基礎(chǔ)的文本分類模型。該項(xiàng)目借助了 PyTorch 生態(tài)中的多個(gè)強(qiáng)大工具,例如 torch、pytorch_lightning 以及 Hugging Face 提供的 transformers,從而構(gòu)建了一個(gè)強(qiáng)大且可擴(kuò)展的機(jī)器學(xué)習(xí)流程。

圖片圖片

代碼庫包含四個(gè)核心的 Python 腳本:

  • data.py:負(fù)責(zé)數(shù)據(jù)的加載和預(yù)處理工作。
  • model.py:構(gòu)建模型的結(jié)構(gòu)。
  • train.py:包含了訓(xùn)練循環(huán)和訓(xùn)練的配置。
  • inference.py:支持使用訓(xùn)練好的模型進(jìn)行推斷。

下面詳細(xì)解析每個(gè)部分,以便理解它們是如何協(xié)同作用,以實(shí)現(xiàn)文本分類的高效工作流程。

1. 數(shù)據(jù)加載與預(yù)處理

在 data.py 文件中,DataModule 類被設(shè)計(jì)用來處理數(shù)據(jù)加載和預(yù)處理的所有環(huán)節(jié)。它利用了 PyTorch Lightning 的 LightningDataModule,這有助于保持?jǐn)?shù)據(jù)處理任務(wù)的模塊化和可復(fù)用性。

class DataModule(pl.LightningDataModule):
    def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", batch_size=32):
        super().__init__()
        self.batch_size = batch_size
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

這個(gè)類在初始化時(shí)需要指定模型名稱和批量大小,并從 Hugging Face 的 Transformers 庫加載一個(gè)分詞器。prepare_data() 函數(shù)會(huì)從 GLUE 基準(zhǔn)測試套件中下載 CoLA 數(shù)據(jù)集,這個(gè)數(shù)據(jù)集經(jīng)常用來評(píng)估自然語言理解(NLU)模型的性能。

setup() 函數(shù)負(fù)責(zé)對文本數(shù)據(jù)進(jìn)行分詞處理,并創(chuàng)建用于訓(xùn)練和驗(yàn)證的 PyTorch DataLoader 對象:

def setup(self, stage=None):
    if stage == "fit" or stage is None:
        self.train_data = self.train_data.map(self.tokenize_data, batched=True)
        self.train_data.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
        self.val_data = self.val_data.map(self.tokenize_data, batched=True)
        self.val_data.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

2. 模型架構(gòu)

在 model.py 文件中定義的 ColaModel 類繼承自 PyTorch Lightning 的 LightningModule。該模型采用 BERT(一種雙向編碼器表示,源自 Transformers)的簡化版本作為文本表示的核心模型。

class ColaModel(pl.LightningModule):
    def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", lr=1e-2):
        super(ColaModel, self).__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.W = nn.Linear(self.bert.config.hidden_size, 2)

模型在前向傳播過程中提取 BERT 的最終隱藏狀態(tài),并通過一個(gè)線性層來生成用于二分類的對數(shù)幾率(logits):

def forward(self, input_ids, attention_mask):
    outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
    h_cls = outputs.last_hidden_state[:, 0]
    logits = self.W(h_cls)
    return logits

另外,training_step() 和 validation_step() 函數(shù)分別負(fù)責(zé)處理訓(xùn)練和驗(yàn)證的邏輯,并記錄諸如損失和準(zhǔn)確率等關(guān)鍵指標(biāo)。

3. Training Loop

train.py 腳本利用 PyTorch Lightning 的 Trainer 類來控制訓(xùn)練過程。它還包含了模型檢查點(diǎn)和提前停止的回調(diào)機(jī)制,以防止模型過擬合。

checkpoint_callback = ModelCheckpoint(dirpath="./models", mnotallow="val_loss", mode="min")
early_stopping_callback = EarlyStopping(mnotallow="val_loss", patience=3, verbose=True, mode="min")

訓(xùn)練過程設(shè)定了最大周期數(shù),并在可能的情況下利用 GPU 進(jìn)行加速:

trainer = pl.Trainer(
    default_root_dir="logs",
    gpus=(1 if torch.cuda.is_available() else 0),
    max_epochs=5,
    fast_dev_run=False,
    logger=pl.loggers.TensorBoardLogger("logs/", name="cola", versinotallow=1),
    callbacks=[checkpoint_callback, early_stopping_callback],
)
trainer.fit(cola_model, cola_data)

這樣的配置不僅讓訓(xùn)練變得更加簡便,還保證了模型能夠定期保存并對其性能進(jìn)行監(jiān)控。

4. 推理

訓(xùn)練結(jié)束后,將利用模型來進(jìn)行預(yù)測。inference.py 腳本中定義了一個(gè)名為 ColaPredictor 的類,該類負(fù)責(zé)加載經(jīng)過訓(xùn)練的模型檢查點(diǎn),并提供了一個(gè)用于生成預(yù)測的方法:

class ColaPredictor:
    def __init__(self, model_path):
        self.model_path = model_path
        self.model = ColaModel.load_from_checkpoint(model_path)
        self.model.eval()
        self.model.freeze()

Predict() 方法接受文本輸入,使用分詞器對其進(jìn)行處理,并返回模型的預(yù)測:

def predict(self, text):
    inference_sample = {"sentence": text}
    processed = self.processor.tokenize_data(inference_sample)
    logits = self.model(
        torch.tensor([processed["input_ids"]]),
        torch.tensor([processed["attention_mask"]]),
    )
    scores = self.softmax(logits[0]).tolist()
    predictions = [{"label": label, "score": score} for score, label in zip(scores, self.labels)]
    return predictions

總結(jié)

本項(xiàng)目展示了如何采用 PyTorch Lightning 進(jìn)行構(gòu)建、訓(xùn)練和部署文本分類模型的系統(tǒng)化方法。盡情地嘗試代碼,調(diào)整參數(shù),并試用不同的數(shù)據(jù)集或模型吧。

責(zé)任編輯:武曉燕 來源: 數(shù)據(jù)科學(xué)工廠
相關(guān)推薦

2021-08-30 09:25:25

Bert模型PyTorch語言

2021-10-04 09:29:41

對象池線程池

2023-03-26 22:02:53

APMPR監(jiān)控

2024-06-21 08:15:25

2024-06-19 09:47:21

2023-06-12 07:41:16

dockerspark集群

2023-09-19 08:03:50

rebase?merge

2023-04-27 08:42:50

效果

2022-02-08 09:09:45

智能指針C++

2024-04-01 08:13:59

排行榜MySQL持久化

2023-05-24 08:14:55

2024-11-29 08:53:46

2024-09-26 09:10:08

2024-01-02 12:05:26

Java并發(fā)編程

2023-08-01 12:51:18

WebGPT機(jī)器學(xué)習(xí)模型

2023-01-10 08:43:15

定義DDD架構(gòu)

2024-02-04 00:00:00

Effect數(shù)據(jù)組件

2023-07-26 13:11:21

ChatGPT平臺(tái)工具

2024-01-19 08:25:38

死鎖Java通信

2023-08-08 08:34:47

漏洞環(huán)境獲取方法
點(diǎn)贊
收藏

51CTO技術(shù)棧公眾號(hào)