TensorFlow中的小知识(tf.flags.DEFINE_xxx())
读别人家的代码的时候经常看到这个,结果两三天不看居然忘记了,这脑子绝对上锈了,决定记下来免得老是查来查去的。。。
内容包含如下几个我们经常看到的几个函数:
①tf.flags.DEFINE_xxx()
②FLAGS = tf.flags.FLAGS
③FLAGS._parse_flags()
简单的说:
用于帮助我们添加命令行的可选参数。
也就是说利用该函数我们可以实现在命令行中选择需要设定的参数来运行程序,
可以不用反复修改源代码中的参数,直接在命令行中进行参数的设定。
举个栗子:
程序train.py文件中的小部分代码如下所示:
FLAGS = tf.flags.FLAGStf.flags.DEFINE_string('name', 'default', 'name of the model')
tf.flags.DEFINE_integer('num_seqs', 100, 'number of seqs in one batch')
tf.flags.DEFINE_integer('num_steps', 100, 'length of one seq')
tf.flags.DEFINE_integer('lstm_size', 128, 'size of hidden state of lstm')
tf.flags.DEFINE_integer('num_layers', 2, 'number of lstm layers')
tf.flags.DEFINE_boolean('use_embedding', False, 'whether to use embedding')
tf.flags.DEFINE_integer('embedding_size', 128, 'size of embedding')
tf.flags.DEFINE_float('learning_rate', 0.001, 'learning_rate')
tf.flags.DEFINE_float('train_keep_prob', 0.5, 'dropout rate during training')
tf.flags.DEFINE_string('input_file', '', 'utf8 encoded text file')
tf.flags.DEFINE_integer('max_steps', 100000, 'max steps to train')
tf.flags.DEFINE_integer('save_every_n', 1000, 'save the model every n steps')
tf.flags.DEFINE_integer('log_every_n', 10, 'log to the screen every n steps')
tf.flags.DEFINE_integer('max_vocab', 3500, 'max char number')
#全局参数设置,显示在命令行
python train.py \
--input_file data/shakespeare.txt\
--name shakespeare \
--num_steps 50 \
--num_seqs 32 \
--learning_rate 0.01 \
--max_steps 20000
- 通过输入不同的文件名、参数,可以快速完成程序的调参和更换训练集的操作,不需要进入源码中更改。
实践操作一下:
现在我们有如下代码:
import tensorflow as tf
#取上述代码中一部分进行实验
tf.flags.DEFINE_integer('num_seqs', 100, 'number of seqs in one batch')
tf.flags.DEFINE_integer('num_steps', 100, 'length of one seq')
tf.flags.DEFINE_integer('lstm_size', 128, 'size of hidden state of lstm')#通过print()确定下面内容的功能
FLAGS = tf.flags.FLAGS #FLAGS保存命令行参数的数据
FLAGS._parse_flags() #将其解析成字典存储到FLAGS.__flags中
print(FLAGS.__flags)print(FLAGS.num_seqs)print("\nParameters:")
for attr, value in sorted(FLAGS.__flags.items()):
print("{}={}".format(attr.upper(), value))
print("")
- 尝试执行一下上述代码了解其各行代码的功能,可能因为tensorflow版本原因出现报错现象。
推荐阅读
- 热闹中的孤独
- 一个小故事,我的思考。
- 家乡的那条小河
- 一个人的碎碎念
- 野营记-第五章|野营记-第五章 讨伐梦魇兽
- 昨夜小楼听风
- JS中的各种宽高度定义及其应用
- 2021-02-17|2021-02-17 小儿按摩膻中穴-舒缓咳嗽
- 基于微信小程序带后端ssm接口小区物业管理平台设计
- 我眼中的佛系经纪人