蘋果讓大模型學會偷懶:更快吐出第一個token,準確度還保住了
Llama 3.1 剛剛發(fā)布,你是否已經(jīng)嘗試了呢?就算你的個人計算機是最近的頂尖配置,運行其中最小的 8B 版本可能也依然會有明顯延遲。為了提升模型的推理效率,研究者想出了多種多樣的方法,但其中很多都會讓模型犧牲一些準確度。
近日,蘋果和 Meta AI 的一個研究團隊提出了一種新方法,可在保證準確度不明顯下降的同時,將 Llama 2 預填充階段的推理速度提升到原來的 2 倍以上,這或許能為 Llama 3.1 的加速提供一些啟發(fā)。他們把這種方法稱為 LazyLLM,即懶惰大型語言模型。
- 論文標題:LazyLLM: Dynamic Token Pruning for Efficient Long Context LLM Inference
- 論文地址:https://arxiv.org/abs/2407.14057
那么他們是怎么讓 LLM 偷懶的呢?要理解他們的方法,我們首先需要知道標準的基于 prompt 的 LLM 推理過程是怎樣的。簡單來說,該過程分為兩個階段:預填充和解碼,如圖 1 所示。
在預填充階段,模型計算和保存 prompt 中每個 token 的 KV 緩存,并預測首個 token。我們將預填充階段所耗費的時間稱為「首個 token 時間(TTFT)」。
預填充階段之后是解碼階段。在這個階段,模型再次使用緩存的 KV 來迭代式地解碼下一個 token,直到滿足停止標準。
在預填充階段,所有 Transformer 層都會使用 prompt 中的所有 token。當 prompt 較長時,TTFT 可能很慢,因為當前最佳的基于 Transformer 的 LLM 既深又寬,并且計算注意力的成本會隨 prompt 中 token 數(shù)量而呈二次增長。舉個例子,Llama 2(7B 版本)堆疊了 32 層 Transformer,模型維度為 4096。在這種情況下,TTFT 需要的 walltime 是每個后續(xù)解碼步驟的 21 倍,在 LongBench 基準上這些時間大約占用了總生成時間的 23%。
因此,要讓 LLM 推理高效進行,優(yōu)化 TTFT 是非常關鍵的步驟。
盡管 LLM 推理優(yōu)化方面是一個活躍的研究領域,但很多方法關注的重心都是提升解碼階段的推理速度。研究者很少關注 TTFT 的改進。一些基于壓縮的研究成果可通過減少 LLM 的大小隱式地提升 TTFT。
另一個研究方向是在靜態(tài)的 Transformer 架構下實現(xiàn)對 TTFT 的改進。對于這個研究方向,很自然會引出一個問題:在生成首個 token 時,所有 prompt token 都必不可少嗎?
圖 2 給出了在 LongBench 基準上的 LLM 分析結果。
可以看到,對于首個生成的 token,輸入 token 的注意力分數(shù)非常稀疏,這說明輸入 prompt 中的許多 token 是多余的,就算移除也不會影響到下一 token 預測。這一觀察正是該團隊提出 LazyLLM 的基礎。
LazyLLM 的優(yōu)勢包括適用范圍廣、無需訓練、效果好。圖 3 對比了標準 LLM 與 LazyLLM。
LazyLLM
圖 4 展示了 LazyLLM 的整體框架。
從完整上下文開始,LazyLLM 會逐漸對 token 進行剪枝,從而逐漸減少得到最終模型所使用的計算數(shù)量。請注意,LazyLLM 允許模型在不同的生成步驟選取不同的 token 子集,即便它們中的一些可能在之前的步驟中被剪枝了。相比于靜態(tài)剪枝(一次性對所有 token 進行剪枝),動態(tài)剪枝會在每個生成步驟對下一 token 預測進行優(yōu)化,這有助于維持模型的性能表現(xiàn)。
漸進式 token 剪枝
之前也有一些研究成功使用過 token 剪枝來優(yōu)化 LLM 推理。但是,這些方法需要積累預測前幾個 token 的完整注意力圖,以便在剪枝開始之前分析 prompt token 的重要性。也因此,它們不適合用于降低 TTFT,因為它們在預填充階段仍需要計算所有 KV 緩存。
相較之下,LazyLLM 「很懶」,會從推理的第一輪迭代(預填充步驟)開始,只計算對預測下一 token 重要的 token。
在第一輪迭代中,一大關鍵難題是確定各個 token 的重要性。受之前已有研究(其中表明 token 隱藏狀態(tài)會在穿過 Transformer 層時發(fā)生演進)的啟發(fā),該團隊的解決方案是在每個生成步驟使用逐層 token 剪枝。具體來說,他們是使用各層的注意力圖來確定輸入 token 對將要預測的 token 的重要性。
在計算了 token 的置信度分數(shù)之后,另一個難題是確定剪枝 token 的閾值。
具體來說,對于不同的層和不同的任務,該閾值可能會隨注意力分數(shù)的變化而改變。該團隊的解決思路是使用 top-k 百分位數(shù)選取策略。具體來說,如果一個 token 的置信度分數(shù)小于輸入 token 中的第 k 個百分位數(shù),便將其剪枝掉。一旦 token 被剪枝去掉了,它就不再參與所有后續(xù)層的計算。
也就是說,后續(xù)層使用的 token 是之前層所使用 token 的子集。
后面的實驗表明,剪枝層的位置和剪枝的 token 數(shù)量不同時,也會導致性能發(fā)生變化。具體來說,對于同一 Transformer 層,隨著被剪枝去掉的 token 越來越多,模型的性能也會逐漸下降。
他們還發(fā)現(xiàn),相比于早期層的剪枝,在后期層執(zhí)行剪枝時會得到更好的性能,這說明后期層對 token 剪枝的敏感度更低。為了更好地平衡速度與準確度,該團隊使用了如圖 4 所示的漸進式剪枝法,從而在早期層保留更多 token,然后在 token 流向后期層的過程中逐漸減少 token 的數(shù)量。
Aux Cache(輔助緩存)
預填充階段沒有 KV 緩存,每個 token 都表示成隱藏狀態(tài)。因此,可通過移除已被剪枝 token 的隱藏狀態(tài)來實現(xiàn)漸進式 token 剪枝。但是,要將漸進式 token 剪枝擴展到后續(xù)的解碼步驟,卻并不簡單。原因是每個解碼步驟都會使用預填充階段計算的 KV 緩存來計算注意力。由于 LazyLLM 是在預填充階段執(zhí)行漸進式 token 剪枝,因此在某一層被剪枝的 token 的 KV 不會出現(xiàn)在下一層的 KV 緩存中。
這里提醒一下,LazyLLM 框架允許在每一步讓每個生成步驟從完整的輸入 token 序列中挑選一個不同的 token 子集,無論它們是否已在之前的步驟中被剪枝。舉個例子,在接下來的解碼步驟中,那些在 KV 緩存中不存在的已被剪枝的 token 可能會被重新選取出來用于計算注意力。在這種情況下,模型無法檢索到這些 token 的 KV 緩存。
對此,一個基于直覺的解決方案是再讓這些 token 通過該 Transformer 的起點。但是,這會導致對同一 token 的重復計算,并最終減慢整體的生成速度。
為解決這個難題,該團隊在原有的 KV 緩存之外引入了另一種緩存:Aux Cache(輔助緩存)。
如果已被剪枝 token(如圖 4 中 T4 和 T7)的 KV 并未出現(xiàn)在后續(xù)層的 KV 緩存中,則會由 Aux Cache 保存它們的隱藏狀態(tài)以供后續(xù)迭代檢索。
如圖 4 所示,在每個解碼步驟,每個 Transformer 層首先會檢索過去 token 的 KV 緩存(如果存在的話)。對于那些不在 KV 緩存中的 token,則直接從其前一層的 Aux Cache 中檢索它們的隱藏狀態(tài),而不必再次經(jīng)過之前的層。Aux Cache 可確保每個 token 在每個 Transformer 層中最多被計算一次,還能確保 LazyLLM 最慢時也比標準 LLM 快。
實驗
該團隊在兩個大型語言模型上檢驗了這種「懶惰」新方法:Llama 2 7B 和 XGen 7B。作為對比的標準 LLM 是同樣的公開發(fā)布的預訓練檢查點模型,同時不進行任何附加訓練。
實驗基準是 LongBench,這是一個針對長內容理解的多任務基準。LongBench 基準包含 16 個數(shù)據(jù)集,涉及 6 個任務,包括單文檔問答、多文檔問答、總結、少樣本學習、合成任務和代碼補全。
評估指標是每種方法在 TTFT 加速與準確度權衡方面的效果和效率。
結果
表 1 給出了 LazyLLM、標準 LLM 和其它基線方法的 TTFT 加速和準確度結果。
在此表中,baseline 是指標準 LLM 推理。random token drop 是指對 token 執(zhí)行隨機剪枝。static token pruning 是指在預填充階段基于前面幾個 Transformer 層的注意力方法來對輸入 token 執(zhí)行一次性剪枝。Prompt Compression 就是 prompt 壓縮方法,也就是使用 LLM 去除輸入上下文中的冗余。
從表 1 可以看到,LazyLLM 在 TTFT 加速方面全面優(yōu)勝,同時準確度方面的下降基本可以忽略不計。需要指出,使用 LLM 來壓縮 prompt 需要大量計算。因此,即使 Prompt Compression 能讓推理速度更快,但其實際的 TTFT 卻比標準 LLM 還長。
對總體生成速度的影響
為了評估新方法對總體生成速度的影響,該團隊分析了計算使用的 prompt token 百分比和生成加速情況,見表 2。
可以看到,LazyLLM 計算使用的 token 的占比總是低于 100%,這說明 LazyLLM 在生成結束時也沒有用完 prompt 中的所有 token,但理論上講該模型可以使用所有 token。這能為不同任務的整體生成過程提供額外的加速。
不同層的丟棄率
該團隊也分析了剪枝層的位置和被剪枝 token 的數(shù)量的影響。結果見圖 6。
可以看到,當在同一 Transformer 層進行剪枝時,留下的 token 越少,模型的性能越差。這也符合我們的直觀認知。此外,相比于在更前期 Transformer 層執(zhí)行剪枝,在后期層進行剪枝會得到更好的性能,這說明后期層對 token 剪枝的敏感度更低。
基于這些觀察,可以說漸進式 token 剪枝的效果得到了證明。
漸進式 KV 增長
最后,該團隊也嘗試了理解使用 token 剪枝邏輯的模型的內部情況。具體來說,他們想要了解 prompt token 中的累積使用比例以及相應的不被使用的比例。這種「累積 token 使用量」可以等價地定義成每一步的 KV 緩存 大小。圖 7 給出了 LazyLLM 的每個階段這些累積的 prompt token 使用量。
該結果支持這一假設:許多 token 永遠不會被模型選擇(即便理論上講模型可以使用 prompt 中的所有 token。
考慮到模型依然能維持執(zhí)行任務的準確度,因此可以得出結論:模型可以有效地丟棄不影響輸出質量的 token。