【Tensorflow之Java部署方案】最近使用Tensorflow的Estimator高阶API进行模型训练,支持保存成checkpoint和saved model格式。
其中saved model可以使用Tensorflow Serving进行部署.
但是目前公司内部还是使用tensorflow的Java API进行部署,所以需要把saved model转化成pb格式,
并用Java进行部署。希望后续能够转移到Tensorflow Serving方式部署,毕竟这是官方推荐的部署方式,而且无论从扩展性,
易用性都能得到保证,而Java API直接就告诉你这是Experimental的了, 还不是稳定版。
saved_model转pb 由于从estimator保存的格式为saved model,所以第一步是需要转成pd格式.
需要注意的是estimator在export_saved_model时,需要指定好对应的placeholder,确保后续有对应的feed的入口。
- 使用tf.saved_model.loader.load加载saved model格式,其中需要指定tags,而estimator保存saved model时默认tag为’serve’
- 根据output node来freeze graph。如果我们用了feature column的话,还需要把
init_all_tables
也加入到output node,该问题在github已经被讨论过, 具体见下面的参考资料。 - 把freeze之后的graph和constants参数序列化为pb文件
# estimator export feed node
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
"index": tf.placeholder(tf.int32, [None, field_size], name='index'),
"value": tf.placeholder(tf.float32, [None, field_size], name='value'),
})estimator.export_saved_model(saved_model_dir, input_fn)
def save_model2pb(save_model_dir, pb_path, output_node, tags=['serve']):
with tf.Session(graph=tf.Graph()) as sess, tf.device("/cpu:0"):
tf.saved_model.loader.load(sess, tags, save_model_dir)
output_graph_def = tf.graph_util.convert_variables_to_constants(
sess=sess,
input_graph_def=sess.graph_def,
output_node_names=output_node)with tf.gfile.GFile(pb_path, "wb") as f:
f.write(output_graph_def.SerializeToString())print("%d ops in the final graph." % len(output_graph_def.node))
Java部署 java的部署方式与python类似,对于下面代码,我们的input是value和index,output是output_node
- 先创建graph和session,以byte的方式读取pb文件
- 构建用于feed的tensor. 其中用于feed的node需要在estimator保存成saved model时export出来
- 使用runner进行predict,取出predict结果。
public class DeepModel {private Graph graph;
private Session sess;
public DeepModel(String pbFile) {
try {
graph = new Graph();
byte[] graphBytes = IOUtils.toByteArray(new FileInputStream(pbFile));
graph.importGraphDef(graphBytes);
sess = new Session(graph);
} catch (java.io.IOException e) {
System.out.println("DeepModel initial fail!!!");
}
}public boolean isNull() {
return (sess == null) || (graph == null);
}public float[][] predict(int[][] index, float[][] value) {
Tensor indexTensor = Tensor.create(index);
Tensor valueTensor = Tensor.create(value);
Tensor rlt = sess.runner().feed("index", indexTensor).feed("value", valueTensor).fetch("output_node").run().get(0);
float[][] finalRlt = new float[index.length][1];
rlt.copyTo(finalRlt);
return finalRlt;
}
}
推荐一下自己的项目:search-deeplearning
参考资料:
- Using the SavedModel format
- freeze_graph not initializing tables
推荐阅读
- tensorflow|vgg16 finetune tensorflow实现
- deeplearning|ubuntu 16.04 + GTX1050安装nvidia驱动
- DCGAN——菜鸟实现
- DCGAN——菜鸟系列 model.py
- DeepLearning|YOLO: Real-Time Object Detection
- tensorflow|Bert瞎玩系列——bert的数据处理部分