PyTorch中的數(shù)據(jù)集Torchvision和Torchtext
對于PyTorch加載和處理不同類型數(shù)據(jù),官方提供了torchvision和torchtext。
之前使用 torchDataLoader類直接加載圖像并將其轉(zhuǎn)換為張量?,F(xiàn)在結(jié)合torchvision和torchtext介紹torch中的內(nèi)置數(shù)據(jù)集
Torchvision 中的數(shù)據(jù)集
MNIST
MNIST是一個(gè)由標(biāo)準(zhǔn)化和中心裁剪的手寫圖像組成的數(shù)據(jù)集。它有超過 60,000 張訓(xùn)練圖像和 10,000 張測試圖像。這是用于學(xué)習(xí)和實(shí)驗(yàn)?zāi)康淖畛S玫臄?shù)據(jù)集之一。要加載和使用數(shù)據(jù)集,使用以下語法導(dǎo)入:torchvision.datasets.MNIST()。
Fashion MNIST
Fashion MNIST數(shù)據(jù)集類似于MNIST,但該數(shù)據(jù)集包含T恤、褲子、包包等服裝項(xiàng)目,而不是手寫數(shù)字,訓(xùn)練和測試樣本數(shù)分別為60,000和10,000。要加載和使用數(shù)據(jù)集,使用以下語法導(dǎo)入:torchvision.datasets.FashionMNIST()
CIFAR
CIFAR數(shù)據(jù)集有兩個(gè)版本,CIFAR10和CIFAR100。CIFAR10 由 10 個(gè)不同標(biāo)簽的圖像組成,而 CIFAR100 有 100 個(gè)不同的類。這些包括常見的圖像,如卡車、青蛙、船、汽車、鹿等。
- torchvision.datasets.CIFAR10()
- torchvision.datasets.CIFAR100()
COCO
COCO數(shù)據(jù)集包含超過 100,000 個(gè)日常對象,如人、瓶子、文具、書籍等。這個(gè)圖像數(shù)據(jù)集廣泛用于對象檢測和圖像字幕應(yīng)用。下面是可以加載 COCO 的位置:torchvision.datasets.CocoCaptions()
EMNIST
EMNIST數(shù)據(jù)集是 MNIST 數(shù)據(jù)集的高級版本。它由包括數(shù)字和字母的圖像組成。如果您正在處理基于從圖像中識別文本的問題,EMNIST是一個(gè)不錯(cuò)的選擇。下面是可以加載 EMNIST的位置::torchvision.datasets.EMNIST()
IMAGE-NET
ImageNet 是用于訓(xùn)練高端神經(jīng)網(wǎng)絡(luò)的旗艦數(shù)據(jù)集之一。它由分布在 10,000 個(gè)類別中的超過 120 萬張圖像組成。通常,這個(gè)數(shù)據(jù)集加載在高端硬件系統(tǒng)上,因?yàn)閱为?dú)的 CPU 無法處理這么大的數(shù)據(jù)集。下面是加載 ImageNet 數(shù)據(jù)集的類:torchvision.datasets.ImageNet()
Torchtext 中的數(shù)據(jù)集
IMDB
IMDB是一個(gè)用于情感分類的數(shù)據(jù)集,其中包含一組 25,000 條高度極端的電影評論用于訓(xùn)練,另外 25,000 條用于測試。使用以下類加載這些數(shù)據(jù)torchtext:torchtext.datasets.IMDB()
WikiText2
WikiText2語言建模數(shù)據(jù)集是一個(gè)超過 1 億個(gè)標(biāo)記的集合。它是從維基百科中提取的,并保留了標(biāo)點(diǎn)符號和實(shí)際的字母大小寫。它廣泛用于涉及長期依賴的應(yīng)用程序??梢詮膖orchtext以下位置加載此數(shù)據(jù):torchtext.datasets.WikiText2()
除了上述兩個(gè)流行的數(shù)據(jù)集,torchtext庫中還有更多可用的數(shù)據(jù)集,例如 SST、TREC、SNLI、MultiNLI、WikiText-2、WikiText103、PennTreebank、Multi30k 等。
深入查看 MNIST 數(shù)據(jù)集
MNIST 是最受歡迎的數(shù)據(jù)集之一?,F(xiàn)在我們將看到 PyTorch 如何從 pytorch/vision 存儲(chǔ)庫加載 MNIST 數(shù)據(jù)集。讓我們首先下載數(shù)據(jù)集并將其加載到名為 的變量中data_train
- from torchvision.datasets import MNIST
- # Download MNIST
- data_train = MNIST('~/mnist_data', train=True, download=True)
- import matplotlib.pyplot as plt
- random_image = data_train[0][0]
- random_image_label = data_train[0][1]
- # Print the Image using Matplotlib
- plt.imshow(random_image)
- print("The label of the image is:", random_image_label)
DataLoader加載MNIST
下面我們使用DataLoader該類加載數(shù)據(jù)集,如下所示。
- import torch
- from torchvision import transforms
- data_train = torch.utils.data.DataLoader(
- MNIST(
- '~/mnist_data', train=True, download=True,
- transform = transforms.Compose([
- transforms.ToTensor()
- ])),
- batch_size=64,
- shuffle=True
- )
- for batch_idx, samples in enumerate(data_train):
- print(batch_idx, samples)
CUDA加載
我們可以啟用 GPU 來更快地訓(xùn)練我們的模型?,F(xiàn)在讓我們使用CUDA加載數(shù)據(jù)時(shí)可以使用的(GPU 支持 PyTorch)的配置。
- device = "cuda" if torch.cuda.is_available() else "cpu"
- kwargs = {'num_workers': 1, 'pin_memory': True} if device=='cuda' else {}
- train_loader = torch.utils.data.DataLoader(
- torchvision.datasets.MNIST('/files/', train=True, download=True),
- batch_size=batch_size_train, **kwargs)
- test_loader = torch.utils.data.DataLoader(
- torchvision.datasets.MNIST('files/', train=False, download=True),
- batch_size=batch_size, **kwargs)
ImageFolder
ImageFolder是一個(gè)通用數(shù)據(jù)加載器類torchvision,可幫助加載自己的圖像數(shù)據(jù)集。處理一個(gè)分類問題并構(gòu)建一個(gè)神經(jīng)網(wǎng)絡(luò)來識別給定的圖像是apple還是orange。要在 PyTorch 中執(zhí)行此操作,第一步是在默認(rèn)文件夾結(jié)構(gòu)中排列圖像,如下所示:
- root
- ├── orange
- │ ├── orange_image1.png
- │ └── orange_image1.png
- ├── apple
- │ └── apple_image1.png
- │ └── apple_image2.png
- │ └── apple_image3.png
可以使用ImageLoader該類加載所有這些圖像。
- torchvision.datasets.ImageFolder(root, transform)
transforms
PyTorch 轉(zhuǎn)換定義了簡單的圖像轉(zhuǎn)換技術(shù),可將整個(gè)數(shù)據(jù)集轉(zhuǎn)換為獨(dú)特的格式。
如果是一個(gè)包含不同分辨率的不同汽車圖片的數(shù)據(jù)集,在訓(xùn)練時(shí),我們訓(xùn)練數(shù)據(jù)集中的所有圖像都應(yīng)該具有相同的分辨率大小。如果我們手動(dòng)將所有圖像轉(zhuǎn)換為所需的輸入大小,則很耗時(shí),因此我們可以使用transforms;使用幾行 PyTorch 代碼,我們數(shù)據(jù)集中的所有圖像都可以轉(zhuǎn)換為所需的輸入大小和分辨率。
現(xiàn)在讓我們加載 CIFAR10torchvision.datasets并應(yīng)用以下轉(zhuǎn)換:
- 將所有圖像調(diào)整為 32×32
- 對圖像應(yīng)用中心裁剪變換
- 將裁剪后的圖像轉(zhuǎn)換為張量
- 標(biāo)準(zhǔn)化圖像
- import torch
- import torchvision
- import torchvision.transforms as transforms
- import matplotlib.pyplot as plt
- import numpy as np
- transform = transforms.Compose([
- # resize 32×32
- transforms.Resize(32),
- # center-crop裁剪變換
- transforms.CenterCrop(32),
- # to-tensor
- transforms.ToTensor(),
- # normalize 標(biāo)準(zhǔn)化
- transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
- ])
- trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
- download=True, transform=transform)
- trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
- shuffle=False)
在 PyTorch 中創(chuàng)建自定義數(shù)據(jù)集
下面將創(chuàng)建一個(gè)由數(shù)字和文本組成的簡單自定義數(shù)據(jù)集。需要封裝Dataset 類中的__getitem__()和__len__()方法。
- __getitem__()方法通過索引返回?cái)?shù)據(jù)集中的選定樣本。
- __len__()方法返回?cái)?shù)據(jù)集的總大小。
下面是曾經(jīng)封裝FruitImagesDataset數(shù)據(jù)集的代碼,基本是比較好的 PyTorch 中創(chuàng)建自定義數(shù)據(jù)集的模板。
- import os
- import numpy as np
- import cv2
- import torch
- import matplotlib.patches as patches
- import albumentations as A
- from albumentations.pytorch.transforms import ToTensorV2
- from matplotlib import pyplot as plt
- from torch.utils.data import Dataset
- from xml.etree import ElementTree as et
- from torchvision import transforms as torchtrans
- class FruitImagesDataset(torch.utils.data.Dataset):
- def __init__(self, files_dir, width, height, transforms=None):
- self.transforms = transforms
- self.files_dir = files_dir
- self.height = height
- self.width = width
- self.imgs = [image for image in sorted(os.listdir(files_dir))
- if image[-4:] == '.jpg']
- self.classes = ['_','apple', 'banana', 'orange']
- def __getitem__(self, idx):
- img_name = self.imgs[idx]
- image_path = os.path.join(self.files_dir, img_name)
- # reading the images and converting them to correct size and color
- img = cv2.imread(image_path)
- img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
- img_res = cv2.resize(img_rgb, (self.width, self.height), cv2.INTER_AREA)
- # diving by 255
- img_res /= 255.0
- # annotation file
- annot_filename = img_name[:-4] + '.xml'
- annot_file_path = os.path.join(self.files_dir, annot_filename)
- boxes = []
- labels = []
- tree = et.parse(annot_file_path)
- root = tree.getroot()
- # cv2 image gives size as height x width
- wt = img.shape[1]
- ht = img.shape[0]
- # box coordinates for xml files are extracted and corrected for image size given
- for member in root.findall('object'):
- labels.append(self.classes.index(member.find('name').text))
- # bounding box
- xmin = int(member.find('bndbox').find('xmin').text)
- xmax = int(member.find('bndbox').find('xmax').text)
- ymin = int(member.find('bndbox').find('ymin').text)
- ymax = int(member.find('bndbox').find('ymax').text)
- xmin_corr = (xmin / wt) * self.width
- xmax_corr = (xmax / wt) * self.width
- ymin_corr = (ymin / ht) * self.height
- ymax_corr = (ymax / ht) * self.height
- boxes.append([xmin_corr, ymin_corr, xmax_corr, ymax_corr])
- # convert boxes into a torch.Tensor
- boxes = torch.as_tensor(boxes, dtype=torch.float32)
- # getting the areas of the boxes
- area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
- # suppose all instances are not crowd
- iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)
- labels = torch.as_tensor(labels, dtype=torch.int64)
- target = {}
- target["boxes"] = boxes
- target["labels"] = labels
- target["area"] = area
- target["iscrowd"] = iscrowd
- # image_id
- image_id = torch.tensor([idx])
- target["image_id"] = image_id
- if self.transforms:
- sample = self.transforms(image=img_res,
- bboxes=target['boxes'],
- labels=labels)
- img_res = sample['image']
- target['boxes'] = torch.Tensor(sample['bboxes'])
- return img_res, target
- def __len__(self):
- return len(self.imgs)
- def get_transform(train):
- if train:
- return A.Compose([
- A.HorizontalFlip(0.5),
- ToTensorV2(p=1.0)
- ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})
- else:
- return A.Compose([
- ToTensorV2(p=1.0)
- ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})
- files_dir = '../input/fruit-images-for-object-detection/train_zip/train'
- test_dir = '../input/fruit-images-for-object-detection/test_zip/test'
- dataset = FruitImagesDataset(train_dir, 480, 480)