自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

MLX vs MPS vs CUDA:蘋果新機(jī)器學(xué)習(xí)框架的基準(zhǔn)測(cè)試

人工智能
如果你是一個(gè)Mac用戶和一個(gè)深度學(xué)習(xí)愛好者,你可能希望在某些時(shí)候Mac可以處理一些重型模型。蘋果剛剛發(fā)布了MLX,一個(gè)在蘋果芯片上高效運(yùn)行機(jī)器學(xué)習(xí)模型的框架。

如果你是一個(gè)Mac用戶和一個(gè)深度學(xué)習(xí)愛好者,你可能希望在某些時(shí)候Mac可以處理一些重型模型。蘋果剛剛發(fā)布了MLX,一個(gè)在蘋果芯片上高效運(yùn)行機(jī)器學(xué)習(xí)模型的框架。

最近在PyTorch 1.12中引入MPS后端已經(jīng)是一個(gè)大膽的步驟,但隨著MLX的宣布,蘋果還想在開源深度學(xué)習(xí)方面有更大的發(fā)展。

在本文中,我們將對(duì)這些新方法進(jìn)行測(cè)試,在三種不同的Apple Silicon芯片和兩個(gè)支持cuda的gpu上和傳統(tǒng)CPU后端進(jìn)行基準(zhǔn)測(cè)試。

這里把基準(zhǔn)測(cè)試集中在圖卷積網(wǎng)絡(luò)(GCN)模型上。這個(gè)模型主要由線性層組成,所以對(duì)于其他的模型也應(yīng)該得到類似的結(jié)果。

創(chuàng)造環(huán)境

要為MLX構(gòu)建環(huán)境,我們必須指定是使用i386還是arm架構(gòu)。使用conda,可以使用:

CONDA_SUBDIR=osx-arm64 conda create -n mlx pythnotallow=3.10 numpy pytorch scipy requests -c conda-forge
 conda activate mlx

如果檢查你的env是否實(shí)際使用了arm,下面命令的輸出應(yīng)該是arm,而不是i386(因?yàn)槲覀冇玫腁pple Silicon):

python -c "import platform; print(platform.processor())"

然后就是使用pip安裝MLX:

pip install mlx

GCN模型

GCN模型是圖神經(jīng)網(wǎng)絡(luò)(GNN)的一種,它使用鄰接矩陣(表示圖結(jié)構(gòu))和節(jié)點(diǎn)特征。它通過收集鄰近節(jié)點(diǎn)的信息來計(jì)算節(jié)點(diǎn)嵌入。每個(gè)節(jié)點(diǎn)獲得其鄰居特征的平均值。這種平均是通過將節(jié)點(diǎn)特征與標(biāo)準(zhǔn)化鄰接矩陣相乘來完成的,并根據(jù)節(jié)點(diǎn)度進(jìn)行調(diào)整。為了學(xué)習(xí)這個(gè)過程,特征首先通過線性層投射到嵌入空間中。

我們將使用MLX實(shí)現(xiàn)一個(gè)GCN層和一個(gè)GCN模型:

import mlx.nn as nn
 
 class GCNLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias)
 
    def __call__(self, x, adj):
        x = self.linear(x)
        return adj @ x
 
 class GCN(nn.Module):
    def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True):
        super(GCN, self).__init__()
 
        layer_sizes = [x_dim] + [h_dim] * nb_layers + [out_dim]
        self.gcn_layers = [
            GCNLayer(in_dim, out_dim, bias)
            for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]
        self.dropout = nn.Dropout(p=dropout)
 
    def __call__(self, x, adj):
        for layer in self.gcn_layers[:-1]:
            x = nn.relu(layer(x, adj))
            x = self.dropout(x)
 
        x = self.gcn_layers[-1](x, adj)
        return x

可以看到,mlx的模型開發(fā)方式與tf2基本一樣,都是調(diào)用__call__進(jìn)行前向傳播,其實(shí)torch也一樣,只不過它自定義了一個(gè)forward函數(shù)。

下面就是訓(xùn)練:

