自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

PyTorch中的數(shù)據(jù)集Torchvision和Torchtext

人工智能 深度學(xué)習(xí)
對于PyTorch加載和處理不同類型數(shù)據(jù),官方提供了torchvision和torchtext。之前使用 torchDataLoader類直接加載圖像并將其轉(zhuǎn)換為張量?,F(xiàn)在結(jié)合torchvision和torchtext介紹torch中的內(nèi)置數(shù)據(jù)集。

[[421061]]

對于PyTorch加載和處理不同類型數(shù)據(jù),官方提供了torchvision和torchtext。

之前使用 torchDataLoader類直接加載圖像并將其轉(zhuǎn)換為張量?,F(xiàn)在結(jié)合torchvision和torchtext介紹torch中的內(nèi)置數(shù)據(jù)集

Torchvision 中的數(shù)據(jù)集

MNIST

MNIST是一個(gè)由標(biāo)準(zhǔn)化和中心裁剪的手寫圖像組成的數(shù)據(jù)集。它有超過 60,000 張訓(xùn)練圖像和 10,000 張測試圖像。這是用于學(xué)習(xí)和實(shí)驗(yàn)?zāi)康淖畛S玫臄?shù)據(jù)集之一。要加載和使用數(shù)據(jù)集,使用以下語法導(dǎo)入:torchvision.datasets.MNIST()。

Fashion MNIST

Fashion MNIST數(shù)據(jù)集類似于MNIST,但該數(shù)據(jù)集包含T恤、褲子、包包等服裝項(xiàng)目,而不是手寫數(shù)字,訓(xùn)練和測試樣本數(shù)分別為60,000和10,000。要加載和使用數(shù)據(jù)集,使用以下語法導(dǎo)入:torchvision.datasets.FashionMNIST()

CIFAR

CIFAR數(shù)據(jù)集有兩個(gè)版本,CIFAR10和CIFAR100。CIFAR10 由 10 個(gè)不同標(biāo)簽的圖像組成,而 CIFAR100 有 100 個(gè)不同的類。這些包括常見的圖像,如卡車、青蛙、船、汽車、鹿等。

  1. torchvision.datasets.CIFAR10() 
  2. torchvision.datasets.CIFAR100() 

COCO

COCO數(shù)據(jù)集包含超過 100,000 個(gè)日常對象,如人、瓶子、文具、書籍等。這個(gè)圖像數(shù)據(jù)集廣泛用于對象檢測和圖像字幕應(yīng)用。下面是可以加載 COCO 的位置:torchvision.datasets.CocoCaptions()

EMNIST

EMNIST數(shù)據(jù)集是 MNIST 數(shù)據(jù)集的高級版本。它由包括數(shù)字和字母的圖像組成。如果您正在處理基于從圖像中識別文本的問題,EMNIST是一個(gè)不錯(cuò)的選擇。下面是可以加載 EMNIST的位置::torchvision.datasets.EMNIST()

IMAGE-NET

ImageNet 是用于訓(xùn)練高端神經(jīng)網(wǎng)絡(luò)的旗艦數(shù)據(jù)集之一。它由分布在 10,000 個(gè)類別中的超過 120 萬張圖像組成。通常,這個(gè)數(shù)據(jù)集加載在高端硬件系統(tǒng)上,因?yàn)閱为?dú)的 CPU 無法處理這么大的數(shù)據(jù)集。下面是加載 ImageNet 數(shù)據(jù)集的類:torchvision.datasets.ImageNet()

Torchtext 中的數(shù)據(jù)集

IMDB

IMDB是一個(gè)用于情感分類的數(shù)據(jù)集,其中包含一組 25,000 條高度極端的電影評論用于訓(xùn)練,另外 25,000 條用于測試。使用以下類加載這些數(shù)據(jù)torchtext:torchtext.datasets.IMDB()

WikiText2

WikiText2語言建模數(shù)據(jù)集是一個(gè)超過 1 億個(gè)標(biāo)記的集合。它是從維基百科中提取的,并保留了標(biāo)點(diǎn)符號和實(shí)際的字母大小寫。它廣泛用于涉及長期依賴的應(yīng)用程序??梢詮膖orchtext以下位置加載此數(shù)據(jù):torchtext.datasets.WikiText2()

