MLX vs MPS vs CUDA:蘋果新機(jī)器學(xué)習(xí)框架的基準(zhǔn)測(cè)試
如果你是一個(gè)Mac用戶和一個(gè)深度學(xué)習(xí)愛好者,你可能希望在某些時(shí)候Mac可以處理一些重型模型。蘋果剛剛發(fā)布了MLX,一個(gè)在蘋果芯片上高效運(yùn)行機(jī)器學(xué)習(xí)模型的框架。
最近在PyTorch 1.12中引入MPS后端已經(jīng)是一個(gè)大膽的步驟,但隨著MLX的宣布,蘋果還想在開源深度學(xué)習(xí)方面有更大的發(fā)展。
在本文中,我們將對(duì)這些新方法進(jìn)行測(cè)試,在三種不同的Apple Silicon芯片和兩個(gè)支持cuda的gpu上和傳統(tǒng)CPU后端進(jìn)行基準(zhǔn)測(cè)試。
這里把基準(zhǔn)測(cè)試集中在圖卷積網(wǎng)絡(luò)(GCN)模型上。這個(gè)模型主要由線性層組成,所以對(duì)于其他的模型也應(yīng)該得到類似的結(jié)果。
創(chuàng)造環(huán)境
要為MLX構(gòu)建環(huán)境,我們必須指定是使用i386還是arm架構(gòu)。使用conda,可以使用:
CONDA_SUBDIR=osx-arm64 conda create -n mlx pythnotallow=3.10 numpy pytorch scipy requests -c conda-forge
conda activate mlx
如果檢查你的env是否實(shí)際使用了arm,下面命令的輸出應(yīng)該是arm,而不是i386(因?yàn)槲覀冇玫腁pple Silicon):
python -c "import platform; print(platform.processor())"
然后就是使用pip安裝MLX:
pip install mlx
GCN模型
GCN模型是圖神經(jīng)網(wǎng)絡(luò)(GNN)的一種,它使用鄰接矩陣(表示圖結(jié)構(gòu))和節(jié)點(diǎn)特征。它通過收集鄰近節(jié)點(diǎn)的信息來計(jì)算節(jié)點(diǎn)嵌入。每個(gè)節(jié)點(diǎn)獲得其鄰居特征的平均值。這種平均是通過將節(jié)點(diǎn)特征與標(biāo)準(zhǔn)化鄰接矩陣相乘來完成的,并根據(jù)節(jié)點(diǎn)度進(jìn)行調(diào)整。為了學(xué)習(xí)這個(gè)過程,特征首先通過線性層投射到嵌入空間中。
我們將使用MLX實(shí)現(xiàn)一個(gè)GCN層和一個(gè)GCN模型:
import mlx.nn as nn
class GCNLayer(nn.Module):
def __init__(self, in_features, out_features, bias=True):
super(GCNLayer, self).__init__()
self.linear = nn.Linear(in_features, out_features, bias)
def __call__(self, x, adj):
x = self.linear(x)
return adj @ x
class GCN(nn.Module):
def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True):
super(GCN, self).__init__()
layer_sizes = [x_dim] + [h_dim] * nb_layers + [out_dim]
self.gcn_layers = [
GCNLayer(in_dim, out_dim, bias)
for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:])
]
self.dropout = nn.Dropout(p=dropout)
def __call__(self, x, adj):
for layer in self.gcn_layers[:-1]:
x = nn.relu(layer(x, adj))
x = self.dropout(x)
x = self.gcn_layers[-1](x, adj)
return x
可以看到,mlx的模型開發(fā)方式與tf2基本一樣,都是調(diào)用__call__進(jìn)行前向傳播,其實(shí)torch也一樣,只不過它自定義了一個(gè)forward函數(shù)。
下面就是訓(xùn)練:
gcn = GCN(
x_dim=x.shape[-1],
h_dim=args.hidden_dim,
out_dim=args.nb_classes,
nb_layers=args.nb_layers,
dropout=args.dropout,
bias=args.bias,
)
mx.eval(gcn.parameters())
optimizer = optim.Adam(learning_rate=args.lr)
loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)
# Training loop
for epoch in range(args.epochs):
# Loss
(loss, y_hat), grads = loss_and_grad_fn(
gcn, x, adj, y, train_mask, args.weight_decay
)
optimizer.update(gcn, grads)
mx.eval(gcn.parameters(), optimizer.state)
# Validation
val_loss = loss_fn(y_hat[val_mask], y[val_mask])
val_acc = eval_fn(y_hat[val_mask], y[val_mask])
在MLX中,計(jì)算是惰性的,這意味著eval()通常用于在更新后實(shí)際計(jì)算新的模型參數(shù)。而另一個(gè)關(guān)鍵函數(shù)是nn.value_and_grad(),它生成一個(gè)計(jì)算參數(shù)損失的函數(shù)。第一個(gè)參數(shù)是保存當(dāng)前參數(shù)的模型,第二個(gè)參數(shù)是用于前向傳遞和損失計(jì)算的可調(diào)用函數(shù)。它返回的函數(shù)接受與forward函數(shù)相同的參數(shù)(在本例中為forward_fn)。我們可以這樣定義這個(gè)函數(shù):
def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
y_hat = gcn(x, adj)
loss = loss_fn(y_hat[train_mask], y[train_mask], weight_decay, gcn.parameters())
return loss, y_hat
它僅僅包括計(jì)算前向傳遞和計(jì)算損失。Loss_fn()和eval_fn()定義如下:
def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
l = mx.mean(nn.losses.cross_entropy(y_hat, y))
if weight_decay != 0.0:
assert parameters != None, "Model parameters missing for L2 reg."
l2_reg = sum(mx.sum(p[1] ** 2) for p in tree_flatten(parameters)).sqrt()
return l + weight_decay * l2_reg
return l
def eval_fn(x, y):
return mx.mean(mx.argmax(x, axis=1) == y)
損失函數(shù)是計(jì)算預(yù)測(cè)和標(biāo)簽之間的交叉熵,并包括L2正則化。由于L2正則化還不是內(nèi)置特性,需要手動(dòng)實(shí)現(xiàn)。
本文的完整代碼:https://github.com/TristanBilot/mlx-GCN
可以看到除了一些細(xì)節(jié)函數(shù)調(diào)用的差別,基本的訓(xùn)練流程與pytorch和tf都很類似,但是這里的一個(gè)很好的事情是消除了顯式地將對(duì)象分配給特定設(shè)備的需要,就像我們?cè)赑yTorch中經(jīng)常使用.cuda()和.to(device)那樣。這是因?yàn)樘O果硅芯片的統(tǒng)一內(nèi)存架構(gòu),所有變量共存于同一空間,也就是說消除了CPU和GPU之間緩慢的數(shù)據(jù)傳輸,這樣也可以保證不會(huì)再出現(xiàn)與設(shè)備不匹配相關(guān)的煩人的運(yùn)行時(shí)錯(cuò)誤。
基準(zhǔn)測(cè)試
我們將使用MLX與MPS, CPU和GPU設(shè)備進(jìn)行比較。我們的測(cè)試平臺(tái)是一個(gè)2層GCN模型,應(yīng)用于Cora數(shù)據(jù)集,其中包括2708個(gè)節(jié)點(diǎn)和5429條邊。
對(duì)于MLX, MPS和CPU測(cè)試,我們對(duì)M1 Pro, M2 Ultra和M3 Max進(jìn)行基準(zhǔn)測(cè)試。在兩款NVIDIA V100 PCIe和V100 NVLINK上進(jìn)行測(cè)試。
MPS:比M1 Pro的CPU快2倍以上,在其他兩個(gè)芯片上,與CPU相比有30-50%的改進(jìn)。
MLX:比M1 Pro上的MPS快2.34倍。與MPS相比,M2 Ultra的性能提高了24%。在M3 Pro上MPS和MLX之間沒有真正的改進(jìn)。
CUDA V100 PCIe & NVLINK:只有23%和34%的速度比M3 Max與MLX,這里的原因可能是因?yàn)槲覀兊哪P捅容^小,所以發(fā)揮不出V100和NVLINK的優(yōu)勢(shì)(NVLINK主要GPU之間的數(shù)據(jù)傳輸大的情況下會(huì)有提高)。這也說明了蘋果的統(tǒng)一內(nèi)存架構(gòu)的確可以消除CPU和GPU之間緩慢的數(shù)據(jù)傳輸。
總結(jié)
與CPU和MPS相比,MLX可以說是非常大的金幣,在小數(shù)據(jù)量的情況下它甚至接近特斯拉V100的性能。也就是說我們可以使用MLX跑一些不是那么大的模型,比如一些表格數(shù)據(jù)。
從上面的基準(zhǔn)測(cè)試也可以看到,現(xiàn)在可以利用蘋果芯片的全部力量在本地運(yùn)行深度學(xué)習(xí)模型(我一直認(rèn)為MPS還沒發(fā)揮蘋果的優(yōu)勢(shì),這回MPS已經(jīng)證明了這一點(diǎn))。
MLX剛剛發(fā)布就已經(jīng)取得了驚人的影響力,并展示了巨大的潛力。相信未來幾年開源社區(qū)的進(jìn)一步增強(qiáng),可以期待在不久的將來更強(qiáng)大的蘋果芯片,將MLX的性能提升到一個(gè)全新的水平。
另外也說明了MPS(雖然也發(fā)布不久)還是有巨大的發(fā)展空間的,畢竟切換框架是一件很麻煩的事情,如果MPS能達(dá)到MLX 80%或者90%的速度,我想不會(huì)有人去換框架的。
最后說到框架,現(xiàn)在已經(jīng)有了Pytorch,TF,JAX,現(xiàn)在又多了一個(gè)MLX。各種設(shè)備、各種后端包括:TPU(pytorch使用的XLA),CUDA,ROCM,現(xiàn)在又多了一個(gè)MPS。