gcn = GCN(
    x_dim=x.shape[-1],
    h_dim=args.hidden_dim,
    out_dim=args.nb_classes,
    nb_layers=args.nb_layers,
    dropout=args.dropout,
    bias=args.bias,
 )
 mx.eval(gcn.parameters())
 
 optimizer = optim.Adam(learning_rate=args.lr)
 loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)
 
 # Training loop
 for epoch in range(args.epochs):
 
    # Loss
    (loss, y_hat), grads = loss_and_grad_fn(
        gcn, x, adj, y, train_mask, args.weight_decay
    )
    optimizer.update(gcn, grads)
    mx.eval(gcn.parameters(), optimizer.state)
 
    # Validation
    val_loss = loss_fn(y_hat[val_mask], y[val_mask])
    val_acc = eval_fn(y_hat[val_mask], y[val_mask])

在MLX中,計(jì)算是惰性的,這意味著eval()通常用于在更新后實(shí)際計(jì)算新的模型參數(shù)。而另一個(gè)關(guān)鍵函數(shù)是nn.value_and_grad(),它生成一個(gè)計(jì)算參數(shù)損失的函數(shù)。第一個(gè)參數(shù)是保存當(dāng)前參數(shù)的模型,第二個(gè)參數(shù)是用于前向傳遞和損失計(jì)算的可調(diào)用函數(shù)。它返回的函數(shù)接受與forward函數(shù)相同的參數(shù)(在本例中為forward_fn)。我們可以這樣定義這個(gè)函數(shù):

def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
    y_hat = gcn(x, adj)
    loss = loss_fn(y_hat[train_mask], y[train_mask], weight_decay, gcn.parameters())
    return loss, y_hat

它僅僅包括計(jì)算前向傳遞和計(jì)算損失。Loss_fn()和eval_fn()定義如下:

def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
    l = mx.mean(nn.losses.cross_entropy(y_hat, y))
 
    if weight_decay != 0.0:
        assert parameters != None, "Model parameters missing for L2 reg."
 
        l2_reg = sum(mx.sum(p[1] ** 2) for p in tree_flatten(parameters)).sqrt()
        return l + weight_decay * l2_reg
 
    return l
 
 def eval_fn(x, y):
    return mx.mean(mx.argmax(x, axis=1) == y)

損失函數(shù)是計(jì)算預(yù)測(cè)和標(biāo)簽之間的交叉熵,并包括L2正則化。由于L2正則化還不是內(nèi)置特性,需要手動(dòng)實(shí)現(xiàn)。

本文的完整代碼:https://github.com/TristanBilot/mlx-GCN

可以看到除了一些細(xì)節(jié)函數(shù)調(diào)用的差別,基本的訓(xùn)練流程與pytorch和tf都很類似,但是這里的一個(gè)很好的事情是消除了顯式地將對(duì)象分配給特定設(shè)備的需要,就像我們?cè)赑yTorch中經(jīng)常使用.cuda()和.to(device)那樣。這是因?yàn)樘O果硅芯片的統(tǒng)一內(nèi)存架構(gòu),所有變量共存于同一空間,也就是說消除了CPU和GPU之間緩慢的數(shù)據(jù)傳輸,這樣也可以保證不會(huì)再出現(xiàn)與設(shè)備不匹配相關(guān)的煩人的運(yùn)行時(shí)錯(cuò)誤。

基準(zhǔn)測(cè)試

我們將使用MLX與MPS, CPU和GPU設(shè)備進(jìn)行比較。我們的測(cè)試平臺(tái)是一個(gè)2層GCN模型,應(yīng)用于Cora數(shù)據(jù)集,其中包括2708個(gè)節(jié)點(diǎn)和5429條邊。

對(duì)于MLX, MPS和CPU測(cè)試,我們對(duì)M1 Pro, M2 Ultra和M3 Max進(jìn)行基準(zhǔn)測(cè)試。在兩款NVIDIA V100 PCIe和V100 NVLINK上進(jìn)行測(cè)試。

MPS:比M1 Pro的CPU快2倍以上,在其他兩個(gè)芯片上,與CPU相比有30-50%的改進(jìn)。

