蘋(píng)果為自家芯片打造開(kāi)源框架MLX,實(shí)現(xiàn)Llama 7B并在M2 Ultra上運(yùn)行
2020 年 11 月,蘋(píng)果推出 M1 芯片,其速度之快、功能之強(qiáng)大一時(shí)令人驚艷。2022 年蘋(píng)果又推出了 M2,今年 10 月,M3 芯片正式登場(chǎng)。
蘋(píng)果在發(fā)布芯片的同時(shí),也非常注重在其芯片上訓(xùn)練和部署 AI 模型的能力。
蘋(píng)果推出的 ML Compute 可用于在 Mac 上進(jìn)行 TensorFlow 模型的訓(xùn)練。PyTorch 則支持在 M1 版本的 Mac 上進(jìn)行 GPU 加速的 PyTorch 機(jī)器學(xué)習(xí)模型訓(xùn)練,使用蘋(píng)果 Metal Performance Shaders (MPS) 作為后端來(lái)實(shí)現(xiàn)。這些使得 Mac 用戶(hù)能夠在本地訓(xùn)練神經(jīng)網(wǎng)絡(luò)。
現(xiàn)在,蘋(píng)果宣布推出專(zhuān)門(mén)在 Apple 芯片上用于機(jī)器學(xué)習(xí)的開(kāi)源陣列框架 ——MLX。
MLX 是專(zhuān)門(mén)為機(jī)器學(xué)習(xí)研究人員設(shè)計(jì)的,旨在有效地訓(xùn)練和部署 AI 模型。框架本身的設(shè)計(jì)在概念上也很簡(jiǎn)單。研究人員能夠輕松地?cái)U(kuò)展和改進(jìn) MLX,以快速探索、測(cè)試新的想法。MLX 的設(shè)計(jì)靈感來(lái)自 NumPy、PyTorch、Jax 和 ArrayFire 等框架。
項(xiàng)目地址:https://github.com/ml-explore/mlx
MLX 項(xiàng)目貢獻(xiàn)者之一、Apple 機(jī)器學(xué)習(xí)研究團(tuán)隊(duì)(MLR)研究科學(xué)家 Awni Hannun 展示了一段使用 MLX 框架實(shí)現(xiàn) Llama 7B 并在 M2 Ultra 上運(yùn)行的視頻。
MLX 迅速引起機(jī)器學(xué)習(xí)研究人員的關(guān)注。TVM、MXNET、XGBoost 作者,CMU 助理教授,OctoML CTO 陳天奇轉(zhuǎn)推表示:「蘋(píng)果芯片又有新的深度學(xué)習(xí)框架了。」
有網(wǎng)友評(píng)價(jià) MLX 稱(chēng),蘋(píng)果再次「重造了輪子」。
圖源:https://twitter.com/ofervic/status/1732305883814596953
MLX 特性、示例
在該項(xiàng)目中,我們可以看到,MLX 有以下一些主要特性。
熟悉的 API。MLX 擁有非常像 NumPy 的 Python API,以及功能齊備的 C++ API(與 Python API 非常相似)。MLX 還有更高級(jí)的包(比如 mlx.nn 和 mlx.optimizers),它們的 API 很像 PyTorch,可以簡(jiǎn)化構(gòu)建更復(fù)雜的模型。
可組合函數(shù)變換。MLX 擁有自動(dòng)微分、自動(dòng)矢量化和計(jì)算圖優(yōu)化的可組合函數(shù)變換。
惰性計(jì)算。MLX 中的計(jì)算是惰性的,陣列只有在需要時(shí)才被實(shí)例化。
動(dòng)態(tài)圖構(gòu)建。MLX 中的計(jì)算圖構(gòu)建是動(dòng)態(tài)的,改變函數(shù)參數(shù)的形狀不會(huì)導(dǎo)致編譯變慢,并且 debug 很簡(jiǎn)單、容易上手。
多設(shè)備。任何支持的設(shè)備上(如 CPU 和 GPU)都可以運(yùn)行操作。
統(tǒng)一內(nèi)存。MLX 與其他框架的顯著差異在于統(tǒng)一內(nèi)存,陣列共享內(nèi)存。MLX 上的操作可以在任何支持的設(shè)備類(lèi)型上運(yùn)行,無(wú)需移動(dòng)數(shù)據(jù)。
此外,項(xiàng)目中提供了多種使用 MLX 框架的示例,比如 MNIST 示例可以很好地讓你學(xué)習(xí)如何使用 MLX。
圖源:https://github.com/ml-explore/mlx-examples/tree/main/mnist
MLX 還有其他更多有用的示例,包括如下:
- Transformer 語(yǔ)言模型訓(xùn)練;
- LLaMA 大規(guī)模文本生成和 LoRA 微調(diào);
- Stable Diffusion 生成圖片;
- OpenAI 的 Whisper 語(yǔ)音識(shí)別。
更詳細(xì)的文檔可參閱:https://ml-explore.github.io/mlx/build/html/install.html#