文章主要包含:官方数据集导入、自定义数据集,自定义网络结构,训练,训练后的模型使用
导入依赖库
import torch
import torchvision
import torchsummary
import os
import numpy as np
import matplotlib.pyplot as plt
常量定义
# Mini-batch size shared by every DataLoader in this file.
BATCH_SIZE = 64
# Image height and width in pixels.
IMAGE_ROW = 28
IMAGE_COL = 28
# Directory that holds the raw MNIST idx files.
DATA_SOURCE_DIR = "../datasets/MNIST/raw/"
# Per-sample preprocessing: convert to a tensor, then normalize with
# mean 0 and std 1.
TRANSFORM = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.,), (1.,)),
    ]
)
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
print(DEVICE)
数据集导入
数据集说明:官方数据集导入
参考网址 https://www.cnblogs.com/xianhan/p/9145966.html
数据集网址 http://yann.lecun.com/exdb/mnist/
train-labels-idx1-ubyte
[offset]  [type]          [value]           [description]
0000      32 bit integer  0x00000801(2049)  magic number (MSB first)
0004      32 bit integer  60000             number of items
0008      unsigned byte   ??                label
0009      unsigned byte   ??                label
........
xxxx      unsigned byte   ??                label
train-images-idx3-ubyte
[offset]  [type]          [value]           [description]
0000      32 bit integer  0x00000803(2051)  magic number
0004      32 bit integer  60000             number of images
0008      32 bit integer  28                number of rows
0012      32 bit integer  28                number of columns
0016      unsigned byte   ??                pixel
0017      unsigned byte   ??                pixel
........
xxxx      unsigned byte   ??                pixel
t10k-labels-idx1-ubyte
[offset]  [type]          [value]           [description]
0000      32 bit integer  0x00000801(2049)  magic number (MSB first)
0004      32 bit integer  10000             number of items
0008      unsigned byte   ??                label
0009      unsigned byte   ??                label
........
xxxx      unsigned byte   ??                label
t10k-images-idx3-ubyte
[offset]  [type]          [value]           [description]
0000      32 bit integer  0x00000803(2051)  magic number
0004      32 bit integer  10000             number of images
0008      32 bit integer  28                number of rows
0012      32 bit integer  28                number of columns
0016      unsigned byte   ??                pixel
0017      unsigned byte   ??                pixel
........
xxxx      unsigned byte   ??                pixel
# Official torchvision MNIST splits rooted at ../datasets (download=True
# allows torchvision to fetch the files), each preprocessed with TRANSFORM.
TRAIN_DATASETS = torchvision.datasets.MNIST(root="../datasets", train=True, download=True, transform=TRANSFORM)
TRAIN_LOADER = torch.utils.data.DataLoader(TRAIN_DATASETS, shuffle=True, batch_size=BATCH_SIZE)
TEST_DATASETS = torchvision.datasets.MNIST(root="../datasets", train=False, download=True, transform=TRANSFORM)
# The test loader must wrap the test split; wrapping TRAIN_DATASETS here
# would report accuracy on the data the model was trained on.
TEST_LOADER = torch.utils.data.DataLoader(TEST_DATASETS, shuffle=False, batch_size=BATCH_SIZE)
查看图片
# Plot the first training sample, using its label as the figure title.
first_sample = TRAIN_DATASETS[0]
img, label = first_sample[0].numpy(), first_sample[1]
plt.title(label)
# img[0] selects the first entry along the leading axis for display.
plt.imshow(img[0])
自定义数据集
torch官方解释文档(纯英文) https://pytorch.org/docs/stable/data.html
torch.utils.data.Dataset 源码:https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#Dataset
自定义DataSet基本结构
class DataSets(torch.utils.data.Dataset):
    """Minimal skeleton of a custom dataset.

    A real dataset fills in the three methods below: ``__init__`` to load
    or index the data, ``__getitem__`` to return one sample, and
    ``__len__`` to return the number of samples.
    """

    def __init__(self):
        super(DataSets, self).__init__()

    def __getitem__(self, idx):
        # Placeholder: return the sample stored at position ``idx``.
        pass

    def __len__(self):
        # Placeholder: return the number of samples.
        pass
struct.unpack_from(fmt, buf, offset)
fmt: 解析格式字符串,由字节序('>' 表示大端,'<' 表示小端)+ 重复次数 + 类型码组成;类型码 'B'/'b' 为无符号/有符号字节,'I'/'i' 为无符号/有符号 32 位整数,例如 ">784B"
buf: 已读入内存的文件内容(bytes)
offset: 开始解析的字节偏移量
import struct
def decode_idx1_ubyte(idx1_ubyte_file):
    """Parse an idx1-ubyte label file.

    The file starts with two big-endian unsigned 32-bit integers (magic
    number and label count) followed by one unsigned byte per label.
    The header values are printed as a side effect.

    Args:
        idx1_ubyte_file: Path of the file to read.

    Returns:
        A list with one entry per label. Each entry is the 1-tuple
        returned by ``struct`` unpacking, e.g. ``(5,)``.

    Raises:
        struct.error: If the file is shorter than its header announces.
    """
    with open(idx1_ubyte_file, 'rb') as fp:
        bin_data = fp.read()
    # Parse the 8-byte header.
    fmt = ">II"
    magic_number, label_number = struct.unpack_from(fmt, bin_data, 0)
    offset = 8  # label bytes start right after the header
    print("magic number:0x{:0>8x}({})\tlabel number:{}".format(magic_number,magic_number,label_number))
    # Compile the one-byte format once instead of re-parsing it per label.
    label_struct = struct.Struct(">B")
    return [
        label_struct.unpack_from(bin_data, offset + idx)
        for idx in range(label_number)
    ]
