机器学习模型训练步骤
文章图片
文章图片
一.DataLoader
torch.utils.data.DataLoader()
功能:构建可迭代的数据装载器
? dataset: Dataset类,决定数据从哪读取
及如何读取
? batchsize : 批大小
? num_works: 是否多进程读取数据
? shuffle: 每个epoch是否乱序
? drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据
DataLoader( dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
multiprocessing_context=None)
[Epoch、Epoch、Batch]三者之间的关系
- Epoch: 所有训练样本都已输入到模型中,称为一个Epoch
- Iteration:一批样本输入到模型中,称之为一个Iteration
- Batchsize:批大小,决定一个Epoch有多少个Iteration
样本总数:80, Batchsize:8
1 Epoch = 10 Iteration
样本总数:87, Batchsize:8
1 Epoch = 10 Iteration ? drop_last = True
1 Epoch = 11 Iteration ? drop_last = False
二、Dataset
torch.utils.data.Dataset()
功能:Dataset抽象类,所有自定义的
Dataset需要继承它,并且复写
__getitem__()
getitem #接收一个索引,返回一个样本
class Dataset(object): def __getitem__(self, index):
raise NotImplementedError def __add__(self, other):
return ConcatDataset([self, other])
文章图片
文章图片
推荐阅读
- Pytorch教程|Pytorch教程[01]张量操作
- Pytorch教程|Pytorch教程[04]torch.nn---Containers
- DEEP|什么是端到端神经网络()
- 问答系统|简单问答系统实现原理 - 基于机器学习的
- 深度学习|【深度学习】经典网络-VGG复现(使用Tensorflow实现)
- python|整理了 47 个 Python 人工智能库
- 图灵奖得主Yann LeCun走进百度,与世界研究工作者展开交流
- 建设领先的AI原生云,百度智能云落地新一代高性能AI计算集群
- 报名啦!中小企业如何借力AI逆势突围(飞桨中国行定档3月23日!)