运行object_detection_tutorial TypeError的问题:load()缺少2个必需的位置参数


11

我对tensorflow非常陌生,并且正在尝试运行object_detection_tutorial。我收到TypeErrror,不知道如何解决。

这是load_model函数,它缺少2个参数:

标签:一组字符串标签,用于标识所需的MetaGraphDef。这些应该与使用SavedModel save()API保存变量时使用的标签相对应。

export_dir:SavedModel协议缓冲区和要加载的变量所在的目录。

def load_model(model_name):
  base_url = 'http://download.tensorflow.org/models/object_detection/'
  model_file = model_name + '.tar.gz'
  model_dir = tf.keras.utils.get_file(
    fname=model_name, 
    origin=base_url + model_file,
    untar=True)

  model_dir = pathlib.Path(model_dir)/"saved_model"

  model = tf.saved_model.load(str(model_dir))
  model = model.signatures['serving_default']

  return model
WARNING:tensorflow:From <ipython-input-9-f8a3c92a04a4>:11: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-12-e10c73a22cc9> in <module>
      1 model_name = 'ssd_mobilenet_v1_coco_2017_11_17'
----> 2 detection_model = load_model(model_name)

<ipython-input-9-f8a3c92a04a4> in load_model(model_name)
      9   model_dir = pathlib.Path(model_dir)/"saved_model"
     10 
---> 11   model = tf.saved_model.load(str(model_dir))
     12   model = model.signatures['serving_default']
     13 

~/.local/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    322               'in a future version' if date is None else ('after %s' % date),
    323               instructions)
--> 324       return func(*args, **kwargs)
    325     return tf_decorator.make_decorator(
    326         func, new_func, 'deprecated',

TypeError: load() missing 2 required positional arguments: 'tags' and 'export_dir'

您能帮我解决这个问题并运行我的第一个物体检测器:D吗?

Answers:


14

我有同样的问题,我现在想解决1周。我想解决方案应该是这样;

model = tf.compat.v2.saved_model.load(str(model_dir), None)

更详细的是(从官方网站);

从export_dir加载SavedModel。

tf.saved_model.load(
    export_dir,
    tags=None
)

别名:

tf.compat.v1.saved_model.load_v2

tf.compat.v2.saved_model.load

1
我使用了您的解决方案,并且遇到了另一个错误。我更新了所有可能的方法,并且可以正常工作!我也有一个错误,pathlib没有被取消。
多米尼克

@Dominik您能更具体些吗?也许我可以帮忙,因为这次张量流冒险使我解决了很多问题:D
Onur Baskin

4
@OnurBaskin稍后会出现错误:TypeError:int()参数必须是字符串,类似字节的对象或数字,而不是'
Tensor'– kaitsu

@Dominik我认为这是您的Tensorflow版本。它应该是2.0版(稳定)。这是我问的问题的链接,也许您遇到的确切错误。另外,搜索任何需要“ compat.v1”的旧导入。稍后,您应该会有更多的错误,但这就是您迁移旧代码的方式。
Onur Baskin

@OnurBaskin我很困惑。我认为对象检测API仅与TensorFlow 1版本兼容。
3

0

我猜想这是一个分支问题,使用tf_2_1_reference分支可以帮到我:

igian@iGians-MBP models % git checkout tf_2_1_reference
M   research/object_detection/object_detection_tutorial.ipynb
Branch 'tf_2_1_reference' set up to track remote branch 'tf_2_1_reference' from 'origin'.
Switched to a new branch 'tf_2_1_reference'
igians@iGians-MBP models % jupyter notebook

然后像一个好新手一样执行本教程的每个木星单元!

这是我使用的分支:https : //github.com/tensorflow/models/tree/tf_2_1_reference

By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.