PyTorch Tutorial | Building and Training on the MNIST Dataset with PyTorch

This article covers: loading the official dataset, building a custom dataset, defining a custom network structure, training, and using the trained model.
Imports

import torch
import torchvision
import torchsummary
import os
import numpy as np
import matplotlib.pyplot as plt

Constants

BATCH_SIZE = 64
# Image height and width in pixels
IMAGE_ROW = 28
IMAGE_COL = 28
# Root path of the raw dataset files
DATA_SOURCE_DIR = "../datasets/MNIST/raw/"
TRANSFORM = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.,), (1.,))
])
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(DEVICE)
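A side note on TRANSFORM: Normalize((0.,), (1.,)) subtracts 0 and divides by 1, so it leaves the [0, 1] output of ToTensor unchanged. If zero-mean, unit-variance inputs are preferred, the widely cited MNIST mean/std can be substituted (an optional alternative, not what the code below uses):

TRANSFORM = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))  # widely cited MNIST mean/std
])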

Loading the dataset
Dataset description:
Reference: https://www.cnblogs.com/xianhan/p/9145966.html
Dataset homepage: http://yann.lecun.com/exdb/mnist/
train-labels-idx1-ubyte
[offset] [type]          [value]            [description]
0000     32 bit integer  0x00000801 (2049)  magic number (MSB first)
0004     32 bit integer  60000              number of items
0008     unsigned byte   ??                 label
0009     unsigned byte   ??                 label
........
xxxx     unsigned byte   ??                 label

train-images-idx3-ubyte
[offset] [type]          [value]            [description]
0000     32 bit integer  0x00000803 (2051)  magic number
0004     32 bit integer  60000              number of images
0008     32 bit integer  28                 number of rows
0012     32 bit integer  28                 number of columns
0016     unsigned byte   ??                 pixel
0017     unsigned byte   ??                 pixel
........
xxxx     unsigned byte   ??                 pixel

t10k-labels-idx1-ubyte
[offset] [type]          [value]            [description]
0000     32 bit integer  0x00000801 (2049)  magic number (MSB first)
0004     32 bit integer  10000              number of items
0008     unsigned byte   ??                 label
0009     unsigned byte   ??                 label
........
xxxx     unsigned byte   ??                 label

t10k-images-idx3-ubyte
[offset] [type]          [value]            [description]
0000     32 bit integer  0x00000803 (2051)  magic number
0004     32 bit integer  10000              number of images
0008     32 bit integer  28                 number of rows
0012     32 bit integer  28                 number of columns
0016     unsigned byte   ??                 pixel
0017     unsigned byte   ??                 pixel
........
xxxx     unsigned byte   ??                 pixel
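As a quick sanity check on the layout above, here is a minimal sketch (assuming the raw files have already been downloaded into ../datasets/MNIST/raw/) that reads just the 8-byte header of the training label file:

import struct

with open("../datasets/MNIST/raw/train-labels-idx1-ubyte", "rb") as fp:
    header = fp.read(8)
# ">II" = big-endian (MSB first), two 32-bit unsigned integers
magic, items = struct.unpack_from(">II", header, 0)
print(hex(magic), items)  # expected: 0x801 (2049) and 60000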
Loading the official dataset

TRAIN_DATASETS = torchvision.datasets.MNIST(root="../datasets", train=True, download=True, transform=TRANSFORM)
TRAIN_LOADER = torch.utils.data.DataLoader(TRAIN_DATASETS, shuffle=True, batch_size=BATCH_SIZE)
TEST_DATASETS = torchvision.datasets.MNIST(root="../datasets", train=False, download=True, transform=TRANSFORM)
TEST_LOADER = torch.utils.data.DataLoader(TEST_DATASETS, shuffle=False, batch_size=BATCH_SIZE)
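To confirm the loaders are wired up correctly, it can help to pull a single batch and inspect its shape (a quick check, not part of the original tutorial):

images, labels = next(iter(TRAIN_LOADER))
print(images.shape)  # torch.Size([64, 1, 28, 28])
print(labels.shape)  # torch.Size([64])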

Viewing an image

img, label = TRAIN_DATASETS[0]
img = img.numpy()
plt.title(label)
plt.imshow(img[0])

Custom dataset
Official torch documentation (English): https://pytorch.org/docs/stable/data.html
torch.utils.data.Dataset source code: https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#Dataset
Basic structure of a custom Dataset
class DataSets(torch.utils.data.Dataset):
    def __init__(self):
        super(DataSets, self).__init__()
        pass
    def __getitem__(self, idx):
        pass
    def __len__(self):
        pass
struct.unpack_from(fmt, buf, offset)
fmt: parsing format, '>' or '<' (byte order) + str(count) + 'B', 'b', 'I', or 'i' (element type)
buf: the byte buffer read from the file
offset: byte offset into the buffer
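A quick illustration of how unpack_from behaves, using bytes chosen to mirror the label-file header described above:

import struct

buf = bytes([0x00, 0x00, 0x08, 0x01, 0x00, 0x00, 0xEA, 0x60])
print(struct.unpack_from(">II", buf, 0))  # (2049, 60000): two big-endian 32-bit ints
print(struct.unpack_from(">B", buf, 3))   # (1,): one unsigned byte at offset 3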

import struct

def decode_idx1_ubyte(idx1_ubyte_file):
    with open(idx1_ubyte_file, 'rb') as fp:
        bin_data = fp.read()
    # Parse the file header
    fmt = ">II"
    magic_number, label_number = struct.unpack_from(fmt, bin_data, 0)
    offset = 8  # byte offset past the header
    print("magic number:0x{:0>8x}({})\tlabel number:{}".format(magic_number, magic_number, label_number))
    fmt = ">B"
    label = []
    for idx in range(label_number):
        label.append(struct.unpack_from(fmt, bin_data, offset + idx))
    return label

def decode_idx3_ubyte(idx3_ubyte_file):
    with open(idx3_ubyte_file, 'rb') as fp:
        bin_data = fp.read()
    # Parse the file header
    fmt = ">IIII"
    magic_number, image_number, rows, cols = struct.unpack_from(fmt, bin_data, 0)
    offset = 16  # byte offset past the header
    print("magic number:0x{:0>8x}({})\t image number:{}".format(magic_number, magic_number, image_number))
    print("rows:{}\t columns:{}".format(rows, cols))
    fmt = '>' + str(rows * cols) + 'B'
    image = []
    for idx in range(image_number):
        data = struct.unpack_from(fmt, bin_data, offset + idx * rows * cols)
        data = np.array(data, dtype=np.uint8).reshape((rows, cols))
        image.append(data)
    image = np.array(image)
    return image

class MyMNISTDataSets(torch.utils.data.Dataset):
    def __init__(self, root=DATA_SOURCE_DIR, train=True, transform=None):
        super(MyMNISTDataSets, self).__init__()
        self.root = root
        self.transform = transform
        self.train = train
        if self.train:
            image_path = "train"
            label_path = "train"
        else:
            image_path = "t10k"
            label_path = "t10k"
        image_path = image_path + "-images-idx3-ubyte"
        label_path = label_path + "-labels-idx1-ubyte"
        image_path = os.path.join(self.root, image_path)
        label_path = os.path.join(self.root, label_path)
        self.data, self.targets = decode_idx3_ubyte(image_path), decode_idx1_ubyte(label_path)

    def __getitem__(self, idx):
        data, label = self.data[idx], self.targets[idx]
        label = torch.as_tensor(label, dtype=torch.int64)
        if self.transform is not None:
            data = self.transform(data)
        data = data.type(torch.FloatTensor)
        return data, label

    def __len__(self):
        return len(self.data)

TRAIN_DATASETS = MyMNISTDataSets(train=True, transform=TRANSFORM)
TRAIN_LOADER = torch.utils.data.DataLoader(TRAIN_DATASETS, shuffle=True, batch_size=BATCH_SIZE)
TEST_DATASETS = MyMNISTDataSets(train=False, transform=TRANSFORM)
TEST_LOADER = torch.utils.data.DataLoader(TEST_DATASETS, shuffle=False, batch_size=BATCH_SIZE)
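A quick sanity check on the custom dataset (assuming the raw files were already downloaded by the torchvision.datasets.MNIST call above):

print(len(TRAIN_DATASETS), len(TEST_DATASETS))  # expected: 60000 10000
data, label = TRAIN_DATASETS[0]
print(data.shape, data.dtype, label)  # torch.Size([1, 28, 28]) torch.float32 tensor([5])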

Viewing an image

img, label = TRAIN_DATASETS[0]
img = img.numpy()
plt.title(label)
plt.imshow(img[0])

Network definition
A custom fully connected (linear) network
class LinearNet(torch.nn.Module):
    def __init__(self):
        super(LinearNet, self).__init__()
        self.l1 = torch.nn.Linear(28*28, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, IMAGE_ROW*IMAGE_COL)
        x = torch.nn.functional.relu(self.l1(x))
        x = torch.nn.functional.relu(self.l2(x))
        x = torch.nn.functional.relu(self.l3(x))
        x = torch.nn.functional.relu(self.l4(x))
        y = self.l5(x)
        return y

model = LinearNet()
model.to(DEVICE)
torchsummary.summary(model, (1, 28, 28))
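Before training, it can be worth pushing a dummy batch through the network to confirm that it produces one logit per class (a quick check, not part of the original code):

dummy = torch.randn(4, 1, 28, 28).to(DEVICE)
print(model(dummy).shape)  # torch.Size([4, 10])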

A custom CNN
class CNNNet(torch.nn.Module):
    def __init__(self):
        super(CNNNet, self).__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = torch.nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.pooling = torch.nn.MaxPool2d(2)
        self.fc = torch.nn.Linear(in_features=320, out_features=10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        batch_size = x.size(0)
        x = self.conv1(x)
        x = self.pooling(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.pooling(x)
        x = self.relu(x)
        x = x.view(batch_size, -1)
        x = self.fc(x)
        return x

model = CNNNet()
model.to(DEVICE)
torchsummary.summary(model, (1, 28, 28))
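The in_features=320 of the fully connected layer follows from the shape arithmetic: a 28x28 input becomes 24x24 after conv1 (kernel 5, no padding), 12x12 after pooling, 8x8 after conv2, and 4x4 after the second pooling, with 20 channels, so 20*4*4 = 320. A quick trace confirms it:

x = torch.randn(1, 1, 28, 28).to(DEVICE)
x = model.pooling(model.conv1(x))
print(x.shape)  # torch.Size([1, 10, 12, 12])
x = model.pooling(model.conv2(x))
print(x.shape)  # torch.Size([1, 20, 4, 4]) -> 20*4*4 = 320 features after flattening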

Model training

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
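Note that torch.nn.CrossEntropyLoss expects raw, unnormalized logits and applies log-softmax internally, which is why neither network above ends with a softmax layer. The equivalence with LogSoftmax + NLLLoss can be checked directly:

logits = torch.randn(3, 10)
target = torch.tensor([1, 0, 4])
print(torch.nn.CrossEntropyLoss()(logits, target))
print(torch.nn.NLLLoss()(torch.nn.functional.log_softmax(logits, dim=1), target))  # same value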

import sys

for epoch in range(2):
    # Training
    model.train()
    running_loss = 0.0
    for batch_idx, data in enumerate(TRAIN_LOADER):
        inputs, target = data
        inputs, target = inputs.to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(inputs)
        target = target.squeeze()
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 50 == 49:
            sys.stdout.write("epoch:{:2d}\t {}\t:{:.2%}\t loss:{:.2f}\t\r".format(
                epoch, "train", (batch_idx+1)/len(TRAIN_LOADER), running_loss/(batch_idx+1)))
            sys.stdout.flush()
    sys.stdout.write('\n')
    sys.stdout.flush()

    # Evaluation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, data in enumerate(TEST_LOADER):
            inputs, target = data
            inputs, target = inputs.to(DEVICE), target.to(DEVICE)
            outputs = model(inputs)
            target = target.squeeze()
            _, predict = torch.max(outputs.data, dim=1)
            total += target.size(0)
            correct += (predict == target).sum().item()
            if batch_idx % 50 == 49:
                sys.stdout.write("epoch:{:2d}\t {}\t:{:.2%}\t accuracy:{:.2%}\t\r".format(
                    epoch, "test", (batch_idx+1)/len(TEST_LOADER), correct/total))
                sys.stdout.flush()
    sys.stdout.write('\n')
    sys.stdout.flush()
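The introduction mentions using the trained model afterwards; a minimal sketch for persisting and reloading the weights (the file name mnist_cnn.pth is only an example):

# Save the trained weights
torch.save(model.state_dict(), "mnist_cnn.pth")

# Later: rebuild the network and load the weights before inference
model = CNNNet()
model.load_state_dict(torch.load("mnist_cnn.pth", map_location=DEVICE))
model.to(DEVICE)
model.eval()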

Testing the result

with torch.no_grad():
    choice = np.random.randint(0, len(TEST_DATASETS))
    inputs, target = TEST_DATASETS[choice]
    inputs = torch.as_tensor(inputs.numpy().reshape((1, 1, 28, 28)))
    inputs, target = inputs.to(DEVICE), target.to(DEVICE)
    outputs = model(inputs)
    print(outputs)
    _, predict = torch.max(outputs.data, dim=1)
    plt.title(predict)
    plt.imshow(inputs.to("cpu").numpy()[0, 0])

