持續(xù)學(xué)習(xí)是指在不忘記從前面的任務(wù)中獲得的知識的情況下,按順序?qū)W習(xí)大量任務(wù)的模型。這是一個(gè)重要的概念,因?yàn)樵诒O(jiān)督學(xué)習(xí)的前提下,機(jī)器學(xué)習(xí)模型被訓(xùn)練為針對給定數(shù)據(jù)集或數(shù)據(jù)分布的最佳函數(shù)。而在現(xiàn)實(shí)環(huán)境中,數(shù)據(jù)很少是靜態(tài)的,可能會(huì)發(fā)生變化。當(dāng)面對不可見的數(shù)據(jù)時(shí),典型的ML模型可能會(huì)性能下降。這種現(xiàn)象被稱為災(zāi)難性遺忘。

解決這類問題的常用方法是在包含新舊數(shù)據(jù)的新的更大數(shù)據(jù)集上對整個(gè)模型進(jìn)行再訓(xùn)練。但是這種做法往往代價(jià)高昂。所以有一個(gè)ML研究領(lǐng)域正在研究這個(gè)問題,基于該領(lǐng)域的研究,本文將討論6種方法,使模型可以在保持舊的性能的同時(shí)適應(yīng)新數(shù)據(jù),并避免需要在整個(gè)數(shù)據(jù)集(舊+新)上進(jìn)行重新訓(xùn)練。
Prompt
Prompt 想法源于對GPT 3的提示(短序列的單詞)可以幫助驅(qū)動(dòng)模型更好地推理和回答。所以在本文中將Prompt 翻譯為提示。提示調(diào)優(yōu)是指使用小型可學(xué)習(xí)的提示,并將其與實(shí)際輸入一起作為模型的輸入。這允許我們只在新數(shù)據(jù)上訓(xùn)練提供提示的小模型,而無需再訓(xùn)練模型權(quán)重。
具體來說,我選擇了使用提示進(jìn)行基于文本的密集檢索的例子,這個(gè)例子改編自Wang的文章《Learning to Prompt for continuous Learning》。
該論文的作者使用下圖描述了他們的想法:

