时装分类的任务 FashionMNIST数据集中包含已经预先划分好的训练集和测试集,其中训练集共60,000张图像,测试集共10,000张图像。每张图像均为单通道黑白图像,大小为32*32pixel,分属10个类别。
首先导入必要的包 import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
配置训练环境和超参数 配置GPU
配置超参数如:batch_size, num_workers, learning rate, 以及总的epochs
数据读入和加载 这里同时展示两种方式:
下载并使用PyTorch提供的内置数据集
从网站下载以csv格式存储的数据,读入并转成预期的格式
第一种数据读入方式只适用于常见的数据集,如MNIST,CIFAR10等,PyTorch官方提供了数据下载。这种方式往往适用于快速测试方法(比如测试下某个idea在MNIST数据集上是否有效)
第二种数据读入方式需要自己构建Dataset,这对于PyTorch应用于自己的工作中十分重要
同时,还需要对数据进行必要的变换,比如说需要将图片统一为一致的大小,以便后续能够输入网络训练;需要将数据格式转为Tensor类,等等。
模型设计 【pytorch|深入浅出Pytorch系列(4)(实战--FashionMNIST时装分类)】通过nn.module以及nn.sequential对网络结构进行搭建
设定损失函数 多分类问题一般使用torch.nn模块自带的nn.CrossEntropy损失
设定优化器 Adam优化器较为常用
训练和测试(验证) 常规做法是将训练和测试各自封装成函数,方便后续调用
两者的主要区别:
- 模型状态设置
- 是否需要初始化优化器
- 是否需要将loss传回到网络
- 是否需要每步更新optimizer
学习链接:
https://github.com/datawhalechina/thorough-pytorch
推荐阅读
- pytorch|深入浅出Pytorch系列(3)(主要组成模块)
- Python|pytorch,yolov5模型经onnx到Android(三)
- pytorch|yolov5-gpu版本部署与测试中遇到的问题与解决
- 【YOLOv5】6.0环境搭建(不定时更新)
- 深入浅出pytorch(四)
- 数据挖掘分析|Weka数据挖掘——分类
- 网络|视觉注意力机制概述
- 数据挖掘与数据仓库|数据挖掘与数据仓库——分类
- pytorch深度学习实战|Mask R-CNN详解(图文并茂)