基础实战——FashionMNIST时装分类 这里基础实战已经给出了pytorch基础实战代码,配置好环境后直接打开,逐步运行即可。
整个流程为
导入必要的包 配置训练环境和超参数 数据读取和加载 数据格式变换 定义DataLoader 可视化验证读取数据 数据集读取 模型构建 设定损失函数 设定优化器 训练及测试 训练完毕
- 读取数据时,由于我们选择使用的是MNIST数据集,由PyTorch,torchvision提供的内置数据集。方式一直接下载速度较慢,选择第二种
- 自行构建Dataset类,提前下载好数据集,读入csv格式的数据
- 定义一个Dataloader类便于后续数据读取
- 设计训练模型,这里只是给了个较容易上手的CNN模型,中间模块化程序较多,需要自己理解。
- 设定损失函数,这里用的是自带的Crossentropy交叉熵损失。
- 设定优化器,选择Adam优化器
- 最后,将各部分代码封装,便于后续改进
- 根据结果?进行训练,调参?
- 至此模型训练完毕。
更新,直到最后,还没意识到其中问题
- 前面都没啥异常,只是最后一步,训练模型的时候一直卡在那里,未果,发现中间并没有使用GPU,反复安装环境,手动下载,pip安装torch对应的cuda版本的GPU包,.whl文件,至此torch.cuda.is_available()结果为True
- jupter notebook内核显示内核挂掉了?将会立即重启?
推荐阅读
- 深度学习|DSC和HD医学图像分割评价指标
- 深度学习|TransUnet官方代码测试自己的数据集(已训练完毕)
- 计算机视觉|TransUnet官方代码训练自己数据集(彩色RGB3通道图像的分割)
- 华为|【Anaconda配置深度学习环境(Tensorflow或Pytorch或MindSpore)】
- pytorch|PyTorch模型转caffe
- 论文学习|Hierarchical Transformer Model for Scientific Named Entity Recognition 论文总结
- 人工智能+大数据|深入浅出pytorch求导机制
- pytorch|pytorch实现手写数字识别 | MNIST数据集(全连接神经网络)
- 超分算法在 WebRTC 高清视频传输弱网优化中的应用