PyTorch 訓練,除了會訓練還要了解這些
讓我們討論一下在訓練過程中幫助你進行實驗的技術(shù)。我將提供一些理論、代碼片段和完整的流程示例。主要要點包括:
- 數(shù)據(jù)集分割
- 指標
- 可重復性
- 配置、日志記錄和可視化
分割數(shù)據(jù)集
我喜歡有訓練集、驗證集和測試集的分割。這里沒什么好說的;你可以使用隨機分割,或者如果你有一個不平衡的數(shù)據(jù)集(就像在實際情況中經(jīng)常發(fā)生的那樣)——分層分割。
對于測試集,嘗試手動挑選一個“黃金數(shù)據(jù)集”,包含你希望模型擅長的所有示例。測試集應該在實驗之間保持不變。它應該只在你完成模型訓練后使用。這將在部署到生產(chǎn)環(huán)境之前給你客觀的指標。別忘了,你的數(shù)據(jù)集應該盡可能接近生產(chǎn)環(huán)境,這樣才有代表性。
指標
為你的任務(wù)選擇正確的指標至關(guān)重要。我最喜歡的錯誤使用指標的例子是 Kaggle 的“深空系外行星狩獵”數(shù)據(jù)集,在那里你可以找到很多筆記本,人們在大約有 5000 個負樣本和 50 個正樣本的嚴重不平衡的數(shù)據(jù)集上使用準確率。當然,他們得到了 99% 的準確率,并且總是預測負樣本。那樣的話,他們永遠也找不到系外行星,所以讓我們明智地選擇指標。
深入討論指標超出了本文的范圍,但我將簡要提及一些可靠的選項:
- F1 分數(shù)
- 精確度和召回率
- mAP(檢測任務(wù))
- IoU(分割任務(wù))
- 準確率(對于平衡的數(shù)據(jù)集)
- ROC-AUC
真實圖像分類問題的分數(shù)示例:
+--------+----------+--------+-----------+--------+
| split | accuracy | f1 | precision | recall |
+--------+----------+--------+-----------+--------+
| val | 0.9915 | 0.9897 | 0.9895 | 0.99 |
| test | 0.9926 | 0.9921 | 0.9927 | 0.9915 |
+--------+----------+--------+-----------+--------+
為你的任務(wù)選擇幾個指標:
def get_metrics(gt_labels: List[int], preds: List[int]) -> Dict[str, float]:
num_classes = len(set(gt_labels))
if num_classes == 2:
average = "binary"
else:
average = "macro"
metrics = {}
metrics["accuracy"] = accuracy_score(gt_labels, preds)
metrics["f1"] = f1_score(gt_labels, preds, average=average)
metrics["precision"] = precision_score(gt_labels, preds, average=average)
metrics["recall"] = recall_score(gt_labels, preds, average=average)
return metrics
此外,繪制精確度-閾值和召回率-閾值曲線,以更好地選擇置信度閾值。
可重復性
沒有可靠的可重復性,我們就不能談?wù)搶嶒?。當你不改變?nèi)魏螙|西時,你應該得到相同的結(jié)果。簡單的例子,如果你使用 torch 和 Nvidia,如何凍結(jié)所有種子:
def set_seeds(seed: int, cudnn_fixed: bool = False) -> None:
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if cudnn_fixed:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
注意:cudnn_fixed 可能會影響性能。我在實驗期間使用它,然后在選擇參數(shù)后的最終訓練階段關(guān)閉它。
這是當你有 cudnn 固定并且不改變參數(shù)時會發(fā)生什么——訓練完全一樣。這就是我們希望從 cudnn_fixed 得到的結(jié)果。
現(xiàn)在你可以處理參數(shù),并確保結(jié)果的變化是因為參數(shù)的變化。
配置、日志記錄和可視化
這是人們經(jīng)常忘記的部分,但我發(fā)現(xiàn)它非常有用:
- 包含變量和超參數(shù)的配置文件
- 記錄所有指標和配置的日志
- 指標的可視化
當你有一個包含很多模塊的項目時,配置文件非常方便,你可以將變量放在一個配置文件中,并在所有模塊中使用。你還應該在那里存儲訓練配置。這里有一個我使用的 Hydra 配置的例子:
project_name: project_name
exp_name: baseline
exp: ${exp_name}_${now_dir}
train:
root: /path/to/project
device: cuda
label_to_name: {0: "class_1", 1: "class_2", 2: "class_3"}
img_size: [256, 256] # (h, w)
train_split: 0.8
val_split: 0.1 # test_split = 1 - train_split - val_split
batch_size: 64
epochs: 15
use_scheduler: True
layers_to_train: -1
num_workers: 10
threads_to_use: 10
data_path: ${train.root}/dataset
path_to_save: ${train.root}/output/models/${exp}
vis_path: ${train.root}/output/visualized
seed: 42
cudnn_fixed: False
debug_img_processing: False
export: # TensorRT must be done on the inference device
half: False
max_batch_size: 1
model_path: ${train.path_to_save}
path_to_data: ${train.root}/to_test
每次訓練會議的配置文件和指標都應該被記錄。通過 wandb(或類似的東西)集成,每次訓練會議都被記錄和可視化。
我也更喜歡保存在本地:
在訓練期間:
- 每個時代后打印驗證指標
- 如果它實現(xiàn)了最佳指標,則保存模型
- 如果調(diào)試模式開啟,則保存預處理圖像
訓練結(jié)束時:
- 保存包含最佳驗證和測試指標的 metrics.csv
訓練后:
- 保存 model.onnx,model.engine 和在模型導出期間創(chuàng)建的其他格式
- 保存顯示模型注意力的可視化
結(jié)構(gòu)示例:
output
|
├── debug_img_processing
| ├── img_1
| └── img_2
|
├── models
| ├── experiment_1
| ├── model.pt
| ├── model.engine
| ├── precision_recall_curves
| └── val_precision_recall_vs_threshold.png
| └── metrics.csv
|
└── visualized
├── class_1
├── img_1
└── img_2