import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
# Number of samples per mini-batch for both data loaders.
batch_size = 64

# Directory the MNIST files are downloaded to / read from.
data_root = './dataset/mnist/'

# Combine the listed operations into a single preprocessing pipeline.
# ToTensor converts each image to a tensor; Normalize then applies
# (x - mean) / std using the MNIST mean (0.1307) and standard deviation (0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307, ), (0.3081, )),
])

train_dataset = datasets.MNIST(root=data_root,
                               train=True,
                               download=True,
                               transform=transform)
train_loader = DataLoader(train_dataset,
                          shuffle=True,
                          batch_size=batch_size)
test_dataset = datasets.MNIST(root=data_root,
                              train=False,
                              download=True,
                              transform=transform)
test_loader = DataLoader(test_dataset,
                         shuffle=False,
                         batch_size=batch_size)

# --- Visualize the first images of the test set in a grid. ---
figure = plt.figure()
num_of_images = 60
num_cols = 10

# Take only the first batch; the test loader uses shuffle=False, so this
# always shows the same images.
imgs, targets = next(iter(test_loader))

# A batch may hold fewer images than requested (e.g. a small batch_size),
# so never index past the end of the batch.
num_of_images = min(num_of_images, imgs.size(0))
# Enough rows to fit every image (6 rows for the default 60 images).
num_rows = (num_of_images + num_cols - 1) // num_cols

for index in range(num_of_images):
    # subplot positions are 1-based.
    plt.subplot(num_rows, num_cols, index + 1)
    plt.axis('off')
    img = imgs[index, ...]
    # Drop singleton dimensions so imshow receives a 2-D array.
    plt.imshow(img.numpy().squeeze(), cmap='gray_r')
plt.show()
【MNIST 数据展示代码：可视化 MNIST（深入浅出 PyTorch）】
文章图片
文章图片
推荐阅读
- Pytorch学习|sklearn-SVM 模型保存、交叉验证与网格搜索
- Pytorch学习|使用torch.load()加载模型参数时,提示“xxx.pt is a zip archive (did you mean to use torch.jit.load()?)”