Python|(系列更新完毕)深度学习零基础使用 PyTorch 框架跑 MNIST 数据集的第四天(单例测试)

1. Introduction 今天是尝试用 PyTorch 框架来跑 MNIST 手写数字数据集的第四天,主要学习导入模型并进行单例测试。本 blog 主要记录一个学习的路径以及学习资料的汇总。
注意:这是用 Python 2.7 版本写的代码
第一天(LeNet 网络的搭建):https://blog.csdn.net/qq_36627158/article/details/108098147
第二天(加载 MNIST 数据集):https://blog.csdn.net/qq_36627158/article/details/108119048
第三天(训练模型):https://blog.csdn.net/qq_36627158/article/details/108163693
第四天(单例测试):https://blog.csdn.net/qq_36627158/article/details/108183655




2. Code(mnist_classify.py) 感谢 凯神 提供的代码与耐心指导!

from torchvision import transforms from PIL import Image, ImageOps from mnist_train import *classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9') model = Net()def load_checkpoint(checkpoint_path, model): state = torch.load(checkpoint_path) model.load_state_dict(state['model'])if __name__ == '__main__': load_checkpoint( 'module/pytorch-mnist-batch-128-1407.pth', model )model = model.to(device) model.eval()img = Image.open("/home/ubuntu/Downloads/C6/3.jpg") img = ImageOps.invert(img)# rgb -> single channel image if len(img.split()) > 1: img = img.split()[0]plt.figure() plt.imshow(img) plt.show()trans = transforms.Compose([ transforms.Resize((28, 28)), transforms.ToTensor(), ]) img = trans(img)img = img.to(device)img = img.unsqueeze(0)output = model(img) prob = F.softmax(output, dim=1)max_value, max_index = torch.max(prob, 1)pred_class = classes[max_index.item()] print 'predicted class is', pred_class, ', probability is', round(max_value.item(), 6) * 100




3. Details 1、im.split()
r, g, b=im.split()该函数用来将RGB图片分割成三个通道的图片
Python-Image 基本的图像处理操作

2、torch.unsqueeze()
为 Torch Tensor 添加维度
【Python|(系列更新完毕)深度学习零基础使用 PyTorch 框架跑 MNIST 数据集的第四天(单例测试)】https://blog.csdn.net/xiexu911/article/details/80820028

    推荐阅读