tf.app.flags命令行参数解析模块
说道命令行参数解析,就不得不提到 python 的 argparse 模块,详情可参考我之前的一篇文章:python argparse 模块命令行参数用法及说明。
在阅读相关工程的源码时,很容易发现 tf.app.flags 模块的身影。其作用与 python 的 argparse 类似。
直接上代码实例,新建一个名为 test_flags.py 的文件,内容如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 | #coding:utf-8 import tensorflow as tf FLAGS = tf.app.flags.FLAGS # tf.app.flags.DEFINE_string("param_name", "default_val", "description") tf.app.flags.DEFINE_string( "train_data_path" , "/home/feige" , "training data dir" ) tf.app.flags.DEFINE_string( "log_dir" , "./logs" , " the log dir" ) tf.app.flags.DEFINE_integer( "train_batch_size" , 128 , "batch size of train data" ) tf.app.flags.DEFINE_integer( "test_batch_size" , 64 , "batch size of test data" ) tf.app.flags.DEFINE_float( "learning_rate" , 0.001 , "learning rate" ) def main(unused_argv): train_data_path = FLAGS.train_data_path print ( "train_data_path" , train_data_path) train_batch_size = FLAGS.train_batch_size print ( "train_batch_size" , train_batch_size) test_batch_size = FLAGS.test_batch_size print ( "test_batch_size" , test_batch_size) size_sum = tf.add(train_batch_size, test_batch_size) with tf.Session() as sess: sum_result = sess.run(size_sum) print ( "sum_result" , sum_result) # 使用这种方式保证了,如果此文件被其他文件 import的时候,不会执行main 函数 if __name__ = = '__main__' : tf.app.run() # 解析命令行参数,调用main 函数 main(sys.argv) |
上述代码已给出较为详细的注释,在此不再赘述。
该文件的调用示例以及运行结果如下所示
如果需要修改默认参数的值,则在命令行传入自定义参数值即可,若全部使用默认参数值,则可直接在命令行运行该 python 文件。
读者可能会对 tf.app.run() 有些疑问,在上述注释中也有所解释,但要真正弄清楚其运行原理
还需查阅其源代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 | def run(main = None , argv = None ): """Runs the program with an optional 'main' function and 'argv' list.""" f = flags.FLAGS # Extract the args from the optional `argv` list. args = argv[ 1 :] if argv else None # Parse the known flags from that list, or from the command # line otherwise. # pylint: disable=protected-access flags_passthrough = f._parse_flags(args = args) # pylint: enable=protected-access main = main or sys.modules[ '__main__' ].main # Call the main function, passing through any arguments # to the final program. sys.exit(main(sys.argv[: 1 ] + flags_passthrough)) |
flags_passthrough=f._parse_flags(args=args)
这里的_parse_flags
就是我们tf.app.flags
源码中用来解析命令行参数的函数。
所以这一行就是解析参数的功能;
下面两行代码也就是 tf.app.run 的核心意思:执行程序中 main 函数,并解析命令行参数!
以上为个人经验,希望能给大家一个参考,也希望大家多多支持IT俱乐部。