给定分配给它的随机参数, 我们绘制了线性模型。我们发现它与我们的数据不太吻合。我们要做的。我们需要训练该模型, 以便该模型具有最佳的权重和偏差参数并拟合该数据。
有以下步骤可以训练模型:
步骤1
我们的第一步是指定损失函数, 我们打算将其最小化。 PyTorch提供了一种非常有效的方法来指定丢失的功能。 PyTorch提供MSELoss()函数(称为均方损失), 以
criterion=nn.MSELoss()
第2步
现在, 我们的下一步是更新参数。为此, 我们指定使用梯度下降算法的优化器。我们使用称为随机梯度下降的SGD()函数进行优化。 SGD一次可以减少一个样本的总损失, 并且通常可以更快地收敛, 因为它会在相同样本大小内频繁更新模型的权重。
optimizer=torch.optim.SGD(model.parameters(), lr=0.01)
在此, lr代表学习率, 最初设置为0.01。
第三步
我们将针对指定的时期数训练模型(我们计算了误差函数, 并对该误差函数的梯度下降进行了反向传播以更新权重)。
epochs=100
现在, 对于每个时代, 我们都必须最小化模型系统的误差。误差只是模型预测与实际值之间的比较。
Losses=[]For i in range (epochs): ypred=model.forward(x) #Prediction of y loss=criterion(ypred, y) #Find loss losses.append()# Add loss in list optimizer.zero_grad() # Set the gradient to zero loss.backward() #To compute derivatives optimizer.step() # Update the parameters
步骤4
现在, 最后, 我们只需调用plotfit()方法来绘制新的线性模型。
plotfit('Trained Model')
完整的代码 【PyTorch线性模型训练实例图解】程序
import torchimport torch.nn as nnimport matplotlib.pyplot as pltimport numpy as npX=torch.randn(100, 1)*10y=X+3*torch.randn(100, 1)plt.plot(X.numpy(), y.numpy(), 'o')plt.ylabel('y')plt.xlabel('x')class LR(nn.Module): def __init__(self, input_size, output_size):super().__init__()self.linear=nn.Linear(input_size, output_size) def forward(self, x):pred=self.linear(X)return pred torch.manual_seed(1) #For consistency of random result model=LR(1, 1)criterion=nn.MSELoss() #Using Loss Functionoptimizer=torch.optim.SGD(model.parameters(), lr=0.01)#Using optimizer which uses GD algorithmprint(model)[a, b]=model.parameters() #Unpacking of parametersepochs=100losses=[]for i in range(epochs): ypred=model.forward(X) loss=criterion(ypred, y) print("epoch:", i, "loss:", loss.item()) losses.append(loss) optimizer.zero_grad() loss.backward() optimizer.step()defgrtparameters():return(a[0][0].item(), b[0].item())defplotfit(title): plt.title=title a1, b1=grtparameters() x1=np.array([-30, 30]) y1=a1*x1+b1 plt.plot(x1, y1, 'r') plt.scatter(X, y) plt.show()plotfit('Trained Model')
输出
文章图片
文章图片
推荐阅读
- PyTorch如何训练感知器模型(实例图解————)
- PyTorch实战(卷积神经网络模型的训练)
- 微信红包提示:解析失败的原因与处理办法_微信
- 微信红包设置异常,7天内暂时无法领红包的处理办法_微信
- 维嘉微博数学题答案是啥?_新浪微博
- 新浪微博怎样防范评论带图?微博防范评论带图办法_新浪微博
- 可达鸭头像LOL英雄联盟系列_微信
- 厘米人AI是啥?厘米人AI介绍
- 王者荣耀系列可达鸭头像大全_微信