實(shí)際編碼的文本輸入用作從提示池中識別最小匹配對的key。在將這些標(biāo)識的提示輸入到模型之前,首先將它們添加到未編碼的文本嵌入中。這樣做的目的是訓(xùn)練這些提示來表示新的任務(wù),同時(shí)保持舊的模型不變,這里提示的很小,大概每個(gè)提示只有20個(gè)令牌。
class PromptPool(nn.Module):
def __init__(self, M = 100, hidden_size = 768, length = 20, N=5):
super().__init__()
self.pool = nn.Parameter(torch.rand(M, length, hidden_size), requires_grad=True).float()
self.keys = nn.Parameter(torch.rand(M, hidden_size), requires_grad=True).float()
self.length = length
self.hidden = hidden_size
self.n = N
nn.init.xavier_normal_(self.pool)
nn.init.xavier_normal_(self.keys)
def init_weights(self, embedding):
pass
# function to select from pool based on index
def concat(self, indices, input_embeds):
subset = self.pool[indices, :] # 2, 2, 20, 768
subset = subset.to("cuda:0").reshape(indices.size(0),
self.n*self.length,
self.hidden) # 2, 40, 768
return torch.cat((subset, input_embeds), 1)
# x is cls output
def query_fn(self, x):
# encode input x to same dim as key using cosine
x = x / x.norm(dim=1)[:, None]
k = self.keys / self.keys.norm(dim=1)[:, None]
scores = torch.mm(x, k.transpose(0,1).to("cuda:0"))
# get argmin
subsets = torch.topk(scores, self.n, 1, False).indices # k smallest
return subsets
pool = PromptPool()
然后我們使用的經(jīng)過訓(xùn)練的舊數(shù)據(jù)模型,訓(xùn)練新的數(shù)據(jù),這里只訓(xùn)練提示部分的權(quán)重。
def train():
count = 0
print("*********** Started Training *************")
start = time.time()
for epoch in range(40):
model.eval()
pool.train()
optimizer.zero_grad(set_to_none=True)
lap = time.time()
for batch in iter(train_dataloader):
count += 1
q, p, train_labels = batch
queries_emb = model(input_ids=q['input_ids'].to("cuda:0"),
attention_mask=q['attention_mask'].to("cuda:0"))
passage_emb = model(input_ids=p['input_ids'].to("cuda:0"),
attention_mask=p['attention_mask'].to("cuda:0"))
# pool
q_idx = pool.query_fn(queries_emb)
raw_qembedding = model.model.embeddings(input_ids=q['input_ids'].to("cuda:0"))
q = pool.concat(indices=q_idx, input_embeds=raw_qembedding)
p_idx = pool.query_fn(passage_emb)
raw_pembedding = model.model.embeddings(input_ids=p['input_ids'].to("cuda:0"))
p = pool.concat(indices=p_idx, input_embeds=raw_pembedding)
qattention_mask = torch.ones(batch_size, q.size(1))
pattention_mask = torch.ones(batch_size, p.size(1))
queries_emb = model.model(inputs_embeds=q,
attention_mask=qattention_mask.to("cuda:0")).last_hidden_state
passage_emb = model.model(inputs_embeds=p,
attention_mask=pattention_mask.to("cuda:0")).last_hidden_state
q_cls = queries_emb[:, pool.n*pool.length+1, :]
p_cls = passage_emb[:, pool.n*pool.length+1, :]
loss, ql, pl = calc_loss(q_cls, p_cls)
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
if count % 10 == 0:
print("Model Loss:", round(loss.item(),4), \
"| QL:", round(ql.item(),4), "| PL:", round(pl.item(),4), \
"| Took:", round(time.time() - lap), "seconds\n")
lap = time.time()
if count % 40 == 0 and count > 0:
print("model saved")
torch.save(model.state_dict(), model_PATH)
torch.save(pool.state_dict(), pool_PATH)
if count == 4600: return
print("Training Took:", round(time.time() - start), "seconds")
print("\n*********** Training Complete *************")
訓(xùn)練完成后,后續(xù)的推理過程需要將輸入與檢索到的提示結(jié)合起來。例如這個(gè)例子得到了性能—93%的新數(shù)據(jù)提示池,而完全(舊+新)訓(xùn)練為—94%。這與原論文中提到的表現(xiàn)類似。但是需要說明的一點(diǎn)是結(jié)果可能會(huì)因任務(wù)而不同,你應(yīng)該嘗試實(shí)驗(yàn)來知道什么是最好的。
要使此方法成為值得考慮的方法,它必須能夠在舊數(shù)據(jù)上保留老模型> 80%的性能,同時(shí)提示也應(yīng)該幫助模型在新數(shù)據(jù)上獲得良好的性能。
這種方法的缺點(diǎn)是需要使用提示池,這會(huì)增加額外的時(shí)間。這也不是一個(gè)永久的解決方案,但是目前來說是可行的,也或許以后還會(huì)有新的方法出現(xiàn)。
Data Distillation
你可能聽說過知識蒸餾一詞,這是一種使用來自教師模型的權(quán)重來指導(dǎo)和訓(xùn)練較小規(guī)模模型的技術(shù)。數(shù)據(jù)蒸餾(Data Distillation)的工作原理也類似,它是使用來自真實(shí)數(shù)據(jù)的權(quán)重來訓(xùn)練更小的數(shù)據(jù)子集。因?yàn)閿?shù)據(jù)集的關(guān)鍵信號被提煉并濃縮為更小的數(shù)據(jù)集,我們對新數(shù)據(jù)的訓(xùn)練只需要提供一些提煉的數(shù)據(jù)以保持舊的性能。
在此示例中,我將數(shù)據(jù)蒸餾應(yīng)用于密集檢索(文本)任務(wù)。目前看沒有其他人在這個(gè)領(lǐng)域使用這種方法,所以結(jié)果可能不是最好的,但如果你在文本分類上使用這種方法應(yīng)該會(huì)得到不錯(cuò)的結(jié)果。
本質(zhì)上,文本數(shù)據(jù)蒸餾的想法源于 Li 的一篇題為 Data Distillation for Text Classification 的論文,該論文的靈感來自 Wang 的 Dataset Distillation,他對圖像數(shù)據(jù)進(jìn)行了蒸餾。Li 用下圖描述了文本數(shù)據(jù)蒸餾的任務(wù):

