一文概述聯(lián)邦持續(xù)學(xué)習(xí)最新研究進(jìn)展
由于數(shù)據(jù)隱私限制,多個(gè)中心之間的數(shù)據(jù)共享受到限制,這就影響了聯(lián)邦學(xué)習(xí)架構(gòu)下多中心合作開(kāi)發(fā)高性能深度學(xué)習(xí)模型的效果。持續(xù)學(xué)習(xí)(Continual Learning)作為點(diǎn)對(duì)點(diǎn)聯(lián)合學(xué)習(xí)的一種方法,可以通過(guò)共享中間模型而不是訓(xùn)練數(shù)據(jù)來(lái)繞過(guò)數(shù)據(jù)隱私的限制,從而促進(jìn)多中心協(xié)作開(kāi)發(fā)深度學(xué)習(xí)算法。近期不斷有研究人員探索聯(lián)邦持續(xù)學(xué)習(xí)方法(Federated Continual Learning,F(xiàn)CL),即,研究持續(xù)學(xué)習(xí)在聯(lián)邦學(xué)習(xí)架構(gòu)下多中心協(xié)作的可行性。
1、背景回顧
1.1 持續(xù)學(xué)習(xí) (Continual Learning)
首先,我們來(lái)回顧一下什么是持續(xù)學(xué)習(xí)。當(dāng)前,一般認(rèn)為持續(xù)學(xué)習(xí) (Continual Learning) 和增量學(xué)習(xí)(Incremental Learning)、終身學(xué)習(xí) (Lifelong Learning) 是等價(jià)表述,它們都是在連續(xù)的數(shù)據(jù)流中訓(xùn)練模型,隨著時(shí)間的推移,更多的數(shù)據(jù)逐漸可用,同時(shí)舊數(shù)據(jù)可能由于存儲(chǔ)限制或隱私保護(hù)等原因而逐漸不可用,并且學(xué)習(xí)任務(wù)的類(lèi)型和數(shù)量沒(méi)有預(yù)定義 (例如分類(lèi)任務(wù)中的類(lèi)別數(shù))。
當(dāng)一個(gè)模型在新的數(shù)據(jù)集或任務(wù)上被重新訓(xùn)練時(shí),深度學(xué)習(xí)會(huì)遇到災(zāi)難性遺忘的問(wèn)題,即深度學(xué)習(xí)模型會(huì)災(zāi)難性地遺忘已經(jīng)學(xué)到的舊知識(shí)。持續(xù)學(xué)習(xí)技術(shù)的目的是讓機(jī)器學(xué)習(xí)模型通過(guò)新的數(shù)據(jù)進(jìn)行更新,同時(shí)保留以前學(xué)到的知識(shí)。持續(xù)學(xué)習(xí)具有兩大優(yōu)點(diǎn):1) 不需要保存之前任務(wù)上學(xué)習(xí)過(guò)的訓(xùn)練數(shù)據(jù),從而在節(jié)約內(nèi)存的同時(shí)解決了由于物理設(shè)備 (例如機(jī)器內(nèi)存 ) 或?qū)W習(xí)策略 (例如隱私保護(hù) ) 的限制所導(dǎo)致的數(shù)據(jù)不能被長(zhǎng)期存儲(chǔ)問(wèn)題;2) 模型能夠保存之前任務(wù)所學(xué)到的知識(shí),并且能夠極大程度地將之前任務(wù)學(xué)習(xí)到的知識(shí)運(yùn)用到未來(lái)任務(wù)的學(xué)習(xí)中,從而提高學(xué)習(xí)效率。
目前,持續(xù)學(xué)習(xí)的方法仍在不斷發(fā)展中,還沒(méi)有嚴(yán)格的數(shù)學(xué)定義。韓等在文章 [1] 中給出了一幅持續(xù)學(xué)習(xí)的示意圖,如圖 1 所示,“在連續(xù)學(xué)習(xí)過(guò)程中,智能體逐個(gè)對(duì)每個(gè)連續(xù)的非獨(dú)立均勻分布流數(shù)據(jù)示例進(jìn)行學(xué)習(xí),并且該智能體對(duì)每個(gè)示例只進(jìn)行一次訪(fǎng)問(wèn)。這種學(xué) 習(xí)方式與動(dòng)物學(xué)習(xí)過(guò)程更為接近。如果我們忽略各個(gè)任務(wù)的先后次序問(wèn)題,單獨(dú)訓(xùn)練每個(gè)任務(wù),這將導(dǎo)致災(zāi)難性遺忘,這也是連續(xù)學(xué)習(xí)一直以來(lái)所面臨的最大問(wèn)題。因此,連續(xù)學(xué)習(xí)的本質(zhì)是通過(guò)各種手段高效地轉(zhuǎn)化和利用已經(jīng)學(xué)過(guò)的知識(shí)來(lái)完成新任務(wù)的學(xué)習(xí),并且能夠極大程度地降低遺忘帶來(lái)的問(wèn)題。[1]”。
圖 1. 連續(xù)學(xué)習(xí)示意圖 [1]
到目前為止,已經(jīng)有許多持續(xù)學(xué)習(xí)算法,主要分為三種類(lèi)型:回放方法(Memory Reply)、動(dòng)態(tài)結(jié)構(gòu)模型(Dynamic Structural Models)和正則化方法(Regularization Model)。1)回放方法從以前的數(shù)據(jù)集中選擇有代表性的樣本,以保留所學(xué)的知識(shí)。該方法的研究重點(diǎn)是:「要保留舊任務(wù)的哪部分?jǐn)?shù)據(jù),以及如何利用舊數(shù)據(jù)與新數(shù)據(jù)一起訓(xùn)練模型」,這對(duì)于克服數(shù)據(jù)存儲(chǔ)的限制是可行的,但對(duì)于多中心合作來(lái)說(shuō)是不可行的,因?yàn)橛捎跀?shù)據(jù)隱私問(wèn)題,其他中心的樣本是不可用的 [6-8]。2)動(dòng)態(tài)結(jié)構(gòu)模型為多任務(wù)場(chǎng)景設(shè)計(jì)動(dòng)態(tài)網(wǎng)絡(luò)架構(gòu)或動(dòng)態(tài)參數(shù),網(wǎng)絡(luò)的各個(gè)部分(如某些權(quán)重或某些神經(jīng)元連接)負(fù)責(zé)對(duì)應(yīng)的各個(gè)任務(wù) [9][10]。3)正則化方法使用相同的傳統(tǒng)神經(jīng)網(wǎng)絡(luò),但在損失函數(shù)中加入了新的正則化項(xiàng),以保留學(xué)習(xí)知識(shí)的重要參數(shù)。該方法的主要思想是「通過(guò)給新任務(wù)的損失函數(shù)施加約束的方法來(lái)保護(hù)舊知識(shí)不被新知識(shí)覆蓋」[11][12]。
1.2 聯(lián)邦持續(xù)學(xué)習(xí) (Federated Continual Learning)
聯(lián)邦學(xué)習(xí)的主要思想是去中心化,將模型下放到各個(gè)參與聯(lián)合訓(xùn)練的客戶(hù)端本地,基于本地客戶(hù)端的數(shù)據(jù)進(jìn)行模型訓(xùn)練,而不需要將用戶(hù)數(shù)據(jù)上傳到中央服務(wù)器,從而保護(hù)各個(gè)客戶(hù)端中的隱私。然而,大多數(shù)現(xiàn)有方法假設(shè)整個(gè)聯(lián)邦學(xué)習(xí)框架的數(shù)據(jù)類(lèi)別隨著時(shí)間的推移是固定不變的。實(shí)際情況中,已經(jīng)參與聯(lián)邦學(xué)習(xí)的客戶(hù)端經(jīng)??赡苁占叫骂?lèi)別的數(shù)據(jù),但考慮到每個(gè)客戶(hù)端本地設(shè)備的存儲(chǔ)空間非常有限,很難對(duì)其收集到的所有類(lèi)別都保存下足夠數(shù)量的數(shù)據(jù),這種情況下,現(xiàn)實(shí)世界中的聯(lián)邦學(xué)習(xí)模型很可能對(duì)于舊類(lèi)數(shù)據(jù)的性能遇到嚴(yán)重的災(zāi)難性遺忘。此外,聯(lián)邦學(xué)習(xí)框架往往還會(huì)有持續(xù)不斷的新用戶(hù)的參與,這些新用戶(hù)往往有著大量的新數(shù)據(jù)類(lèi)別,這樣會(huì)進(jìn)一步加劇全局模型的災(zāi)難性遺忘。
近年來(lái),有研究人員陸續(xù)提出將聯(lián)邦學(xué)習(xí)和持續(xù)學(xué)習(xí)的思想結(jié)合起來(lái)構(gòu)建聯(lián)邦持續(xù)學(xué)習(xí)框架。然而,直接簡(jiǎn)單的將聯(lián)邦學(xué)習(xí)和持續(xù)學(xué)習(xí)相結(jié)合會(huì)帶來(lái)新的問(wèn)題。首先是,聯(lián)邦持續(xù)學(xué)習(xí)仍然面臨災(zāi)難性遺忘,此外還會(huì)帶來(lái)來(lái)自其他客戶(hù)端潛在的干擾。因此我們需要有選擇地利用來(lái)自其他客戶(hù)端的知識(shí),以最小化客戶(hù)端間的干擾,最大化進(jìn)行客戶(hù)端間的知識(shí)轉(zhuǎn)移。第二個(gè)問(wèn)題是聯(lián)邦學(xué)習(xí)之間進(jìn)行通信交換知識(shí)時(shí),可能會(huì)造成通信成本過(guò)大,“通信代價(jià)” 成為了一個(gè)不可忽視的問(wèn)題。
我們通過(guò)四篇近期發(fā)表的文章概覽聯(lián)邦持續(xù)學(xué)習(xí)的最新研究進(jìn)展。
- 第一篇文章提出了一個(gè)新的聯(lián)邦持續(xù)學(xué)習(xí)框架 —— 聯(lián)邦加權(quán)客戶(hù)端間的傳輸(FedWeIT)。FedWeIT 將各個(gè)客戶(hù)端的本地模型參數(shù)分解為稠密基參數(shù)(a dense base parameter)和稀疏的任務(wù)自適應(yīng)參數(shù)( sparse task-adaptive parameters),以便進(jìn)行更高效地通信 [2]。
- 第二篇文章提出了一種全新的全局 - 本地遺忘補(bǔ)償 (GLFC) 模型,即同時(shí)從全局和本地兩個(gè)角度出發(fā),盡可能地減弱災(zāi)難性遺忘,使得聯(lián)邦學(xué)習(xí)最終可訓(xùn)練一個(gè)全局增量模型 [3]。
- 第三篇文章提出了一種聯(lián)邦互相關(guān)和持續(xù)學(xué)習(xí)方法。對(duì)于異構(gòu)性問(wèn)題,該方法利用未標(biāo)記的公共數(shù)據(jù)進(jìn)行通信,并構(gòu)造互相關(guān)矩陣來(lái)學(xué)習(xí)域偏移(domain shift)下的可概括性的表示。同時(shí),對(duì)于災(zāi)難性遺忘,在本地更新中利用跨域和本域信息進(jìn)行知識(shí)蒸餾,有效地提供域間和域內(nèi)知識(shí)而不泄露參與者的隱私 [4]。
- 第四篇文章提出了一個(gè)聯(lián)邦學(xué)習(xí)架構(gòu),稱(chēng)為聯(lián)邦多語(yǔ)者 TTS 系統(tǒng) Fed-Speech。該架構(gòu)使用漸進(jìn)式修剪掩碼來(lái)分離參數(shù),以保留說(shuō)話(huà)人的語(yǔ)調(diào)。此外,應(yīng)用選擇性掩碼來(lái)有效地重用任務(wù)中的知識(shí)。最后,引入 private speaker embedding 以保持用戶(hù)的隱私 [5]。
2、Federated Continual Learning with Weighted Inter-client Transfer
持續(xù)學(xué)習(xí)和聯(lián)邦學(xué)習(xí)在現(xiàn)實(shí)世界的深度神經(jīng)網(wǎng)絡(luò)中都很重要。然而,對(duì)于每個(gè)客戶(hù)端從私有的循環(huán)數(shù)據(jù)流中學(xué)習(xí)一連串任務(wù)的情況,卻很少有人進(jìn)行研究。這種聯(lián)邦學(xué)習(xí)的問(wèn)題給持續(xù)學(xué)習(xí)帶來(lái)了新的挑戰(zhàn),比如,如何有效利用其他客戶(hù)端的知識(shí)同時(shí)防止不相關(guān)知識(shí)的干擾?為了解決這些問(wèn)題,本文提出了一個(gè)新穎的聯(lián)邦持續(xù)學(xué)習(xí)框架,即聯(lián)邦加權(quán)客戶(hù)端間傳輸(Federated Weighted Inter-client Transfer,F(xiàn)edWeIT),該框架將網(wǎng)絡(luò)工作權(quán)重分解為全局的聯(lián)邦參數(shù)( global federated parameters)和稀疏的特定任務(wù)參數(shù)( sparse task-specific parameters),每個(gè)客戶(hù)端可以通過(guò)對(duì)其特定任務(wù)參數(shù)進(jìn)行加權(quán)組合,從其他客戶(hù)端那里獲得選擇性知識(shí)。具體是通過(guò)中央服務(wù)端獲得其他客戶(hù)端的 task-specific parameters,再對(duì)這些參數(shù)進(jìn)行加權(quán)聚合得到 selective knowledge,從而最大化相似任務(wù)之間共識(shí)知識(shí)的傳遞。FedWeIT 最大限度地減少了不兼容任務(wù)之間的干擾,并且在學(xué)習(xí)過(guò)程中允許客戶(hù)端之間的積極知識(shí)轉(zhuǎn)移。
作者將 Fed-WeIT 與現(xiàn)有的聯(lián)邦學(xué)習(xí)和連續(xù)學(xué)習(xí)方法在不同程度的客戶(hù)端之間的任務(wù)相似性進(jìn)行了驗(yàn)證,本文模型明顯較優(yōu),通信成本大大降低。代碼已公布在 https://github.com/wyjeong/FedWeIT。
2.1 方法介紹
受人類(lèi)從間接經(jīng)驗(yàn)中學(xué)習(xí)過(guò)程的啟發(fā),作者引入了一種新的聯(lián)邦學(xué)習(xí)環(huán)境下的持續(xù)學(xué)習(xí) --- 聯(lián)邦持續(xù)學(xué)習(xí)(Federated Continual Learning,F(xiàn)CL)。FCL 假設(shè)多個(gè)客戶(hù)端在私有數(shù)據(jù)流中的任務(wù)序列上進(jìn)行訓(xùn)練,同時(shí)與中央服務(wù)器交互所學(xué)到的參數(shù)。在標(biāo)準(zhǔn)的持續(xù)學(xué)習(xí)中(在一臺(tái)機(jī)器上),模型從一連串的任務(wù) {T (1),T (2),...,T (T)} 中反復(fù)學(xué)習(xí),其中 T (t) 是第 t 個(gè)任務(wù)的標(biāo)記數(shù)據(jù)集。假設(shè)現(xiàn)實(shí)情況如下,任務(wù)序列是一個(gè)具有未知到達(dá)順序的任務(wù)流,這樣,模型只能在任務(wù) t 的訓(xùn)練期訪(fǎng)問(wèn) T (t),之后就無(wú)法訪(fǎng)問(wèn)了。給定 T (t) 和到目前為止學(xué)到的模型,任務(wù) t 的學(xué)習(xí)目標(biāo)如下:
然后,將傳統(tǒng)的持續(xù)學(xué)習(xí)擴(kuò)展到有多個(gè)客戶(hù)端和一個(gè)中央服務(wù)器的聯(lián)邦學(xué)習(xí)環(huán)境。假設(shè)有 C 個(gè)客戶(hù)端,每個(gè)客戶(hù)端 c_c∈{c_1, . . . , c_C } 在一個(gè)私人可訪(fǎng)問(wèn)的任務(wù)序列 {T^(1)_c , T^(2)_c , ..., T^(t)_c }? T 上訓(xùn)練一個(gè)模型。需要注意的是,在步驟 t 收到的跨客戶(hù)端任務(wù)之間沒(méi)有關(guān)系。
現(xiàn)在的目標(biāo)是通過(guò)與中央服務(wù)器溝通模型參數(shù),有效地在他們自己的私有任務(wù)流上訓(xùn)練 C 類(lèi)持續(xù)學(xué)習(xí)模型,中央服務(wù)器匯總每個(gè)客戶(hù)端發(fā)送的參數(shù),并將它們重新分配給客戶(hù)端。在聯(lián)邦持續(xù)學(xué)習(xí)框架中,將參數(shù)匯總為一個(gè)全局參數(shù) θ_G,允許客戶(hù)端間的知識(shí)轉(zhuǎn)移,因?yàn)榭蛻?hù)端 c_i 在第 q 輪學(xué)到的任務(wù)可能與客戶(hù)端 c_j 在第 r 輪學(xué)到的任務(wù)相似或相關(guān)。然而,作者分析使用單一的綜合參數(shù) θ_G 可能是實(shí)現(xiàn)這一目標(biāo)的次優(yōu)選擇,因?yàn)閬?lái)自不相關(guān)任務(wù)的知識(shí)可能沒(méi)有用處,甚至可能會(huì)通過(guò)將其參數(shù)改變到不正確的方向來(lái)阻礙每個(gè)客戶(hù)端的訓(xùn)練,作者將此描述為客戶(hù)端間的干擾。
另一個(gè)實(shí)際上也很重要的問(wèn)題是通信效率。從客戶(hù)端到中央服務(wù)器,以及從中央服務(wù)器到客戶(hù)端的參數(shù)傳輸都會(huì)產(chǎn)生很大的通信成本,這對(duì)于持續(xù)學(xué)習(xí)環(huán)境來(lái)說(shuō)是有問(wèn)題的,因?yàn)榭蛻?hù)端可能會(huì)在無(wú)限的任務(wù)流上進(jìn)行訓(xùn)練。如前所述,造成這些問(wèn)題的主要原因是,在多個(gè)客戶(hù)端學(xué)到的所有任務(wù)的知識(shí)被存儲(chǔ)在一組參數(shù) θ_G 中。然而,為了使知識(shí)轉(zhuǎn)移有效,每個(gè)客戶(hù)端應(yīng)該有選擇地只利用在其他客戶(hù)端訓(xùn)練的相關(guān)任務(wù)的知識(shí)。這種選擇性轉(zhuǎn)移也是最小化客戶(hù)端間干擾的關(guān)鍵,因?yàn)樗豢紤]可能干擾學(xué)習(xí)的不相關(guān)任務(wù)的知識(shí)。
作者通過(guò)分解參數(shù)來(lái)解決這個(gè)問(wèn)題,這些參數(shù)分為三種不同的類(lèi)型,具有不同的作用:全局參數(shù)(θ_G),捕獲所有客戶(hù)端的全局和通用知識(shí);本地基礎(chǔ)參數(shù)(B),捕獲每個(gè)客戶(hù)端的通用知識(shí);任務(wù)適應(yīng)性參數(shù)(A),用于每個(gè)客戶(hù)端的每個(gè)具體任務(wù)。將一組在持續(xù)學(xué)習(xí)客戶(hù)端 c_c 的任務(wù) t 的模型參數(shù) θ^(t)_c 定義如下:
其中,B^(t)_c 是第 c 個(gè)客戶(hù)端的基本參數(shù)集,在客戶(hù)端的所有任務(wù)中共享。m^(t)_c 是稀疏向量掩碼的集合,它允許對(duì)任務(wù) t 的 B^(t)_c 進(jìn)行適應(yīng)性轉(zhuǎn)換,A^(t)_c 是客戶(hù)端 c_c 的稀疏任務(wù)適應(yīng)性參數(shù)集合。L 是神經(jīng)網(wǎng)絡(luò)中的層數(shù),I_l、O_l 分別是第 l 層權(quán)重的輸入和輸出維度。
上式中的第一項(xiàng)允許有選擇地利用全局知識(shí)。作者希望每個(gè)客戶(hù)端的基礎(chǔ)參數(shù) B^(t)_c 能夠捕獲所有客戶(hù)端的所有任務(wù)中的通用知識(shí)。如圖 2(a),在每一輪 t 中用前一次迭代的全局參數(shù) θ^(t-1)_G 來(lái)初始化,匯總從客戶(hù)端發(fā)送的參數(shù)。這使得 B^(t)_c 也能從關(guān)于所有任務(wù)的全局知識(shí)中受益。然而,由于 θ^(t-1)_G 也包含與當(dāng)前任務(wù)無(wú)關(guān)的知識(shí),我們不是原封不動(dòng)地使用它,而是學(xué)習(xí)稀疏掩碼 m^(t)_c,只為給定的任務(wù)選擇相關(guān)參數(shù)。這種稀疏的參數(shù)選擇有助于最大限度地減少客戶(hù)端之間的干擾,從而實(shí)現(xiàn)高效的通信。上式中的第二項(xiàng)是任務(wù)適應(yīng)性參數(shù) A^(t)_c。對(duì)參數(shù)進(jìn)行加法分解處理后,能夠?qū)W會(huì)捕捉第一項(xiàng)沒(méi)有捕捉到的關(guān)于任務(wù)的知識(shí),因此將捕捉到關(guān)于任務(wù) T^(t)_c 的具體知識(shí)。上式中的最后一項(xiàng)描述了加權(quán)的客戶(hù)端間知識(shí)轉(zhuǎn)移。我們擁有一組從中央服務(wù)器傳輸?shù)膮?shù),其中包含了所有客戶(hù)端的所有任務(wù)適應(yīng)性參數(shù)。為了有選擇地利用這些來(lái)自其他客戶(hù)端的間接經(jīng)驗(yàn),進(jìn)一步在這些參數(shù)上分配注意力 α^(t)_c,并采取加權(quán)組合的方式。通過(guò)學(xué)習(xí)這種注意力,每個(gè)客戶(hù)端可以只選擇有助于學(xué)習(xí)給定任務(wù)的相關(guān)任務(wù)適應(yīng)性參數(shù)。盡管作者將 A^(j)_i 設(shè)計(jì)成高度稀疏的,在實(shí)踐中使用大約 2-3% 的全參數(shù)內(nèi)存,但發(fā)送所有的任務(wù)知識(shí)仍然是不可取的。因此,作者選擇從知識(shí)庫(kù)中傳輸所有時(shí)間步驟的隨機(jī)抽樣的任務(wù)適應(yīng)性參數(shù),根據(jù)經(jīng)驗(yàn),作者發(fā)現(xiàn)這種處理方式在實(shí)踐中取得了良好的效果。
圖 2. FedWeIT 更新。(a) 客戶(hù)端發(fā)送稀疏化的聯(lián)邦參數(shù) B_c ⊙m^(t)_c 。之后,中央服務(wù)器將聚合的參數(shù)重新分配給客戶(hù)端。(b) 知識(shí)庫(kù)存儲(chǔ)了客戶(hù)端先前的任務(wù)適應(yīng)性參數(shù),每個(gè)客戶(hù)端有選擇地利用這些參數(shù),并有一個(gè)注意力掩碼訓(xùn)練。我們通過(guò)優(yōu)化以下目標(biāo)函數(shù)來(lái)學(xué)習(xí)可分解參數(shù) θ^(t)_c:
其中,L 是損失函數(shù),?(?) 是所有任務(wù)自適應(yīng)參數(shù)和掩碼變量的稀疏性誘導(dǎo)正則化項(xiàng),以使它們變得稀疏。第二個(gè)正則化項(xiàng)用于追溯更新過(guò)去的任務(wù)適應(yīng)性參數(shù),通過(guò)反映基礎(chǔ)參數(shù)的變化,幫助任務(wù)適應(yīng)性參數(shù)保持目標(biāo)任務(wù)的原始解決方案。?B^(t)_c 是指當(dāng)前時(shí)間段和前一個(gè)時(shí)間段的基礎(chǔ)參數(shù)之差。?A^(i)_c 是任務(wù) i 在當(dāng)前和前一時(shí)間段的任務(wù)適應(yīng)性參數(shù)之間的差異。這種正則化處理對(duì)于防止災(zāi)難性的遺忘至關(guān)重要。λ1 和 λ2 是控制兩個(gè)正則化作用的超參數(shù)。
客戶(hù)端。在每個(gè)輪次 r,每個(gè)客戶(hù)端 c_c 用中央服務(wù)器發(fā)送的全局參數(shù)的非零分量部分更新其基礎(chǔ)參數(shù);也就是說(shuō),B_c (n) = θ_G (n),其中,n 是全局參數(shù)的非零元素。它為新任務(wù)獲得一個(gè)稀疏的基礎(chǔ)參數(shù) ^Bb^(t)_c 和任務(wù)適應(yīng)性參數(shù) A^(t)_c,將這兩個(gè)參數(shù)發(fā)送到中央服務(wù)器,與 FCL 基線(xiàn)方法相比,成本更低。FCL 基線(xiàn)方法需要 | C|×R×|θ| 的資源用于客戶(hù)端到中央服務(wù)器的通信,而 FedWeIT 需要 | C|×(R×|Bb|+|A|),其中 R 是每個(gè)任務(wù)的通信輪數(shù),|?| 是參數(shù)數(shù)量。
中央服務(wù)器。中央服務(wù)器首先對(duì)所有客戶(hù)端發(fā)送的基礎(chǔ)參數(shù)進(jìn)行匯總,取其加權(quán)平均值 θ_G。然后,將 θ_G 廣播給所有客戶(hù)端。t-1 的任務(wù)適應(yīng)性參數(shù)在訓(xùn)練任務(wù) t 期間在每個(gè)客戶(hù)端廣播一次。FCL 基線(xiàn)需要 | C|×R×|θ| 的中央服務(wù)器 - 客戶(hù)端通信成本,而 FedWeIT 需要 | C|×(R×|θG|+(|C|-1)×|A|),其中 θ_G、A 是高度稀疏的。算法 1 中描述了 FedWeIT 的算法。
2.2 實(shí)驗(yàn)情況介紹
作者驗(yàn)證了 FedWeIT 在不同的任務(wù)序列配置下與基線(xiàn)方法(Overlapped-CIFAR-100 和 NonIID-50)的對(duì)比。1) Overlapped-CIFAR-100:將 100 個(gè) CIFAR-100 數(shù)據(jù)集類(lèi)分組為 20 個(gè) NonIID 超類(lèi)任務(wù)。然后,從 20 個(gè)任務(wù)中隨機(jī)抽取 10 個(gè)任務(wù)并拆分實(shí)例,為每個(gè)任務(wù)重疊的客戶(hù)端創(chuàng)建一個(gè)任務(wù)序列。2) NonIID-50:使用以下八個(gè)基準(zhǔn)數(shù)據(jù)集:MNIST、CIFAR-10/-100、SVHN、Fashion MNIST,Not MNIST 和 TrafficSigns。將 8 個(gè)數(shù)據(jù)集中的類(lèi)劃分為 50 個(gè) NonIID 任務(wù),每個(gè)任務(wù)由 5 個(gè)類(lèi)組成,這些類(lèi)與用于其他任務(wù)的類(lèi)不相交。
實(shí)驗(yàn)中用到的對(duì)比模型如下:1)STL:?jiǎn)稳蝿?wù)學(xué)習(xí)每個(gè)到達(dá)的任務(wù)。2) EWC:每個(gè)客戶(hù)端進(jìn)行個(gè)人持續(xù)學(xué)習(xí)。3) Stable-SGD:每個(gè)客戶(hù)端持續(xù)學(xué)習(xí) Stable-SGD。4) APD:每個(gè)客戶(hù)端使用 APD 進(jìn)行個(gè)人持續(xù)學(xué)習(xí)。5) FedProx:使用 FedProx 算法的 FCL。6) Scaffold :使用 Scaffold 算法的 FCL。7) FedCurv:使用 FedCurv 算法的 FCL。8) FedProx-[model]:使用帶有 [model] 的 FedProx 算法進(jìn)行訓(xùn)練的 FCL。9) FedWeIT:FedWeIT 算法。
表 1 給出了在兩個(gè)數(shù)據(jù)集上完成(聯(lián)邦)連續(xù)學(xué)習(xí)后,每項(xiàng)任務(wù)的最終平均性能。我們觀(guān)察到,基于 FedProx 的 FCL 方法與沒(méi)有聯(lián)邦學(xué)習(xí)的相同方法相比,會(huì)降低連續(xù)學(xué)習(xí)(CL)方法的性能。這是因?yàn)樵诓幌嚓P(guān)的任務(wù)中學(xué)習(xí)的所有客戶(hù)端參數(shù)的匯總導(dǎo)致了對(duì)每個(gè)任務(wù)學(xué)習(xí)的嚴(yán)重干擾,這導(dǎo)致了災(zāi)難性的遺忘和次優(yōu)的任務(wù)適應(yīng)性。Scaffold 在 FCL 上的表現(xiàn)很差,因?yàn)樗鼘?duì)本地梯度的正則化處理對(duì) FCL 是有害的,因?yàn)樗械目蛻?hù)端都是從不同的任務(wù)序列中學(xué)習(xí)的。雖然 FedCurv 減少了任務(wù)間的參數(shù)差異,但它不能最大限度地減少任務(wù)間的干擾,這導(dǎo)致它的表現(xiàn)不如單機(jī) CL 方法。另一方面,F(xiàn)edWeIT 在兩個(gè)數(shù)據(jù)集上的表現(xiàn)都明顯優(yōu)于單機(jī) CL 基線(xiàn)和 FCL 基線(xiàn)。即使有更多的客戶(hù)端(C = 100),F(xiàn)edWeIT 也一直優(yōu)于所有基線(xiàn)(圖 3)。這種改進(jìn)主要?dú)w功于 FedWeIT 有選擇地利用其他客戶(hù)端的知識(shí)來(lái)迅速適應(yīng)目標(biāo)任務(wù)的能力,并獲得更好的最終性能。
表 1. 5 個(gè)客戶(hù)端在 FCL 期間對(duì)兩個(gè)數(shù)據(jù)集的平均每任務(wù)表現(xiàn)(分?jǐn)?shù) = 1.0)。在完成所有學(xué)習(xí)階段的 3 次單獨(dú)試驗(yàn)后,作者測(cè)量了任務(wù)準(zhǔn)確性和模型大小。作者還測(cè)量了訓(xùn)練每個(gè)任務(wù)的 C2S/S2C 通信成本
圖 3. 訓(xùn)練最后兩個(gè)(第 9 和第 10 個(gè))任務(wù)時(shí)的平均任務(wù)適應(yīng)性,有 5 個(gè)和 100 個(gè)客戶(hù)端
對(duì)新任務(wù)的快速適應(yīng)是客戶(hù)端間知識(shí)轉(zhuǎn)移的另一個(gè)明顯優(yōu)勢(shì)。為了進(jìn)一步證明本文方法在更大的網(wǎng)絡(luò)中的實(shí)用性,作者在 ResNet-18 的 NonIID 數(shù)據(jù)集上進(jìn)行了實(shí)驗(yàn)(表 2),F(xiàn)edWeIT 在使用較少參數(shù)的情況下仍然明顯優(yōu)于最強(qiáng)基線(xiàn)(FedProx-APD)。
表 2. 使用 ResNet-18 在 NonIID-50 數(shù)據(jù)集上的 FCL 結(jié)果
此外,作者研究了在持續(xù)學(xué)習(xí)過(guò)程中過(guò)去任務(wù)的表現(xiàn)如何變化,以了解每種方法的災(zāi)難性遺忘的嚴(yán)重程度。圖 4 給出了 FedWeIT 和 FCL 基線(xiàn)在第 3、第 6 和第 8 個(gè)任務(wù)上的表現(xiàn)。我們觀(guān)察到,F(xiàn)CL 基線(xiàn)比帶有 EWC 的本地持續(xù)學(xué)習(xí)遭受了更嚴(yán)重的災(zāi)難性遺忘,這是因?yàn)榭蛻?hù)端間的干擾,來(lái)自其他客戶(hù)端的不相關(guān)任務(wù)的知識(shí)覆蓋了過(guò)去的任務(wù)知識(shí)。與此相反,本文模型沒(méi)有顯示出災(zāi)難性遺忘的跡象。這主要是由于選擇性地利用了通過(guò)全局 / 任務(wù)自適應(yīng)參數(shù)從其他客戶(hù)端那里學(xué)到的先驗(yàn)知識(shí),這使得它能夠有效地緩解客戶(hù)端間的干擾。FedProx-APD 也不存在災(zāi)難性遺忘的問(wèn)題,但由于知識(shí)轉(zhuǎn)移的無(wú)效性,它們的性能較差。
圖 4. 災(zāi)難性遺忘。在 NonIID-50 的聯(lián)邦持續(xù)學(xué)習(xí)過(guò)程中,在第 3、6 和 8 個(gè)任務(wù)中關(guān)于當(dāng)前任務(wù)適應(yīng)性的性能比較
3、Federated Class Incremental Learning
3.1 本地災(zāi)難性遺忘補(bǔ)償
通過(guò)在分散的客戶(hù)端上進(jìn)行數(shù)據(jù)私有的協(xié)作訓(xùn)練,聯(lián)邦學(xué)習(xí)吸引了越來(lái)越多的關(guān)注。然而,大多數(shù)現(xiàn)有的方法假設(shè)整體框架的對(duì)象類(lèi)別是固定的。這使得全局模型在現(xiàn)實(shí)世界的場(chǎng)景中遭受了嚴(yán)重的災(zāi)難性遺忘,因?yàn)楸镜乜蛻?hù)端經(jīng)常不斷地收集新的類(lèi)別,而用于存儲(chǔ)舊的類(lèi)別的存儲(chǔ)空間非常有限。此外,存有之前未見(jiàn)過(guò)的新類(lèi)別數(shù)據(jù)的新客戶(hù)端可能參與 FL 訓(xùn)練,這就進(jìn)一步加劇了全局模型的災(zāi)難性遺忘。為了應(yīng)對(duì)這些挑戰(zhàn),本文提出了一個(gè)新的全局 - 本地遺忘補(bǔ)償(Global-Local Forgetting Compensation,GLFC)模型,從本地和全局的角度學(xué)習(xí)一個(gè)全局類(lèi)別增量模型來(lái)緩解災(zāi)難性遺忘。作者表示,這是第一次嘗試在 FL 設(shè)置中學(xué)習(xí)全局類(lèi)增量模型(a global class-incremental model)。具體來(lái)說(shuō),為了解決本地客戶(hù)端的類(lèi)別不平衡引起的本地遺忘,作者設(shè)計(jì)了一個(gè)類(lèi)別意識(shí)的梯度補(bǔ)償損失和一個(gè)類(lèi)別語(yǔ)義關(guān)系蒸餾損失,以平衡舊類(lèi)別的遺忘,并在不同任務(wù)中蒸餾出一致的類(lèi)間關(guān)系。為了解決 non-i.i.d 類(lèi)不平衡帶來(lái)的全局遺忘問(wèn)題,作者提出了一個(gè)代理服務(wù)器,選擇最佳的舊全局模型來(lái)協(xié)助本地關(guān)系蒸餾??紤]到隱私保護(hù),代理服務(wù)器通過(guò)基于原型梯度的通信機(jī)制從本地客戶(hù)端收集新類(lèi)的擾動(dòng)原型樣本,然后利用它們來(lái)監(jiān)測(cè)全局模型的性能以選擇最佳模型。本文模型在代表性的基準(zhǔn)數(shù)據(jù)集上的平均準(zhǔn)確率比 SOTA 方法高出 4.4%~15.1%。代碼已公布在 https://github.com/conditionWang/FCIL。
圖 5 描述了本文模型的概況。為了滿(mǎn)足 FCIL 的要求,本文模型通過(guò)類(lèi)別意識(shí)梯度補(bǔ)償損失和類(lèi)別語(yǔ)義關(guān)系蒸餾損失來(lái)解決本地遺忘問(wèn)題,同時(shí)通過(guò)代理服務(wù)器為本地客戶(hù)端選擇最佳舊模型來(lái)解決全局遺忘問(wèn)題。
圖 5. GLFC 模型概述。它主要由類(lèi)別意識(shí)梯度補(bǔ)償損失 L_GC 和類(lèi)別語(yǔ)義關(guān)系蒸餾損失 L_RD 組成,以克服本地的類(lèi)別不平衡造成的本地災(zāi)難性遺忘。使用代理服務(wù)器 S_P 來(lái)解決 non-ii.d. 類(lèi)別不平衡帶來(lái)的跨客戶(hù)端的全局災(zāi)難性遺忘,其中 S_P 和客戶(hù)端之間開(kāi)發(fā)了一個(gè)原型梯度通信機(jī)制用于私人通信,同時(shí)為 L_RD 選擇最佳舊的全局模型
在第 t 個(gè)增量任務(wù)中,給定第 l 個(gè)本地客戶(hù)端 S_l∈S_b 的新類(lèi)別訓(xùn)練數(shù)據(jù)和樣本存儲(chǔ)器 M_l,minibatch 的分類(lèi)損失 L_CE 為:
(1)
其中,b 是批次大小,Θ_r,t 是第 r 輪全局任務(wù)的分類(lèi)模型,由中央服務(wù)器傳送給本地客戶(hù)端。P^t_l (x_t^(l_i, Θ_r,t) ∈R^(C^p+C^t)表示通過(guò) Θ_r,t 預(yù)測(cè)的 sigmoid 概率,DCE (?,?) 是二元交叉熵?fù)p失。
如前所述,新舊類(lèi)別(T^t_l 和 M_l)本地不平衡,使得本地訓(xùn)練在舊的類(lèi)別上出現(xiàn)了明顯的性能下降(即本地災(zāi)難性遺忘)。為了防止本地遺忘,如圖 5 所示,本文為本地客戶(hù)端開(kāi)發(fā)了一個(gè)類(lèi)別意識(shí)的梯度補(bǔ)償損失和一個(gè)類(lèi)別語(yǔ)義關(guān)系蒸餾損失,它可以糾正不平衡的梯度傳播并確保跨增量任務(wù)的類(lèi)別間語(yǔ)義一致性。
- 類(lèi)別感知梯度補(bǔ)償損失:在 S_G 將 Θ^r,t 分布到本地客戶(hù)端后,本地客戶(hù)端的類(lèi)別不平衡分布導(dǎo)致 Θ^r,t 中最后輸出層的梯度反向傳播不平衡。它使得本地模型 Θ^r,t_l 的更新在本地訓(xùn)練后,在新的類(lèi)別中執(zhí)行不同的 learning paces,在舊的類(lèi)別中執(zhí)行不同的 forgetting paces。當(dāng)新的流媒體數(shù)據(jù)不斷成為舊的類(lèi)別的一部分時(shí),這種現(xiàn)象嚴(yán)重惡化了對(duì)舊類(lèi)別的本地遺忘。
對(duì)應(yīng)于此問(wèn)題,本文設(shè)計(jì)了一個(gè)類(lèi)別意識(shí)梯度補(bǔ)償損失 L_GC,通過(guò)重新加權(quán)梯度傳播,分別對(duì)新類(lèi)別的學(xué)習(xí)速度和舊類(lèi)別的遺忘速度進(jìn)行規(guī)范。具體來(lái)說(shuō),對(duì)于單個(gè)樣本(x^t_li, y^t_li),我們得到一個(gè)相對(duì)于 Θ^r,t_l 中最后輸出層的第 y^t_li 個(gè)神經(jīng)元的梯度測(cè)量 G^t_li:
(2)
為了使新類(lèi)別的學(xué)習(xí)速度和舊類(lèi)別的遺忘速度正?;?,我們對(duì)新舊類(lèi)別分別進(jìn)行梯度規(guī)范化處理,并利用它來(lái)重估 L_CE。給定一個(gè)小批次 {x^t_li, y^t_li},定義如下:
(3)
作為新舊兩類(lèi)別的梯度手段,其中 I (?) 是指標(biāo)函數(shù),如果下標(biāo)條件為真,I (True)=1;否則,I (False)=0。因此,重新加權(quán)的 L_CE 損失表述如下:
(4)
- 類(lèi)別 - 語(yǔ)義關(guān)系蒸餾損失。在初始化為當(dāng)前全局模型 Θ^r,t 的本地模型 Θ^r,t_l 的訓(xùn)練過(guò)程中,Θ^r,t_l 預(yù)測(cè)的概率表示類(lèi)間語(yǔ)義相似關(guān)系。為了確保不同增量任務(wù)之間的類(lèi)間語(yǔ)義一致性,作者通過(guò)考慮新舊類(lèi)之間的基本關(guān)系,設(shè)計(jì)了一個(gè)類(lèi)別 - 語(yǔ)義關(guān)系蒸餾損失 L_RD。如圖 5 所示,分別將一個(gè)小批次的數(shù)據(jù)集 {X^t_lb, Y^t_lb} 轉(zhuǎn)發(fā)到存儲(chǔ)的舊模型 Θ^t-1_l 和當(dāng)前的本地模型 Θ^r,t_l。這些概率反映了新舊兩類(lèi)之間的類(lèi)間關(guān)系。與現(xiàn)有的知識(shí)蒸餾策略不同的是,作者通過(guò)優(yōu)化 L_RD 同時(shí)考慮了新舊類(lèi)之間的類(lèi)間關(guān)系,只保證舊類(lèi)在 Θ^t-1_l 和 Θ^r,t_l 之間的語(yǔ)義一致性。也就是說(shuō),利用獨(dú)熱編碼標(biāo)簽 Y^t_lb 的一個(gè)變體,用 P^t-1_l (X^t_lb, Θ^t-1_l) 代替 Y^t_lb 的第一個(gè) C^p 維度,并將這個(gè)變體表示為:
由此得到 L_RD 如下:
(5)
總的來(lái)說(shuō),第 l 個(gè)本地客戶(hù)端的優(yōu)化目標(biāo)為:
(6)
任務(wù)轉(zhuǎn)移檢測(cè)。在 FCIL 中,我們沒(méi)有關(guān)于本地客戶(hù)端何時(shí)收到新類(lèi)數(shù)據(jù)的先驗(yàn)知識(shí)。為了解決這個(gè)問(wèn)題,作者考慮一個(gè)解決方案:識(shí)別訓(xùn)練數(shù)據(jù)的標(biāo)簽以前是否被觀(guān)察到過(guò)。然而,由于類(lèi)別分布的 non-i.i.d. 設(shè)置,這種方法不能確定新收到的標(biāo)簽是來(lái)自新的類(lèi)別還是其他本地客戶(hù)端觀(guān)察到的舊的類(lèi)別。另一個(gè)直觀(guān)的解決方案是使用性能下降作為收集新類(lèi)的信號(hào)。這種解決方案在 FCIL 中是不可行的,因?yàn)殡S機(jī)選擇 {So, Sb, Sn} 和它們的 non-i.i.d. 類(lèi)分布會(huì)導(dǎo)致性能急劇下降,即使沒(méi)有收到新的類(lèi)。為此,作者提出了一種任務(wù)轉(zhuǎn)移檢測(cè)機(jī)制,以準(zhǔn)確識(shí)別本地客戶(hù)端何時(shí)收到新的類(lèi)別。具體來(lái)說(shuō),在第 r 個(gè)全局輪次,每個(gè)客戶(hù)端通過(guò)收到的全局模型 Θ^r,t 對(duì)其當(dāng)前訓(xùn)練數(shù)據(jù) T^t_l 計(jì)算平均熵 H^r,t_l:
(7)
3.2 全局災(zāi)難性遺忘補(bǔ)償
雖然公式(6)可以解決本地的類(lèi)別不平衡帶來(lái)的本地災(zāi)難性遺忘,但它不能解決來(lái)自其他客戶(hù)端的異質(zhì)性遺忘(即全局災(zāi)難性遺忘)。換句話(huà)說(shuō),non-i.i.d. 類(lèi)的不平衡分布在本地客戶(hù)端上導(dǎo)致了某些舊類(lèi)的全局災(zāi)難性遺忘,進(jìn)一步惡化了本地災(zāi)難性遺忘。因此,有必要從全局角度解決不同客戶(hù)端的異質(zhì)性遺忘問(wèn)題。如前所述,公式(5)中提出的類(lèi)別 - 語(yǔ)義關(guān)系蒸餾損失 L_RD 需要存儲(chǔ)以前任務(wù)的舊分類(lèi)模型 Θ^t-1_l 來(lái)提煉類(lèi)間關(guān)系。一個(gè)較好的 Θ^t-1_l 可以在全局上增加以前任務(wù)的蒸餾收益,從全局上加強(qiáng)對(duì)舊類(lèi)的記憶。因此,Θ^t-1_l 的選擇在全局災(zāi)難性遺忘補(bǔ)償中起著重要作用,應(yīng)從全局角度考慮。
然而在 FCIL 中,由于隱私保護(hù)很難選擇最佳 Θ^t-1_l。直觀(guān)的解決方案是,每個(gè)客戶(hù)端在第(t-1)個(gè)任務(wù)期間用訓(xùn)練數(shù)據(jù) T^t-1_l 存儲(chǔ)其最佳舊模型 {Θ^t-1_l}。不幸的是,這個(gè)解決方案是從本地角度考慮選擇 Θ^t-1_l 的,不能保證所選擇的 Θ^t-1_l 對(duì)所有的舊類(lèi)都有最好的記憶,因?yàn)槊總€(gè)本地客戶(hù)端只有一個(gè)舊類(lèi)子集(non-i.i.d)。為此,作者引入一個(gè)代理服務(wù)器 S_P,從全局角度為所有客戶(hù)端選擇最佳的 Θ^t-1,如圖 5 所描述。具體來(lái)說(shuō),當(dāng)本地客戶(hù)端通過(guò)任務(wù)轉(zhuǎn)換檢測(cè)在第 t 個(gè)任務(wù)開(kāi)始時(shí)識(shí)別了新的類(lèi)(即 T^t_l),他們將通過(guò)基于原型梯度的通信機(jī)制將新類(lèi)的擾動(dòng)原型樣本傳送給 S_P。在收到這些梯度后,S_P 重建被擾動(dòng)的原型樣本,并利用它們來(lái)監(jiān)測(cè)全局模型 Θ^r,t(從 S_G 收到)的性能,直到找到最佳模型。當(dāng)步入下一個(gè)任務(wù)(t+1)時(shí),S_P 將最優(yōu) Θ^r,t 分發(fā)給本地客戶(hù)端,本地客戶(hù)端將其視為最優(yōu)舊模型來(lái)執(zhí)行 L_RD。
- 基于梯度的原型通信。給定第 l 個(gè)本地客戶(hù)端 Sl∈Sb∪Sn,該客戶(hù)端收到訓(xùn)練數(shù)據(jù) T^t_ l 的新類(lèi),Sl 通過(guò)任務(wù)轉(zhuǎn)換 detention 來(lái)識(shí)別新的類(lèi)。然后 Sl 從 T^t_l 中為每個(gè)新的類(lèi)(c = C^p_l + 1, - -, C^p_l +C^t_l)只選擇一個(gè)有代表性的原型樣本 x^t_lc?,其中 x^t_lc?的特征最接近屬于 c 類(lèi)的所有樣本在潛在特征空間的平均嵌入。然后將這些原型樣本和它們的標(biāo)簽送入 L 層梯度編碼網(wǎng)絡(luò) Γ = {Wi},計(jì)算梯度?Γ_lc。S_P 隨機(jī)處理所有從本輪全局的選定客戶(hù)端處收到的梯度,以構(gòu)建一個(gè)梯度池,假設(shè)這個(gè)池子里有 N^t_g 個(gè)梯度。這種操作可以防止 S_P 通過(guò)注釋特殊的梯度分布來(lái)追蹤某些選定的客戶(hù)端。對(duì)于?Γ^t 的第 n 個(gè)元素?Γ^t_n,我們可以通過(guò)觀(guān)察?Γ 中最后一層的梯度符號(hào)來(lái)獲得其對(duì)應(yīng)的 ground-truth 標(biāo)簽 y^t_n(有一個(gè)獨(dú)熱編碼標(biāo)簽 y^t_n)。給定一個(gè)由標(biāo)準(zhǔn)高斯 (N0,1) 初始化的假樣本 xˉ^t_n,將所有的 {xˉ^t_n, ?Γ^t_n, y^t_n} 對(duì)轉(zhuǎn)發(fā)到 Γ = {Wi},該網(wǎng)絡(luò)與本地客戶(hù)端使用的梯度編碼網(wǎng)絡(luò)相同,以恢復(fù)每個(gè)新類(lèi)的原型樣本。重建損失 L_RT 如下:
(8、9)
- 最優(yōu)舊模型的選擇。當(dāng)檢測(cè)到新的類(lèi)時(shí),S_P 只能在第 t 個(gè)任務(wù)的第一輪接收本地客戶(hù)端的梯度。然后,S_P 通過(guò)優(yōu)化公式(9)重建 N^t_g 個(gè)新類(lèi)別的原型樣本及其標(biāo)簽(即 {xˉ^t_n, y^t_n})。在第 t 個(gè)任務(wù)中,S_P 將這些重建的樣本轉(zhuǎn)發(fā)到全局模型 Θ^r,t(從 S_G 處收到),通過(guò)評(píng)估哪個(gè)模型具有最佳精度來(lái)選擇最佳 Θ^t,直到收到下一個(gè)任務(wù)的新類(lèi)的梯度。在從第二個(gè)任務(wù)開(kāi)始的每一輪全局處理過(guò)程中,S_P 將上一個(gè)任務(wù)和當(dāng)前任務(wù)的最優(yōu)模型(即 Θ^t-1 和 Θ^t),分配給所有被選中的客戶(hù)端。如果這些被選中的客戶(hù)端在第 t 個(gè)任務(wù)中檢測(cè)到 T^t+1_l 的新類(lèi),他們將把 Θ^t 設(shè)置為舊模型 Θ^t-1_l,否則,將 Θ^t-1 設(shè)置為 Θ^t-1_l 來(lái)執(zhí)行 L_RD。
擾動(dòng)的原型樣本構(gòu)建。盡管網(wǎng)絡(luò) Γ 只有 S_P 和本地客戶(hù)端可以私下訪(fǎng)問(wèn),但可以竊取 Γ 和這些梯度來(lái)重建第 l 個(gè)本地客戶(hù)端的原始原型樣本 {x^t_lc? , y^t_lc? }。為了實(shí)現(xiàn)隱私保護(hù),作者建議在這些原型樣本中加入擾動(dòng)。即使能重建原型樣本,也只能從擾動(dòng)的原型樣本中獲得很少的有用信息。給定一個(gè)原型樣本 {x^t_lc? , y^t_lc? },將其轉(zhuǎn)入通過(guò)公式(6)訓(xùn)練的本地模型 Θ^r,t_l,并應(yīng)用反向傳播來(lái)更新這個(gè)樣本。為了產(chǎn)生擾動(dòng)的原型樣本,作者在原型樣本的潛在特征中引入一個(gè)高斯噪聲,然后通過(guò)公式(11)更新 x^t_lc?:
(10)
(11)
其中,Φ(x^t_lc?) 表示 x^t_lc?的潛在特征,P^t_l (Φ(x^t_lc?)+γN (0, σ2), Θ^r,t_l) 是在 Φ(x^t_lc?) 中加入高斯噪聲 N (0, σ2) 時(shí)通過(guò) Θ^r,t_l 預(yù)測(cè)的概率。σ2 代表屬于 y^t_lc?的所有樣本的特征方差,作者根據(jù)經(jīng)驗(yàn)設(shè)定 γ=0.1,以控制本文中高斯噪聲的影響。圖 6 展示了一些重建的原型樣本。
圖 6. CIFAR-100 中原始原型樣本(上行)、擾動(dòng)原型樣本(中行)和通過(guò)代理服務(wù)器重建的原型樣本(下行)的可視化情況
3.3 實(shí)驗(yàn)情況介紹
本文在 CIFAR-100、ImageNetSubset 和 TinyImageNet 上進(jìn)行實(shí)驗(yàn),對(duì)比實(shí)驗(yàn)結(jié)果如表 3-5。其中,△表示本文模型與其他比較方法相比的改進(jìn)。我們觀(guān)察到,在 FCIL 設(shè)置中,本文模型以 4.4%~15.1% 的幅度超過(guò)了現(xiàn)有的 class-incremental 方法的平均精度。這驗(yàn)證了本文模型可以使本地客戶(hù)端協(xié)同訓(xùn)練一個(gè)全局的 class-incremental 模型。此外,與其他方法相比,本文模型在所有增量任務(wù)中都有穩(wěn)定的性能提升,這驗(yàn)證了本文模型在解決 FCIL 中遺忘問(wèn)題的有效性。
表 3. 本文模型和其他基線(xiàn)方法在 CIFAR-100 上的性能比較
表 4. 本文模型和其他基線(xiàn)方法在 ImageNet-Subset 上的性能比較
表 5. TinyImageNet 上前 10 個(gè)任務(wù)與 20 個(gè)任務(wù)的比較
此外,作者對(duì)基準(zhǔn)數(shù)據(jù)集上的各種增量任務(wù)(T=5、10、20)進(jìn)行定性分析,以驗(yàn)證 GLFC 的性能。根據(jù)這些曲線(xiàn),我們可以很容易地觀(guān)察到,在不同的任務(wù)數(shù)量(T=5,10,20)的設(shè)置下,我們的模型在所有的增量任務(wù)中都比其他基線(xiàn)方法表現(xiàn)更好。這表明 GLFC 模型能夠使多個(gè)本地客戶(hù)端以流式方式學(xué)習(xí)新的類(lèi),同時(shí)解決本地和全局遺忘問(wèn)題。如圖 7、8 所示。
圖 7. T=5(左)、T=10(中)和 T=20(右)時(shí),對(duì) CIFAR-100 上不同增量任務(wù)的定性分析
圖 8. T=5(左)、T=10(中)和 T=20(右)時(shí),ImageNet-Subset 上增量任務(wù)的定性分析
4、Learn from Others and Be Yourself in Heterogeneous Federated Learning
聯(lián)邦學(xué)習(xí)已經(jīng)成為一種重要的分布式學(xué)習(xí)范式,它通常涉及與他人的協(xié)作更新和私有數(shù)據(jù)的本地更新。然而,異質(zhì)性問(wèn)題和災(zāi)難性遺忘給聯(lián)邦學(xué)習(xí)帶來(lái)了挑戰(zhàn)。首先,由于 non-i.i.d 數(shù)據(jù)和異質(zhì)結(jié)構(gòu),模型在其他域的性能下降,并與參與者的模型存在溝通障礙。其次,在本地更新中模型是在私有數(shù)據(jù)上單獨(dú)優(yōu)化的,這很容易過(guò)度擬合當(dāng)前的數(shù)據(jù)分布,忘記以前獲得的知識(shí),導(dǎo)致災(zāi)難性的遺忘。本文提出了聯(lián)邦交叉相關(guān)和持續(xù)學(xué)習(xí)(Federated Cross-Correlation and Continual Learning,F(xiàn)CCL)。對(duì)于異質(zhì)性問(wèn)題,F(xiàn)CCL 利用未標(biāo)記的公共數(shù)據(jù)進(jìn)行通信,并構(gòu)建交叉相關(guān)矩陣來(lái)學(xué)習(xí)域偏移的可泛化表示。同時(shí),對(duì)于災(zāi)難性遺忘,F(xiàn)CCL 利用本地更新中的知識(shí)提煉,在不泄露隱私的情況下提供域間和域內(nèi)信息。作者通過(guò)各種圖像分類(lèi)任務(wù)的實(shí)證結(jié)果證明了本文方法的有效性和模塊的效率。
按照標(biāo)準(zhǔn)的聯(lián)邦學(xué)習(xí)設(shè)置,有 K 個(gè)參與者(以 i 為索引),每個(gè)參與者都有一個(gè)本地模型 θ_i 和私有數(shù)據(jù) D_i = {(X_i,Y_i)|X_i∈R^(Ni×D), Y_i∈R^(Ni×C) },其中,N_i 表示私有數(shù)據(jù)的數(shù)量,D 表示輸入大小,C 定義為分類(lèi)的類(lèi)別數(shù)量。同時(shí),私有數(shù)據(jù)分布表示為 P_i (X, Y),并改寫(xiě)為 P_i (X|Y ) P_i (Y)。此外,在異質(zhì)聯(lián)邦學(xué)習(xí)中,數(shù)據(jù)的異質(zhì)性和模型的異質(zhì)性定義如下:
- 數(shù)據(jù)的異質(zhì)性。P_i (X|Y )≠P_j (X|Y )。私有數(shù)據(jù)之間存在域偏移,即私有數(shù)據(jù)的條件分布 P (X|Y) 在不同的參與者中是不同的,即使 P (Y) 是共享的。具體來(lái)說(shuō),同一個(gè)標(biāo)簽 Y 在不同域有不同的特征 X。
- 模型的異質(zhì)性:Shape (θ_i) ≠ Shape (θ_j )。參與者獨(dú)立定制模型,即對(duì)于分類(lèi)任務(wù),所選擇的骨干網(wǎng)(如 ResNet、EfficientNet 和 MobileNet)是不同的,具有不同的分類(lèi)器模型。
作者利用無(wú)標(biāo)簽的公共數(shù)據(jù) D_0={X_0|X_0∈R^(N0×D)} 來(lái)實(shí)現(xiàn)通信。公共數(shù)據(jù)在實(shí)際場(chǎng)景中相對(duì)容易獲取。第一名參與者的目標(biāo)是達(dá)成溝通,并學(xué)習(xí)具有可概括性的模型 θ_i。此外,考慮到災(zāi)難性的問(wèn)題,θ_k 需要呈現(xiàn)更高更穩(wěn)定的域間和域內(nèi)性能。本文方法的框架如圖 9 所示。具體來(lái)說(shuō),在協(xié)作更新中,作者測(cè)量未標(biāo)記的公共數(shù)據(jù)上輸出的 Logits 之間的交叉相關(guān)矩陣,以實(shí)現(xiàn)相似性和減少冗余。同時(shí),在本地更新中,通過(guò)知識(shí)蒸餾不斷平衡多域的信息。
圖 9. FCCL 示例。(a) 本文方法的簡(jiǎn)化示意圖,該方法通過(guò)聯(lián)邦交叉相關(guān)學(xué)習(xí)和聯(lián)邦持續(xù)學(xué)習(xí)解決了異質(zhì)性問(wèn)題和災(zāi)難性遺忘;(b) 聯(lián)邦交叉相關(guān)學(xué)習(xí);(c) 聯(lián)邦持續(xù)學(xué)習(xí)。梯度顏色比例反映了其他參與者的影響程度
4.1 聯(lián)邦交叉相關(guān)學(xué)習(xí)
維度級(jí)(Dimension-Level)操作的啟發(fā)。受信息瓶頸進(jìn)行自監(jiān)督學(xué)習(xí)的成功經(jīng)驗(yàn)啟發(fā),作者提出:一個(gè)可概括的表征應(yīng)該盡可能地提供關(guān)于圖像的信息,同時(shí)盡可能地對(duì)應(yīng)用于該樣本的特定域的變換不產(chǎn)生影響。在本文工作中,域偏移導(dǎo)致同一標(biāo)簽 Y 在不同域有不同的特征 X。因此,不同域的 logits 輸出沿批次維度的分布是不一樣的。此外,不同維度的 logits 輸出對(duì)應(yīng)著不同的類(lèi)別。因此,我們需要鼓勵(lì)相同維度的不變性和不同維度的多樣性。私有數(shù)據(jù)帶有特定的域信息,并且受到隱私保護(hù),這對(duì)于進(jìn)行自監(jiān)督學(xué)習(xí)是不合適的,也是不可行的。因此,我們利用未標(biāo)記的公共數(shù)據(jù),這些數(shù)據(jù)通常是從多個(gè)域產(chǎn)生和收集的,而且很容易獲得。我們通過(guò)要求 logits 輸出不受域偏移的影響,以及在無(wú)標(biāo)簽的公共數(shù)據(jù)上對(duì) logits 輸出的不同維度進(jìn)行修飾來(lái)優(yōu)化私有模型。
交叉相關(guān)矩陣的構(gòu)建。具體來(lái)說(shuō),我們得到第 i 個(gè)參與者的 logits 輸出。Z_i =f (θ_i, X_0) ∈R^(N_0×C) 。對(duì)于第 i 個(gè)和第 j 個(gè)參與者,在未標(biāo)記的公共數(shù)據(jù)上的 logits 輸出為 Z_i 和 Z_j。值得注意的是,考慮到中央服務(wù)器端的計(jì)算負(fù)擔(dān),我們計(jì)算 average Logits 輸出:
然后,計(jì)算交叉相關(guān)矩陣,第 i 個(gè)參與者的 average logits 輸出為 M_i:
(12)
其中,b 指的是批次樣本,u、v 指的是 logits 輸出的維度,||?|| 是沿批次維度進(jìn)行的歸一化操作。M_i 是一個(gè)正方形矩陣,大小為輸出維度 C,數(shù)值在 - 1(即不相似)和 1(即相似)之間。那么,第 i 個(gè)參與者的協(xié)作損失定義為:
(13)
其中,λ_Col 是一個(gè)正常數(shù),用來(lái)權(quán)衡損失的第一項(xiàng)和第二項(xiàng)的重要性。當(dāng)交叉相關(guān)矩陣的對(duì)角線(xiàn)項(xiàng)取值為 + 1 時(shí),它鼓勵(lì)不同參與者的 logits 輸出相似。當(dāng)交叉相關(guān)矩陣的對(duì)角線(xiàn)項(xiàng)取值為 - 1 時(shí),它鼓勵(lì) logits 輸出的多樣性,因?yàn)檫@些 logits 輸出的不同維度是彼此不相關(guān)的。
4.2 聯(lián)邦持續(xù)相關(guān)學(xué)習(xí)
典型的監(jiān)督損失。對(duì)于聯(lián)邦學(xué)習(xí)中的本地更新,目前的方法通常將這個(gè)過(guò)程作為一個(gè)監(jiān)督分類(lèi)問(wèn)題。具體來(lái)說(shuō),在第 t 輪通信中,協(xié)作更新后,將第 i 個(gè)私有模型定義為 (θ^t,im)_i。然后,在固定 epoch 的私有數(shù)據(jù) D_i (X_i, Y_i) 上優(yōu)化 (θ^t,im)_i。給定如下 logits 輸出:
用 softmax 對(duì)交叉熵?fù)p失進(jìn)行優(yōu)化:
(14)
這樣的訓(xùn)練目標(biāo)設(shè)計(jì)可能會(huì)面臨災(zāi)難性遺忘的問(wèn)題,主要是由于以下兩個(gè)限制:1)在本地更新中,如果沒(méi)有其他參與者的監(jiān)督,模型很容易過(guò)度擬合當(dāng)前的數(shù)據(jù)分布,呈現(xiàn)出糟糕的域間性能。2)它只對(duì)預(yù)測(cè)進(jìn)行獨(dú)立的先驗(yàn)概率懲罰,這提供了有限的和 hard 的域內(nèi)信息。
雙域知識(shí)蒸餾損失。作者開(kāi)發(fā)了一種聯(lián)邦持續(xù)學(xué)習(xí)方法,通過(guò)從模型方面對(duì)目標(biāo)進(jìn)行正則化來(lái)解決 1)和 2)的問(wèn)題。具體來(lái)說(shuō),在第 t-1 輪訓(xùn)練結(jié)束時(shí),更新的模型 (θ^t-1)_i 包含了從其他參與者學(xué)到的知識(shí)。在私有數(shù)據(jù)上計(jì)算 logits 輸出如下:
域內(nèi)知識(shí)蒸餾的損失定義為:
(15)
其中,σ 表示 softmax 函數(shù)。如公式(15),其目的是在保護(hù)隱私的同時(shí)不斷向他人學(xué)習(xí),從而保證域間性能,并處理聯(lián)邦學(xué)習(xí)中的災(zāi)難性遺忘問(wèn)題。此外,對(duì)于第 i 個(gè)參與者,在私有數(shù)據(jù)上預(yù)訓(xùn)練一個(gè)模型 (θ^?)_i 是可行的。給定如下私有數(shù)據(jù):
域內(nèi)知識(shí)蒸餾損失定義為:
(16)
帶有預(yù)訓(xùn)練模型的知識(shí)蒸餾提供了 soft 而豐富的域內(nèi)信息。此外,它與公式(14)中的典型監(jiān)督損失(即交叉熵?fù)p失)合作,提供 soft 和 hard 的域內(nèi)信息以確保域內(nèi)性能。在某種程度上,上述兩個(gè)模型(即更新的模型 (θ^t-1)_i 和預(yù)訓(xùn)練的模型 (θ^?)_i)分別代表了 "教師" 之間和內(nèi)部的模型。通過(guò)知識(shí)蒸餾,平衡來(lái)自他人和自身的知識(shí),同時(shí)提升了域間和域內(nèi)的性能。雙域知識(shí)蒸餾的計(jì)算方法為:
(17)
公式(14)中的典型監(jiān)督損失和公式(17)中的雙域知識(shí)蒸餾損失是相互補(bǔ)充的。前者要求模型學(xué)習(xí)對(duì)分類(lèi)任務(wù)有意義的鑒別性表征,而后者則有助于用域內(nèi)和域間的 soft 豐富信息使模型規(guī)范化。因此,總的訓(xùn)練目標(biāo)是:
完整的 FCCL 流程如下:
4.3 實(shí)驗(yàn)情況介紹
作者在兩個(gè)分類(lèi)任務(wù)(如 Digits 和 Office-Home)和三個(gè)公共數(shù)據(jù)集(如 Cifar-100、ImageNet 和 Fashion-MNIST)上廣泛地評(píng)估了本文方法。具體來(lái)說(shuō),Digits 任務(wù)包括四個(gè)域(MNIST(M)、USPS(U)、SVHN(SV)和 SYN(SY)),共有 10 個(gè)類(lèi)別。Office-Home 任務(wù)也有四個(gè)域(藝術(shù)(A)、剪貼畫(huà)(C)、產(chǎn)品(P)和現(xiàn)實(shí)世界(R))。請(qǐng)注意,對(duì)于這兩項(xiàng)任務(wù)來(lái)說(shuō),從不同域獲得的數(shù)據(jù)呈現(xiàn)出域偏移(數(shù)據(jù)異質(zhì)性)特性。對(duì)于這兩個(gè)分類(lèi)任務(wù),參與者定制的模型可以從差異化的骨干網(wǎng)和分類(lèi)器中獲得差異(模型異質(zhì)性)。在實(shí)驗(yàn)中,作者將這四個(gè)域的模型設(shè)置為 ResNet、EfficientNet、MobileNet 和 GoogLeNet。作者將 FCCL 與最先進(jìn)的方法,包括 FedDF、FML、FedMD、RCFL 和 FedMatch,進(jìn)行比較。此外,還比較了 SOLO,即參與者在沒(méi)有聯(lián)邦學(xué)習(xí)的情況下在私有數(shù)據(jù)上訓(xùn)練一個(gè)模型。由于具體的實(shí)驗(yàn)設(shè)置并不完全一致,作者保留了一些方法的關(guān)鍵特征進(jìn)行比較。
評(píng)價(jià)指標(biāo)。作者報(bào)告了衡量方法質(zhì)量的標(biāo)準(zhǔn)指標(biāo):準(zhǔn)確性,將其定義為成對(duì)的樣本數(shù)除以樣本數(shù)。具體來(lái)說(shuō),為了評(píng)價(jià)域內(nèi)和域間的性能,定義如下指標(biāo):
域間分析。表 6 報(bào)告了不同方法的域間性能。在域偏移的情況下,SOLO 在這兩個(gè)任務(wù)中表現(xiàn)得最差。我們觀(guān)察到,F(xiàn)CCL 的表現(xiàn)明顯優(yōu)于其他同類(lèi)方法。圖 10 顯示,F(xiàn)CCL 在參與者之間實(shí)現(xiàn)了相似的 logits 輸出,并在 logits 輸出中實(shí)現(xiàn)了冗余,這證實(shí)了 FCCL 在公共和私有數(shù)據(jù)上成功地執(zhí)行了相同維度的相關(guān)性和不同維度的去相關(guān)性。
表 6. 域間性能與最先進(jìn)方法的比較。M→表示私有數(shù)據(jù)是 MNIST,各自的模型在其他域測(cè)試。AVG 表示從每個(gè)域計(jì)算出的平均精度
圖 10. 用 Cifar-100 對(duì)不同域的數(shù)字任務(wù)進(jìn)行交叉相關(guān)矩陣的可視化分析
域內(nèi)分析。為了比較緩解災(zāi)難性遺忘的效果,表 7 報(bào)告了不同方法的域內(nèi)性能。以 Cifar-100 的 Digits 任務(wù)為例,本文方法比 RCFL 要好 2.30%。此外,圖 11a 中通過(guò)增加通信輪次的域內(nèi)精度和圖 11b 中的優(yōu)化目標(biāo)值表明,F(xiàn)CCL 受到的周期性性能沖擊較小,而且不容易對(duì)當(dāng)前的數(shù)據(jù)分布進(jìn)行過(guò)擬合(L^Loc=0.0225),說(shuō)明 FCCL 能夠平衡多種知識(shí),緩解災(zāi)難性遺忘。
表 7. 用 Cifar-100 在這兩項(xiàng)任務(wù)上與最先進(jìn)的方法進(jìn)行域內(nèi)性能比較
圖 11. 用 Cifar-100 通過(guò)增加通信輪次對(duì)數(shù)字任務(wù)進(jìn)行本地更新時(shí)的域內(nèi)性能和優(yōu)化目標(biāo)值的比較
模型同質(zhì)性分析。作者進(jìn)一步將 FCCL 與其他方法進(jìn)行模型同質(zhì)性比較。將共享模型設(shè)定為 ResNet-18,并在協(xié)作更新和本地更新之間添加平均參數(shù)操作。表 8 列出了域間和域內(nèi)的數(shù)據(jù),展示了 Cifar-100 在 Office-Home 任務(wù)中的域間和域內(nèi)性能。
表 8. 用 Cifar-100 在 Office-Home 任務(wù)中與最先進(jìn)的方法進(jìn)行比較
5、FedSpeech: Federated Text-to-Speech with Continual Learning
聯(lián)邦學(xué)習(xí)可以在嚴(yán)格的隱私限制下對(duì)機(jī)器學(xué)習(xí)模型進(jìn)行協(xié)作訓(xùn)練,而聯(lián)邦文本到語(yǔ)音的應(yīng)用目的是利用存儲(chǔ)在本地設(shè)備中的少量音頻訓(xùn)練樣本合成多個(gè)用戶(hù)的自然語(yǔ)音。然而,聯(lián)邦文本到語(yǔ)音面臨著幾個(gè)挑戰(zhàn):每個(gè)說(shuō)話(huà)人的訓(xùn)練樣本很少,訓(xùn)練樣本都存儲(chǔ)在每個(gè)用戶(hù)的本地設(shè)備中,而且全局模型容易受到各種打擊。本文提出了一個(gè)新穎的聯(lián)邦學(xué)習(xí)架構(gòu),稱(chēng)為聯(lián)邦多語(yǔ)者文本到語(yǔ)音 TTS 系統(tǒng) Fed-Speech,基于持續(xù)學(xué)習(xí)方法來(lái)克服上述困難。具體如下:1)通過(guò)選擇性掩碼,F(xiàn)edSpeech 可以有效地從協(xié)作訓(xùn)練中獲益,以減少有限訓(xùn)練數(shù)據(jù)的影響。2)使用漸進(jìn)式修剪掩碼來(lái)分離不同說(shuō)話(huà)人的參數(shù),以克服災(zāi)難性遺忘問(wèn)題。因此,F(xiàn)edSpeech 避免了所有說(shuō)話(huà)人的語(yǔ)調(diào)變化問(wèn)題。3) 引入私有說(shuō)話(huà)人嵌入,加上上述兩類(lèi)掩碼,以保護(hù)隱私并避免對(duì)說(shuō)話(huà)人的各種打擊。在縮小的 VCTK 數(shù)據(jù)集上的實(shí)驗(yàn)(每個(gè)說(shuō)話(huà)人的訓(xùn)練集減少到四分之一,以模擬低資源的語(yǔ)言場(chǎng)景)表明,F(xiàn)edSpeech 在語(yǔ)音質(zhì)量方面幾乎可以與上限、多任務(wù)訓(xùn)練相匹配,甚至在說(shuō)話(huà)人相似性實(shí)驗(yàn)中明顯優(yōu)于所有系統(tǒng)。
5.1 模型結(jié)構(gòu)
FedSpeech 的整體模型結(jié)構(gòu)如圖 12 所示。編碼器將音素嵌入序列轉(zhuǎn)換為音素隱序列,然后在隱序列中加入不同的差異信息,如持續(xù)時(shí)間和語(yǔ)調(diào),最后由旋律譜解碼器將適應(yīng)的隱序列轉(zhuǎn)換為旋律譜序列。采用前饋轉(zhuǎn)化器(Feed-forward Transformer)模塊,它是 FastSpeech 中自注意力層和 1D-convolution 前饋網(wǎng)絡(luò)的疊加,作為編碼器和旋律譜圖解碼器的基本結(jié)構(gòu)。此外,還采用了一個(gè)音高預(yù)測(cè)器和一個(gè)持續(xù)時(shí)間預(yù)測(cè)器來(lái)引入更多的信息。每個(gè)網(wǎng)絡(luò)都包括一個(gè)具有 ReLU 激活的 2 層一維卷積網(wǎng)絡(luò),然后是層歸一化和 DropOut 層,以及一個(gè)額外的線(xiàn)性層,將隱狀態(tài)投射到輸出序列中。在訓(xùn)練階段,將從錄音中提取的持續(xù)時(shí)間和語(yǔ)調(diào)的真實(shí)值作為輸入到隱序列中以預(yù)測(cè)目標(biāo)語(yǔ)音。同時(shí),用真實(shí)的時(shí)長(zhǎng)和音高值作為目標(biāo)來(lái)訓(xùn)練預(yù)測(cè)器。使用這些輸出進(jìn)行推理,以合成目標(biāo)語(yǔ)音。
圖 12. FedSpeech 的整體架構(gòu)。+ 表示元素相加的操作
為了通過(guò)從潛在空間估計(jì)說(shuō)話(huà)人的特征來(lái)控制語(yǔ)音并保護(hù)隱私,作者引入了一個(gè)私有說(shuō)話(huà)人模塊,它是一個(gè)可訓(xùn)練的查找表,將說(shuō)話(huà)人的身份號(hào)碼 S_id 作為輸入,并生成說(shuō)話(huà)人表示 R={r_1, r_2, ..., r_n},其中 n 是模型的隱空間大小。然后,將說(shuō)話(huà)人表示 R 傳遞到編碼器的輸出端,作為額外的關(guān)鍵信息來(lái)控制訓(xùn)練和推理中的語(yǔ)調(diào)特征??紤]到隱私問(wèn)題,每個(gè)說(shuō)話(huà)人都會(huì)訓(xùn)練并保持自己的模塊參數(shù)集,這樣其他人即使用他的 S_id 也無(wú)法合成他的聲音。
圖 13. 用 FedSpeech 進(jìn)行的兩輪訓(xùn)練過(guò)程。在第一輪中,逐步修剪掩碼以隔離每個(gè)說(shuō)話(huà)人的權(quán)重。如果為某個(gè)說(shuō)話(huà)人保留的權(quán)重小于閾值,模型就會(huì)擴(kuò)大。在第二輪,以 speaker 2 為例。選擇性掩碼的訓(xùn)練是為了重新利用為其他說(shuō)話(huà)人保留的權(quán)重中的知識(shí)
本文未采用聯(lián)邦聚合訓(xùn)練的方法,因?yàn)檫@種方法存在災(zāi)難性遺忘問(wèn)題。如圖 13 所示,作者采用了持續(xù)學(xué)習(xí)中常用的連續(xù)訓(xùn)練設(shè)置。在經(jīng)典設(shè)置的基礎(chǔ)上,本文提出了兩輪的順序訓(xùn)練。在第一輪訓(xùn)練中,模型分別學(xué)習(xí)并固定每個(gè)說(shuō)話(huà)人的一部分權(quán)重,這樣在第二輪訓(xùn)練中,可以有選擇地重用前一個(gè)和后一個(gè)說(shuō)話(huà)人的知識(shí)。
具體來(lái)說(shuō),在第一輪訓(xùn)練中,計(jì)算得到圖 13 中的漸進(jìn)修剪掩碼以隔離每個(gè)說(shuō)話(huà)人的參數(shù)。將從 1 到 N 的 speaker 表示為 S_1:N。S_1:N 的任務(wù)表示為 T_1:N。以 S_t 為例。當(dāng) T_t 開(kāi)始時(shí),首先將全局模型 M_g 發(fā)送給 S_t,并使用他的私有數(shù)據(jù)進(jìn)行訓(xùn)練,直到收斂。將第 i 層的學(xué)習(xí)權(quán)重矩陣表示為 (W^l_i)_1。然后,逐漸修剪每層 (W^l_i)_1 中最小的權(quán)重的一部分,將其設(shè)置為 0,并重新訓(xùn)練其他權(quán)重以恢復(fù)性能。最后,將權(quán)重劃分為三部分:1)后來(lái)的 speaker S_t+1:N 釋放的零值權(quán)重;2)由以前的 speaker S_1:t-1 保留的固定權(quán)重 (W^1:t-1)_S;3)由 S_t 保留的權(quán)重 W^t_S。如果后來(lái)的 speaker S_t+1:N 的釋放權(quán)重小于閾值 λ,將模型的隱藏大小擴(kuò)展為 μ。修剪狀態(tài)存儲(chǔ)在漸進(jìn)修剪掩碼中,表示為 m_p。然后固定 W^t_S,并將 m_p 和 M_g(除了 private speaker 模塊)發(fā)送到下一個(gè) speaker S_t+1 的設(shè)備上,繼續(xù)進(jìn)行順序訓(xùn)練。當(dāng)?shù)谝惠喗Y(jié)束時(shí),每個(gè)說(shuō)話(huà)人都保留了某一部分權(quán)重,表示為 (W^1:N)_S,由 m_p 表示。由于每個(gè)任務(wù)的權(quán)重都是固定的,每個(gè)說(shuō)話(huà)人都可以在推理中完美地保留他們的語(yǔ)調(diào)。最后,將 m_p 和 Mg 發(fā)送到 S_1:N 的設(shè)備上。因此,每個(gè)說(shuō)話(huà)人都有 m_p、M_g 和他所保留的 private speaker 模塊的參數(shù)。
在第二輪訓(xùn)練中,引入選擇性掩碼來(lái)轉(zhuǎn)移說(shuō)話(huà)者的知識(shí),以解決數(shù)據(jù)稀缺的問(wèn)題。將圖 13 中的選擇性掩碼訓(xùn)練為自動(dòng)選擇說(shuō)話(huà)人保留的有用權(quán)重。作者提出了一個(gè)修改后的選擇程序,從所有任務(wù)中選擇權(quán)重,這對(duì) federated TTS 任務(wù)中的每個(gè) speaker(特別是對(duì)更多的 previous speakers)都是更公平的。對(duì)于一個(gè)特定的說(shuō)話(huà)人 S_t,兩輪訓(xùn)練放棄了 W^t_S 和選擇性掩碼的聯(lián)邦訓(xùn)練,這導(dǎo)致了輕微的性能下降。但是對(duì)于每個(gè) speaker,我們使其有可能從之前和之后的任務(wù)中選擇權(quán)重,從而在整體上顯著改善了性能。
假設(shè)當(dāng)?shù)谝惠喗Y(jié)束時(shí),M_g 的權(quán)重分成幾個(gè)部分 (W^1:N)_S,這些部分被 S_1:N 保存。為了在保持隱私的同時(shí)從協(xié)作訓(xùn)練中獲益,作者引入了一個(gè)可學(xué)習(xí)的掩碼 m_b∈{0,1} 來(lái)轉(zhuǎn)移由其他說(shuō)話(huà)人保留的參數(shù)的知識(shí)。本文使用 piggyback 方法,學(xué)習(xí)一個(gè)實(shí)值掩碼 m_s,并應(yīng)用一個(gè)閾值進(jìn)行二值化處理以構(gòu)建 m_b。對(duì)于某個(gè)說(shuō)話(huà)人 S_t 來(lái)說(shuō),掩碼 (m^t)_b 是在他的本地?cái)?shù)據(jù)集上訓(xùn)練出來(lái)的,通過(guò)以下方式從其他說(shuō)話(huà)人位置選擇權(quán)重:
以一維卷積層中的選擇性掩碼的訓(xùn)練過(guò)程為例進(jìn)行描述。在任務(wù) t,M_g(即 W^1:N_S)是固定的。將二進(jìn)制掩碼表示為 m^t_b。那么,輸入 - 輸出關(guān)系的方程為:
在反向傳播過(guò)程中,m^t_b 是不可分的。所以引入實(shí)值的選擇性掩碼,表示為 (m^t)_s。將 σ 表示為選擇的閾值。在訓(xùn)練二進(jìn)制掩碼 (m^t)_b 時(shí),在后向傳遞中更新實(shí)值掩碼 (m^t)_s;然后用應(yīng)用于 (m^t)_s 的二進(jìn)制函數(shù) β 對(duì) (m^t)_b 進(jìn)行量化,并在前向傳遞中使用。訓(xùn)練結(jié)束后,丟棄 (m^t)_s,只存儲(chǔ) (m^t)_b 用于推理。將 m^t_s 的方程表述為:
為了簡(jiǎn)單起見(jiàn),作者用 S_t 的例子來(lái)描述推理階段?,F(xiàn)在 S_t 有 m_p、(m^t)_b、M_g 和本地保存的說(shuō)話(huà)人模塊的參數(shù)。使用 m_p 挑選權(quán)重 W^t_S,并使用 (m^t)_b 選擇性地重復(fù)使用 (W^1:t-1)_S∪(W^t+1:N)_S 中的權(quán)重。為了不傷害 S_t 的語(yǔ)調(diào),將未使用的權(quán)重固定為零。用 FedSpeech 進(jìn)行的兩輪訓(xùn)練的總體過(guò)程見(jiàn)算法 1。
5.2 實(shí)驗(yàn)結(jié)果分析
作者在 VCTK 數(shù)據(jù)集上進(jìn)行了實(shí)驗(yàn),該數(shù)據(jù)集包含了約 44 小時(shí)的語(yǔ)音,由 109 位具有不同口音的英語(yǔ)母語(yǔ)者說(shuō)出來(lái)。每個(gè)說(shuō)話(huà)人讀出了約 400 個(gè)句子,其中大部分是從報(bào)紙上選出來(lái)的,再加上《彩虹傳》和一個(gè)旨在識(shí)別說(shuō)話(huà)人口音的 elicitation 段落。為了模擬低資源語(yǔ)言場(chǎng)景,隨機(jī)選擇并將每個(gè)說(shuō)話(huà)人的樣本分成 3 組:100 個(gè)樣本用于訓(xùn)練,20 個(gè)樣本用于驗(yàn)證,20 個(gè)樣本用于測(cè)試。作者隨機(jī)選擇了 10 位 speaker,分別表示為任務(wù) 1 至 10,進(jìn)行評(píng)估。為了緩解發(fā)音錯(cuò)誤的問(wèn)題,作者用一個(gè)開(kāi)源的字母到音素的轉(zhuǎn)換工具將文本序列轉(zhuǎn)換成音素序列。作者將原始波形轉(zhuǎn)換為 mel-spectrograms,并將幀大小和跳躍大小設(shè)置為 1024 和 256,采樣率為 22050。
作者在測(cè)試集上評(píng)估 MOS(mean opinion score)來(lái)衡量音頻質(zhì)量。不同模型之間的設(shè)置和文本內(nèi)容是一致的,以排除其他干擾因素,只考察音頻質(zhì)量。每個(gè)音頻都由 10 個(gè)英語(yǔ)為母語(yǔ)的人進(jìn)行評(píng)判。作者將本文模型生成的音頻樣本的 MOS 與其他系統(tǒng)進(jìn)行比較,其中包括:1)GT,VCTK 中的 ground truth 音頻。2) GT (Mel + PWG), 首先將 ground-truth 音頻轉(zhuǎn)換為 Mel-spectrograms, 然后使用 ParallelWaveGAN (PWG) 將 Mel-spectrograms 轉(zhuǎn)換為音頻;3) Multi-task, 無(wú)隱私限制的聯(lián)邦訓(xùn)練;4) Scratch, 從頭開(kāi)始獨(dú)立學(xué)習(xí)每個(gè)任務(wù);5) Finetune, 從隨機(jī)選擇的前一個(gè)模型進(jìn)行微調(diào)并重復(fù) 5 次(對(duì)于任務(wù) 1, Finetune 等同于 Scratch)。6)FedAvg,聚集本地信息(如梯度或模型參數(shù))并訓(xùn)練一個(gè)全局模型。7)CPG,一種用于持續(xù)學(xué)習(xí)的參數(shù)隔離方法。作者把 3)表示為上界,其他的表示為基線(xiàn)。相應(yīng)地,3)、4)、5)、6)、7) 和 FedSpeech 中的所有系統(tǒng)都使用預(yù)先訓(xùn)練好的 PWG 作為聲碼器進(jìn)行公平比較。MOS 結(jié)果顯示在表 9 中。從表中我們可以看出,與所有基線(xiàn)相比,F(xiàn)edSpeech 取得了最高的 MOS。值得一提的是 FedSpeech 的表現(xiàn)優(yōu)于 CPG,這說(shuō)明了有選擇地重用以前和以后的 speaker 的知識(shí)的有效性。此外,F(xiàn)edAvg 的結(jié)果明顯比其他方法差,這意味著來(lái)自其他 speaker 的梯度極大地影響了每個(gè) speaker 的語(yǔ)氣。此外,F(xiàn)edSpeech 在 VCTK 上的 MOS 值接近于多任務(wù)訓(xùn)練(上限)。這些結(jié)果證明了 FedSpeech 在聯(lián)邦多語(yǔ)者 TTS 任務(wù)中的優(yōu)勢(shì)。
表 9. MOS 與 95% 的置信區(qū)間。
作者在測(cè)試集上進(jìn)行說(shuō)話(huà)人相似度評(píng)估,以衡量合成音頻和 ground-truth 音頻之間的相似度。為了排除其他干擾因素,作者在不同的模型中保持文本內(nèi)容的一致性。對(duì)于每項(xiàng)任務(wù),作者利用編碼器推導(dǎo)出總結(jié)說(shuō)話(huà)人聲音特征的高級(jí)表示向量。具體來(lái)說(shuō),編碼器是一個(gè)帶有投影的 3 層 LSTM,它為提取說(shuō)話(huà)人的語(yǔ)調(diào)嵌入進(jìn)行了預(yù)訓(xùn)練。余弦相似度是衡量說(shuō)話(huà)人表述向量相似度的標(biāo)準(zhǔn),其定義為 cos sim (A, B) =A - B/kAk kBk。其結(jié)果范圍從 - 1 到 1,數(shù)值越大,說(shuō)明向量越相似。作者計(jì)算合成音頻的說(shuō)話(huà)人表示向量和 ground-truth 音頻之間的余弦相似度作為評(píng)價(jià)標(biāo)準(zhǔn)。
最終實(shí)驗(yàn)結(jié)果顯示在表 10 中。FedSpeech 的平均得分最高,甚至高于多任務(wù)的上限。這意味著 FedSpeech 可以在推理階段更好地保留每個(gè)說(shuō)話(huà)人的聲音,并證明了參數(shù)隔離的有效性。此外,在任務(wù) 1 中,F(xiàn)edSpeech 的結(jié)果明顯高于 CPG??梢钥闯?,有選擇地重用前一個(gè)和后一個(gè)說(shuō)話(huà)人的知識(shí)給說(shuō)話(huà)人帶來(lái)了很大的好處,因此,在聯(lián)邦多語(yǔ)者 TTS 任務(wù)中,所有的說(shuō)話(huà)人都能獲得更好的聲音。
表 10. 基線(xiàn)和 FedSpeech 之間說(shuō)話(huà)人相似度的比較。平均值是指 10 個(gè)任務(wù)的平均值,γ 是指與 256 個(gè)隱大小的 FedSpeech 相比的模型擴(kuò)展率
為了測(cè)量音頻質(zhì)量,作者進(jìn)行了 MOS 評(píng)估,每個(gè)音頻由 10 個(gè)英語(yǔ)母語(yǔ)者進(jìn)行評(píng)判。如表 11 所示,去除漸進(jìn)修剪掩碼或去除選擇性掩碼都不會(huì)導(dǎo)致明顯的質(zhì)量下降,這意味著選擇性掩碼有能力自動(dòng)選擇漸進(jìn)修剪掩碼所保留的權(quán)重。然而,去除這兩種類(lèi)型的掩碼會(huì)導(dǎo)致災(zāi)難性的質(zhì)量下降。此外,作者還進(jìn)行了說(shuō)話(huà)人相似性評(píng)估。如表 11 所示,稍微去除這些選擇性掩碼或漸進(jìn)修剪掩碼會(huì)導(dǎo)致輕微的性能下降,而去除這兩個(gè)掩碼則會(huì)導(dǎo)致災(zāi)難性的下降??梢钥闯觯瑵u進(jìn)修剪掩碼完美地保留了每個(gè)說(shuō)話(huà)人的語(yǔ)調(diào)。此外,選擇性掩碼有能力自動(dòng)選擇漸進(jìn)修剪掩碼所保留的權(quán)重,將它們結(jié)合起來(lái)會(huì)導(dǎo)致更好的結(jié)果。
表 11. 在消融實(shí)驗(yàn)中 MOS 和說(shuō)話(huà)人相似度的比較。SM 指的是選擇性掩碼,GPM 指的是漸進(jìn)修剪掩碼,相似度是余弦相似度
對(duì)于未來(lái)的工作,作者提出將繼續(xù)提高合成語(yǔ)音的質(zhì)量,并提出新的掩碼策略來(lái)壓縮模型和加快訓(xùn)練速度。此外,他們還將把 FedSpeech 應(yīng)用于 zero-shot 的 multi-speaker 設(shè)置,通過(guò)使用 private speaker 模塊來(lái)生成掩碼。
6、本文小結(jié)
在這篇文章中,我們淺析了四篇聯(lián)邦連續(xù)學(xué)習(xí)相關(guān)的最新論文。這四篇文章的重點(diǎn)都是解決聯(lián)邦學(xué)習(xí)框架下不同客戶(hù)端間相互干擾的問(wèn)題,具體選擇了將參數(shù)分解為全局參數(shù)和本地參數(shù)、著重考慮設(shè)備上存儲(chǔ)數(shù)據(jù)的類(lèi)別意識(shí)和類(lèi)別語(yǔ)義、增加知識(shí)蒸餾以平衡不同數(shù)據(jù)域關(guān)系等方法,在論文給出的場(chǎng)景中都獲得不錯(cuò)的效果。不過(guò),這些文章都沒(méi)有著重分析這種方法可能造成的通信代價(jià)。關(guān)于聯(lián)邦持續(xù)學(xué)習(xí)的實(shí)用性還有待更進(jìn)一步的研究,以更好的滿(mǎn)足當(dāng)前數(shù)據(jù)隱私保護(hù)高要求條件下的數(shù)據(jù)分析和應(yīng)用需求。