笔记|Python读取CIFAR10数据集,附代码详解

Python读取CIFAR10数据集 初次接触机器学习,用到的第一个数据集就是CIFAR10。这是一个小型数据集。一共包含 10 个类别的 RGB 彩色图 片:飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。图片的尺寸为 32×32 ,数据集中一共有 50000 张训练图片和 10000 张测试图片。
从网站上下载好压缩包之后,解压是这样一个文件。data_bath_1到5是训练数据,test_bath是测试数据。
笔记|Python读取CIFAR10数据集,附代码详解
文章图片

下面学习对这个数据集的简单处理。

import pickle import numpy as np import os import matplotlib.pyplot as plt from matplotlib.pyplot import imshow CIFAR_DIR='e:/learn-spyder/cifar-10-batches-py'#这个是CIFAR10文件目录,自行设定的 print(os.listdir(CIFAR_DIR))

首先是导入我们所需要的库,定义好CIFAR10文件所在的文件路径。
pickle:对python对象结构进行二进制序列化和反序列化的协议实现 。
os:操作系统接口模块。
matplotlib.pyplot:数据可视化 ,把数据显示成图形用的。
os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。
此时print打印的结果就和上图显示的一样为:[‘batches.meta’, ‘data_batch_1’, ‘data_batch_2’, ‘data_batch_3’, ‘data_batch_4’, ‘data_batch_5’, ‘readme.html’, ‘test_batch’]
with open(os.path.join(CIFAR_DIR,'data_batch_2'),'rb') as f: img_dict=pickle.load(f,encoding='bytes') print(img_dict.keys())

【笔记|Python读取CIFAR10数据集,附代码详解】with open()as f :常见的读写操作。 rb: 以二进制格式打开一个文件用于只读。文件指针将会放在文件的开头。
os.path.join :把路径连接起来,运行结果就是 e:/learn-spyder/cifar-10-batches-py\data_batch_2也就是这次实验使用data_batch_2文件中的图片。
pickle.load函数是实现反序列化,将文件中的数据解析为一个python对象,说白了就是将CIFAR10提供的文件读取到Python的数据结构(字典)中,也就是把数据从硬盘中出来。返回的是一个字典类型,每一个键值对应一个value值。
这个时候print打印的结果为:
dict_keys([b’batch_label’, b’labels’, b’data’, b’filenames’])
我们也可以用print去查看每个键值对应的value值,加深对每个键值的理解
print(img_dict[b'batch_label']) print(img_dict[b'labels']) print(img_dict[b'data']) print(img_dict[b'filenames'])

b’batch_label’:对应当前数据集是训练集中的那一份,打印结果为 b’training batch 2 of 5’
b’labels’:可以理解为每个图像的标签,我们可以在batches.meta可以找出对应的字符结果,比如飞机用0表示,打印结果为[1,6,6,8,8…]
b’data’:每张图片的数据,打印结果为:
[[ 35 27 25 … 169 168 168]
[ 20 20 18 … 111 97 51]
[116 115 155 … 18 84 124]

[127 139 155 … 197 192 191]
[190 200 208 … 163 182 192]
[177 174 182 … 119 127 136]]
b’filenames’:对应数据集中每张图片的文件名,打印结果为[b’auto_s_000241.png’, …]
for i in range(0,5): img=np.reshape(img_dict[b'data'][i],(3,32,32)) img=img.transpose((1,2,0)) imshow(img) plt.show() print(img_dict[b'filenames'][i])f.close()

for循环表示我们只显示前5个图片,先把b’date’中的图片数据转换成33232大小。此时img的输出是(channels,imagesize,imagesize) imshow()函数负责对图像进行处理,并显示其格式,但是不能显示图片。其后跟着plt.show()才能显示出来。 但请注意:plt.show()的输入格式为(imagesize,imagesize,channels),故需要通过转置transpose把(channels,imagesize,imagesize)转换成(imagesize,imagesize,channels)
f.close() 表示关闭文件。
再次说明一些 img=img.transpose((1,2,0)),表示专置,这里的0 1 2代表原本的位置,通过改变0 1 2的位置来改变维度。
到此为止,整个代码部分就讲解完啦,运行一下就会出现我们想看到的图片啦
笔记|Python读取CIFAR10数据集,附代码详解
文章图片

    推荐阅读