TensorFlow中的tf.app.flags的目的是什么?


115

我在Tensorflow中阅读一些示例代码,发现以下代码

flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
flags.DEFINE_integer('max_steps', 2000, 'Number of steps to run trainer.')
flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
flags.DEFINE_integer('batch_size', 100, 'Batch size.  '
                 'Must divide evenly into the dataset sizes.')
flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')
flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
                 'for unit testing.')

tensorflow/tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py

但我找不到有关的用法的任何文档tf.app.flags

我发现该标志的实现在 tensorflow/tensorflow/python/platform/default/_flags.py

显然,这tf.app.flags是以某种方式用于配置网络的,所以为什么在API文档中没有呢?谁能解释这是怎么回事?

Answers:


110

tf.app.flags模块目前是python-gflags的 一个瘦包装,因此该项目文档是如何使用它的最佳资源argparse,它实现了一部分功能python-gflags

请注意,该模块当前已打包为方便编写演示应用程序使用,从技术上讲,它不是公共API的一部分,因此将来可能会更改。

我们建议您使用argparse或任何您喜欢的库来实现自己的标志解析。

编辑:tf.app.flags模块实际上并未使用实现python-gflags,但它使用了类似的API。


80
“打包为编写演示应用程序提供了便利,并且从技术上讲,它不是公共AP的一部分。”奇怪的是,几乎所有教程都使用了它,但是其中没有文档。导致大量混乱。
speedplane

2
对于如何使用argparse传递参数给一个TensorFlow模型,以及如何将其捆绑成一个Python模块为云一个很好的例子,看到task.pytaxifare模块中的,一部分训练数据分析师课程教材
charlesreid1

3
tf.app.run也不是公共API的一部分吗?因为它依赖tf.app.flags并且具有公共文档(tensorflow.org/api_docs/python/tf/app/run),所以我认为它是公共的并受支持。如果建议改为使用argparse,是否可以提供有关将其与一起使用的推荐方法的简短示例argparse
naktinis

6
张量流中的所有内容都不是文档问题。
死码

37

tf.app.flags模块是Tensorflow提供的功能,用于为Tensorflow程序实现命令行标志。例如,您遇到的代码将执行以下操作:

flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')

第一个参数定义标志的名称,第二个参数定义默认值,以防执行文件时未指定标志。

因此,如果运行以下命令:

$ python fully_connected_feed.py --learning_rate 1.00

那么学习率将设置为1.00,如果未指定该标志,则将保持0.01。

本文所述,文档可能不存在,因为这可能是Google内部要求开发人员使用的文档。

此外,如文章中所述,使用Tensorflow标志比其他Python软件包提供的标志功能有多个优势,例如argparse在处理Tensorflow模型时尤其如此,最重要的是可以向代码提供Tensorflow特定信息,例如信息有关使用哪个GPU的信息。


1
第三个参数说什么?大概就像小文档字符串。很想知道我是否错了。
shivam13juna

是的,大概就是这样。到目前为止,我还没有看到任何实际的用途,因此我想让您理解。
Vedang Waradpande

11

在Google,他们使用标记系统来设置参数的默认值。它类似于argparse。他们使用自己的标记系统,而不是argparse或sys.argv。

资料来源:我以前在那里工作过。


5

使用时tf.app.run(),可以使用方便地在线程之间传输变量tf.app.flags。请参阅此内容以进一步使用tf.app.flags


4

经过多次尝试后,我发现它可以打印所有FLAGS键以及实际值-

for key in tf.app.flags.FLAGS.flag_values_dict():

  print(key, FLAGS[key].value)

3
您的意思是FLAGS [key]
Physncubus
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.