【Pytorch|Pytorch之模型加载/保存】pytorch保存模型有两种方法:
- 保存整个模型 (结构+参数)
- 只保存参数(官方推荐)
torch.save(obj, dir)
实现,这个函数的作用是将对象保存到磁盘中,它的内部是使用Python的pickle实现。两种方法的区别其实就是obj参数的不同:前者的obj是整个model对象,后者的obj是从model里获取的存储了model参数的词典,推荐用第二种,虽然麻烦了一丁点,但是比较灵活,有利于实现预训练、参数迁移等操作。
保存整个模型 这种方法很简单,保存和加载就两行代码,和Python pickle包的用法是一样的,把model当作一个对象直接保存加载就行。
# 保存
model = Mymodel()
torch.save(model, path)
# 加载
model = torch.load(path)
Note:PyTorch约定使用.pt或.pth后缀命名保存文件。
保存参数 重点介绍一下这种方法,一般训完一个模型之后我们不会单独只保存一个模型的参数,为了方便后续操作,比如恢复训、参数迁移等,我们会保存当前转态的一个快照,具体信息可以根据自己的需要,下面列出几个方面:
- 模型参数
- 优化器参数
- loss
- epoch
- args
这种方式保存的模型只是它的参数,所以我们在加载时需要先创建好模型,然后再把参数加载进去,如下:
# 获得保存信息
save_data = https://www.it610.com/article/{'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
'epoch': epoch,
'args': args
...
}
# 保存
torch.save(save_data , path)
load_data = https://www.it610.com/article/torch.load(path)
model = Mymodel()
optimizer = Myoptimizer()
# 加载参数
model.load_state_dict(load_data ['model_state_dict'])
optimizer.load_state_dict(load_data ['optimizer_state_dict'])
...
Note:PyTorch约定使用.pt或.pth后缀命名保存文件。
推荐阅读
- C语言学习|第十一届蓝桥杯省赛 大学B组 C/C++ 第一场
- paddle|动手从头实现LSTM
- pytorch|使用pytorch从头实现多层LSTM
- SG平滑轨迹算法的原理和实现
- 人工智能|干货!人体姿态估计与运动预测
- 推荐系统论文进阶|CTR预估 论文精读(十一)--Deep Interest Evolution Network(DIEN)
- Python专栏|数据分析的常规流程
- pytorch|YOLOX 阅读笔记
- 读书笔记|《白话大数据和机器学习》学习笔记1
- Pytorch学习|sklearn-SVM 模型保存、交叉验证与网格搜索