Pytorch学习笔记|【Pytorch学习笔记】3.温习matplotlib——实用的 matplotlib.pyplot 预览图片类数据集的方法(以FashionMNIST为例)


文章目录

    • matplotlib绘图原理和步骤
      • matplotlib.pyplot绘图原理
      • 快速绘图
    • imshow()快速显示Fashion-MNIST数据集图片
    • pyplot.subplots()批量显示FashionMNIST图片
      • subplots() 语法格式:
      • 按一行10个显示FashionMNIST数据集中的图片

我们在使用图像分类数据集时,在载入数据集后,一定会先取一批数据预览一下。
最常用的方法之一是matplotlib库显示图片。
这里自己记录一下图片类数据集常用的图片预览方法,顺便复习matplotlib的绘图原理。
matplotlib绘图原理和步骤 先温习一下matplotlib的绘图原理
matplotlib.pyplot绘图原理
【Pytorch学习笔记|【Pytorch学习笔记】3.温习matplotlib——实用的 matplotlib.pyplot 预览图片类数据集的方法(以FashionMNIST为例)】简述一下,在使用matplotlib.pyplot模块绘图时,会先创建三大核心元素:
画布(figure)→坐标系(axes)→坐标轴(axis),如图:
Pytorch学习笔记|【Pytorch学习笔记】3.温习matplotlib——实用的 matplotlib.pyplot 预览图片类数据集的方法(以FashionMNIST为例)
文章图片

因此底层步骤为:
创建画布figure → 立画图区域axes(又叫坐标系)→ 区域内设定坐标轴
→ 使用绘图语句绘图(如plot()函数)→ 使用show()展现出来
创建画布常用 plt.figure() 函数;
建立画图区域常用 fig.subplot() 函数;
坐标轴使用默认,然后再 axes上使用绘图函数绘图。
例子:
import matplotlib.pyplot as pltfig = plt.figure() ax = fig.add_subplot(1,1,1) x = [1,2,3] y = [1,2,3] ax.plot(x, y) plt.show() # 输出:

Pytorch学习笔记|【Pytorch学习笔记】3.温习matplotlib——实用的 matplotlib.pyplot 预览图片类数据集的方法(以FashionMNIST为例)
文章图片

快速绘图
但是我们使用pyplot画图,如果直接使用绘图函数,matplotlib自己会创建一张默认的画布,再开辟一块默认的坐标轴区域,然后给我们绘制上去。
所以在jupyter notebook中我们可以直接使用plt.plot快速绘图:
Pytorch学习笔记|【Pytorch学习笔记】3.温习matplotlib——实用的 matplotlib.pyplot 预览图片类数据集的方法(以FashionMNIST为例)
文章图片

imshow()快速显示Fashion-MNIST数据集图片 我们使用FashionMNIST数据集,Fashion-MNIST是一个10分类数据集,包括了衣物、包包、运动鞋等时尚妆扮的类别。
使用imshow()方法直接传入一个符合要求的Tensor,取一张图片快速显示:
(何为符合要求的Tensor见下面注释)
import torchvision import torchvision.transforms as transforms import matplotlib.pyplot as plt # 使用torchvision的datasets类获取FashionMNIST数据集,并新建一个训练集 # 使用了transform = transforms.ToTensor()转换器,使所有数据转换为Tensor。 # 转换成的Tensor 数据类型为torch.float32,位于[0.0, 1.0] mnist_train = torchvision.datasets.FashionMNIST( root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor()) # feature为图像信息,label为标签 feature, label = mnist_train[0] print(feature.shape, label) plt.imshow(feature.view((28,28))) # 图片是单通道1*28*28像素图片,须转成28*28像素灰度图片来显示 # 显示: (9号类别即为“ankle boot(短靴)”)

Pytorch学习笔记|【Pytorch学习笔记】3.温习matplotlib——实用的 matplotlib.pyplot 预览图片类数据集的方法(以FashionMNIST为例)
文章图片

pyplot.subplots()批量显示FashionMNIST图片 在一开始的例子里使用过subplot()方法:ax = fig.add_subplot(1,1,1)
其传入figure中axes的位置,返回一个axes对象。
当需要同时展示多张图片的时候,我们可以使用pyplot.subplots()方法,传入图片陈列的行数列数等信息,返回一个figure对象和一个axes对象。
因此代码里常用:fig, ax = plt.subplots(),然后在各个ax中配置参数并绘图。
subplots() 语法格式:
官方文档
matplotlib.pyplot.subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, subplot_kw=None, gridspec_kw=None, **fig_kw)
常用参数:
nrows:图表的行数,默认为 1
ncols:图表的列数,默认为 1
sharex、sharey:设置各个区域 x、y 轴是否使用相同的刻度
**fig_kw:其他关键字参数传递给 pyplot.figure调用,比如设置figsize
按一行10个显示FashionMNIST数据集中的图片
例子:
import torchvision import torchvision.transforms as transforms import matplotlib.pyplot as plt # 新建一个训练集 mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())def show_fashion_mnist(images, labels): # _表示忽略不适用的变量,返回的fig用不上 _, axes = plt.subplots(nrows=1, ncols=len(images), figsize=(12,12)) # 使用zip方法,在for循环中设置axes中的各个子区域的参数并绘图 for ax, img, lbl in zip(axes, images, labels): ax.imshow(img.view((28, 28)).numpy()) ax.set_title(lbl) ax.axes.get_xaxis().set_visible(False) ax.axes.get_yaxis().set_visible(False) plt.show()X, y = [], [] for i in range(10): X.append(mnist_train[i][0]) y.append(mnist_train[i][1])show_fashion_mnist(X, y) # 显示:

Pytorch学习笔记|【Pytorch学习笔记】3.温习matplotlib——实用的 matplotlib.pyplot 预览图片类数据集的方法(以FashionMNIST为例)
文章图片

    推荐阅读