TensorFlow,保存模型后为什么会有3个文件?


113

阅读文档后,我在中保存了一个模型TensorFlow,这是我的演示代码:

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

但是之后,我发现有3个文件

model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta

而且我无法通过还原model.ckpt文件来还原模型,因为没有这样的文件。这是我的代码

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")

那么,为什么有3个文件?


2
您知道如何解决这个问题吗?如何重新加载模型(使用Keras)?
rajkiran '17

Answers:


116

试试这个:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
    saver.restore(sess, "/tmp/model.ckpt")

TensorFlow保存方法可保存三种文件,因为它将图结构变量值分开存储。该.meta文件描述了已保存的图形结构,因此您需要在还原检查点之前将其导入(否则,它不知道所保存的检查点值对应于哪些变量)。

或者,您可以这样做:

# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")

...

# Now load the checkpoint variable values
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, "/tmp/model.ckpt")

即使没有名为的文件model.ckpt,在还原文件时,仍会使用该名称来引用已保存的检查点。从saver.py源代码

用户只需要与用户指定的前缀进行交互即可...而不是任何物理路径名。


1
所以不使用.index和.data吗?那两个文件什么时候使用?
ajfbiw.s

26
@ ajfbiw.s .meta存储图形结构,.data存储图形中每个变量的值,.index标识检查对象。因此,在上面的示例中:import_meta_graph使用.meta,而saver.restore使用.data和.index
TK Bartel

哦,我明白了。谢谢。
ajfbiw.s

1
您有可能使用与加载模型不同的TensorFlow版本保存模型吗?(github.com/tensorflow/tensorflow/issues/5639
TK巴特尔

5
有人知道那个 0000000001数字是什么意思吗?在variables.data-?????-of-?????档案中
伊万·塔拉拉耶夫

55
  • meta文件:描述保存的图形结构,包括GraphDef,SaverDef等;然后申请tf.train.import_meta_graph('/tmp/model.ckpt.meta'),将还原SaverGraph

  • 索引文件:它是一个不可变的字符串表(tensorflow :: table :: Table)。每个键是张量的名称,其值是序列化的BundleEntryProto。每个BundleEntryProto都描述张量的元数据:哪个“数据”文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等。

  • 数据文件:它是TensorBundle集合,保存所有变量的值。


我有用于图像分类的pb文件。我可以将其用于实时视频分类吗?

您能否让我知道,使用Keras 2,如果将模型另存为3个文件,如何加载该模型?
rajkiran '17

5

我正在从Word2Vec tensorflow教程中恢复经过训练的单词嵌入。

如果您创建了多个检查点:

例如,创建的文件如下所示

型号.ckpt-55695.data-00000-of-00001

型号.ckpt-55695.index

型号.ckpt-55695.meta

试试这个

def restore_session(self, session):
   saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta')
   saver.restore(session, './tmp/model.ckpt-55695')

调用restore_session()时:

def test_word2vec():
   opts = Options()    
   with tf.Graph().as_default(), tf.Session() as session:
       with tf.device("/cpu:0"):            
           model = Word2Vec(opts, session)
           model.restore_session(session)
           model.get_embedding("assistance")

“ model.ckpt-55695.data-00000-of-00001”中的“ 00000-of-00001”是什么意思?
hafiz031

0

例如,如果您使用辍学训练了CNN,则可以执行以下操作:

def predict(image, model_name):
    """
    image -> single image, (width, height, channels)
    model_name -> model file that was saved without any extensions
    """
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./' + model_name + '.meta')
        saver.restore(sess, './' + model_name)
        # Substitute 'logits' with your model
        prediction = tf.argmax(logits, 1)
        # 'x' is what you defined it to be. In my case it is a batch of RGB images, that's why I add the extra dimension
        return prediction.eval(feed_dict={x: image[np.newaxis,:,:,:], keep_prob_dnn: 1.0})
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.