How PyTorch's New Compilation Feature TorchDynamo Works, with Usage Examples
In deep learning, optimizing model performance is critical, especially for applications that require fast execution and real-time inference. PyTorch, however, has long faced the challenge of balancing dynamic graph execution with high performance: traditional PyTorch optimization techniques have limited effect on dynamic computation graphs, leading to longer training times and suboptimal model performance. TorchDynamo is a just-in-time (JIT) compiler designed for PyTorch that addresses this problem by intercepting Python code at runtime, optimizing it, and compiling it into efficient machine code. This article demonstrates TorchDynamo in practice on a synthetic dataset, covering feature engineering, hyperparameter tuning, cross-validation, and evaluation metrics.
Introduction to TorchDynamo
TorchDynamo is a compiler frontend developed by the PyTorch team that automatically optimizes PyTorch programs to improve runtime efficiency. It works by dynamically analyzing and transforming PyTorch code at runtime and then handing the captured graph off to a backend compiler (such as TorchScript, TVM, or Triton) to deliver the performance gains.
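As a minimal sketch of how this looks in practice (the tiny nn.Sequential model below is only an illustration; torch.compile, its backend argument, and torch._dynamo.list_backends are part of the public PyTorch 2.x API):
import torch
import torch._dynamo
import torch.nn as nn
# Show which backend compilers are registered in the current installation,
# e.g. 'inductor', 'cudagraphs', ...
print(torch._dynamo.list_backends())
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
# TorchDynamo captures the Python-level graph and hands it to the chosen backend
compiled_model = torch.compile(model, backend="inductor")
out = compiled_model(torch.randn(8, 16))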
Deep learning models need to execute quickly, especially in applications that demand real-time responses such as autonomous driving or financial forecasting. Traditional optimization techniques often struggle with Python's dynamic nature, and this is exactly where TorchDynamo is strong: it captures the computation graph just in time and applies optimizations tailored to the specific workload and hardware, reducing latency and increasing throughput.
Another standout feature of TorchDynamo is how easy it is to integrate. Rewriting an entire codebase to adopt a new tool can be a daunting task, but TorchDynamo requires only minimal changes to an existing PyTorch workflow. This simplicity, combined with its powerful optimization capabilities, makes it a strong choice for experienced researchers and industry practitioners alike.
Integrating TorchDynamo into an existing PyTorch program is straightforward: import it and use it to wrap the part of the program that executes the model.
import torch
from torch import _dynamo as torchdynamo

# Define the model and optimizer (MyModel, loss_fn and data_loader are placeholders
# for your own model, loss function and data pipeline)
model = MyModel()
optimizer = torch.optim.Adam(model.parameters())

# The training step that TorchDynamo will optimize
def training_step(input, target):
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    return loss

# Wrap the training step with torchdynamo.optimize, specifying a backend compiler
optimized_training_step = torchdynamo.optimize("inductor")(training_step)

# Training loop
for input, target in data_loader:
    loss = optimized_training_step(input, target)
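In PyTorch 2.x the same wrapping is usually done through torch.compile, which drives TorchDynamo under the hood; a minimal equivalent of the snippet above (reusing the same training_step) would be:
# torch.compile is the PyTorch 2.x entry point; internally it invokes TorchDynamo
optimized_training_step = torch.compile(training_step)
for input, target in data_loader:
    loss = optimized_training_step(input, target)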
How TorchDynamo Works
TorchDynamo captures the computation graph dynamically by tracing the execution of PyTorch code. This involves understanding the code's dependencies and control flow, which allows it to identify opportunities for optimization.
Once the computation graph has been captured, TorchDynamo applies a range of optimization techniques. These include operator fusion, which merges multiple operations into a single one to reduce overhead, and improved memory management, which minimizes data movement and reuses resources efficiently.
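As a rough illustration of what operator fusion buys (a sketch; the pointwise_chain function below is made up for demonstration): a chain of pointwise operations that would launch several separate kernels in eager mode can be fused into a single kernel by a backend such as Inductor.
import torch

def pointwise_chain(x):
    # Three elementwise operations; eager execution runs them as separate kernels,
    # while a fusing backend can combine them into one
    y = torch.sin(x)
    y = y * 2.0
    return torch.relu(y)

fused = torch.compile(pointwise_chain)  # Dynamo captures the graph, the backend fuses the ops
x = torch.randn(1024, 1024)
assert torch.allclose(fused(x), pointwise_chain(x), atol=1e-6)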
After the graph has been optimized, TorchDynamo compiles it into efficient machine code. This compilation can target different backends, such as TorchScript or NVFuser, to ensure the code runs as fast as possible on the available hardware.
In the final stage the compiled code is executed. Compared with the original Python code, the optimizations above can improve performance significantly, and because they are applied just in time, execution adapts to different workloads and input data.
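Because compilation happens at runtime, TorchDynamo guards the compiled code on properties of the inputs and recompiles when a guard fails, for example when an input arrives with a new dtype. A minimal sketch (the scale function is made up for illustration):
import torch

@torch.compile
def scale(x, factor):
    return x * factor

scale(torch.randn(32, 8), 2.0)           # first call: Dynamo traces and compiles
scale(torch.randn(32, 8), 2.0)           # guards pass, the cached compiled code is reused
scale(torch.randn(32, 8).double(), 2.0)  # new dtype: a guard fails and the function is recompiled

# In recent PyTorch versions, recompilations can be surfaced by running with the
# environment variable TORCH_LOGS="recompiles"; torch._dynamo.reset() clears the cache.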
Usage Example
Below we walk through a TorchDynamo example on a synthetic dataset, covering feature engineering, hyperparameter tuning, cross-validation, prediction, and interpretation of the results.
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
from torch import _dynamo as torchdynamo
from typing import List
# Generate synthetic dataset
np.random.seed(42)
torch.manual_seed(42)
# Feature engineering: create synthetic data
n_samples = 1000
n_features = 10
X = np.random.rand(n_samples, n_features)
y = X @ np.random.rand(n_features) + np.random.rand(n_samples) * 0.1 # Linear relation with noise
# Split data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Standardize the features
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
# Convert to PyTorch tensors
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32).view(-1, 1)
# Define the model
class SimpleNN(nn.Module):
    def __init__(self, input_dim):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# Hyperparameters
input_dim = X_train.shape[1]
learning_rate = 0.001
n_epochs = 100
# Initialize the model, loss function, and optimizer
model = SimpleNN(input_dim)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Define custom compiler
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable
@torchdynamo.optimize(my_compiler)
def train_and_evaluate(model, criterion, optimizer, X_train, y_train, X_test, y_test, n_epochs):
    # Training loop with K-Fold Cross-Validation
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    train_losses_per_epoch = np.zeros(n_epochs)
    val_losses_per_epoch = np.zeros(n_epochs)
    kf_count = 0
    for train_idx, val_idx in kf.split(X_train):
        X_kf_train, X_kf_val = X_train[train_idx], X_train[val_idx]
        y_kf_train, y_kf_val = y_train[train_idx], y_train[val_idx]
        for epoch in range(n_epochs):
            model.train()
            optimizer.zero_grad()
            y_pred_train = model(X_kf_train)
            train_loss = criterion(y_pred_train, y_kf_train)
            train_loss.backward()
            optimizer.step()

            model.eval()
            y_pred_val = model(X_kf_val)
            val_loss = criterion(y_pred_val, y_kf_val)

            train_losses_per_epoch[epoch] += train_loss.item()
            val_losses_per_epoch[epoch] += val_loss.item()
        kf_count += 1

    # Average losses over K-Folds
    train_losses_per_epoch /= kf_count
    val_losses_per_epoch /= kf_count

    # Evaluate on test data
    model.eval()
    y_pred_test = model(X_test)
    test_loss = criterion(y_pred_test, y_test).item()
    test_r2 = r2_score(y_test.detach().numpy(), y_pred_test.detach().numpy())
    return train_losses_per_epoch, val_losses_per_epoch, test_loss, test_r2
# Run training and evaluation with TorchDynamo optimization
train_losses, val_losses, test_loss, test_r2 = train_and_evaluate(model, criterion, optimizer, X_train, y_train, X_test, y_test, n_epochs)
# Print metrics
print(f"Test MSE: {test_loss:.4f}")
print(f"Test R^2: {test_r2:.4f}")
# Plot results
epochs = list(range(1, n_epochs + 1))
plt.plot(epochs, train_losses, label='Train Loss')
plt.plot(epochs, val_losses, label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Validation Loss')
plt.show()
We define a simple neural network with two hidden layers in PyTorch. The model is trained with K-Fold cross-validation to ensure robust performance, and TorchDynamo is used to optimize the training loop. The model is then evaluated on a held-out test set, and metrics such as MSE and R² are computed.
Running the code produces the training and validation loss curves plotted at the end of the script.
In my_compiler we print the FX graphs that TorchDynamo captures; let's look at what they actually contain:
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ---------------------- ----------------------------------- ------------------------------------------------- --------
call_function train_losses_per_epoch <Wrapped function <original zeros>> (100,) {}
call_function val_losses_per_epoch <Wrapped function <original zeros>> (100,) {}
output output output ((train_losses_per_epoch, val_losses_per_epoch),) {}
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------------- ------------------------------------------------------- ---------------- --------
placeholder l_x_ L_x_ () {}
call_module l__self___fc1 L__self___fc1 (l_x_,) {}
call_function x <built-in method relu of type object at 0x792eaaa81760> (l__self___fc1,) {}
call_module l__self___fc2 L__self___fc2 (x,) {}
call_function x_1 <built-in method relu of type object at 0x792eaaa81760> (l__self___fc2,) {}
call_module x_2 L__self___fc3 (x_1,) {}
output output output ((x_2,),) {}
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ----------------------- -------------------------------------------------------- ------------------------------------------ --------------------------------
placeholder grad L_self_param_groups_0_params_0_grad () {}
placeholder grad_1 L_self_param_groups_0_params_1_grad () {}
placeholder grad_2 L_self_param_groups_0_params_2_grad () {}
placeholder grad_3 L_self_param_groups_0_params_3_grad () {}
placeholder grad_4 L_self_param_groups_0_params_4_grad () {}
placeholder grad_5 L_self_param_groups_0_params_5_grad () {}
get_attr param self___param_groups_0__params___0 () {}
get_attr param_1 self___param_groups_0__params___1 () {}
get_attr param_2 self___param_groups_0__params___2 () {}
get_attr param_3 self___param_groups_0__params___3 () {}
get_attr param_4 self___param_groups_0__params___4 () {}
get_attr param_5 self___param_groups_0__params___5 () {}
get_attr exp_avg self___state_list_L__self___state_keys____0___exp_avg () {}
get_attr exp_avg_1 self___state_list_L__self___state_keys____1___exp_avg () {}
get_attr exp_avg_2 self___state_list_L__self___state_keys____2___exp_avg () {}
get_attr exp_avg_3 self___state_list_L__self___state_keys____3___exp_avg () {}
get_attr exp_avg_4 self___state_list_L__self___state_keys____4___exp_avg () {}
get_attr exp_avg_5 self___state_list_L__self___state_keys____5___exp_avg () {}
get_attr exp_avg_sq self___state_list_L__self___state_keys____0___exp_avg_sq () {}
get_attr exp_avg_sq_1 self___state_list_L__self___state_keys____1___exp_avg_sq () {}
get_attr exp_avg_sq_2 self___state_list_L__self___state_keys____2___exp_avg_sq () {}
get_attr exp_avg_sq_3 self___state_list_L__self___state_keys____3___exp_avg_sq () {}
get_attr exp_avg_sq_4 self___state_list_L__self___state_keys____4___exp_avg_sq () {}
get_attr exp_avg_sq_5 self___state_list_L__self___state_keys____5___exp_avg_sq () {}
get_attr step_t self___state_list_L__self___state_keys____0___step () {}
get_attr step_t_2 self___state_list_L__self___state_keys____1___step () {}
get_attr step_t_4 self___state_list_L__self___state_keys____2___step () {}
get_attr step_t_6 self___state_list_L__self___state_keys____3___step () {}
get_attr step_t_8 self___state_list_L__self___state_keys____4___step () {}
get_attr step_t_10 self___state_list_L__self___state_keys____5___step () {}
call_function step <built-in function iadd> (step_t, 1) {}
call_method lerp_ lerp_ (exp_avg, grad, 0.09999999999999998) {}
call_method mul_ mul_ (exp_avg_sq, 0.999) {}
call_method conj conj (grad,) {}
call_method addcmul_ addcmul_ (mul_, grad, conj) {'value': 0.0010000000000000009}
call_function pow_1 <built-in function pow> (0.9, step) {}
call_function bias_correction1 <built-in function sub> (1, pow_1) {}
call_function pow_2 <built-in function pow> (0.999, step) {}
call_function bias_correction2 <built-in function sub> (1, pow_2) {}
call_function step_size <built-in function truediv> (0.001, bias_correction1) {}
call_method step_size_neg neg (step_size,) {}
call_method bias_correction2_sqrt sqrt (bias_correction2,) {}
call_method sqrt_1 sqrt (exp_avg_sq,) {}
call_function mul <built-in function mul> (bias_correction2_sqrt, step_size_neg) {}
call_function truediv_1 <built-in function truediv> (sqrt_1, mul) {}
call_function truediv_2 <built-in function truediv> (1e-08, step_size_neg) {}
call_method denom add_ (truediv_1, truediv_2) {}
call_method addcdiv_ addcdiv_ (param, exp_avg, denom) {}
call_function step_1 <built-in function iadd> (step_t_2, 1) {}
call_method lerp__1 lerp_ (exp_avg_1, grad_1, 0.09999999999999998) {}
call_method mul__1 mul_ (exp_avg_sq_1, 0.999) {}
call_method conj_1 conj (grad_1,) {}
call_method addcmul__1 addcmul_ (mul__1, grad_1, conj_1) {'value': 0.0010000000000000009}
call_function pow_3 <built-in function pow> (0.9, step_1) {}
call_function bias_correction1_1 <built-in function sub> (1, pow_3) {}
call_function pow_4 <built-in function pow> (0.999, step_1) {}
call_function bias_correction2_1 <built-in function sub> (1, pow_4) {}
call_function step_size_1 <built-in function truediv> (0.001, bias_correction1_1) {}
call_method step_size_neg_1 neg (step_size_1,) {}
call_method bias_correction2_sqrt_1 sqrt (bias_correction2_1,) {}
call_method sqrt_3 sqrt (exp_avg_sq_1,) {}
call_function mul_1 <built-in function mul> (bias_correction2_sqrt_1, step_size_neg_1) {}
call_function truediv_4 <built-in function truediv> (sqrt_3, mul_1) {}
call_function truediv_5 <built-in function truediv> (1e-08, step_size_neg_1) {}
call_method denom_1 add_ (truediv_4, truediv_5) {}
call_method addcdiv__1 addcdiv_ (param_1, exp_avg_1, denom_1) {}
call_function step_2 <built-in function iadd> (step_t_4, 1) {}
call_method lerp__2 lerp_ (exp_avg_2, grad_2, 0.09999999999999998) {}
call_method mul__2 mul_ (exp_avg_sq_2, 0.999) {}
call_method conj_2 conj (grad_2,) {}
call_method addcmul__2 addcmul_ (mul__2, grad_2, conj_2) {'value': 0.0010000000000000009}
call_function pow_5 <built-in function pow> (0.9, step_2) {}
call_function bias_correction1_2 <built-in function sub> (1, pow_5) {}
call_function pow_6 <built-in function pow> (0.999, step_2) {}
call_function bias_correction2_2 <built-in function sub> (1, pow_6) {}
call_function step_size_2 <built-in function truediv> (0.001, bias_correction1_2) {}
call_method step_size_neg_2 neg (step_size_2,) {}
call_method bias_correction2_sqrt_2 sqrt (bias_correction2_2,) {}
call_method sqrt_5 sqrt (exp_avg_sq_2,) {}
call_function mul_2 <built-in function mul> (bias_correction2_sqrt_2, step_size_neg_2) {}
call_function truediv_7 <built-in function truediv> (sqrt_5, mul_2) {}
call_function truediv_8 <built-in function truediv> (1e-08, step_size_neg_2) {}
call_method denom_2 add_ (truediv_7, truediv_8) {}
call_method addcdiv__2 addcdiv_ (param_2, exp_avg_2, denom_2) {}
call_function step_3 <built-in function iadd> (step_t_6, 1) {}
call_method lerp__3 lerp_ (exp_avg_3, grad_3, 0.09999999999999998) {}
call_method mul__3 mul_ (exp_avg_sq_3, 0.999) {}
call_method conj_3 conj (grad_3,) {}
call_method addcmul__3 addcmul_ (mul__3, grad_3, conj_3) {'value': 0.0010000000000000009}
call_function pow_7 <built-in function pow> (0.9, step_3) {}
call_function bias_correction1_3 <built-in function sub> (1, pow_7) {}
call_function pow_8 <built-in function pow> (0.999, step_3) {}
call_function bias_correction2_3 <built-in function sub> (1, pow_8) {}
call_function step_size_3 <built-in function truediv> (0.001, bias_correction1_3) {}
call_method step_size_neg_3 neg (step_size_3,) {}
call_method bias_correction2_sqrt_3 sqrt (bias_correction2_3,) {}
call_method sqrt_7 sqrt (exp_avg_sq_3,) {}
call_function mul_3 <built-in function mul> (bias_correction2_sqrt_3, step_size_neg_3) {}
call_function truediv_10 <built-in function truediv> (sqrt_7, mul_3) {}
call_function truediv_11 <built-in function truediv> (1e-08, step_size_neg_3) {}
call_method denom_3 add_ (truediv_10, truediv_11) {}
call_method addcdiv__3 addcdiv_ (param_3, exp_avg_3, denom_3) {}
call_function step_4 <built-in function iadd> (step_t_8, 1) {}
call_method lerp__4 lerp_ (exp_avg_4, grad_4, 0.09999999999999998) {}
call_method mul__4 mul_ (exp_avg_sq_4, 0.999) {}
call_method conj_4 conj (grad_4,) {}
call_method addcmul__4 addcmul_ (mul__4, grad_4, conj_4) {'value': 0.0010000000000000009}
call_function pow_9 <built-in function pow> (0.9, step_4) {}
call_function bias_correction1_4 <built-in function sub> (1, pow_9) {}
call_function pow_10 <built-in function pow> (0.999, step_4) {}
call_function bias_correction2_4 <built-in function sub> (1, pow_10) {}
call_function step_size_4 <built-in function truediv> (0.001, bias_correction1_4) {}
call_method step_size_neg_4 neg (step_size_4,) {}
call_method bias_correction2_sqrt_4 sqrt (bias_correction2_4,) {}
call_method sqrt_9 sqrt (exp_avg_sq_4,) {}
call_function mul_4 <built-in function mul> (bias_correction2_sqrt_4, step_size_neg_4) {}
call_function truediv_13 <built-in function truediv> (sqrt_9, mul_4) {}
call_function truediv_14 <built-in function truediv> (1e-08, step_size_neg_4) {}
call_method denom_4 add_ (truediv_13, truediv_14) {}
call_method addcdiv__4 addcdiv_ (param_4, exp_avg_4, denom_4) {}
call_function step_5 <built-in function iadd> (step_t_10, 1) {}
call_method lerp__5 lerp_ (exp_avg_5, grad_5, 0.09999999999999998) {}
call_method mul__5 mul_ (exp_avg_sq_5, 0.999) {}
call_method conj_5 conj (grad_5,) {}
call_method addcmul__5 addcmul_ (mul__5, grad_5, conj_5) {'value': 0.0010000000000000009}
call_function pow_11 <built-in function pow> (0.9, step_5) {}
call_function bias_correction1_5 <built-in function sub> (1, pow_11) {}
call_function pow_12 <built-in function pow> (0.999, step_5) {}
call_function bias_correction2_5 <built-in function sub> (1, pow_12) {}
call_function step_size_5 <built-in function truediv> (0.001, bias_correction1_5) {}
call_method step_size_neg_5 neg (step_size_5,) {}
call_method bias_correction2_sqrt_5 sqrt (bias_correction2_5,) {}
call_method sqrt_11 sqrt (exp_avg_sq_5,) {}
call_function mul_5 <built-in function mul> (bias_correction2_sqrt_5, step_size_neg_5) {}
call_function truediv_16 <built-in function truediv> (sqrt_11, mul_5) {}
call_function truediv_17 <built-in function truediv> (1e-08, step_size_neg_5) {}
call_method denom_5 add_ (truediv_16, truediv_17) {}
call_method addcdiv__5 addcdiv_ (param_5, exp_avg_5, denom_5) {}
output output output ((),) {}
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------------- ------------------------------------------------------- ---------------- --------
placeholder s0 s0 () {}
placeholder l_x_ L_x_ () {}
call_module l__self___fc1 L__self___fc1 (l_x_,) {}
call_function x <built-in method relu of type object at 0x792eaaa81760> (l__self___fc1,) {}
call_module l__self___fc2 L__self___fc2 (x,) {}
call_function x_1 <built-in method relu of type object at 0x792eaaa81760> (l__self___fc2,) {}
call_module x_2 L__self___fc3 (x_1,) {}
output output output ((x_2,),) {}
The FX graph output shows how the model's structure and operations are organized:
The placeholders s0 and l_x_ represent the input data.
The input is passed through the fully connected layers l__self___fc1, l__self___fc2, and l__self___fc3, the three layers of the network.
A ReLU activation is applied after each of the first two layers.
The final output is produced after the third fully connected layer.
Summary
Training large and complex models can be time-consuming for researchers and engineers. TorchDynamo reduces this training time by optimizing the computation graph and accelerating execution, allowing more iterations and experiments in less time. In applications that require real-time processing, such as video streaming or interactive AI systems, latency is critical, and TorchDynamo's ability to optimize and compile code at runtime helps these systems run smoothly and respond quickly to new data.
TorchDynamo's flexibility in supporting multiple backends and hardware architectures makes it well suited for deployment in a wide range of environments. Whether running on high-performance GPUs or on edge devices, TorchDynamo adapts to deliver the best possible performance.