【mxnet|mxnet模型转pytorch模型】转换基本流程:
1)创建pytorch的网络结构模型;
2)利用mxnet来读取其存储的预训练模型,用于读取模型的参数;
3)遍历mxnet加载的模型参数;
4)对一些指定的key值,需要进行相应的处理和转换;
5)对修改后的层名(key值),利用numpy之间的转换来实现加载;
6)对相应层进行参数(feature)进行比较;
流程基本是与caffe模型转pytorch模型这篇文章一致,唯一需要注意的一点就是:
mxnet中解析的参数有三个:
sym, arg_params, aux_params = mx.model.load_checkpoint(model_path, epoch)
arg_params是主要参数如weights;
aux_params是辅助参数主要是bias或者是batchnorm中的一些参数;
不是所有参数都在arg_params中,batchnorm层的权值和偏置就是保存在aux_params
以Resnet20为例:
1)(略)
2)加载mxnet模型
#加载符号图与模型参数
def get_model(model_path, epoch):
sym, arg_params, aux_params = mx.model.load_checkpoint(model_path, epoch)
return sym, arg_params,aux_params
3)4)5)
def init_model(self, model,param_dict,aux_params):
# print(model)
for n, m in model.named_modules():
print(n)
if isinstance(m, BatchNorm2d):
self.bn_init(n, m, param_dict,aux_params)
elif isinstance(m, Conv2d):
self.conv_init(n, m, param_dict)
elif isinstance(m, Linear):
self.fc_init(n, m, param_dict)
elif isinstance(m, PReLU):
self.prelu_init(n, m, param_dict)return modeldef bn_init(self, n, m, param_dict,aux_params):
if not (m.weight is None):
m.weight.data.copy_(torch.FloatTensor(param_dict[n+'_gamma'].asnumpy()))
m.bias.data.copy_(torch.FloatTensor(param_dict[n+'_beta'].asnumpy()))
m.running_mean.copy_(torch.FloatTensor(aux_params[n+'_moving_mean'].asnumpy()))
m.running_var.copy_(torch.FloatTensor(aux_params[n+'_moving_var'].asnumpy()))def conv_init(self, n, m, param_dict):
# print('n = ', n)
m.weight.data.copy_(torch.FloatTensor(param_dict[n+'_weight'].asnumpy()))
if n in ['conv1_1', 'conv4_1', 'conv3_1', 'conv2_1']:
m.bias.data.copy_(torch.FloatTensor(param_dict[n + '_bias'].asnumpy()))def fc_init(self, n, m, param_dict):
m.weight.data.copy_(torch.FloatTensor(param_dict[n+'_weight'].asnumpy()))
m.bias.data.copy_(torch.FloatTensor(param_dict[n+'_bias'].asnumpy()))def prelu_init(self, n, m, net):
m.weight.data.copy_(torch.FloatTensor(param_dict[n + '_gamma'].asnumpy()))
6)(略,我自己也没写,直接使用测试集测试了一下转换后的模型,小数点后四位没有精度偏差,就没有做了)
推荐阅读
- 深度学习|【庖丁解牛】从零实现FCOS(终)(CenterSample的重要性)
- 炼丹|使用FCOS训练自己的数据
- 图像对比度修正|论文学习笔记: Learning Multi-Scale Photo Exposure Correction(含pytorch代码复现)
- 深度学习|yolov5之可视化特征图和检测结果
- 图像分类|保姆级使用PyTorch训练与评估自己的Wide ResNet网络教程
- 强化学习|强化学习-PPO算法实现pendulum
- pytorch|对于torch.nn.AdaptiveAvgPool2d()自适应平均池化函数的一些理解
- 用PyTorch研究张量
- 论文学习|resnet 论文笔记