深度学习(基于pytorch)|深度学习笔记(七)——pytorch数据处理工具箱(一)


数据处理工具箱

  • 数据处理工具箱概述
  • utils.data简介
      • 参考文献

数据下载和预处理是机器学习、深度学习实际项目中耗时又重要的任务,尤其是数据预处理,关系到数据质量和模型性能,往往要占据项目的大部分时间。 【深度学习(基于pytorch)|深度学习笔记(七)——pytorch数据处理工具箱(一)】
数据处理工具箱概述 pytorch涉及数据处理(数据装载、数据预处理、数据增强等)主要工具包及相互关系如下图所示:
深度学习(基于pytorch)|深度学习笔记(七)——pytorch数据处理工具箱(一)
文章图片

上图的左边是torch.utils.data工具包,它包括以下4个类。
1)Dataset:是一个抽象类,其他数据集需要继承这个类,并且覆写其中的两个方法(_getitem_,_ len_)。
2)DataLoader:定义了一个新的迭代器,实现批量(batch)读取,打乱数据(shuffle)并提供并行加速等功能。
3)random_split:把数据集随机拆分为给定长度的非重叠的新数据集。
4)*sample:多种采样函数。
上图中间是pytorch可视化处理工具(torchvision),其是pytorch的一个视觉处理工具包。
它包括四个类,各类主要功能如下。
1)datasets:提供常用的数据集加在,设计上都是继承自torch.utils.data.Dataset,主要包括MNIST、CIFAR10/100、ImageNet和COCO等。
2)models:提供深度学习中各种经典的网络结构以及训练好的模型(如果选择pretrained=True),包括AlexNet、VGG系列、ResNet系列、Inception系列等。
3)transforms:常用的数据预处理操作,主要包括对Tensor及PIL Image对象的操作。
4)utils:含两个函数,一个是make_grid,它能将多张图片拼接在一个网格中,另一个是save_img,它能将Tensor保存成图片。
utils.data简介 utils.data包括Dataset和DataLoader。torch.utils.data.Dataset为抽象类。自定义数据集需要继承这个类,并实现两个函数,一个是_len_,另一个是_getitem_,前者提供数据的大小(size),后者通过给定索引获取数据和标签。_getitem_一次只能获取一个数据,所以需要通过torch.utils.data.DataLoader来定义一个新的迭代器,实现batch读取。首先我们来定义一个简单的数据集,然后通过具体使用Dataset及DataLoader。
1)导入需要的模块
import torch from torch.utils import data import numpy as np

2)定义获取数据集的类
该类继承基类Dataset,自定义一个数据集及对应标签。
class TestDataset(data.Dataset):# 继承Dataset def __init__(self): self.Data = https://www.it610.com/article/np.asarray([[1, 2], [3, 4], [2, 1], [3, 4], [4, 5]])# 一些由2维向量表示的数据集 self.Label = np.asarray([0, 1, 0, 1, 2])# 这是数据集对应的标签def __getitem__(self, index): # 把numpy转换为Tensor txt = torch.from_numpy(self.Data[index]) label = torch.tensor(self.Label[index]) return txt, labeldef __len__(self): return len(self.Data)

3)获取数据集中数据。
Test = TestDataset() print(Test[2]) print(Test.__len__())

运行结果:
深度学习(基于pytorch)|深度学习笔记(七)——pytorch数据处理工具箱(一)
文章图片

以上数据以tuple返回,每次返回一个样本。实际上,Dataset只负责数据的抽取,调用一次_getitem_只返回一个样本。如果希望批量处理(batch),还要同时进行shuffle和并行加速等操作,可选择DataLoader。DataLoader的格式为:
data.DataLoader( dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None )

主要参数说明:
dataset:加载的数据集。
batch_size:批大小。
shuffle:是否将数据打乱。
sampler:样本抽样。
num_workers:使用多进程加载的进程数,0代表不使用多进程。
collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可。
pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU中会更快一些。
drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃。
test_loader = data.DataLoader(Test, batch_size=2, shuffle=False, num_workers=0) for i, traindata in enumerate(test_loader): print('i: ', i) Data, Label = traindata print('data: ', Data) print('Label: ', Label)

运行结果:
深度学习(基于pytorch)|深度学习笔记(七)——pytorch数据处理工具箱(一)
文章图片

从这个结果可以看出,这是批量读取。我们可以像迭代器一样使用它,比如对它进行循环操作。不过由于它不是迭代器,我们可以通过iter命令将其转换为迭代器。
test_loader = data.DataLoader(Test, batch_size=1, shuffle=False, num_workers=0) dataiter = iter(test_loader)

针对迭代器的使用,这里给出两个方案:
1)直接对迭代器使用for循环:
for imgs, labels in dataiter: print(imgs) print(labels)

2)将异常检测结构与next函数配合使用。
while True: try: imgs, labels = next(dataiter) print(imgs) print(labels) except StopIteration: break

一般用data.Dataset处理同一个目录下的数据。如果数据在不同目录下,因为不同的目录代表不同类别(这种情况比较普遍),使用data.Dataset来处理就不太方便。不过,使用PyTorch另一种可视化数据处理工具(即:torchvision)就非常方便,不但可以自动获取标签,还提供了很多数据预处理、数据增强等转换函数。
参考文献
吴茂贵,郁明敏,杨本法,李涛,张粤磊. Python深度学习(基于Pytorch). 北京:机械工业出版社,2019.

    推荐阅读