tf.app.run()
Tensorflow 中的工作如何翻译演示?
在中tensorflow/models/rnn/translate/translate.py
,有一个呼叫到tf.app.run()
。如何处理?
if __name__ == "__main__":
tf.app.run()
tf.app.run()
Tensorflow 中的工作如何翻译演示?
在中tensorflow/models/rnn/translate/translate.py
,有一个呼叫到tf.app.run()
。如何处理?
if __name__ == "__main__":
tf.app.run()
Answers:
if __name__ == "__main__":
表示当前文件在shell下执行,而不是作为模块导入。
tf.app.run()
如您所见,通过文件 app.py
def run(main=None, argv=None):
"""Runs the program with an optional 'main' function and 'argv' list."""
f = flags.FLAGS
# Extract the args from the optional `argv` list.
args = argv[1:] if argv else None
# Parse the known flags from that list, or from the command
# line otherwise.
# pylint: disable=protected-access
flags_passthrough = f._parse_flags(args=args)
# pylint: enable=protected-access
main = main or sys.modules['__main__'].main
# Call the main function, passing through any arguments
# to the final program.
sys.exit(main(sys.argv[:1] + flags_passthrough))
让我们逐行中断:
flags_passthrough = f._parse_flags(args=args)
这样可以确保您通过命令行传递的参数有效,例如,
python my_model.py --data_dir='...' --max_iteration=10000
实际上,此功能是基于python标准argparse
模块实现的。
main = main or sys.modules['__main__'].main
第一个main
在右边=
是当前函数的第一个参数run(main=None, argv=None)
。While sys.modules['__main__']
表示当前正在运行的文件(例如my_model.py
)。
因此有两种情况:
您没有的main
功能,my_model.py
那么您必须调用tf.app.run(my_main_running_function)
您在中具有main
功能my_model.py
。(通常是这种情况。)
最后一行:
sys.exit(main(sys.argv[:1] + flags_passthrough))
确保使用解析后的参数正确调用您的main(argv)
or my_main_running_function(argv)
函数。
abseil
TF必须吸收abseil.io/docs/python/guides/flags
这只是一个非常快速的包装程序,可以处理标志解析,然后分派到您自己的主程序。参见代码。
main = main or sys.modules['__main__'].main
和 sys.exit(main(sys.argv[:1] + flags_passthrough))
意味着什么?
main()
呢?
简单来说,的工作tf.app.run()
是首先设置全局标志以供以后使用,例如:
from tensorflow.python.platform import flags
f = flags.FLAGS
然后使用一组参数运行自定义的main函数。
例如,在TensorFlow NMT代码库中,用于训练/推理的程序执行的第一个入口点就是从这一点开始(请参见下面的代码)
if __name__ == "__main__":
nmt_parser = argparse.ArgumentParser()
add_arguments(nmt_parser)
FLAGS, unparsed = nmt_parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
在使用解析参数之后argparse
,tf.app.run()
您将运行函数“ main”,其定义如下:
def main(unused_argv):
default_hparams = create_hparams(FLAGS)
train_fn = train.train
inference_fn = inference.inference
run_main(FLAGS, default_hparams, train_fn, inference_fn)
因此,在设置了供全局使用的标志之后,tf.app.run()
只需运行main
传递给它的函数argv
作为其参数即可。
PS:正如萨尔瓦多·达利(Salvador Dali)的回答所说,我猜这只是一个很好的软件工程实践,尽管我不确定TensorFlow是否会执行main
比使用常规CPython 进行的函数优化的运行。
Google代码很大程度上取决于要在库/二进制文件/ python脚本中访问的全局标志,因此tf.app.run()解析出这些标志以在FLAGs(或类似的变量)中创建全局状态,然后调用python main( ) 正如它应该。
如果他们没有对tf.app.run()的调用,则用户可能会忘记进行FLAG解析,从而导致这些库/二进制文件/脚本无法访问所需的FLAG。
2.0兼容答:如果你想使用tf.app.run()
中Tensorflow 2.0
,我们应该使用的命令,
tf.compat.v1.app.run()
或者您可以使用tf_upgrade_v2
将1.x
代码转换为2.0
。
tf.flags.DEFINE_integer('batch_size', 128, 'Number of images to process in a batch.')
,然后使用tf.app.run()
它进行设置,以便可以全局访问所定义标志的传递值,例如tf.flags.FLAGS.batch_size
从代码中需要的地方访问。