文章目录
- 1.什么是早停止?为什么使用早停止?
- 2.如何使用早停止?
- 3. Refferences
1.什么是早停止?为什么使用早停止? 早停止(Early Stopping)是 当达到某种或某些条件时,认为模型已经收敛,结束模型训练,保存现有模型的一种手段。
机器学习或深度学习中,有很大一批算法是依靠梯度下降,求来优化模型的。是通过更新参数,让Loss往小的方向走,来优化模型的。可参考BP神经网络推导过程详解
关于模型何时收敛(模型训练好了,性能达到要求了或不能再优化了),此时我们可以采取一些判断标准:
1.验证集上的Loss在模型多次迭代后,没有下降
2.验证集上的Loss开始上升
3.验证集上的准确率在模型多次迭代后,没有上升
3.验证集上的准确率开始下降
……
这时,一般可以认为,模型没必要再训练了,可以及时结束训练了。这就被称为早停止,也是避免模型过拟合的一种方法(不等模型拟合,就结束训练了)。
2.如何使用早停止? early_stopping.py
import numpy as np
import torch
import osclass EarlyStopping:
"""Early stops the training if validation loss doesn't improve after a given patience."""
def __init__(self, save_path, patience=7, verbose=False, delta=0):
"""
Args:
save_path : 模型保存文件夹
patience (int): How long to wait after last time validation loss improved.
Default: 7
verbose (bool): If True, prints a message for each validation loss improvement.
Default: False
delta (float): Minimum change in the monitored quantity to qualify as an improvement.
Default: 0
"""
self.save_path = save_path
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = np.Inf
self.delta = deltadef __call__(self, val_loss, model):score = -val_lossif self.best_score is None:
self.best_score = score
self.save_checkpoint(val_loss, model)
elif score < self.best_score + self.delta:
self.counter += 1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_loss, model)
self.counter = 0def save_checkpoint(self, val_loss, model):
'''Saves model when validation loss decrease.'''
if self.verbose:
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).Saving model ...')
path = os.path.join(self.save_path, 'best_network.pth')
torch.save(model.state_dict(), path) # 这里会存储迄今最优模型的参数
self.val_loss_min = val_loss
把该文件拷贝到自己项目中,
在需要使用早停止的文件中,导入:
from early_stopping import EarlyStopping
使用示例(大致代码):
train_losses = []
train_acces = []
# 用数组保存每一轮迭代中,在测试数据上测试的损失值和精确度,也是为了通过画图展示出来。
eval_losses = []
eval_acces = []save_path = ".\\" #当前目录下
early_stopping = EarlyStopping(save_path)for e in range(20000):# 4.1==========================训练模式==========================
train_loss = 0
train_acc = 0
model.train()# 将模型改为训练模式# 每次迭代都是处理一个小批量的数据,batch_size是64
for im, label in train_data:
im = Variable(im)
targets = Variable(label)# 计算前向传播,并且得到损失函数的值
outputs = model(im)
loss = criterion(outputs, targets)#add by tyb#model.save_metrics(metrics)
# 反向传播,记得要把上一次的梯度清0,反向传播,并且step更新相应的参数。
optimizer.zero_grad()loss.backward()
optimizer.step()
#scheduler.step()# 记录误差
train_loss += loss.item()# 计算分类的准确率
out_t = outputs.argmax(dim=1) #取出预测的最大值
num_correct = (out_t == targets).sum().item()
acc = num_correct / im.shape[0]
train_acc += acctrain_losses.append(train_loss / len(train_data))
train_acces.append(train_acc / len(train_data))# 4.2==========================每次进行完一个训练迭代,就去测试一把看看此时的效果==========================
# 在测试集上检验效果
eval_loss = 0
eval_acc = 0model.eval()# 将模型改为预测模式# 每次迭代都是处理一个小批量的数据,batch_size是128
for im, label in test_data:#print("test_data len:",len(test_data))
im = Variable(im)# torch中训练需要将其封装即Variable,此处封装像素即784
label = Variable(label)# 此处为标签out = model(im)# 经网络输出的结果
loss = criterion(out, label)# 得到误差# 记录误差
eval_loss += loss.item()# 记录准确率
out_t = out.argmax(dim=1)# 取出预测的最大值的索引
num_correct = (out_t == label).sum().item()# 判断是否预测正确
acc = num_correct / im.shape[0]# 计算准确率
eval_acc += acceval_losses.append(eval_loss / len(test_data))
eval_acces.append(eval_acc / len(test_data))
#scheduler.step()print('epoch: {}, Train Loss: {:.6f}, Train Acc: {:.6f}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
.format(e, train_loss / len(train_data), train_acc / len(train_data),
eval_loss / len(test_data), eval_acc / len(test_data)))# 早停止
early_stopping(eval_loss, model)
#达到早停止条件时,early_stop会被置为True
if early_stopping.early_stop:
print("Early stopping")
break #跳出迭代,结束训练
【pytorch|PyTorch框架中使用早停止Early Stopping(含详细代码)】未用早停止:训练集和验证集上的accuracy和loss曲线
文章图片
使用早停止:训练集和验证集上的accuracy和loss曲线
文章图片
3. Refferences
- 在 Pytorch 中实现 early stopping
- 线性代数及其应用——“早停止”与“L2正则”的关系
- BP神经网络推导过程详解
推荐阅读
- python|10、python——模块与包
- python|python2和python3版本的区别
- 强化学习|强化学习笔记(七)演员-评论家算法(Actor-Critic Algorithms)及Pytorch实现
- 强化学习|强化学习笔记(五)Pytorch实现简单DQN
- opencv|基于深度学习的口罩识别与检测PyTorch实现
- 定位|利用稀疏的语义视觉特征进行道路建图和定位(ICRA2021)
- 可视化|opencv基础入门——环境搭建与基础操作
- Python学习笔记|Newton法求解非线性方程(Python实现)
- Python学习笔记|利用pyecharts绘制省市地图