僅用4塊GPU、不到3天訓練出「開源版GPT-4o」,這是國內團隊最新研究
以 ChatGPT 為代表的大型語言模型(LLM)已成為強大的通用任務解決器,但大多數(shù) LLM 僅支持基于文本的交互,這限制了它們在不適合文本輸入輸出的場景中的應用。GPT-4o 的出現(xiàn)使得通過語音與 LLM 進行交互成為可能。然而,開源社區(qū)對于構建此類基于 LLM 的語音交互模型仍然缺乏探索。
實現(xiàn)與 LLM 進行語音交互最簡單的方法是采用基于自動語音識別(ASR)和語音合成(TTS)模型的級聯(lián)系統(tǒng),其中 ASR 模型將用戶的語音指令轉錄為文本, TTS 模型將 LLM 的響應合成為語音。
然而,由于級聯(lián)系統(tǒng)依次輸出轉錄文本、文本響應和語音響應,整個系統(tǒng)往往具有較高的延遲。相比之下,一些多模態(tài)語音 - 語言模型將語音離散化為 token 并擴展 LLM 的詞表以支持語音輸入和輸出。這種語音 - 語言模型理論上可以直接從語音指令生成語音響應,無需生成中間文本,從而實現(xiàn)極低的響應延遲。然而,在實踐中,由于涉及語音之間復雜的映射,直接語音到語音的生成通常極具挑戰(zhàn)性。
為了解決上述問題,來自中國科學院計算技術研究所、中國科學院大學的研究者提出了一種新型模型架構 ——LLaMA-Omni,它可以實現(xiàn)與 LLM 的低延遲、高質量交互。
- 論文:https://arxiv.org/pdf/2409.06666
- 代碼:https://github.com/ictnlp/LLaMA-Omni
- 模型:https://huggingface.co/ICTNLP/Llama-3.1-8B-Omni
- 論文標題:LLaMA-Omni: Seamless Speech Interaction with Large Language Models
LLaMA-Omni 由語音編碼器、語音適配器、LLM 和流式語音解碼器組成。用戶的語音指令由語音編碼器進行編碼,經過語音適配器后輸入到 LLM。LLM 直接從語音指令中解碼文本響應,無需首先將語音轉錄為文本。語音解碼器是一個非自回歸(NAR)流式 Transformer,它將 LLM 的輸出表示作為輸入,并使用連接時序分類(Connectionist Temporal Classification, CTC)來預測與語音響應相對應的離散單元序列。
在推理過程中,當 LLM 自回歸生成文本響應時,語音解碼器同步生成相應的離散單元。為了更好地契合語音交互場景的特點,該研究通過重寫現(xiàn)有的文本指令數(shù)據(jù)并進行語音合成,構建了名為 InstructS2S-200K 的數(shù)據(jù)集。實驗結果表明,LLaMA-Omni 可以同步生成高質量的文本和語音響應,延遲低至 226ms。
此外,與 SpeechGPT 等語音 - 語言模型相比,LLaMA-Omni 顯著減少了所需的訓練數(shù)據(jù)和計算資源,從而能夠基于最新的 LLM 高效開發(fā)強大的語音交互模型。
LLaMA-Omni 模型概覽
如圖 2 所示,LLaMA-Omni 由語音編碼器、語音適配器、LLM 和語音解碼器組成,其中將用戶的語音指令、文本響應和語音響應分別表示為 X^S、Y^T 和 Y^S。
語音編碼器
該研究使用 Whisper-large-v3 (Radford et al., 2023)的編碼器作為語音編碼器 E。Whisper 是一種在大量音頻數(shù)據(jù)上訓練的通用語音識別模型,其編碼器能夠從語音中提取有意義的表征。
具體來說,對于用戶的語音指令 X^S,編碼后的語音表征由 H = ε(X^S) 給出,其中 H = [h_1, ..., h_N ] 是長度為 N 的語音表征序列,語音編碼器的參數(shù)在整個訓練過程中都被凍結。
語音適配器
為了使 LLM 能夠理解輸入語音,LLaMA-Omni 結合了一個可訓練的語音適配器 A,它將語音表征映射到 LLM 的嵌入空間中。語音適配器首先對語音表征 H 進行下采樣以減少序列長度。具體來說,每 k 個連續(xù)幀沿特征維度拼接:
接下來,H′ 通過具有 ReLU 激活的 2 層感知器,得到最終的語音表征 S:
大型語言模型
該研究使用 Llama-3.1-8B-Instruct(Dubey et al., 2024)作為 LLM M,它是目前 SOTA 開源 LLM,具有很強的推理能力,并且與人類偏好進行了對齊。prompt 模板 P (?) 如圖 3 所示。
將語音表征序列 S 填充到對應位置,然后將整個序列 P (S) 輸入到 LLM 中。最后,LLM 直接根據(jù)語音指令自回歸生成文本響應 Y^T = [y^T_1 , ..., y^T_M],并使用交叉熵損失進行訓練:
語音解碼器
為了與文本響應同步生成語音響應,LLaMA-Omni 在 LLM 之后添加了一個流式語音解碼器 D。它由幾個標準 Transformer 層組成,其架構與 LLaMA (Dubey et al., 2024) 相同,每個層都包含一個因果自注意力模塊和一個前饋網絡。
語音解碼器以非自回歸方式運行,將 LLM 的輸出表示經過上采樣后作為輸入,并生成與語音響應相對應的離散單元序列。
訓練
如圖 2 所示,LLaMA-Omni 采用兩階段訓練策略。第一階段訓練模型直接根據(jù)語音指令生成文本響應的能力。具體來說,語音編碼器被凍結,語音適配器和 LLM 使用公式 (3) 中的目標 L_LLM 進行訓練。語音解碼器在此階段不參與訓練。第二階段訓練模型來生成語音響應。在此階段,語音編碼器、語音適配器和 LLM 都被凍結,僅使用公式 (5) 中的目標 L_CTC 來訓練語音解碼器。
推理
在推理過程中,LLM 根據(jù)語音指令自回歸生成文本響應。同時,由于語音解碼器使用因果注意力,一旦 LLM 生成文本響應前綴,相應的輸出狀態(tài)
就可以被輸入到語音解碼器中以生成部分對齊
,進而產生與生成的文本前綴相對應的離散單元。
為了進一步實現(xiàn)語音波形的流式合成,當生成的離散單元數(shù)量達到預定義的塊大小 Ω 時,即將該離散單元片段輸入到聲碼器中以合成語音片段,然后立即播放給用戶。因此,用戶無需等待生成完整的文本響應即可開始收聽語音響應,從而確保低響應延遲。算法 1 描述了上述過程。
語音指令數(shù)據(jù)的構建:INSTRUCTS2S-200K
為了訓練 LLaMA-Omni,需要構建三元組數(shù)據(jù):語音指令,文本響應,語音響應。
對于語音指令數(shù)據(jù)而言,包含三步:指令重寫、響應生成、語音合成。
對于基礎文本指令,作者從 Alpaca 數(shù)據(jù)集中收集了大約 50K 條指令,該數(shù)據(jù)集涵蓋了廣泛的主題。此外,作者從 UltraChat 數(shù)據(jù)集中收集了大約 150K 條指令,該數(shù)據(jù)集主要由有關世界的問題組成。值得注意的是,UltraChat 是一個大規(guī)模多輪對話數(shù)據(jù)集,但作者僅選擇了前 150K 條條目并僅使用第一輪指令。最終獲得 200K 語音指令數(shù)據(jù),稱為 InstructS2S-200K。
實驗結果
訓練數(shù)據(jù)。作者采用 InstructS2S-200K 數(shù)據(jù)集,其包括 200K 語音指令數(shù)據(jù)。
模型配置。作者使用 Whisper-large-v3 編碼器作為語音編碼器,使用 Llama-3.1-8B-Instruct 作為 LLM。
訓練。LLaMA-Omni 遵循兩階段訓練過程:在第一階段,作者訓練語音適配器和 LLM,批處理大小為 32,共 3 個 epoch;在第二階段,作者訓練語音解碼器,使用與第一階段相同的批處理大小、step 數(shù)等。整個訓練過程在 4 個 NVIDIA L40 GPU 上大約需要 65 小時。
在評估方面,作者從以下方面對模型進行了評估:
- ChatGPT 得分;
- 語音 - 文本對齊;
- 語音質量;
- 響應延遲。
除此以外,語音 - 語言模型的基線系統(tǒng)包含 SpeechGPT 、 SALMONN (+TTS) 、 Qwen2-Audio (+TTS) 。
主要結果
表 1 給出了 InstructS2S-Eval 基準測試主要結果。
首先,在 S2TIF 任務中,從內容(content)角度來看,LLaMA-Omni 相比之前的模型有了顯著提升,這主要是因為 LLaMA-Omni 是基于最新的 Llama-3.1-8B Instruct 模型開發(fā)的,充分利用了其強大的文本指令跟隨能力。
從風格(style)角度來看,SALMONN 和 Qwen2-Audio 得分較低,因為它們是語音 - 文本模型,輸出風格與語音交互場景不太對齊,經常產生格式化的內容,包含大量冗余解釋。相比之下,SpeechGPT 作為語音 - 語音模型,獲得了更高的風格得分。
LLaMA-Omni 獲得了最高的風格得分,這說明在 InstructS2S-200K 數(shù)據(jù)集上訓練后,模型輸出風格已經與語音交互場景很好地對齊。
對于 S2SIF 任務,LLaMA-Omni 在內容和風格得分上也都優(yōu)于之前的模型。這進一步證實了 LLaMA-Omni 能夠以簡潔、高效的方式通過語音有效地處理用戶指令。
此外,在語音和文本響應的對齊方面,LLaMA-Omni 的 ASR-WER 和 ASR-CER 得分最低。相比之下,SpeechGPT 在對齊語音和文本響應方面表現(xiàn)不佳,這可能是因為它是串行生成文本和語音的。
級聯(lián)系統(tǒng)(如 SALMONN+TTS 和 Qwen2-Audio+TTS)的語音 - 文本對齊也不是最理想的,主要是因為生成的文本響應可能包含無法合成語音的字符。這個問題在 Qwen2-Audio 中尤為明顯,它偶爾會輸出中文字符,從而導致語音響應錯誤。
相比之下,LLaMA-Omni 的 ASR-WER 和 ASR-CER 得分最低,表明生成的語音和文本響應之間的對齊程度更高,進一步驗證了 LLaMA-Omni 在同時生成文本和語音響應方面的優(yōu)勢。
語音質量和響應延遲之間的權衡
為了更好地理解 Ω 的影響,作者對系統(tǒng)延遲、語音和文本響應之間的對齊以及不同 Ω 設置下生成的語音質量進行了探索。
如表 2 所示,當 Ω 設置為 10 時,系統(tǒng)的響應延遲低至 226ms,甚至低于 GPT-4o 的平均音頻延遲 320ms。
綜合來看,可以根據(jù)不同的場景調整 Ω 的值,以實現(xiàn)響應延遲和語音質量之間的權衡。
解碼時間
表 3 列出了不同模型在 S2TIF 和 S2SIF 任務上的平均解碼時間。
LLaMA-Omni 直接提供簡潔的答案,從而顯著縮短解碼時間,平均每條指令僅為 1.49 秒。
LLaMA-Omni 同時輸出文本和語音響應,并采用非自回歸架構生成離散單元,總生成時間僅增加 1.28 倍,體現(xiàn)出 LLaMA-Omni 在解碼速度上的優(yōu)勢。
案例研究
為了直觀的了解不同模型的響應差異,作者在表 4 中提供了一個示例。
可以觀察到 Qwen2-Audio 的響應相當冗長,并且包含換行符和括號等無法合成語音的元素。
SALMONN 的響應也有點長。
SpeechGPT 的響應風格更適合語音交互場景,但其響應所包含的信息量較少。
相比之下,LLaMA-Omni 給出的響應在保持簡潔風格的同時更加詳細和有幫助,在語音交互場景中的表現(xiàn)優(yōu)于之前的模型。