算法 | 使用 PyTorch 框架训练 MNIST 数据集

代码:

"""Train a CNN or an LSTM classifier on MNIST with PyTorch."""
import torch
import torch.nn as nn
import torch.optim as optim
from datetime import datetime


class Config:
    """Hyper-parameters shared by the whole script."""

    batch_size = 64
    epoch = 10
    alpha = 1e-3  # learning rate passed to Adam
    print_per_step = 100  # log training progress every this many steps


class CNN(nn.Module):
    """Two conv/pool stages followed by three fully connected layers.

    Sized for single-channel 28x28 images: input (batch, 1, 28, 28),
    output (batch, 10) class logits.
    """

    def __init__(self):
        super().__init__()
        # nn.Conv2d positional arguments:
        #   in_channels, out_channels, kernel_size, stride (default 1),
        #   padding (default 0, i.e. no padding).
        # 28x28 -> conv(k=3, p=2) -> 30x30 -> pool -> 15x15
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # 15x15 -> conv(k=5) -> 11x11 -> pool -> 5x5, hence 64 * 5 * 5 below
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(64 * 5 * 5, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
        )
        # Batch norm speeds up convergence; it is usually placed after the
        # fully connected layer and before the activation.
        self.fc2 = nn.Sequential(
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
        )
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 64 * 5 * 5)
        x = self.fc1(x)
        x = self.fc2(x)
        return self.fc3(x)


class LSTM(nn.Module):
    """Single-layer LSTM that reads an image one row per time step.

    Input (batch, 28, 28): 28 steps of 28 features. The output at the last
    step is mapped to (batch, 10) class logits.
    """

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=28,
            hidden_size=64,
            num_layers=1,
            batch_first=True,
        )
        self.output = nn.Linear(64, 10)

    def forward(self, x):
        r_out, _ = self.lstm(x, None)  # None -> zero initial hidden/cell state
        return self.output(r_out[:, -1, :])


class TrainProcess:
    """Train the selected network on MNIST and evaluate it after every epoch."""

    def __init__(self, model="CNN"):
        # Validate the model name first so a typo fails before any download.
        if model == "CNN":
            net = CNN()
        elif model == "LSTM":
            net = LSTM()
        else:
            raise ValueError('"CNN" or "LSTM" is expected, but received "%s".' % model)
        self.model = model
        self.net = net
        self.train, self.test = self.load_data()
        self.criterion = nn.CrossEntropyLoss()  # loss function
        self.optimizer = optim.Adam(self.net.parameters(), lr=Config.alpha)

    @staticmethod
    def load_data():
        """Load MNIST (downloaded into ./data/ if absent) as (train_loader, test_loader).

        Only the training loader shuffles its samples.
        """
        # Imported here so the model classes stay usable without torchvision.
        from torchvision import datasets, transforms

        print("Loading Data......")
        train_data = datasets.MNIST(root='./data/', train=True,
                                    transform=transforms.ToTensor(), download=True)
        test_data = datasets.MNIST(root='./data/', train=False,
                                   transform=transforms.ToTensor())
        train_loader = torch.utils.data.DataLoader(
            dataset=train_data, batch_size=Config.batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            dataset=test_data, batch_size=Config.batch_size, shuffle=False)
        return train_loader, test_loader

    def _prepare(self, data):
        """Reshape an image batch into the layout the selected network expects."""
        # The LSTM takes 3-D (batch, rows, cols) input; the CNN keeps 4-D images.
        if self.model == "LSTM":
            return data.view(-1, 28, 28)
        return data

    def evaluate(self, loader=None):
        """Return ``(mean loss, accuracy)`` over ``loader`` (default: the test set).

        The network is switched to eval mode for the pass and then restored
        to whichever mode it was in before.

        Raises:
            ValueError: if the loader yields no samples.
        """
        loader = self.test if loader is None else loader
        was_training = self.net.training
        self.net.eval()  # BatchNorm must use running stats, not batch stats
        total_loss = 0.0
        total_correct = 0
        total_seen = 0
        try:
            with torch.no_grad():
                for data, label in loader:
                    outputs = self.net(self._prepare(data))
                    batch = label.size(0)
                    # The criterion returns a batch mean; weight it by the real
                    # batch size because the last batch is usually smaller.
                    total_loss += self.criterion(outputs, label).item() * batch
                    total_correct += (outputs.argmax(dim=1) == label).sum().item()
                    total_seen += batch
        finally:
            self.net.train(was_training)
        if total_seen == 0:
            raise ValueError("cannot evaluate on an empty loader")
        return total_loss / total_seen, total_correct / total_seen

    def train_step(self):
        """Train for ``Config.epoch`` epochs and report test metrics after each one."""
        steps = 0
        start_time = datetime.now()
        print("Training & Evaluating based on '%s'......" % self.model)
        for epoch in range(Config.epoch):
            print("Epoch {:3}.".format(epoch + 1))
            self.net.train()
            for data, label in self.train:
                self.optimizer.zero_grad()  # clear gradients from the previous step
                outputs = self.net(self._prepare(data))  # forward pass
                loss = self.criterion(outputs, label)
                loss.backward()  # back-propagate
                self.optimizer.step()  # one parameter update
                if steps % Config.print_per_step == 0:
                    # Accuracy of the current batch only; divide by its real size.
                    correct = (outputs.argmax(dim=1) == label).sum().item()
                    accuracy = correct / label.size(0)
                    time_diff = int((datetime.now() - start_time).total_seconds())
                    time_usage = '{:3}m{:3}s'.format(time_diff // 60, time_diff % 60)
                    msg = "Step {:5}, Loss:{:6.2f}, Accuracy:{:8.2%}, Time usage:{:9}."
                    print(msg.format(steps, loss.item(), accuracy, time_usage))
                steps += 1
            test_loss, test_accuracy = self.evaluate()
            print("Test Loss: {:5.2f}, Accuracy: {:6.2%}".format(test_loss, test_accuracy))
        time_diff = (datetime.now() - start_time).total_seconds()
        print("Time Usage: {:5.2f} mins.".format(time_diff / 60.))


if __name__ == "__main__":
    p = TrainProcess(model='CNN')
    p.train_step()

使用默认的 CNN 网络(`TrainProcess(model='CNN')`)运行,输出如下:

(原文此处为训练输出截图)

上面用的是 CNN 网络。若要改用 LSTM 网络,只需把入口处的 `TrainProcess(model='CNN')` 改为 `TrainProcess(model='LSTM')`:

(原文此处为代码截图)

LSTM 网络的训练结果:

(原文此处为训练输出截图)

结束~!

    推荐阅读