def decode_idx3_ubyte(idx3_ubyte_file):
    """Parse an idx3-ubyte image file.

    The file starts with four big-endian unsigned 32-bit integers (magic
    number, image count, rows, columns) followed by one unsigned byte per
    pixel. The header values are printed as a side effect.

    Args:
        idx3_ubyte_file: Path of the file to read.

    Returns:
        A writable ``numpy.uint8`` array of shape
        ``(image_number, rows, cols)``.

    Raises:
        ValueError: If the file holds fewer pixel bytes than its header
            announces.
    """
    with open(idx3_ubyte_file, 'rb') as fp:
        bin_data = fp.read()
    # Parse the 16-byte header.
    fmt = ">IIII"
    magic_number, image_number, rows, cols = struct.unpack_from(fmt, bin_data, 0)
    offset = 16  # pixel bytes start right after the header
    print("magic number:0x{:0>8x}({})\t image number:{}".format(magic_number, magic_number, image_number))
    print("rows:{}\t columns:{}".format(rows, cols))
    # Read every pixel in one pass instead of unpacking image by image.
    pixels = np.frombuffer(
        bin_data,
        dtype=np.uint8,
        count=image_number * rows * cols,
        offset=offset,
    )
    # frombuffer yields a read-only view of the bytes object; copy so the
    # caller receives an independent, writable array.
    image = pixels.reshape((image_number, rows, cols)).copy()
    return image