根據(jù)論文,首先將一批蒸餾數(shù)據(jù)輸入到模型以更新其權(quán)重。然后使用真實(shí)數(shù)據(jù)評估更新后的模型,并將信號反向傳播到蒸餾數(shù)據(jù)集。該論文在 8 個(gè)公共基準(zhǔn)數(shù)據(jù)集上報(bào)告了良好的分類結(jié)果(> 80% 準(zhǔn)確率)。
按照提出的想法,我做了一些小的改動(dòng),使用了一批蒸餾數(shù)據(jù)和多個(gè)真實(shí)數(shù)據(jù)。以下是為密集檢索訓(xùn)練創(chuàng)建蒸餾數(shù)據(jù)的代碼:
class DistilledData(nn.Module):
def __init__(self, num_labels, M, q_len=64, hidden_size=768):
super().__init__()
self.num_samples = M
self.q_len = q_len
self.num_labels = num_labels
self.data = nn.Parameter(torch.rand(num_labels, M, q_len, hidden_size), requires_grad=True) # i.e. shape: 1000, 4, 64, 768
# init using model embedding, xavier, or load from state dict
def init_weights(self, model, path=None):
if model:
self.data.requires_grad = False
print("Init weights using model embedding")
raw_embedding = model.model.get_input_embeddings()
soft_embeds = raw_embedding.weight[:, :].clone().detach()
nums = soft_embeds.size(0)
for i1 in range(self.num_labels):
for i2 in range(self.num_samples):
for i3 in range(self.q_len):
random_idx = random.randint(0, nums-1)
self.data[i1, i2, i3, :] = soft_embeds[random_idx, :]
print(self.data.shape)
self.data.requires_grad = True
if not path:
nn.init.xavier_normal_(self.data)
else:
distilled_data.load_state_dict(torch.load(path), strict=False)
# function to sample a passage and positive sample as in the article, i am doing dense retrieval
def get_sample(self, label):
q_idx = random.randint(0, self.num_samples-1)
sampled_dist_q = self.data[label, q_idx, :, :]
p_idx = random.randint(0, self.num_samples-1)
while q_idx == p_idx:
p_idx = random.randint(0, self.num_samples-1)
sampled_dist_p = self.data[label, p_idx, :, :]
return sampled_dist_q, sampled_dist_p, q_idx, p_idx
這是將信號提取到蒸餾數(shù)據(jù)上的代碼
def distll_train(chunk_size=32):
count, times = 0, 0
print("*********** Started Training *************")
start = time.time()
lap = time.time()
for epoch in range(40):
distilled_data.train()
for batch in iter(train_dataloader):
count += 1
# get real query, pos, label, distilled data query, distilled data pos, ... from batch
q, p, train_labels, dq, dp, q_indexes, p_indexes = batch
for idx in range(0, dq['input_ids'].size(0), chunk_size):
model.train()
with torch.enable_grad():
# train on distiled data first
x1 = dq['input_ids'][idx:idx+chunk_size].clone().detach().requires_grad_(True)
x2 = dp['input_ids'][idx:idx+chunk_size].clone().detach().requires_grad_(True)
q_emb = model(inputs_embeds=x1.to("cuda:0"),
attention_mask=dq['attention_mask'][idx:idx+chunk_size].to("cuda:0")).cpu()
p_emb = model(inputs_embeds=x2.to("cuda:0"),
attention_mask=dp['attention_mask'][idx:idx+chunk_size].to("cuda:0"))
loss = default_loss(q_emb.to("cuda:0"), p_emb)
del q_emb, p_emb
loss.backward(retain_graph=True, create_graph=False)
state_dict = model.state_dict()
# update model weights
with torch.no_grad():
for idx, param in enumerate(model.parameters()):
if param.requires_grad and not param.grad is None:
param.data -= (param.grad*3e-5)
# real data
model.eval()
q_embs = []
p_embs = []
for k in range(0, len(q['input_ids']), chunk_size):
with torch.no_grad():
q_emb = model(input_ids=q['input_ids'][k:k+chunk_size].to("cuda:0"),).cpu()
p_emb = model(input_ids=p['input_ids'][k:k+chunk_size].to("cuda:0"),).cpu()
q_embs.append(q_emb)
p_embs.append(p_emb)
q_embs = torch.cat(q_embs, 0)
p_embs = torch.cat(p_embs, 0)
r_loss = default_loss(q_embs.to("cuda:0"), p_embs.to("cuda:0"))
del q_embs, p_embs
# distill backward
if count % 2 == 0:
d_grad = torch.autograd.grad(inputs=[x1.to("cuda:0")],#, x2.to("cuda:0")],
outputs=loss,
grad_outputs=r_loss)
indexes = q_indexes
else:
d_grad = torch.autograd.grad(inputs=[x2.to("cuda:0")],
outputs=loss,
grad_outputs=r_loss)
indexes = p_indexes
loss.detach()
r_loss.detach()
grads = torch.zeros(distilled_data.data.shape) # lbl, 10, 100, 768
for i, k in enumerate(indexes):
grads[train_labels[i], k, :, :] = grads[train_labels[i], k, :, :].to("cuda:0") \
+ d_grad[0][i, :, :]
distilled_data.data.grad = grads
data_optimizer.step()
data_optimizer.zero_grad(set_to_none=True)
model.load_state_dict(state_dict)
model_optimizer.step()
model_optimizer.zero_grad(set_to_none=True)
if count % 10 == 0:
print("Count:", count ,"| Data:", round(loss.item(), 4), "| Model:", \
round(r_loss.item(),4), "| Time:", round(time.time() - lap, 4))
# print()
lap = time.time()
if count % 100 == 0:
torch.save(model.state_dict(), model_PATH)
torch.save(distilled_data.state_dict(), distill_PATH)
if loss < 0.1 and r_loss < 1:
times += 1
if times > 100:
print("Training Took:", round(time.time() - start), "seconds")
print("\n*********** Training Complete *************")
return
del loss, r_loss, grads, q, p, train_labels, dq, dp, x1, x2, state_dict
print("Training Took:", round(time.time() - start), "seconds")
print("\n*********** Training Complete *************")
這里省略了數(shù)據(jù)加載等代碼,訓(xùn)練完蒸餾的數(shù)據(jù)后,我們可以通過在其上訓(xùn)練新模型來使用它,例如將其與新數(shù)據(jù)合并一起訓(xùn)練。
根據(jù)我的實(shí)驗(yàn),一個(gè)在蒸餾數(shù)據(jù)上訓(xùn)練的模型(每個(gè)標(biāo)簽只包含4個(gè)樣本)獲得了66%的最佳性能,而一個(gè)完全在原始數(shù)據(jù)上訓(xùn)練的模型也是得到了66%的最佳性能。而未經(jīng)訓(xùn)練的普通模型得到45%的性能。就像上面提到的這些數(shù)字對于密集檢索任務(wù)可能不太好,分類數(shù)據(jù)上會(huì)好很多。
要使此方法成為在調(diào)整模型以適應(yīng)新數(shù)據(jù)時(shí)值是一個(gè)有用的方法,需要能夠提取出比原始數(shù)據(jù)小得多的數(shù)據(jù)集(即~ 1%)。經(jīng)過提煉的數(shù)據(jù)也能夠給你一個(gè)略低于或等于主動(dòng)學(xué)習(xí)方法的表現(xiàn)。
這個(gè)方法的優(yōu)點(diǎn)是可以創(chuàng)建用于永久使用的蒸餾數(shù)據(jù)。缺點(diǎn)是提取的數(shù)據(jù)沒有可解釋性,并且需要額外的訓(xùn)練時(shí)間。
Curriculum/Active training
Curriculum training是一種方法,訓(xùn)練時(shí)向模型提供訓(xùn)練樣本的難度逐漸變大。在對新數(shù)據(jù)進(jìn)行訓(xùn)練時(shí),此方法需要人工的對任務(wù)進(jìn)行標(biāo)注,將任務(wù)分為簡單、中等或困難,然后對數(shù)據(jù)進(jìn)行采樣。為了理解模型的簡單、中等或困難意味著什么,我以這張圖片為例:

