tensorflow|tensorflow serving (二)(发布自己的服务)
https://www.jianshu.com/p/d673c9507988
通过简单运行了官网例子,对tensorflow serving有了大致的了解,但是怎么把自己的模型发布成服务呢?现在通过一个小例子来学习下。
0. 介绍
这里介绍两种保存模型的方式,发布服务需要的不再是之前保存的ckpt格式数据,而是export出来的模型或者pb模型。通过这两种方式把模型准备好,之后只需要挂在到指定路径下,就可以起服务了。
1. 1 exporter 模型
把官方的half_plus_two简单修改成了half_plus_ten。
与我们保存ckpt不同,需要调用的接口是:
from tensorflow.contrib.session_bundle import exporter
需要把输入输出给重新定义下,然后再用接口导出。
import tensorflow as tf
from tensorflow.contrib.session_bundle import exporterdef Export():
export_path = "model/half_plus_ten"
with tf.Session() as sess:
# Make model parameters a&b variables instead of constants to
# exercise the variable reloading mechanisms.
a = tf.Variable(0.5)
b = tf.Variable(10.0)# Calculate, y = a*x + b
# here we use a placeholder 'x' which is fed at inference time.
x = tf.placeholder(tf.float32)
y = tf.add(tf.multiply(a, x), b)# Run an export.
tf.global_variables_initializer().run()
export = exporter.Exporter(tf.train.Saver())
export.init(named_graph_signatures={
"inputs": exporter.generic_signature({"x": x}),
"outputs": exporter.generic_signature({"y": y}),
"regress": exporter.regression_signature(x, y)
})
export.export(export_path, tf.constant(123), sess)def main(_):
Export()if __name__ == "__main__":
tf.app.run()
保存好的模型看起来很像ckpt,但是再checkpoint里面可以看到,是“export”。 “00000123”这个文件名是自动生成的,我也不知道为什么会刚好是这个数字。
文章图片
保存好的模型 1.2 保存pb模型 https://www.jianshu.com/p/9221fbf52c55 通过这个教程,我们把模型保存为pb格式。同样把这个模型文件夹挂在到docker相应的目录下。
文章图片
保存为pb模型
2. 通过docker起服务 【tensorflow|tensorflow serving (二)(发布自己的服务)】要指定端口,挂载目录,docker才能访问这个模型,挂在的目录得是绝对路径。
- export之后的模型挂载。
docker run -t --rm -p 8501:8501 \
-v "$(pwd)/model/half_plus_ten:/models/half_plus_ten" \
-e MODEL_NAME=half_plus_ten \
tensorflow/serving
- pb模型需要修改挂载路径,可以重新给模型起名字,这里还是用上面的名字“half_plus_ten"。
docker run -t --rm -p 8501:8501 \
-v "$(pwd)/pb_model:/models/half_plus_ten" \
-e MODEL_NAME=half_plus_ten \
tensorflow/serving
3. 测试服务 给它几个值来测试下这个服务。
curl -d '{"instances": [1.0, 2.0, 5.0]}' -X POST http://localhost:8501/v1/models/half_plus_ten:predict
能得到half plus ten这个结果!
文章图片
输出正确 用python代码访问服务
import osimport requests
from time import timeimport numpy as npurl = 'http://localhost:8501/v1/models/half_plus_ten:predict'a = np.array([1,2 ,3,4])predict_request = '{"instances" : [{"input": %s}]}' % list(a)# 一定要list才能传输,不然json错误print("start")
start_time = time()
r = requests.post(url,data=https://www.it610.com/article/predict_request)
print(r.content)
end_time = time()
Tips: 代码改写自官方例子:https://github.com/tensorflow/serving/blob/master/tensorflow_serving/servables/tensorflow/testdata/export_half_plus_two.py
代码和模型都放在:
https://github.com/xxlxx1/learing_tf_serving
推荐阅读
- EffectiveObjective-C2.0|EffectiveObjective-C2.0 笔记 - 第二部分
- 遇到一哭二闹三打滚的孩子,怎么办┃山伯教育
- 赢在人生六项精进二阶Day3复盘
- 2019年12月24日
- 陇上秋二|陇上秋二 罗敷媚
- 一百二十三夜,请嫁给我
- 迷失的世界(二十七)
- 我要我们在一起(二)
- 基于|基于 antd 风格的 element-table + pagination 的二次封装
- (二)ES6第一节变量(let|(二)ES6第一节变量(let,const)