Pytorch|Pytorch之模型加载/保存

【Pytorch|Pytorch之模型加载/保存】pytorch保存模型有两种方法:

  1. 保存整个模型 (结构+参数)
  2. 只保存参数(官方推荐)
两者都是用torch.save(obj, dir)实现,这个函数的作用是将对象保存到磁盘中,它的内部是使用Python的pickle实现。
两种方法的区别其实就是obj参数的不同:前者的obj是整个model对象,后者的obj是从model里获取的存储了model参数的词典,推荐用第二种,虽然麻烦了一丁点,但是比较灵活,有利于实现预训练、参数迁移等操作。
保存整个模型 这种方法很简单,保存和加载就两行代码,和Python pickle包的用法是一样的,把model当作一个对象直接保存加载就行。
# 保存 model = Mymodel() torch.save(model, path) # 加载 model = torch.load(path)

Note:PyTorch约定使用.pt或.pth后缀命名保存文件。
保存参数 重点介绍一下这种方法,一般训完一个模型之后我们不会单独只保存一个模型的参数,为了方便后续操作,比如恢复训、参数迁移等,我们会保存当前转态的一个快照,具体信息可以根据自己的需要,下面列出几个方面:
  • 模型参数
  • 优化器参数
  • loss
  • epoch
  • args
把这些信息用字典包装起来,然后保存即可。
这种方式保存的模型只是它的参数,所以我们在加载时需要先创建好模型,然后再把参数加载进去,如下:
# 获得保存信息 save_data = https://www.it610.com/article/{'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, 'epoch': epoch, 'args': args ... } # 保存 torch.save(save_data , path) load_data = https://www.it610.com/article/torch.load(path) model = Mymodel() optimizer = Myoptimizer() # 加载参数 model.load_state_dict(load_data ['model_state_dict']) optimizer.load_state_dict(load_data ['optimizer_state_dict']) ...

Note:PyTorch约定使用.pt或.pth后缀命名保存文件。

    推荐阅读