除了上述兩個(gè)流行的數(shù)據(jù)集,torchtext庫中還有更多可用的數(shù)據(jù)集,例如 SST、TREC、SNLI、MultiNLI、WikiText-2、WikiText103、PennTreebank、Multi30k 等。

深入查看 MNIST 數(shù)據(jù)集

MNIST 是最受歡迎的數(shù)據(jù)集之一?,F(xiàn)在我們將看到 PyTorch 如何從 pytorch/vision 存儲(chǔ)庫加載 MNIST 數(shù)據(jù)集。讓我們首先下載數(shù)據(jù)集并將其加載到名為 的變量中data_train

  1. from torchvision.datasets import MNIST 
  2.  
  3. # Download MNIST  
  4. data_train = MNIST('~/mnist_data', train=True, download=True
  5.  
  6. import matplotlib.pyplot as plt 
  7.  
  8. random_image = data_train[0][0] 
  9. random_image_label = data_train[0][1] 
  10.  
  11. # Print the Image using Matplotlib 
  12. plt.imshow(random_image) 
  13. print("The label of the image is:", random_image_label) 

DataLoader加載MNIST

下面我們使用DataLoader該類加載數(shù)據(jù)集,如下所示。

  1. import torch 
  2. from torchvision import transforms 
  3.  
  4. data_train = torch.utils.data.DataLoader( 
  5.     MNIST( 
  6.           '~/mnist_data', train=True, download=True,  
  7.           transform = transforms.Compose([ 
  8.               transforms.ToTensor() 
  9.           ])), 
  10.           batch_size=64, 
  11.           shuffle=True 
  12.           ) 
  13.  
  14. for batch_idx, samples in enumerate(data_train): 
  15.       print(batch_idx, samples) 

CUDA加載

我們可以啟用 GPU 來更快地訓(xùn)練我們的模型?,F(xiàn)在讓我們使用CUDA加載數(shù)據(jù)時(shí)可以使用的(GPU 支持 PyTorch)的配置。

  1. device = "cuda" if torch.cuda.is_available() else "cpu" 
  2. kwargs = {'num_workers': 1, 'pin_memory'True} if device=='cuda' else {} 
  3.  
  4. train_loader = torch.utils.data.DataLoader( 
  5.   torchvision.datasets.MNIST('/files/', train=True, download=True), 
  6.   batch_size=batch_size_train, **kwargs) 
  7.  
  8. test_loader = torch.utils.data.DataLoader( 
  9.   torchvision.datasets.MNIST('files/', train=False, download=True), 
  10.   batch_size=batch_size, **kwargs) 

ImageFolder

ImageFolder是一個(gè)通用數(shù)據(jù)加載器類torchvision,可幫助加載自己的圖像數(shù)據(jù)集。處理一個(gè)分類問題并構(gòu)建一個(gè)神經(jīng)網(wǎng)絡(luò)來識別給定的圖像是apple還是orange。要在 PyTorch 中執(zhí)行此操作,第一步是在默認(rèn)文件夾結(jié)構(gòu)中排列圖像,如下所示:

  1. root 
  2. ├── orange 
  3. │   ├── orange_image1.png 
  4. │   └── orange_image1.png 
  5. ├── apple 
  6. │   └── apple_image1.png 
  7. │   └── apple_image2.png 
  8. │   └── apple_image3.png 

可以使用ImageLoader該類加載所有這些圖像。

  1. torchvision.datasets.ImageFolder(root, transform) 

transforms

PyTorch 轉(zhuǎn)換定義了簡單的圖像轉(zhuǎn)換技術(shù),可將整個(gè)數(shù)據(jù)集轉(zhuǎn)換為獨(dú)特的格式。

如果是一個(gè)包含不同分辨率的不同汽車圖片的數(shù)據(jù)集,在訓(xùn)練時(shí),我們訓(xùn)練數(shù)據(jù)集中的所有圖像都應(yīng)該具有相同的分辨率大小。如果我們手動(dòng)將所有圖像轉(zhuǎn)換為所需的輸入大小,則很耗時(shí),因此我們可以使用transforms;使用幾行 PyTorch 代碼,我們數(shù)據(jù)集中的所有圖像都可以轉(zhuǎn)換為所需的輸入大小和分辨率。

現(xiàn)在讓我們加載 CIFAR10torchvision.datasets并應(yīng)用以下轉(zhuǎn)換:

  • 將所有圖像調(diào)整為 32×32
  • 對圖像應(yīng)用中心裁剪變換
  • 將裁剪后的圖像轉(zhuǎn)換為張量
  • 標(biāo)準(zhǔn)化圖像
  1. import torch 
  2. import torchvision 
  3. import torchvision.transforms as transforms 
  4. import matplotlib.pyplot as plt 
  5. import numpy as np 
  6.  
  7. transform = transforms.Compose([ 
  8.     # resize 32×32 
  9.     transforms.Resize(32), 
  10.     # center-crop裁剪變換 
  11.     transforms.CenterCrop(32), 
  12.     # to-tensor 
  13.     transforms.ToTensor(), 
  14.     # normalize 標(biāo)準(zhǔn)化 
  15.     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 
  16. ]) 
  17.  
  18. trainset = torchvision.datasets.CIFAR10(root='./data', train=True
  19.                                         download=True, transform=transform) 
  20. trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, 
  21.                                           shuffle=False

在 PyTorch 中創(chuàng)建自定義數(shù)據(jù)集

下面將創(chuàng)建一個(gè)由數(shù)字和文本組成的簡單自定義數(shù)據(jù)集。需要封裝Dataset 類中的__getitem__()和__len__()方法。

  • __getitem__()方法通過索引返回?cái)?shù)據(jù)集中的選定樣本。
  • __len__()方法返回?cái)?shù)據(jù)集的總大小。

