PyTorch线性模型训练实例图解

给定分配给它的随机参数, 我们绘制了线性模型。我们发现它与我们的数据不太吻合。我们要做的。我们需要训练该模型, 以便该模型具有最佳的权重和偏差参数并拟合该数据。
有以下步骤可以训练模型:
步骤1
我们的第一步是指定损失函数, 我们打算将其最小化。 PyTorch提供了一种非常有效的方法来指定丢失的功能。 PyTorch提供MSELoss()函数(称为均方损失), 以

criterion=nn.MSELoss()

第2步
现在, 我们的下一步是更新参数。为此, 我们指定使用梯度下降算法的优化器。我们使用称为随机梯度下降的SGD()函数进行优化。 SGD一次可以减少一个样本的总损失, 并且通常可以更快地收敛, 因为它会在相同样本大小内频繁更新模型的权重。
optimizer=torch.optim.SGD(model.parameters(), lr=0.01)

在此, lr代表学习率, 最初设置为0.01。
第三步
我们将针对指定的时期数训练模型(我们计算了误差函数, 并对该误差函数的梯度下降进行了反向传播以更新权重)。
epochs=100

现在, 对于每个时代, 我们都必须最小化模型系统的误差。误差只是模型预测与实际值之间的比较。
Losses=[]For i in range (epochs): ypred=model.forward(x) #Prediction of y loss=criterion(ypred, y) #Find loss losses.append()# Add loss in list optimizer.zero_grad() # Set the gradient to zero loss.backward() #To compute derivatives optimizer.step() # Update the parameters

步骤4
现在, 最后, 我们只需调用plotfit()方法来绘制新的线性模型。
plotfit('Trained Model')

完整的代码 【PyTorch线性模型训练实例图解】程序
import torchimport torch.nn as nnimport matplotlib.pyplot as pltimport numpy as npX=torch.randn(100, 1)*10y=X+3*torch.randn(100, 1)plt.plot(X.numpy(), y.numpy(), 'o')plt.ylabel('y')plt.xlabel('x')class LR(nn.Module): def __init__(self, input_size, output_size):super().__init__()self.linear=nn.Linear(input_size, output_size) def forward(self, x):pred=self.linear(X)return pred torch.manual_seed(1) #For consistency of random result model=LR(1, 1)criterion=nn.MSELoss() #Using Loss Functionoptimizer=torch.optim.SGD(model.parameters(), lr=0.01)#Using optimizer which uses GD algorithmprint(model)[a, b]=model.parameters() #Unpacking of parametersepochs=100losses=[]for i in range(epochs): ypred=model.forward(X) loss=criterion(ypred, y) print("epoch:", i, "loss:", loss.item()) losses.append(loss) optimizer.zero_grad() loss.backward() optimizer.step()defgrtparameters():return(a[0][0].item(), b[0].item())defplotfit(title): plt.title=title a1, b1=grtparameters() x1=np.array([-30, 30]) y1=a1*x1+b1 plt.plot(x1, y1, 'r') plt.scatter(X, y) plt.show()plotfit('Trained Model')

输出
PyTorch线性模型训练实例图解

文章图片
PyTorch线性模型训练实例图解

文章图片

    推荐阅读