Pytorch教程|Pytorch教程[02]DataLoader与Dataset

机器学习模型训练步骤 Pytorch教程|Pytorch教程[02]DataLoader与Dataset
文章图片

Pytorch教程|Pytorch教程[02]DataLoader与Dataset
文章图片

一.DataLoader torch.utils.data.DataLoader()
功能:构建可迭代的数据装载器
? dataset: Dataset类,决定数据从哪读取
及如何读取
? batchsize : 批大小
? num_works: 是否多进程读取数据
? shuffle: 每个epoch是否乱序
? drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据

DataLoader( dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)

[Epoch、Epoch、Batch]三者之间的关系
  • Epoch: 所有训练样本都已输入到模型中,称为一个Epoch
  • Iteration:一批样本输入到模型中,称之为一个Iteration
  • Batchsize:批大小,决定一个Epoch有多少个Iteration
【Pytorch教程|Pytorch教程[02]DataLoader与Dataset】例:
样本总数:80, Batchsize:8
1 Epoch = 10 Iteration
样本总数:87, Batchsize:8
1 Epoch = 10 Iteration ? drop_last = True
1 Epoch = 11 Iteration ? drop_last = False
二、Dataset torch.utils.data.Dataset()
功能:Dataset抽象类,所有自定义的
Dataset需要继承它,并且复写
__getitem__() getitem #接收一个索引,返回一个样本

class Dataset(object): def __getitem__(self, index): raise NotImplementedError def __add__(self, other): return ConcatDataset([self, other])

Pytorch教程|Pytorch教程[02]DataLoader与Dataset
文章图片

Pytorch教程|Pytorch教程[02]DataLoader与Dataset
文章图片

    推荐阅读