下面是曾經(jīng)封裝FruitImagesDataset數(shù)據(jù)集的代碼,基本是比較好的 PyTorch 中創(chuàng)建自定義數(shù)據(jù)集的模板。

  1. import os 
  2. import numpy as np 
  3. import cv2 
  4. import torch 
  5. import matplotlib.patches as patches 
  6. import albumentations as A 
  7. from albumentations.pytorch.transforms import ToTensorV2 
  8. from matplotlib import pyplot as plt 
  9. from torch.utils.data import Dataset 
  10. from xml.etree import ElementTree as et 
  11. from torchvision import transforms as torchtrans 
  12.  
  13. class FruitImagesDataset(torch.utils.data.Dataset): 
  14.     def __init__(self, files_dir, width, height, transforms=None): 
  15.         self.transforms = transforms 
  16.         self.files_dir = files_dir 
  17.         self.height = height 
  18.         self.width = width 
  19.  
  20.  
  21.         self.imgs = [image for image in sorted(os.listdir(files_dir)) 
  22.                      if image[-4:] == '.jpg'
  23.  
  24.         self.classes = ['_','apple''banana''orange'
  25.  
  26.     def __getitem__(self, idx): 
  27.  
  28.         img_name = self.imgs[idx] 
  29.         image_path = os.path.join(self.files_dir, img_name) 
  30.  
  31.         # reading the images and converting them to correct size and color 
  32.         img = cv2.imread(image_path) 
  33.         img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) 
  34.         img_res = cv2.resize(img_rgb, (self.width, self.height), cv2.INTER_AREA) 
  35.         # diving by 255 
  36.         img_res /= 255.0 
  37.  
  38.         # annotation file 
  39.         annot_filename = img_name[:-4] + '.xml' 
  40.         annot_file_path = os.path.join(self.files_dir, annot_filename) 
  41.  
  42.         boxes = [] 
  43.         labels = [] 
  44.         tree = et.parse(annot_file_path) 
  45.         root = tree.getroot() 
  46.  
  47.         # cv2 image gives size as height x width 
  48.         wt = img.shape[1] 
  49.         ht = img.shape[0] 
  50.  
  51.         # box coordinates for xml files are extracted and corrected for image size given 
  52.         for member in root.findall('object'): 
  53.             labels.append(self.classes.index(member.find('name').text)) 
  54.  
  55.             # bounding box 
  56.             xmin = int(member.find('bndbox').find('xmin').text) 
  57.             xmax = int(member.find('bndbox').find('xmax').text) 
  58.  
  59.             ymin = int(member.find('bndbox').find('ymin').text) 
  60.             ymax = int(member.find('bndbox').find('ymax').text) 
  61.  
  62.             xmin_corr = (xmin / wt) * self.width 
  63.             xmax_corr = (xmax / wt) * self.width 
  64.             ymin_corr = (ymin / ht) * self.height 
  65.             ymax_corr = (ymax / ht) * self.height 
  66.  
  67.             boxes.append([xmin_corr, ymin_corr, xmax_corr, ymax_corr]) 
  68.  
  69.         # convert boxes into a torch.Tensor 
  70.         boxes = torch.as_tensor(boxes, dtype=torch.float32) 
  71.  
  72.         # getting the areas of the boxes 
  73.         area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) 
  74.  
  75.         # suppose all instances are not crowd 
  76.         iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64) 
  77.  
  78.         labels = torch.as_tensor(labels, dtype=torch.int64) 
  79.  
  80.         target = {} 
  81.         target["boxes"] = boxes 
  82.         target["labels"] = labels 
  83.         target["area"] = area 
  84.         target["iscrowd"] = iscrowd 
  85.         # image_id 
  86.         image_id = torch.tensor([idx]) 
  87.         target["image_id"] = image_id 
  88.  
  89.         if self.transforms: 
  90.             sample = self.transforms(image=img_res, 
  91.                                      bboxes=target['boxes'], 
  92.                                      labels=labels) 
  93.  
  94.             img_res = sample['image'
  95.             target['boxes'] = torch.Tensor(sample['bboxes']) 
  96.         return img_res, target 
  97.     def __len__(self): 
  98.         return len(self.imgs) 
  99.  
  100. def get_transform(train): 
  101.     if train: 
  102.         return A.Compose([ 
  103.             A.HorizontalFlip(0.5), 
  104.             ToTensorV2(p=1.0) 
  105.         ], bbox_params={'format''pascal_voc''label_fields': ['labels']}) 
  106.     else
  107.         return A.Compose([ 
  108.             ToTensorV2(p=1.0) 
  109.         ], bbox_params={'format''pascal_voc''label_fields': ['labels']}) 
  110.  
  111. files_dir = '../input/fruit-images-for-object-detection/train_zip/train' 
  112. test_dir = '../input/fruit-images-for-object-detection/test_zip/test' 
  113.  
  114. dataset = FruitImagesDataset(train_dir, 480, 480) 

 

責(zé)任編輯:姜華 來源: Python之王
相關(guān)推薦

2023-04-07 07:29:54

Torchvisio計(jì)算機(jī)視覺

2023-12-18 10:41:28

深度學(xué)習(xí)NumPyPyTorch

2021-12-13 09:14:06

清單管理數(shù)據(jù)集

2019-06-19 09:13:29

機(jī)器學(xué)習(xí)中數(shù)據(jù)集深度學(xué)習(xí)

2021-09-10 10:26:45

PyTorch數(shù)據(jù)集S3 Plugin

2010-04-27 13:21:58

Oracle數(shù)據(jù)字符集

2023-12-01 16:23:52

大數(shù)據(jù)人工智能

2020-07-15 13:51:48

TensorFlow數(shù)據(jù)機(jī)器學(xué)習(xí)

2023-07-28 09:54:14

SQL數(shù)據(jù)Excel

2021-09-03 06:46:34

SQL分組集功能

2021-11-09 08:48:48

Python開源項(xiàng)目

2020-06-24 07:53:03

機(jī)器學(xué)習(xí)技術(shù)人工智能

2020-10-05 21:57:17

GitHub 開源開發(fā)

2020-10-27 09:37:43

PyTorchTensorFlow機(jī)器學(xué)習(xí)

2009-08-03 14:39:25

Asp.Net函數(shù)集

2020-04-29 13:40:32

數(shù)據(jù)集數(shù)據(jù)科學(xué)冠狀病毒

2022-09-16 00:11:45

PyTorch神經(jīng)網(wǎng)絡(luò)存儲(chǔ)

2010-06-17 09:29:32

SQLServer 2

2023-02-21 08:00:00

2020-10-15 11:22:34

PyTorchTensorFlow機(jī)器學(xué)習(xí)
點(diǎn)贊
收藏

51CTO技術(shù)棧公眾號