TensorFlow保存到文件中/从文件中加载图形


98

根据到目前为止的经验,有几种不同的方法可以将TensorFlow图转储到文件中,然后再将其加载到另一个程序中,但是我无法找到关于它们如何工作的清晰示例/信息。我已经知道的是:

  1. 使用a将模型的变量保存到检查点文件(.ckpt)中,tf.train.Saver()并在以后还原它们(
  2. 将模型保存到.pb文件,然后使用tf.train.write_graph()tf.import_graph_def()source)将其加载回
  3. 从.pb文件加载模型,对其进行重新训练,然后使用Bazel将其转储到新的.pb文件中(
  4. 冻结图形以将图形和权重保存在一起(
  5. 使用as_graph_def()保存模型,并为权重/变量,它们映射到常数(

但是,我无法清除有关这些不同方法的几个问题:

  1. 关于检查点文件,它们仅保存模型的训练权重吗?是否可以将检查点文件加载到新程序中并用于运行模型,还是仅将它们用作在特定时间/阶段将权重保存在模型中的方法?
  2. 关于tf.train.write_graph(),权重/变量也被保存吗?
  3. 关于Bazel,它只能保存到.pb文件中或从中加载以进行重新训练吗?是否有一个简单的Bazel命令只是将图形转储到.pb中?
  4. 关于冻结,是否可以使用来加载冻结图tf.import_graph_def()
  5. TensorFlow的Android演示从.pb文件加载到Google的Inception模型中。如果我想替换自己的.pb文件,该怎么做?我需要更改任何本机代码/方法吗?
  6. 通常,所有这些方法之间到底有什么区别?或更广泛地说,/。as_graph_def()ckpt / .pb有什么区别?

简而言之,我正在寻找一种将图形(如各种操作等)及其权重/变量都保存到文件中的方法,然后可以将其用于将图形和权重加载到另一个程序中,以供使用(不一定要继续/训练)。

关于此主题的文档不是很简单,因此,非常感谢您提供任何答案/信息。


2
最新/最完善的API是元图,这将为您提供一次保存所有三个的方法-1)图2)参数值3)集合:tensorflow.org/versions/r0.10/how_tos/meta_graph/ index.html
Yaroslav Bulatov

Answers:


80

有很多方法可以解决在TensorFlow中保存模型的问题,这可能会使它有些混乱。依次处理您的每个子问题:

  1. 检查点文件(例如产生通过调用saver.save()一个上tf.train.Saver对象)只包含的权重,并且在相同程序中定义的任何其它变量。要在另一个程序中使用它们,您必须重新创建关联的图形结构(例如,通过运行代码以再次构建它,或调用tf.import_graph_def()),这告诉TensorFlow如何处理这些权重。请注意,调用saver.save()还会生成一个包含的文件MetaGraphDef,该文件包含一个图形以及如何将检查点的权重与该图形相关联的详细信息。有关更多详细信息,请参见教程

  2. tf.train.write_graph()只写图结构;不是重量。

  3. Bazel与读取或写入TensorFlow图无关。(也许我误会了您的问题:请随时在评论中予以澄清。)

  4. 冻结的图可以使用加载tf.import_graph_def()。在这种情况下,权重(通常)嵌入在图形中,因此您无需加载单独的检查点。

  5. 主要更改将是更新输入到模型中的张量的名称以及从模型中获取的张量的名称。在TensorFlow Android演示中,这将与传递给的inputNameoutputName字符串相对应TensorFlowClassifier.initializeTensorFlow()

  6. GraphDef是该程序的结构,其通常不通过训练过程而改变。检查点是训练过程状态的快照,通常在训练过程的每个步骤都会改变。结果,TensorFlow对这些类型的数据使用不同的存储格式,并且低级API提供了不同的方式来保存和加载它们。更高级别的库,如MetaGraphDef图书馆,Kerasskflow对这些机制的构建提供更加便捷的方式来保存和恢复整个模型。


当它说您可以加载保存的图形然后执行它时,是否表示C ++ API文档存在tf.train.write_graph()
mnicky

2
C ++ API文档并不存在,但是缺少一些细节。最重要的细节是,除了由GraphDef保存tf.train.write_graph(),还需要记住在执行图形时要馈送和获取的张量的名称(上面的项目5)。
mrry

@mrry:我尝试使用tensorflows DeepDream示例。但似乎需要pb格式的预训练模型!我运行了Cifar10示例,但它仅创建检查点!我找不到任何PB文件或任何东西!如何将检查点转换为deepdream示例使用的pb格式?
里卡

2
@ Coderx7我真的认为您不能将.ckpt转换为.pb,因为检查点仅包含权重和变量,并且对图形的结构
一无所知

1
是否有简单的代码加载.pb文件然后运行?

1

您可以尝试以下代码:

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.