斯坦福博士圖解AlphaFold 3:超多細(xì)節(jié)+可視化還原ML工程師眼中的AF3
谷歌DeepMind的人工智能模型AlphaFold 3兩個月前橫空出世,顛覆了生物學(xué)。
這個「值得獲得諾貝爾獎的發(fā)明」不僅在學(xué)術(shù)圈引起了巨震,還轟動了制藥界——它可能帶來數(shù)千億美元的商業(yè)價值,并對藥物研發(fā)產(chǎn)生深遠(yuǎn)影響。
論文地址:https://www.nature.com/articles/s41586-024-07487-w
如此重要的AlphaFold3,其具體工作原理是什么?
因為AlphaFold3的結(jié)構(gòu)非常復(fù)雜,論文有相當(dāng)高的閱讀門檻,讓人望而卻步。兩位斯坦福大學(xué)的兩位博士生制作了一個論文的「圖解版」,比論文閱讀起來友好多了,而且還很詳盡!
每一位機器學(xué)習(xí)工程師都不應(yīng)該錯過這篇圖文并茂的文章——
博客地址:https://elanapearl.github.io/blog/2024/the-illustrated-alphafold/
此前已經(jīng)有很多關(guān)于蛋白質(zhì)結(jié)構(gòu)預(yù)測的研究動機、CASP競賽、模型失效模式、關(guān)于評估的爭論、對生物技術(shù)的影響等主題的文章,因此以上內(nèi)容都不是這篇博文關(guān)注的重點。
這篇博文關(guān)注的重點在于AlphaFold3(以下簡稱AF3)是如何在技術(shù)上實現(xiàn)的:分子在模型中被如何表示,它們又是如何被轉(zhuǎn)換成預(yù)測結(jié)構(gòu)的?
架構(gòu)概述
AlphaFold3和前代模型最大的不同點在于——預(yù)測目標(biāo)不同。
AF3不僅預(yù)測單個蛋白質(zhì)序列(AF2)或蛋白質(zhì)復(fù)合物(AF-multimeter)的結(jié)構(gòu),還能預(yù)測蛋白質(zhì)與其他蛋白質(zhì)、核酸、小分子中的一種或多種物質(zhì)的復(fù)合結(jié)構(gòu),而且僅根據(jù)序列信息。
因此,前代的AF模型只需表示標(biāo)準(zhǔn)的氨基酸序列,但AF3需要引入更復(fù)雜的輸入類型,因此設(shè)計了更復(fù)雜的特征表示和tokenization機制。
tokenization過程會在稍后單獨描述,目前我們只需要知道,token可能代表單個氨基酸(蛋白質(zhì))、核苷酸(DNA/RNA),或者單個原子(其他物質(zhì))。
模型主要由3部分組成:
- 輸入準(zhǔn)備:給定輸入的分子序列,模型需要檢索一系列的結(jié)構(gòu)相似的分子。這一步驟會識別出這些分子,并將其編碼為數(shù)值張量。
- 表征學(xué)習(xí):給定上一步中創(chuàng)建的張量,使用注意力機制的多種變體來更新這些表征。
- 結(jié)構(gòu)預(yù)測:基于第一部分創(chuàng)建的原始輸入以及第二部分改進(jìn)后的表征,使用條件擴散進(jìn)行結(jié)構(gòu)預(yù)測。
在整個模型中,蛋白質(zhì)復(fù)合物有兩種表示形式:單一表征(single representation)和配對表征(pair representation),這兩種表示都可以應(yīng)用于token級別或原子級別。
前者僅僅表示復(fù)合物中的所有token或原子,后者則表征了物質(zhì)中所有token/原子之間的關(guān)系(如距離、潛在相互作用等)。
為了簡單起見,下述的結(jié)構(gòu)中忽略了大多數(shù)LayerNorm層,但其實它們無處不在。
輸入準(zhǔn)備
用戶向AF3提供的實際輸入是一個蛋白質(zhì)序列和可選的其他分子。
本節(jié)的目標(biāo)是將這些序列轉(zhuǎn)換成一系列6個張量,這些張量將作為模型主干的輸入.
如圖所示,這6個張量分別是:
-s(token級單一表征)
-z(token級配對表征)
-q(原子級單一表征)
-p(原子級配對表征)
-m(MSA表征)
-t(模板表征)
本節(jié)包含5個步驟,分別是tokenization、檢索、創(chuàng)建原子級表征、更新原子級表征、原子級到token級集成。
tokenization
在AF2中,由于模型只表示具有固定氨基酸集的蛋白質(zhì),因此每個氨基酸都擁有自己的token。
AF3保留了這一點,但也引入了額外的token,以便處理其他分子類型:
-標(biāo)準(zhǔn)氨基酸:1個token(同AF2)
-標(biāo)準(zhǔn)核苷酸:1個token
-非標(biāo)準(zhǔn)氨基酸或核苷酸(甲基化核苷酸、翻譯后修飾的氨基酸等):每個原子1個token
-其他分子:每個原子1個token
因此,我們可以認(rèn)為某些token(如標(biāo)準(zhǔn)氨基酸/核苷酸)對應(yīng)多個原子,而其他token (如配體中的原子)只對應(yīng)單個原子。
比如,35個標(biāo)準(zhǔn)氨基酸的序列(可能大于600個原子)將由35個token來表示;同時,一個由35個原子組成的配體也同樣由35個token表示。
檢索
AF3的早期關(guān)鍵步驟之一類似于語言模型中的檢索增強生成(Retrieval Augmented Generation,RAG)。
模型會檢索到與與輸入序列相似的序列(收集到多序列比對中,即MSA),以及與這些序列相關(guān)的任何結(jié)構(gòu)(稱為「模板」),將它們作為模型的附加輸入,分別寫作m和t。
與AF-multimer相比,這些檢索步驟中唯一的新內(nèi)容是,除了蛋白質(zhì)序列外,我們現(xiàn)在還對RNA序列進(jìn)行檢索。
請注意,這在傳統(tǒng)上并不被稱為「檢索」,因為早在RAG這個術(shù)語出現(xiàn)之前,使用結(jié)構(gòu)模板指導(dǎo)蛋白質(zhì)結(jié)構(gòu)建模就已經(jīng)是同源建模領(lǐng)域的常見做法了。
不過,盡管AlphaFold沒有明確將這一過程稱為檢索,但它確實與現(xiàn)在流行的RAG非常相似。
那么,模板(templates,t)如何表征呢?
通過模板搜索,我們獲得了每個模板的三維結(jié)構(gòu),以及有關(guān)哪些token位于哪些鏈中的信息。
首先,計算給定模板中所有token對之間的歐氏距離。如果是對應(yīng)多個原子token,則使用一個有代表性的「中心原子」來計算距離。
比如,對于氨基酸來說,中心原子是Cɑ原子,而標(biāo)準(zhǔn)核苷酸的中心原子是C1'原子。
這會為每個模板生成一個Ntoken x Ntoken大小的矩陣。不過,距離并不是用數(shù)值來表示,而是將距離離散化為「距離直方圖」。
然后,我們向每個直方圖添加元數(shù)據(jù),關(guān)于每個token屬于哪個鏈、該token是否在晶體結(jié)構(gòu)中得到解析、以及每個氨基酸內(nèi)部的局部距離信息。
然后,我們對這個矩陣進(jìn)行掩碼,只查看每條鏈內(nèi)部的距離(忽略鏈之間的距離)。根據(jù)論文的解釋,這樣做的原因是「并不嘗試選擇模板......以獲取鏈間交互的信息」。
創(chuàng)建原子級表征
為了創(chuàng)建q(原子級單一表征),我們需要提取所有原子級特征。
第一步是計算每個氨基酸、核苷酸和配體的「參考構(gòu)象」(reference conformer)。雖然我們還不知道整個復(fù)合物的結(jié)構(gòu),但我們對每個單獨組件的局部結(jié)構(gòu)有很多的先驗知識。
構(gòu)象(構(gòu)象異構(gòu)體的簡稱)是分子中原子的三維排列,通過對單鍵的旋轉(zhuǎn)角度進(jìn)行采樣而生成。
每種氨基酸都有一個「標(biāo)準(zhǔn)」構(gòu)象,但這只是該氨基酸存在的低能量構(gòu)象之一,可以通過查表找到。
不過,每個小分子都需要生成自己的構(gòu)象,利用RDKit中的ETKDGv3算法,同時結(jié)合了實驗數(shù)據(jù)和扭轉(zhuǎn)角偏好來生成三維構(gòu)象。
然后,我們將該構(gòu)象信息(相對位置)與每個原子的電荷、原子序數(shù)和其他標(biāo)識符連接起來。矩陣c存儲了序列中所有原子的這些信息。
然后,我們用c來初始化原子級別的配對表征p,以存儲原子間的相對距離。
由于我們只知道每個token內(nèi)的參考距離,因此先使用掩碼機制(v)來確保這個初始距離矩陣只代表我們在構(gòu)象生成過程中計算出的距離。
最后,我們將原子級別的單一表征復(fù)制一份,并將這個副本稱為q。這個矩陣q是我們接下來要更新的,但c確實會被保存并稍后使用。
原子Transformer
在生成了q(單個原子的表征)和 p(原子配對表征)之后,我們現(xiàn)在要根據(jù)附近的其他原子更新這些表征。
AF3使用一個名為原子Transformer的模塊,在原子級別應(yīng)用注意力機制時
原子Transformer主要遵循標(biāo)準(zhǔn)的Transformer結(jié)構(gòu)。不過,其具體步驟都經(jīng)過了調(diào)整,以處理來自c和p的額外輸入。
原子級->token級
到目前為止,所有數(shù)據(jù)都是以原子級別存儲的,而AF3的表征學(xué)習(xí)部分則從這里開始以token級別運行。
為了創(chuàng)建token級表征,我們首先將原子級表征投影到一個更大的維度(catom=128,ctoken=384)。然后對分配給同一token的所有原子取平均值。
請注意,這只適用于與標(biāo)準(zhǔn)氨基酸和核苷酸相關(guān)的原子,其余原子保持不變。
現(xiàn)在我們就從「原子空間」進(jìn)入了「token空間」。
之后將token級特征和MSA中的統(tǒng)計信息連接起來,形成矩陣sinputs并被向下投影到ctoken,作為序列的起始表征sinit。
sinit將在表征學(xué)習(xí)部分中被更新,但sinputs保持不變,用于結(jié)構(gòu)預(yù)測部分。
表征學(xué)習(xí)
經(jīng)過一系列輸入準(zhǔn)備后,我們就來到了模型的主干部分,也是完成大部分計算量的部分。
這部分模型的學(xué)習(xí)目標(biāo)是改進(jìn)上述token級別的單一(s)或成對(z)張量的初始化表示,因此被稱為「表征學(xué)習(xí)」。
這部分主要包含三類步驟:
- 模板模塊(template module):使用模板t更新張量z
- MSA模塊:更新MSA表征m,再將其引入token級別的張量z
- Pairformer:使用三角注意力更新張量s、z
以上步驟會重復(fù)運行多次,每次輸出結(jié)果后再將其反饋到自身繼續(xù)作為輸入,繼續(xù)進(jìn)行計算(如上圖中藍(lán)色需先所示),這種做法被稱為「回收」(recycling)。
模板模塊
該模塊的計算流程如下圖所示(模板個數(shù)Ntemplate=2)。
每個模板t和張量z經(jīng)過線性投影后相加得到矩陣v,再經(jīng)過一系列被稱作Pairformer Stack的操作(后文詳述)。
之后,N個模板被平均到一起,再通過另一個線性層和一次ReLU,得到最終結(jié)果。
有趣的是,這是AF3模型中唯二使用ReLU的地方之一,但論文中并沒有解釋為什么選擇ReLU而非其他非線性函數(shù)。
MSA模塊
AF3中的MSA與AF2中的Evoformer非常類似,都是在同時改進(jìn)MSA表征和配對表征,對兩者分別獨立執(zhí)行一系列操作后進(jìn)行交互。
下采樣
處理MSA表征的第一步是下采樣,而非使用之前生成的MSA的所有行(最多可達(dá) 16k)。下采樣后,還要加入經(jīng)過投影映射的單一表征s。
外積均值
之后,MSA表征m通過外積均值方法(outer product mean)被合并到配對表征中。
如下圖所示,比較MSA中的兩列揭示了有關(guān)序列中兩個位點之間的關(guān)系信息(比如進(jìn)化過程中的相關(guān)性)。
對于每對標(biāo)記索引i,j,我們迭代所有進(jìn)化序列s,獲取 ms,i和 ms,j的外積,在所有進(jìn)化序列中進(jìn)行平均。
然后,我們壓平這個外部積并將其投影回去,最后將其添加至配對表征zi,j 。
雖然每個外積僅對給定序列ms內(nèi)的值進(jìn)行操作,但取平均值時會混合序列之間的信息。這是模型中唯一能在進(jìn)化序列之間共享信息的機制。
這個方法是相對于AF2的重大改變,旨在降低Evoformer的計算復(fù)雜度。
行內(nèi)自注意力
根據(jù)MSA更新配對表征后,模型接下來根據(jù)后者更新MSA,這種特定的更新模式是原子Transformer中所述的「具有配對偏差的自注意力」的簡化版本,被稱為「僅使用配對偏差的行內(nèi)門控自注意力」(row-wise gated self-attention using only pair bias)。
這種方法受到注意力機制的啟發(fā),但并不使用查詢和鍵計算,而是直接使用存儲在配對表征z中的token間關(guān)系。
如下圖所示,在張量z中,每個cz維度的向量zi,j都表示第i個和第j個token間的關(guān)系。將z線性投影到矩陣b后,每個zi,j向量變?yōu)闃?biāo)量,就可以相當(dāng)于「注意力分?jǐn)?shù)」(attention score),用于加權(quán)平均。
最后,MSA通過一系列「三角更新」(triangle updates)和注意力機制來更新配對表征,其中「三角更新」與下面Pairformer的描述相同。
Pairformer
經(jīng)過前兩個模塊后,模板t和MSA表征m的作用就結(jié)束了,只有單一表征z和經(jīng)過更新的配對表征s進(jìn)入Pairformer并用于相互更新。
Pairformer中值得注意的是「三角更新」和「三角自注意力」方法,它們首次在AF2模型中出現(xiàn),并被保留在AF3中,而且正在被應(yīng)用到越來越多的架構(gòu)中。
為什么是「三角形」
這里的指導(dǎo)原則是三角形不等式的思想:「三角形任意兩邊之和大于或等于第三條邊?!?/span>
回想一下,張量z中的值zi,j編碼序列中位置i和j之間的關(guān)系。雖然并沒有顯式地對token間的物理距離進(jìn)行編碼,但的確包含了這層含義。
如果我們想象每個zi,j代表兩個氨基酸之間的距離,并且有zi,j=1和zj,k=1。那么根據(jù)三角形不等式,zi,k不能大于2。
「三角更新」和「三角形自注意力」的目標(biāo)就是嘗試將這些幾何約束編碼到模型中,但并不會強制執(zhí)行,而是鼓勵模型在每次更新zi,j的值時考慮所有可能的三元組(i,j,k)。
此外,z不僅代表物理距離,還編碼了token之間復(fù)雜的物理關(guān)系,因此向量zi,j是有方向的。
所以,如上圖所示,在對節(jié)點k進(jìn)行「三角更新」和「三角自注意力」操作時,需要分別查看兩種有向路徑,出邊(outgoing edge)和入邊(incoming edges)。
三角更新
從圖論角度理解「三角」操作后,我們就能明白以下的張量更新和注意力機制是如何通過張量運算實現(xiàn)的。
使用出邊進(jìn)行更新時,使用到了z的三個線性投影a、b和g。
為了更新zi,j,需要對zi,k和zj,k進(jìn)行操作,即對a中的第i行和b中的第j行進(jìn)行逐元素(element-wise)乘法,之后對所有行(不同k值)求和,再用g進(jìn)行門控。
入邊的操作與出邊類似,只是進(jìn)行了行列翻轉(zhuǎn)。
為了更新zi,j,需要對zk,i和zk,j進(jìn)行操作,即對a中的第i列和b中的第j列進(jìn)行逐元素乘法,再對所有行求和。
可以發(fā)現(xiàn),出邊和入邊的「三角更新」操作與上面標(biāo)出有向路徑的兩個紫色三角形完全對應(yīng)。
三角自注意力
接下來,分別使用出邊和入邊的「三角自注意力」更新每個zi,j值。AF3論文將這兩個過程分別稱為「圍繞起始節(jié)點」(around starting node)和「圍繞結(jié)束節(jié)點」(around ending node)。
回憶一下,典型的一維序列自注意力中,查詢、鍵和值都是原始一維序列的轉(zhuǎn)換。自注意力的二維變體——軸向注意力中,在二維矩陣的不同軸上(行,然后列)上獨立應(yīng)用一維自注意力。
以此類推,「三角自注意力」在軸向注意力的基礎(chǔ)上添加了之前討論的「三角形不等式」,通過合并所有k值的zi,k和zj,k來更新zi,j。
比如,在圍繞起始節(jié)點的情況中(下圖),為了計算注意力分?jǐn)?shù)zi,j,需要將qi,j與k矩陣中第i行每個值相乘(以確定第j列受到其他列的影響),然后加上zj,k的注意力偏置。
圍繞結(jié)束節(jié)點的情況同樣是行列對稱,為了計算zi,j,需要將qi,j與k矩陣中第i列每個值相乘,注意力偏置則來自第j列。
用四個與「三角」有關(guān)的步驟更新了配對表征z后,我們還希望用它來更新單一表征s,此處使用的方法是「帶有配對偏置的單一注意力」(single attention with pair bias,下圖)。
這個方法在輸入準(zhǔn)備部分的原子Transformer中也被用到,區(qū)別只在于一個是token級別,一個是原子級別。
由于在token級別上運行,因此這里用到的是完全注意力,而非塊內(nèi)稀疏模式。
以上是Pairformer的所有計算流程。經(jīng)過48個塊的重復(fù)計算后,我們就得到了張量strunk和ztrunk。
結(jié)構(gòu)預(yù)測
擴散模塊
每一個去噪擴散步驟都根據(jù)輸入序列的多個表征來調(diào)整預(yù)測。
AF3論文將其擴散過程分為3個步驟,包括從token到原子、回到token、再回到原子。
準(zhǔn)備token級條件張量
為了初始化token級條件表征,我們將ztrunk與相對位置編碼拼接,然后將這個更大的表征向下投影,并通過幾個帶有殘差連接的轉(zhuǎn)換模塊。
同樣地,對于token級的單一表征,我們將模型的初始表征(sinputs)和我們當(dāng)前表征(strunk)拼接起來,然后投影回原始大小。
然后,我們根據(jù)當(dāng)前的擴散時間步長創(chuàng)建一個傅立葉編碼模型,將其添加到單一表征中,并將該組合通過多個轉(zhuǎn)換模塊。
通過在這里的條件輸入中加入擴散時間步長,可以確保模型在進(jìn)行去噪預(yù)測時能夠意識到擴散過程的時間步長,從而預(yù)測出該時間步長下需要消除噪聲的合適規(guī)模。
準(zhǔn)備原子級張量,應(yīng)用原子級注意力,并聚合回token級
此時,條件向量存儲的是每個token級別的信息,但我們還希望在原子級別上運行注意力。
為了解決這個問題,我們采用編碼部分(c 和 p)中創(chuàng)建的輸入的初始原子級表征,并根據(jù)當(dāng)前的token級表征更新它們,以創(chuàng)建原子級條件張量。
接下來,使用數(shù)據(jù)方差來縮放原子的當(dāng)前坐標(biāo)(x),從而有效地創(chuàng)建出具有單位方差(稱為r)的「無量綱」坐標(biāo)。
然后我們根據(jù)r更新q,這樣q就包含了原子的當(dāng)前位置信息。最后,使用原子Transformer更新q,并將原子聚合為我們之前看到的token。
在token級應(yīng)用注意力機制
這一步的目標(biāo)是運用注意力機制更新原子坐標(biāo)和序列信息的token級表征a。這一步使用的是輸入準(zhǔn)備過程中可視化的擴散Transformer,它與原子Transformer相同,但針對token。
在原子級應(yīng)用注意力機制來預(yù)測原子級噪聲更新
現(xiàn)在,我們回到原子空間,使用原子Transformer和經(jīng)過更新的a來對q進(jìn)行更新。
與上一步一樣,我們廣播token表征,使其與開始時的原子數(shù)量相匹配,然后運行原子Transformer。
最重要的是,最后一層線性層將原子級表征q映射回R。這是關(guān)鍵的一步:我們使用所有這些條件表征為所有原子生成坐標(biāo)更新rupdate。
由于是在「無量綱」空間rl中生成這些更新的,因此需要進(jìn)行重新縮放,將 rupdate中的更新轉(zhuǎn)換為為非單位方差的形式xupdate,再引入到xl中。
至此,我們就完成了對AlphaFold 3主要架構(gòu)的介紹!
此外,作者還提供了一些有關(guān)損失函數(shù)機器訓(xùn)練細(xì)節(jié)的補充信息,好學(xué)的朋友可以去博文中一看究竟。
ML工程師眼中的AF3
在對AF3的架構(gòu)及其與AF2的比較進(jìn)行了如此詳盡的介紹之后,作者沒有止步于此,而是將之與更廣泛的機器學(xué)習(xí)趨勢相關(guān)聯(lián),這一點很有意思。
作者提到了AF與檢索增強生成、Pair-Bias注意力機制、自監(jiān)督訓(xùn)練、分類與回歸、LSTM循環(huán)架構(gòu)、交叉蒸餾之間的關(guān)系。
作者簡介
本文作者是來自斯坦福大學(xué)的兩位博士生Elana Simon和Jake Silberg。
Elana Simon本科從哈佛大學(xué)畢業(yè)并獲得計算機科學(xué)領(lǐng)域?qū)W士學(xué)位,曾在谷歌、Facebook等機構(gòu)實習(xí),并擔(dān)任過Reveries Labs的機器學(xué)習(xí)工程師。
她目前是博士二年級,由James Zou指導(dǎo),從事機器學(xué)習(xí)和生物學(xué)(免疫學(xué)和基因?qū)W)交叉領(lǐng)域的研究。
Jake Silberg本科畢業(yè)于哈佛大學(xué)社會研究專業(yè),輔修全球健康和衛(wèi)生政策,碩士畢業(yè)于斯坦福大學(xué)統(tǒng)計系數(shù)據(jù)科學(xué)專業(yè)。
他曾在麥肯錫擔(dān)任高級業(yè)務(wù)分析師,目前在斯坦福大學(xué)攻讀生物醫(yī)學(xué)數(shù)據(jù)科學(xué)方面的博士學(xué)位,研究興趣包括主動學(xué)習(xí)、人機交互,以及Transformer在衛(wèi)星和醫(yī)學(xué)圖像中的應(yīng)用。