pytorch|深入浅出Pytorch系列(4)(实战--FashionMNIST时装分类)

时装分类的任务 FashionMNIST数据集中包含已经预先划分好的训练集和测试集,其中训练集共60,000张图像,测试集共10,000张图像。每张图像均为单通道黑白图像,大小为32*32pixel,分属10个类别。
首先导入必要的包 import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
配置训练环境和超参数 配置GPU
配置超参数如:batch_size, num_workers, learning rate, 以及总的epochs
数据读入和加载 这里同时展示两种方式:
下载并使用PyTorch提供的内置数据集
从网站下载以csv格式存储的数据,读入并转成预期的格式
第一种数据读入方式只适用于常见的数据集,如MNIST,CIFAR10等,PyTorch官方提供了数据下载。这种方式往往适用于快速测试方法(比如测试下某个idea在MNIST数据集上是否有效)
第二种数据读入方式需要自己构建Dataset,这对于PyTorch应用于自己的工作中十分重要
同时,还需要对数据进行必要的变换,比如说需要将图片统一为一致的大小,以便后续能够输入网络训练;需要将数据格式转为Tensor类,等等。
模型设计 【pytorch|深入浅出Pytorch系列(4)(实战--FashionMNIST时装分类)】通过nn.module以及nn.sequential对网络结构进行搭建
设定损失函数 多分类问题一般使用torch.nn模块自带的nn.CrossEntropy损失
设定优化器 Adam优化器较为常用
训练和测试(验证) 常规做法是将训练和测试各自封装成函数,方便后续调用
两者的主要区别:

  • 模型状态设置
  • 是否需要初始化优化器
  • 是否需要将loss传回到网络
  • 是否需要每步更新optimizer
此外,对于测试或验证过程,可以计算分类准确率
学习链接:
https://github.com/datawhalechina/thorough-pytorch

    推荐阅读