class MyMNISTDataSets(torch.utils.data.Dataset):
    """Dataset that reads the raw MNIST idx files found under ``root``.

    Args:
        root: Directory containing the ``*-images-idx3-ubyte`` and
            ``*-labels-idx1-ubyte`` files.
        train: When true the ``train-*`` files are loaded, otherwise the
            ``t10k-*`` files.
        transform: Optional callable applied to each image in
            ``__getitem__``.
    """

    def __init__(self, root=DATA_SOURCE_DIR, train=True, transform=None):
        super(MyMNISTDataSets, self).__init__()
        self.root = root
        self.transform = transform
        self.train = train
        # Both files of a split share the same filename prefix.
        prefix = "train" if self.train else "t10k"
        image_path = os.path.join(self.root, prefix + "-images-idx3-ubyte")
        label_path = os.path.join(self.root, prefix + "-labels-idx1-ubyte")
        self.data = decode_idx3_ubyte(image_path)
        self.targets = decode_idx1_ubyte(label_path)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        The label is converted to an int64 tensor. When a transform is
        set, the image is passed through it and cast to a float tensor;
        otherwise the stored image is returned unchanged.
        """
        data, label = self.data[idx], self.targets[idx]
        label = torch.as_tensor(label, dtype=torch.int64)
        if self.transform is not None:
            data = self.transform(data)
            # Kept inside the branch: an untransformed sample is still a
            # NumPy array, which has no ``.type`` method.
            data = data.type(torch.FloatTensor)
        return data, label

    def __len__(self):
        """Return the number of images loaded."""
        return len(self.data)
# Rebuild the datasets and loaders on top of the hand-written idx parser.
TRAIN_DATASETS = MyMNISTDataSets(train=True, transform=TRANSFORM)
TRAIN_LOADER = torch.utils.data.DataLoader(TRAIN_DATASETS, shuffle=True, batch_size=BATCH_SIZE)
TEST_DATASETS = MyMNISTDataSets(train=False, transform=TRANSFORM)
# The test loader must wrap the test split; wrapping TRAIN_DATASETS here
# would report accuracy on the data the model was trained on.
TEST_LOADER = torch.utils.data.DataLoader(TEST_DATASETS, shuffle=False, batch_size=BATCH_SIZE)
查看图片
# Plot the first training sample, using its label as the figure title.
first_sample = TRAIN_DATASETS[0]
img, label = first_sample[0].numpy(), first_sample[1]
plt.title(label)
# img[0] selects the first entry along the leading axis for display.
plt.imshow(img[0])
网络定义 自定义线性网络
class LinearNet(torch.nn.Module):
    """Fully connected classifier: 784 -> 512 -> 256 -> 128 -> 64 -> 10.

    ReLU follows each of the first four linear layers; the output of the
    last layer is returned without an activation.
    """

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=28 * 28, out_features=512)
        self.l2 = torch.nn.Linear(in_features=512, out_features=256)
        self.l3 = torch.nn.Linear(in_features=256, out_features=128)
        self.l4 = torch.nn.Linear(in_features=128, out_features=64)
        self.l5 = torch.nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        """Flatten ``x`` to rows of IMAGE_ROW * IMAGE_COL values and classify."""
        activate = torch.nn.functional.relu
        hidden = x.view(-1, IMAGE_ROW * IMAGE_COL)
        for layer in (self.l1, self.l2, self.l3, self.l4):
            hidden = activate(layer(hidden))
        return self.l5(hidden)
# Build the linear network on DEVICE and print a per-layer summary for a
# single (1, 28, 28) input.
model = LinearNet().to(DEVICE)
torchsummary.summary(model, input_size=(1, 28, 28))
自定义卷积神经网络(CNN)
class CNNNet(torch.nn.Module):
    """Two conv stages (conv -> 2x2 max-pool -> ReLU) and a linear head.

    The linear layer expects 320 flattened features per sample and
    produces 10 outputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
        self.pooling = torch.nn.MaxPool2d(2)
        self.fc = torch.nn.Linear(320, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Run both conv stages, flatten per sample, and apply the head."""
        n_samples = x.size(0)
        features = self.relu(self.pooling(self.conv1(x)))
        features = self.relu(self.pooling(self.conv2(features)))
        # Keep the batch axis and collapse everything else for the head.
        flattened = features.view(n_samples, -1)
        return self.fc(flattened)
# Build the convolutional network on DEVICE and print a per-layer summary
# for a single (1, 28, 28) input.
model = CNNNet().to(DEVICE)
torchsummary.summary(model, input_size=(1, 28, 28))
模型训练
import sys

# Loss function and optimizer for the current ``model``.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

for epoch in range(2):
    # ---- training pass over TRAIN_LOADER ----
    model.train()
    running_loss = 0.0
    for batch_idx, data in enumerate(TRAIN_LOADER):
        inputs, target = data
        inputs, target = inputs.to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(inputs)
        # Flatten the labels to one value per sample. Unlike squeeze(),
        # reshape(-1) keeps the batch axis when a batch has one sample.
        target = target.reshape(-1)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Refresh the single-line progress display every 50 batches.
        if batch_idx % 50 == 49:
            sys.stdout.write("epoch:{:2d}\t {}\t:{:.2%}\t loss:{:.2f}\t\r".format(epoch,"train",(batch_idx+1)/len(TRAIN_LOADER),running_loss/(batch_idx+1)))
            sys.stdout.flush()
    sys.stdout.write('\n')
    sys.stdout.flush()

    # ---- evaluation pass over TEST_LOADER, gradients disabled ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, data in enumerate(TEST_LOADER):
            inputs, target = data
            inputs, target = inputs.to(DEVICE), target.to(DEVICE)
            outputs = model(inputs)
            target = target.reshape(-1)
            # The predicted class is the index of the largest output.
            _, predict = torch.max(outputs.data, dim=1)
            total += target.size(0)
            correct += (predict == target).sum().item()
            if batch_idx % 50 == 49:
                sys.stdout.write("epoch:{:2d}\t {}\t:{:.2%}\t accuracy:{:.2%}\t\r".format(epoch,"test",(batch_idx+1)/len(TEST_LOADER),correct/total))
                sys.stdout.flush()
    sys.stdout.write('\n')
    sys.stdout.flush()
结果测试
# Classify one randomly chosen test sample and plot it with the prediction
# as the figure title.
choice = np.random.randint(0, len(TEST_DATASETS))
inputs, target = TEST_DATASETS[choice]
with torch.no_grad():
    # Prepend a batch axis so the model receives a batch of one sample
    # (assumes each sample is (1, 28, 28), the shape the original
    # hard-coded reshape to (1, 1, 28, 28) required).
    inputs = inputs.unsqueeze(0).to(DEVICE)
    # as_tensor accepts both tensor labels and plain int labels, so this
    # works with either dataset assigned to TEST_DATASETS.
    target = torch.as_tensor(target).to(DEVICE)
    outputs = model(inputs)
    print(outputs)
    # The predicted class is the index of the largest output.
    _, predict = torch.max(outputs.data, dim=1)
plt.title(predict)
plt.imshow(inputs.to("cpu").numpy()[0, 0])
推荐阅读
- 学习|(系列更新完毕)深度学习零基础使用 PyTorch 框架跑 MNIST 数据集的第二天(加载 MNIST 数据集)
- Python|(系列更新完毕)深度学习零基础使用 PyTorch 框架跑 MNIST 数据集的第三天(训练模型)
- Python|(系列更新完毕)深度学习零基础使用 PyTorch 框架跑 MNIST 数据集的第四天(单例测试)
- 人脸识别|推荐 6 个 yyds 的人脸识别系统
- 深度学习|神经网络中的激活函数与损失函数&深入理解推导softmax交叉熵
- 算法|Pytorch框架训练MNIST数据集
- 枚举模拟|2020年第十一届蓝桥杯省赛Python组(真题+解析+代码)(蛇形填数)
- 蓝桥杯|蓝桥杯备战 每日训练3道 真题解析
- 二叉树(Binary|LeetCode 536. Construct Binary Tree from String - 二叉树系列题18