在 PyTorch 中使用 Datasets 和 DataLoader 自定義數(shù)據(jù)
有時(shí)候,在處理大數(shù)據(jù)集時(shí),一次將整個(gè)數(shù)據(jù)加載到內(nèi)存中變得非常難。
因此,唯一的方法是將數(shù)據(jù)分批加載到內(nèi)存中進(jìn)行處理,這需要編寫額外的代碼來(lái)執(zhí)行此操作。對(duì)此,PyTorch 已經(jīng)提供了 Dataloader 功能。
DataLoader
下面顯示了 PyTorch 庫(kù)中DataLoader函數(shù)的語(yǔ)法及其參數(shù)信息。
- DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
- batch_sampler=None, num_workers=0, collate_fn=None,
- pin_memory=False, drop_last=False, timeout=0,
- worker_init_fn=None, *, prefetch_factor=2,
- persistent_workers=False)
幾個(gè)重要參數(shù)
- dataset:必須首先使用數(shù)據(jù)集構(gòu)造 DataLoader 類。
- Shuffle :是否重新整理數(shù)據(jù)。
- Sampler :指的是可選的 torch.utils.data.Sampler 類實(shí)例。采樣器定義了檢索樣本的策略,順序或隨機(jī)或任何其他方式。使用采樣器時(shí)應(yīng)將 Shuffle 設(shè)置為 false。
- Batch_Sampler :批處理級(jí)別。
- num_workers :加載數(shù)據(jù)所需的子進(jìn)程數(shù)。
- collate_fn :將樣本整理成批次。Torch 中可以進(jìn)行自定義整理。
加載內(nèi)置 MNIST 數(shù)據(jù)集
MNIST 是一個(gè)著名的包含手寫數(shù)字的數(shù)據(jù)集。下面介紹如何使用DataLoader功能處理 PyTorch 的內(nèi)置 MNIST 數(shù)據(jù)集。
- import torch
- import matplotlib.pyplot as plt
- from torchvision import datasets, transforms
上面代碼,導(dǎo)入了 torchvision 的torch計(jì)算機(jī)視覺(jué)模塊。通常在處理圖像數(shù)據(jù)集時(shí)使用,并且可以幫助對(duì)圖像進(jìn)行規(guī)范化、調(diào)整大小和裁剪。
對(duì)于 MNIST 數(shù)據(jù)集,下面使用了歸一化技術(shù)。
ToTensor()能夠把灰度范圍從0-255變換到0-1之間。
- transform = transforms.Compose([transforms.ToTensor()])
下面代碼用于加載所需的數(shù)據(jù)集。使用 PyTorchDataLoader通過(guò)給定 batch_size = 64來(lái)加載數(shù)據(jù)。shuffle=True打亂數(shù)據(jù)。
- trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
- trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
為了獲取數(shù)據(jù)集的所有圖像,一般使用iter函數(shù)和數(shù)據(jù)加載器DataLoader。
- dataiter = iter(trainloader)
- images, labels = dataiter.next()
- print(images.shape)
- print(labels.shape)
- plt.imshow(images[1].numpy().squeeze(), cmap='Greys_r')
自定義數(shù)據(jù)集
下面的代碼創(chuàng)建一個(gè)包含 1000 個(gè)隨機(jī)數(shù)的自定義數(shù)據(jù)集。
- from torch.utils.data import Dataset
- import random
- class SampleDataset(Dataset):
- def __init__(self,r1,r2):
- randomlist=[]
- for i in range(120):
- n = random.randint(r1,r2)
- randomlist.append(n)
- self.samples=randomlist
- def __len__(self):
- return len(self.samples)
- def __getitem__(self,idx):
- return(self.samples[idx])
- dataset=SampleDataset(1,100)
- dataset[100:120]
在這里插入圖片描述
最后,將在自定義數(shù)據(jù)集上使用 dataloader 函數(shù)。將 batch_size 設(shè)為 12,并且還啟用了num_workers =2 的并行多進(jìn)程數(shù)據(jù)加載。
- from torch.utils.data import DataLoader
- loader = DataLoader(dataset,batch_size=12, shuffle=True, num_workers=2 )
- for i, batch in enumerate(loader):
- print(i, batch)
寫在后面通過(guò)幾個(gè)示例了解了 PyTorch Dataloader 在將大量數(shù)據(jù)批量加載到內(nèi)存中的作用。