pytorch|pytorch基本操作(使用神经网络进行分类任务)

1.读取Mnist数据 首先,读取Mnist数据,在深度学习框架中,数据的基本结构是tensor,据需转换成tensor才能参与后续建模训练,可用map函数将数据转换为tensor格式

import torchx_train, y_train, x_valid, y_valid = map( torch.tensor, (x_train, y_train, x_valid, y_valid) ) n, c = x_train.shape x_train, x_train.shape, y_train.min(), y_train.max() print(x_train, y_train) print(x_train.shape) print(y_train.min(), y_train.max())

pytorch|pytorch基本操作(使用神经网络进行分类任务)
文章图片

2.torch.nn.functionaltorch.nn.functional中有很多功能, 比如,常见的激活函数、损失函数,一般情况下,如果模型有可学习的参数,最好用nn.Module,其他情况nn.functional相对更简单一些
pytorch|pytorch基本操作(使用神经网络进行分类任务)
文章图片

3.创建一个model
  • 必须继承nn.Module且在其构造函数中需调用nn.Module的构造函数
  • 无需写反向传播函数,nn.Module能够利用autograd自动实现反向传播
  • Module中的可学习参数可以通过named_parameters()或者parameters()返回迭代器
from torch import nnclass Mnist_NN(nn.Module): def __init__(self): super().__init__() self.hidden1 = nn.Linear(784, 128) self.hidden2 = nn.Linear(128, 256) self.out= nn.Linear(256, 10)def forward(self, x): x = F.relu(self.hidden1(x)) x = F.relu(self.hidden2(x)) x = self.out(x) return x

打印出来:
pytorch|pytorch基本操作(使用神经网络进行分类任务)
文章图片

通过named_parameters()或者parameters()返回迭代器
pytorch|pytorch基本操作(使用神经网络进行分类任务)
文章图片

4.使用TensorDataset和DataLoader加载数据TensorDataset:将训练数据的特征和标签组合
DataLoader:随机读取小批量
pytorch|pytorch基本操作(使用神经网络进行分类任务)
文章图片


5.训练模块 梯度下降方法和损失函数
pytorch|pytorch基本操作(使用神经网络进行分类任务)
文章图片

torch默认会叠加梯度,所以结束后需要将梯度置零
pytorch|pytorch基本操作(使用神经网络进行分类任务)
文章图片


  • 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
  • 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout
import numpy as npdef fit(steps, model, loss_func, opt, train_dl, valid_dl): for step in range(steps): model.train() for xb, yb in train_dl: loss_batch(model, loss_func, xb, yb, opt)model.eval() with torch.no_grad(): # 验证时不进行梯度下降 losses, nums = zip( *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl] ) val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums) # 平均损失 print('当前step:'+str(step), '验证集损失:'+str(val_loss))

pytorch|pytorch基本操作(使用神经网络进行分类任务)
文章图片









【pytorch|pytorch基本操作(使用神经网络进行分类任务)】

    推荐阅读