這是在分類任務(wù)中的混淆矩陣,困難樣本是假陽性(False Positive),是指模型預(yù)測為True的可能性很高,但實(shí)際上不是True的樣本。中等樣本是那些具有中到高的正確性可能性但低于預(yù)測閾值的True Negative。而簡單樣本則是那些可能性較低的True Positive/Negative。
Maximally Interfered Retrieval
這是 Rahaf 在題為“Online Continual Learning with Maximally Interfered Retrieval”的論文(1908.04742)中介紹的一種方法。主要思想是,對于正在訓(xùn)練的每個(gè)新數(shù)據(jù)批次,如果針對較新數(shù)據(jù)更新模型權(quán)重,將需要識別在損失值方面受影響最大的舊樣本。保留由舊數(shù)據(jù)組成的有限大小的內(nèi)存,并檢索最大干擾的樣本以及每個(gè)新數(shù)據(jù)批次以一起訓(xùn)練。
這篇論文在持續(xù)學(xué)習(xí)領(lǐng)域是一篇成熟的論文,并且有很多引用,因此可能適用于您的案例。
Retrieval Augmentation
檢索增強(qiáng)(Retrieval Augmentation)是指通過從集合中檢索項(xiàng)目來擴(kuò)充輸入、樣本等的技術(shù)。這是一個(gè)普遍的概念而不是一個(gè)特定的技術(shù)。我們到目前為止所討論的方法,大多數(shù)都在一定程度都是檢索相關(guān)的操作。Izacard 的題為 Few-shot Learning with Retrieval Augmented Language Models 的論文使用更小的模型獲得了出色的少樣本 學(xué)習(xí)的性能。檢索增強(qiáng)也用于許多其他情況,例如單詞生成或回答事實(shí)問題。
擴(kuò)展模型
在訓(xùn)練時(shí)使用附加層是最常見也最簡單的方法,但是不一定有效,所以在這里不進(jìn)行詳細(xì)的討論,這里的一個(gè)例子是 Lewis 的 Efficient Few-Shot Learning without Prompts。使用附加層通常是在新舊數(shù)據(jù)上獲得良好性能的最簡單但經(jīng)過嘗試和測試的方法。主要思想是保持模型權(quán)重固定,并通過分類損失在新數(shù)據(jù)上訓(xùn)練一層或幾層。
總結(jié)
在本文中,我介紹了在新數(shù)據(jù)上訓(xùn)練模型時(shí)可以使用的 6 種方法。與往常一樣應(yīng)該進(jìn)行實(shí)驗(yàn)并決定哪種方法最適合,但是需要注意的是,除了我上面的方法外還有很多方法,例如數(shù)據(jù)蒸餾是計(jì)算機(jī)視覺中的一個(gè)活躍領(lǐng)域,你可以找到很多關(guān)于它的論文。最后說明的一點(diǎn)是:要使這些方法有價(jià)值,它們應(yīng)該在舊數(shù)據(jù)和新數(shù)據(jù)上同時(shí)獲得良好的性能 。