将keras的模型转化为onnx模型
ONNX模型(开放神经网络交换), 是一种用于表示深度学习模型的开放格式。ONNX可以获得很多框架的支持。您可以使用任何选定的框架来设计、训练和部署深度学习模型。ONNX 模型的好处是,它们可以在框架之间轻松移动。
我们经常使用keras训练一些模型,考虑到不同框架之间的转换和移动,可以考虑采用 onnx模型作为中建模型.
我们使用keras2onnx工具实现这个转换过程.
1. 安装
您可以从PyPi安装Keras2ONNX的最新版本:由于某些原因,软件包的发布已暂停,请从源代码进行安装,并且tensorflow 2.x上对keras或tf.keras的支持仅在源代码中可用。
安装使用pip install keras2onnx
或从源代码安装:
pip install -U git + https://github.com/microsoft/onnxconverter-common
pip install -U git + https://github.com/onnx/keras-onnx
在运行转换器之前,请注意,必须在python环境中安装tensorflow,您可以选择tensorflow / tensorflow-cpu软件包(CPU版本)或tensorflow-gpu(GPU版本)
Keras2ONNX依赖于onnxconverter-common。 实际上,此转换器的最新代码需要onnxconverter-common的最新版本,因此,如果从源代码安装此转换器,请在安装keras2onnx之前以源代码模式安装onnxconverter-common。
通过调用keras2onnx.convert_keras可以成功转换大多数Keras模型,包括CV,GAN,NLP,Speech。 但是,某些具有许多自定义操作的模型需要自定义转换,以下是一些示例,例如YOLOv3和Mask RCNN。
2.代码示例
代码1:
import keras
import keras2onnx
import onnx
from keras.models import load_model
model = load_model('./model/mymodel5.h5')
onnx_model = keras2onnx.convert_keras(model, model.name)
temp_model_file = './model/mymodel.onnx'
onnx.save_model(onnx_model, temp_model_file)
【将keras的模型转化为onnx模型】代码2
import keras2onnx
keras2onnx.convert_keras(model, name=None, doc_string='', target_opset=None, channel_first_inputs=None):
# type: (keras.Model, str, str, int, []) -> onnx.ModelProto
"""
:param model: keras model
:param name: the converted onnx model internal name
:param doc_string:
:param target_opset:
:param channel_first_inputs: A list of channel first input.
:return:
"""
import numpy as np
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input
import keras2onnx
import onnxruntime# image preprocessing
img_path = 'street.jpg'# make sure the image is in img_path
img_size = 224
img = image.load_img(img_path, target_size=(img_size, img_size))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)# load keras model
from keras.applications.resnet50 import ResNet50
model = ResNet50(include_top=True, weights='imagenet')# convert to onnx model
onnx_model = keras2onnx.convert_keras(model, model.name)# runtime prediction
content = onnx_model.SerializeToString()
sess = onnxruntime.InferenceSession(content)
x = x if isinstance(x, list) else [x]
feed = dict([(input.name, x[n]) for n, input in enumerate(sess.get_inputs())])
pred_onnx = sess.run(None, feed)temp_model_file = 'model.onnx'
keras2onnx.save_model(onnx_model, temp_model_file)
sess = onnxruntime.InferenceSession(temp_model_file)
推荐阅读
- 热闹中的孤独
- JAVA(抽象类与接口的区别&重载与重写&内存泄漏)
- 放屁有这三个特征的,请注意啦!这说明你的身体毒素太多
- 一个人的旅行,三亚
- 布丽吉特,人生绝对的赢家
- 慢慢的美丽
- 尽力
- 一个小故事,我的思考。
- 家乡的那条小河
- 《真与假的困惑》???|《真与假的困惑》??? ——致良知是一种伟大的力量