tf.app.run()如何工作?


Answers:


134
if __name__ == "__main__":

表示当前文件在shell下执行,而不是作为模块导入。

tf.app.run()

如您所见,通过文件 app.py

def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""
  f = flags.FLAGS

  # Extract the args from the optional `argv` list.
  args = argv[1:] if argv else None

  # Parse the known flags from that list, or from the command
  # line otherwise.
  # pylint: disable=protected-access
  flags_passthrough = f._parse_flags(args=args)
  # pylint: enable=protected-access

  main = main or sys.modules['__main__'].main

  # Call the main function, passing through any arguments
  # to the final program.
  sys.exit(main(sys.argv[:1] + flags_passthrough))

让我们逐行中断:

flags_passthrough = f._parse_flags(args=args)

这样可以确保您通过命令行传递的参数有效,例如, python my_model.py --data_dir='...' --max_iteration=10000实际上,此功能是基于python标准argparse模块实现的。

main = main or sys.modules['__main__'].main

第一个main在右边=是当前函数的第一个参数run(main=None, argv=None) 。While sys.modules['__main__']表示当前正在运行的文件(例如my_model.py)。

因此有两种情况:

  1. 您没有的main功能,my_model.py那么您必须调用tf.app.run(my_main_running_function)

  2. 您在中具有main功能my_model.py。(通常是这种情况。)

最后一行:

sys.exit(main(sys.argv[:1] + flags_passthrough))

确保使用解析后的参数正确调用您的main(argv)or my_main_running_function(argv)函数。


67
对于初学者Tensorflow用户而言,这是一个缺少的难题:Tensorflow具有一些内置的命令行标志处理机制。您可以定义自己的标志tf.flags.DEFINE_integer('batch_size', 128, 'Number of images to process in a batch.'),然后使用tf.app.run()它进行设置,以便可以全局访问所定义标志的传递值,例如tf.flags.FLAGS.batch_size从代码中需要的地方访问。
isarandi

1
在我看来,这是(当前)三个问题的更好答案。它解释了“ tf.app.run()如何工作”,而其他两个答案仅说明了它的作用。
Thomas Fauskanger

看起来abseilTF必须吸收abseil.io/docs/python/guides/flags
-CpILL

75

这只是一个非常快速的包装程序,可以处理标志解析,然后分派到您自己的主程序。参见代码


12
“处理标志解析”是什么意思?也许您可以添加一个链接来通知初学者这意味着什么?
Pinocchio

4
它使用标志包解析提供给程序的命令行参数。(它在后台使用了标准的“ argparse”库,并带有一些包装器)。它与我在答案中链接到的代码链接在一起。
dga

1
在app.py,做什么 main = main or sys.modules['__main__'].mainsys.exit(main(sys.argv[:1] + flags_passthrough))意味着什么?
hAcKnRoCk

3
这对我来说似乎很奇怪,如果可以直接调用它,为什么还要把main函数包装起来main()呢?
查理·帕克

2
hAcKnRoCk:如果文件中没有main,则使用sys.modules [' main '] .main中的任何内容。sys.exit意味着使用args和通过的所有标志运行由此找到的main命令,并以main的返回值退出。@CharlieParker-与Google现有的python应用程序库(例如gflags和google-apputils)兼容。参见例如github.com/google/google-apputils
dga


5

简单来说,的工作tf.app.run()首先设置全局标志以供以后使用,例如:

from tensorflow.python.platform import flags
f = flags.FLAGS

然后使用一组参数运行自定义的main函数。

例如,在TensorFlow NMT代码库中,用于训练/推理的程序执行的第一个入口点就是从这一点开始(请参见下面的代码)

if __name__ == "__main__":
  nmt_parser = argparse.ArgumentParser()
  add_arguments(nmt_parser)
  FLAGS, unparsed = nmt_parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

在使用解析参数之后argparsetf.app.run()您将运行函数“ main”,其定义如下:

def main(unused_argv):
  default_hparams = create_hparams(FLAGS)
  train_fn = train.train
  inference_fn = inference.inference
  run_main(FLAGS, default_hparams, train_fn, inference_fn)

因此,在设置了供全局使用的标志之后,tf.app.run()只需运行main传递给它的函数argv作为其参数即可。

PS:正如萨尔瓦多·达利(Salvador Dali)的回答所说,我猜这只是一个很好的软件工程实践,尽管我不确定TensorFlow是否会执行main比使用常规CPython 进行的函数优化的运行。


2

Google代码很大程度上取决于要在库/二进制文件/ python脚本中访问的全局标志,因此tf.app.run()解析出这些标志以在FLAGs(或类似的变量)中创建全局状态,然后调用python main( ) 正如它应该。

如果他们没有对tf.app.run()的调用,则用户可能会忘记进行FLAG解析,从而导致这些库/二进制文件/脚本无法访问所需的FLAG。


1

2.0兼容答:如果你想使用tf.app.run()Tensorflow 2.0,我们应该使用的命令,

tf.compat.v1.app.run()或者您可以使用tf_upgrade_v21.x代码转换为2.0

By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.