警告:tensorflow:sample_weight模式从…强制为['…']


47

使用.fit_generator().fit()将字典传递class_weight=为参数来训练图像分类器。

我从未在TF1.x中遇到错误,但在2.1中,开始训练时得到以下输出:

WARNING:tensorflow:sample_weight modes were coerced from
  ...
    to  
  ['...']

强制从...['...']到底意味着什么?

tensorflow的回购中此警告的来源在此处,注释为:

尝试将sample_weight_modes强制转换为目标结构。这隐含地依赖于模型展平其内部表示的输出这一事实。


7
很高兴看到这样一个最近的问题成为我自己警告的唯一搜索结果。
jmkjaer

1
@jorijnsmit您可以提供复制问题/警告的代码吗?
hv89

2
实际上,使用切换到TF2 %tensorflow_version 2.x足以使此警告出现:colab.research.google.com/gist/jorijnsmit/…–
jorijnsmit

1
@jorijnsmit,不,我收到相同的警告,但实际上已在pip install tensorflow(在pyenv / virtualenv环境中)安装了TF2.1
lurix66

1
是的,确实是@ lurix66,在中引入了生成此错误的代码2.1.0rc0
jorijnsmit

Answers:


11

这似乎是假消息。升级到TensorFlow 2.1后,我得到了相同的警告消息,但是我根本不使用任何类权重或样本权重。我确实使用了一个生成器,该生成器返回这样的元组:

return inputs, targets

现在我将其更改为以下内容,以消除警告:

return inputs, targets, [None]

我不知道这是否相关,但是我的模型使用3个输入,因此我的inputs变量实际上是3个numpy数组的列表。 targets仅仅是单个numpy数组。

无论如何,这只是一个警告。无论哪种方式,培训都可以正常进行。

编辑TensorFlow 2.2:

这个错误似乎已在TensorFlow 2.2中修复,这很棒。但是,上面的修复将在TF 2.2中失败,因为它会尝试获取样本权重的形状,显然会导致AttributeError: 'NoneType' object has no attribute 'shape'。因此,在升级到2.2时,请撤消上述修复。


这也对我有用。
罗伯特·拉格

14

我相信这是tensorflow的错误,当您model.compile()使用默认参数sample_weight_mode=None调用然后model.fit()使用指定的sample_weight或调用时,将会发生此错误class_weight

从张量流存储库中:

  • fit() 最终致电 _process_training_inputs()
  • _process_training_inputs() sample_weight_modes = [None]基础上model.sample_weight_mode = None,然后创建一个DataAdaptersample_weight_modes = [None]
  • DataAdapter通话broadcast_sample_weight_modes()sample_weight_modes = [None]初始化
  • broadcast_sample_weight_modes() 似乎期望 sample_weight_modes = None但收到[None]
  • 它断言[None]sample_weight/的结构不同class_weightNone通过适合sample_weight/ 的结构将其覆盖回去class_weight并输出警告

警告抛开这有没有影响fit()作为sample_weight_modesDataAdapter重新设置为None

请注意,tensorflow 文档指出sample_weight必须为numpy数组。如果调用fit()具有sample_weight.tolist()相反,你不会得到一个警告,但sample_weight被悄悄改写到None_process_numpy_inputs()被称为预处理和接收长度更大的输入不止一个。


1
非常详尽的解释,谢谢。我唯一不了解的是警告描述...为被强迫[...],而在您的情况下[None]则被强迫为None……
jorijnsmit

4

我已接管您的Gist并安装了Tensorflow 2.0,而不是TFA,并且在没有任何此类警告的情况下可以正常工作。

这是完整代码的要点。下面显示了安装Tensorflow的代码:

!pip install tensorflow==2.0

成功执行的屏幕截图如下所示:

在此处输入图片说明

更新:此错误已修复Tensorflow Version 2.2.


5
谢谢您的答复。没错,直到版本才引入警告消息2.1.0rc0。但是,我怕我的遗体的问题:“这是什么意思强迫从东西...['...']?”
jorijnsmit '19

3
我注意到,有些情况可能意想不到的东西时,sample_weight_mode=Nonetarget_structure为类型dictsample_weight_modes然后[None]在例外broadcast_sample_weight_modes由于被捉住dict。可以将其视为错误吗?
FranzKnülle'19

2
不。问题一直在收集观点和支持,但没有答案。
jorijnsmit

1
@gkennos:如果您认为这是一个错误,可以在Github Tensorflow存储库中提交错误吗?
Tensorflow支持

1
这绝对是一个错误,但它现在固定TensorFlow 2.2
JLH

2

而不是提供字典

weights = {'0': 42.0, '1': 1.0}

我尝试了一个清单

weights = [42.0, 1.0]

警告消失了。


谢啦!我正在尝试(不成功)使用字典。通过使用列表,可以修复错误!
维克多·

尽管这确实消除了错误,但对我来说,这破坏了每个类的权重,导致结果更糟。在切换到列表之前,我会检查一致性。
CanofDrink
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.