就機(jī)器學(xué)習(xí)而言,音頻本身是一個(gè)有廣泛應(yīng)用的完整的領(lǐng)域,包括語(yǔ)音識(shí)別、音樂(lè)分類(lèi)和聲音事件檢測(cè)等等。傳統(tǒng)上音頻分類(lèi)一直使用譜圖分析和隱馬爾可夫模型等方法,這些方法已被證明是有效的,但也有其局限性。近期VIT已經(jīng)成為音頻任務(wù)的一個(gè)有前途的替代品,OpenAI的Whisper就是一個(gè)很好的例子。

在本文中,我們將利用ViT - Vision Transformer的是一個(gè)Pytorch實(shí)現(xiàn)在音頻分類(lèi)數(shù)據(jù)集GTZAN數(shù)據(jù)集-音樂(lè)類(lèi)型分類(lèi)上訓(xùn)練它。數(shù)據(jù)集介紹
GTZAN 數(shù)據(jù)集是在音樂(lè)流派識(shí)別 (MGR) 研究中最常用的公共數(shù)據(jù)集。 這些文件是在 2000-2001 年從各種來(lái)源收集的,包括個(gè)人 CD、收音機(jī)、麥克風(fēng)錄音,代表各種錄音條件下的聲音。

這個(gè)數(shù)據(jù)集由子文件夾組成,每個(gè)子文件夾是一種類(lèi)型。

加載數(shù)據(jù)集
我們將加載每個(gè).wav文件,并通過(guò)librosa庫(kù)生成相應(yīng)的Mel譜圖。
mel譜圖是聲音信號(hào)的頻譜內(nèi)容的一種可視化表示,它的垂直軸表示mel尺度上的頻率,水平軸表示時(shí)間。它是音頻信號(hào)處理中常用的一種表示形式,特別是在音樂(lè)信息檢索領(lǐng)域。
梅爾音階(Mel scale,英語(yǔ):mel scale)是一個(gè)考慮到人類(lèi)音高感知的音階。因?yàn)槿祟?lèi)不會(huì)感知線性范圍的頻率,也就是說(shuō)我們?cè)跈z測(cè)低頻差異方面要?jiǎng)儆诟哳l。 例如,我們可以輕松分辨出500 Hz和1000 Hz之間的差異,但是即使之間的距離相同,我們也很難分辨出10,000 Hz和10,500 Hz之間的差異。所以梅爾音階解決了這個(gè)問(wèn)題,如果梅爾音階的差異相同,則意指人類(lèi)感覺(jué)到的音高差異將相同。
def wav2melspec(fp):
y, sr = librosa.load(fp)
S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128)
log_S = librosa.amplitude_to_db(S, ref=np.max)
img = librosa.display.specshow(log_S, sr=sr, x_axis='time', y_axis='mel')
# get current figure without white border
img = plt.gcf()
img.gca().xaxis.set_major_locator(plt.NullLocator())
img.gca().yaxis.set_major_locator(plt.NullLocator())
img.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0,
hspace = 0, wspace = 0)
img.gca().xaxis.set_major_locator(plt.NullLocator())
img.gca().yaxis.set_major_locator(plt.NullLocator())
# to pil image
img.canvas.draw()
img = Image.frombytes('RGB', img.canvas.get_width_height(), img.canvas.tostring_rgb())
return img
上述函數(shù)將產(chǎn)生一個(gè)簡(jiǎn)單的mel譜圖:

現(xiàn)在我們從文件夾中加載數(shù)據(jù)集,并對(duì)圖像應(yīng)用轉(zhuǎn)換。
class AudioDataset(Dataset):
def __init__(self, root, transform=None):
self.root = root
self.transform = transform
self.classes = sorted(os.listdir(root))
self.class_to_idx = {c: i for i, c in enumerate(self.classes)}
self.samples = []
for c in self.classes:
for fp in os.listdir(os.path.join(root, c)):
self.samples.append((os.path.join(root, c, fp), self.class_to_idx[c]))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
fp, target = self.samples[idx]
img = Image.open(fp)
if self.transform:
img = self.transform(img)
return img, target
train_dataset = AudioDataset(root, transform=transforms.Compose([
transforms.Resize((480, 480)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]))
ViT模型
我們將利用ViT來(lái)作為我們的模型:Vision Transformer在論文中首次介紹了一幅圖像等于16x16個(gè)單詞,并成功地展示了這種方式不依賴任何的cnn,直接應(yīng)用于圖像Patches序列的純Transformer可以很好地執(zhí)行圖像分類(lèi)任務(wù)。

將圖像分割成Patches,并將這些Patches的線性嵌入序列作為T(mén)ransformer的輸入。Patches的處理方式與NLP應(yīng)用程序中的標(biāo)記(單詞)是相同的。
由于缺乏CNN固有的歸納偏差(如局部性),Transformer在訓(xùn)練數(shù)據(jù)量不足時(shí)不能很好地泛化。但是當(dāng)在大型數(shù)據(jù)集上訓(xùn)練時(shí),它確實(shí)在多個(gè)圖像識(shí)別基準(zhǔn)上達(dá)到或擊敗了最先進(jìn)的水平。
實(shí)現(xiàn)的結(jié)構(gòu)如下所示:
class ViT(nn.Sequential):
def __init__(self,
in_channels: int = 3,
patch_size: int = 16,
emb_size: int = 768,
img_size: int = 356,
depth: int = 12,
n_classes: int = 1000,
**kwargs):
super().__init__(
PatchEmbedding(in_channels, patch_size, emb_size, img_size),
TransformerEncoder(depth, emb_size=emb_size, **kwargs),
ClassificationHead(emb_size, n_classes)
訓(xùn)練
訓(xùn)練循環(huán)也是傳統(tǒng)的訓(xùn)練過(guò)程:
vit = ViT(
n_classes = len(train_dataset.classes)
)
vit.to(device)
# train
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
optimizer = optim.Adam(vit.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, 'max', factor=0.3, patience=3, verbose=True)
criterion = nn.CrossEntropyLoss()
num_epochs = 30
for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10)
vit.train()
running_loss = 0.0
running_corrects = 0
for inputs, labels in tqdm.tqdm(train_loader):
inputs = inputs.to(device)
labels = labels.to(device)
optimizer.zero_grad()
with torch.set_grad_enabled(True):
outputs = vit(inputs)
loss = criterion(outputs, labels)
_, preds = torch.max(outputs, 1)
loss.backward()
optimizer.step()
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
epoch_loss = running_loss / len(train_dataset)
epoch_acc = running_corrects.double() / len(train_dataset)
scheduler.step(epoch_acc)
print('Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
總結(jié)
使用PyTorch從頭開(kāi)始訓(xùn)練了這個(gè)Vision Transformer架構(gòu)的自定義實(shí)現(xiàn)。因?yàn)閿?shù)據(jù)集非常小(每個(gè)類(lèi)只有100個(gè)樣本),這影響了模型的性能,只獲得了0.71的準(zhǔn)確率。
這只是一個(gè)簡(jiǎn)單的演示,如果需要提高模型表現(xiàn),可以使用更大的數(shù)據(jù)集,或者稍微調(diào)整架構(gòu)的各種超參數(shù)!
這里使用的vit代碼來(lái)自:
https://medium.com/artificialis/vit-visiontransformer-a-pytorch-implementation-8d6a1033bdc5