MLX:比M1 Pro上的MPS快2.34倍。與MPS相比,M2 Ultra的性能提高了24%。在M3 Pro上MPS和MLX之間沒有真正的改進(jìn)。

CUDA V100 PCIe & NVLINK:只有23%和34%的速度比M3 Max與MLX,這里的原因可能是因?yàn)槲覀兊哪P捅容^小,所以發(fā)揮不出V100和NVLINK的優(yōu)勢(shì)(NVLINK主要GPU之間的數(shù)據(jù)傳輸大的情況下會(huì)有提高)。這也說明了蘋果的統(tǒng)一內(nèi)存架構(gòu)的確可以消除CPU和GPU之間緩慢的數(shù)據(jù)傳輸。

總結(jié)

與CPU和MPS相比,MLX可以說是非常大的金幣,在小數(shù)據(jù)量的情況下它甚至接近特斯拉V100的性能。也就是說我們可以使用MLX跑一些不是那么大的模型,比如一些表格數(shù)據(jù)。

從上面的基準(zhǔn)測(cè)試也可以看到,現(xiàn)在可以利用蘋果芯片的全部力量在本地運(yùn)行深度學(xué)習(xí)模型(我一直認(rèn)為MPS還沒發(fā)揮蘋果的優(yōu)勢(shì),這回MPS已經(jīng)證明了這一點(diǎn))。

MLX剛剛發(fā)布就已經(jīng)取得了驚人的影響力,并展示了巨大的潛力。相信未來幾年開源社區(qū)的進(jìn)一步增強(qiáng),可以期待在不久的將來更強(qiáng)大的蘋果芯片,將MLX的性能提升到一個(gè)全新的水平。

另外也說明了MPS(雖然也發(fā)布不久)還是有巨大的發(fā)展空間的,畢竟切換框架是一件很麻煩的事情,如果MPS能達(dá)到MLX 80%或者90%的速度,我想不會(huì)有人去換框架的。

最后說到框架,現(xiàn)在已經(jīng)有了Pytorch,TF,JAX,現(xiàn)在又多了一個(gè)MLX。各種設(shè)備、各種后端包括:TPU(pytorch使用的XLA),CUDA,ROCM,現(xiàn)在又多了一個(gè)MPS。

責(zé)任編輯:華軒 來源: DeepHub IMBA
相關(guān)推薦

2020-05-18 07:00:00

性能測(cè)試壓力測(cè)試負(fù)載測(cè)試

2011-06-08 16:59:04

性能測(cè)試載測(cè)試壓力測(cè)試

2017-11-23 22:32:18

框架ScrumXP

2017-09-26 11:25:00

AI

2017-09-11 10:55:22

PythonWeb框架

2018-06-05 11:30:22

數(shù)據(jù)科學(xué)機(jī)器學(xué)習(xí)統(tǒng)計(jì)學(xué)

2017-05-10 09:26:41

機(jī)器學(xué)習(xí)深度學(xué)習(xí)

2021-11-08 10:25:39

機(jī)器人疫苗人工智能

2023-10-09 12:36:25

AI模型

2013-05-06 10:40:40

工程師IBM機(jī)器人

2023-03-01 11:18:59

人工智能機(jī)器學(xué)習(xí)

2021-09-01 09:19:03

人工智能機(jī)器學(xué)習(xí)數(shù)據(jù)分析

2018-07-03 15:59:14

KerasPyTorch深度學(xué)習(xí)

2009-12-02 15:11:04

Vs.Net 2010

2018-07-06 08:58:53

機(jī)器人人工智能系統(tǒng)

2024-03-04 13:21:00

模型訓(xùn)練

2020-05-06 14:19:53

大數(shù)據(jù)數(shù)據(jù)科學(xué)機(jī)器學(xué)習(xí)

2022-08-18 09:42:02

人工智能機(jī)器學(xué)習(xí)

2021-08-15 21:36:00

框架開發(fā)JavaScript
點(diǎn)贊
收藏

51CTO技術(shù)棧公眾號(hào)