tf踩坑记-NotFoundError

tf的estimator api使用起来非常方便,填充下几个函数,就能写出高度结构化的模型代码。但封装越高级,使用中一旦遇到问题,处理起来就会相当麻烦。
今天在日常炼丹中就遇到了这么一个问题。模型训练完成,导出saved model时,始终报一个key找不到对应的var。

NotFoundError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

这种错误一般是两次使用的graph不一致造成的。可能的原因一般有:
  • 使用了不同版本的api,api会对变量加上一些不一样的前缀或后缀,不同的api可能处理的逻辑不一致。导致两次加载的graph不一样。
  • checkpoint里没保存这个变量。这个一般不太会存在,需要train的变量肯定需要保存。有些不需要train的变量才会出现这种可能。
这两种情况排查起来比较简单。可以用下面这个工具检查下checkpoint里的变量:
/tensorflow/python/tools/inspect_checkpoint.py checkpoint_file_name

一般会得到以下的输出:
tf踩坑记-NotFoundError
文章图片

这次我的错误,就是找不到一个age_1/embeddings的变量。从输出看,save的变量名是age/embeddings,加载时却要找一个age_1/embeddings的变量。名字变了,看起来应该是第一种原因,save时和export时用了不同的图。然而我的代码train和export是在同一份代码,使用的是同一个model_fn,且在同一个环境下测试的,应该说是不会使用到不同的graph才对的。
排查了很久,最终还是在图不一致上解决了问题。estimator有两个输入函数,一个input_fn,一个model_fn,除了model_fn会构建模型graph之外,input_fn里的tensor也会添加到graph里去。而训练阶段和export阶段使用的input_fn是不一致的。export的模型是要给线上serving用的,所以在input_fn里定义了一堆placeholder作为输入,而placeholder里也有一个tensor的name被set为 ”age“,这就导致model_fn里的age/embeddings在build graph时,被改成了age_1/embeddings,再去checkpoint里查找这个变量的值,自然是找不到的。
【tf踩坑记-NotFoundError】这个问题一开始被tf的报错给误导了,一直在model_fn里查问题。报错的地方不一定是有问题的地方,而tf1.13也不会报这种重名的问题。大多数时候我们创建的tensor也不会给一个name,tf会自动命名一个name,自动命名有一套规则,有重复了就新生成一个name,这就导致不管是自动set的name还是人工定义的name,都不会报重名的错误。一旦自己命名重名了,很可能会造成问题。

    推荐阅读