tensorflow|tensorflow serving (二)(发布自己的服务)

https://www.jianshu.com/p/d673c9507988
通过简单运行了官网例子,对tensorflow serving有了大致的了解,但是怎么把自己的模型发布成服务呢?现在通过一个小例子来学习下。
0. 介绍 这里介绍两种保存模型的方式,发布服务需要的不再是之前保存的ckpt格式数据,而是export出来的模型或者pb模型。通过这两种方式把模型准备好,之后只需要挂在到指定路径下,就可以起服务了。
1. 1 exporter 模型 把官方的half_plus_two简单修改成了half_plus_ten。
与我们保存ckpt不同,需要调用的接口是:

from tensorflow.contrib.session_bundle import exporter

需要把输入输出给重新定义下,然后再用接口导出。
import tensorflow as tf from tensorflow.contrib.session_bundle import exporterdef Export(): export_path = "model/half_plus_ten" with tf.Session() as sess: # Make model parameters a&b variables instead of constants to # exercise the variable reloading mechanisms. a = tf.Variable(0.5) b = tf.Variable(10.0)# Calculate, y = a*x + b # here we use a placeholder 'x' which is fed at inference time. x = tf.placeholder(tf.float32) y = tf.add(tf.multiply(a, x), b)# Run an export. tf.global_variables_initializer().run() export = exporter.Exporter(tf.train.Saver()) export.init(named_graph_signatures={ "inputs": exporter.generic_signature({"x": x}), "outputs": exporter.generic_signature({"y": y}), "regress": exporter.regression_signature(x, y) }) export.export(export_path, tf.constant(123), sess)def main(_): Export()if __name__ == "__main__": tf.app.run()

保存好的模型看起来很像ckpt,但是再checkpoint里面可以看到,是“export”。 “00000123”这个文件名是自动生成的,我也不知道为什么会刚好是这个数字。

tensorflow|tensorflow serving (二)(发布自己的服务)
文章图片
保存好的模型 1.2 保存pb模型 https://www.jianshu.com/p/9221fbf52c55 通过这个教程,我们把模型保存为pb格式。同样把这个模型文件夹挂在到docker相应的目录下。

tensorflow|tensorflow serving (二)(发布自己的服务)
文章图片
保存为pb模型
2. 通过docker起服务 【tensorflow|tensorflow serving (二)(发布自己的服务)】要指定端口,挂载目录,docker才能访问这个模型,挂在的目录得是绝对路径。
  1. export之后的模型挂载。
docker run -t --rm -p 8501:8501 \ -v "$(pwd)/model/half_plus_ten:/models/half_plus_ten" \ -e MODEL_NAME=half_plus_ten \ tensorflow/serving

  1. pb模型需要修改挂载路径,可以重新给模型起名字,这里还是用上面的名字“half_plus_ten"。
docker run -t --rm -p 8501:8501 \ -v "$(pwd)/pb_model:/models/half_plus_ten" \ -e MODEL_NAME=half_plus_ten \ tensorflow/serving

3. 测试服务 给它几个值来测试下这个服务。
curl -d '{"instances": [1.0, 2.0, 5.0]}' -X POST http://localhost:8501/v1/models/half_plus_ten:predict

能得到half plus ten这个结果!

tensorflow|tensorflow serving (二)(发布自己的服务)
文章图片
输出正确 用python代码访问服务
import osimport requests from time import timeimport numpy as npurl = 'http://localhost:8501/v1/models/half_plus_ten:predict'a = np.array([1,2 ,3,4])predict_request = '{"instances" : [{"input": %s}]}' % list(a)# 一定要list才能传输,不然json错误print("start") start_time = time() r = requests.post(url,data=https://www.it610.com/article/predict_request) print(r.content) end_time = time()

Tips: 代码改写自官方例子:https://github.com/tensorflow/serving/blob/master/tensorflow_serving/servables/tensorflow/testdata/export_half_plus_two.py
代码和模型都放在:
https://github.com/xxlxx1/learing_tf_serving

    推荐阅读