如何在Keras中为不平衡的班级设置班级权重?


128

我知道在Keras中使用class_weights参数字典进行拟合是有可能的,但是我找不到任何示例。有人愿意提供吗?

顺便说一句,在这种情况下,适当的做法仅仅是根据少数群体的代表性不足来按比例增加少数群体的比例?


是否有使用Keras的新更新方法?为什么字典由三个类别组成,对于以下类别:0:1.0 1:50.0 2:2.0 ???? 不应该:2:1.0吗?
Chuck

Answers:


112

如果您谈论的是正常情况,即您的网络仅产生一个输出,那么您的假设是正确的。为了强制您的算法将类1的每个实例视为类0的 50个实例,您必须:

  1. 使用标签及其相关权重定义字典

    class_weight = {0: 1.,
                    1: 50.,
                    2: 2.}
    
  2. 将字典作为参数输入:

    model.fit(X_train, Y_train, nb_epoch=5, batch_size=32, class_weight=class_weight)

编辑:“将类1的每个实例作为类0的 50个实例进行处理”意味着在损失函数中,您将更高的值分配给这些实例。因此,损失变为加权平均值,其中每个样本的权重由class_weight及其对应的类别指定。

Keras文档:class_weight:可选的字典,将类索引(整数)映射到权重(浮点)值,用于加权损失函数(仅在训练期间)。


1
如果您正在使用3D数据,还可以查看github.com/fchollet/keras/issues/3653
herve

对我来说,它给出一个错误dic没有形状属性。
弗拉维奥Filho的

我相信Keras可能会改变这种工作方式,这是针对2016年8月的版本。我将在一周内为您验证
上演者

4
@layser这仅适用于“ category_crossentropy”损失吗?如何为“ Sigmoid”和“ binary_crossentropy”损失给keras提供class_weight?
纳曼

1
@layser您能解释一下将类1的每个实例视为类0的50个实例吗?是在训练集中,对应于第1类的行重复50次以使其平衡,还是遵循其他一些过程?
Divyanshu Shekhar

121

您可以简单地实现class_weightfrom sklearn

  1. 让我们先导入模块

    from sklearn.utils import class_weight
  2. 为了计算班级权重,请执行以下操作

    class_weights = class_weight.compute_class_weight('balanced',
                                                     np.unique(y_train),
                                                     y_train)
    
  3. 第三,最后将其添加到模型拟合中

    model.fit(X_train, y_train, class_weight=class_weights)

注意:我编辑这篇文章,从改变变量名class_weightclass_weight 小号为了不覆盖导入的模块。从注释复制代码时进行相应调整。


20
对我来说,class_weight.compute_class_weight 产生一个数组,我需要将其更改为字典以便与Keras一起使用。更具体地说,在步骤2之后,请使用class_weight_dict = dict(enumerate(class_weight))
C.Lee

5
这对我不起作用。对于keras中的三类问题,y_train(300096, 3)numpy数组。因此,该class_weight=行给了我TypeError:无法散列的类型:'numpy.ndarray'–
Lembik

3
@Lembik我有一个类似的问题,其中y的每一行都是类索引的单编码向量。我通过将one-hot表示形式转换为int来解决此问题,如下所示:y_ints = [y.argmax() for y in y_train]
tkocmathla '18

3
如果我正在做多类标记,以便我的y_true向量中具有多个1,该怎么办:例如[1 0 0 0 1 0 0],其中有些x标记为0和4。即使如此,我每个向量的总数标签不平衡。我该如何使用课程加权?
Aalok

22

我将这种规则用于class_weight

import numpy as np
import math

# labels_dict : {ind_label: count_label}
# mu : parameter to tune 

def create_class_weight(labels_dict,mu=0.15):
    total = np.sum(labels_dict.values())
    keys = labels_dict.keys()
    class_weight = dict()

    for key in keys:
        score = math.log(mu*total/float(labels_dict[key]))
        class_weight[key] = score if score > 1.0 else 1.0

    return class_weight

# random labels_dict
labels_dict = {0: 2813, 1: 78, 2: 2814, 3: 78, 4: 7914, 5: 248, 6: 7914, 7: 248}

create_class_weight(labels_dict)

math.log平滑非常不平衡的班级的权重!返回:

{0: 1.0,
 1: 3.749820767859636,
 2: 1.0,
 3: 3.749820767859636,
 4: 1.0,
 5: 2.5931008483842453,
 6: 1.0,
 7: 2.5931008483842453}

3
为什么要使用log而不是仅将一个类的样本数除以样本总数?我假设在model.fit_generator(...)的参数class_weight中有一些我不理解的东西
startoftext

@startoftext这就是我的方法,但我认为您已将其反转。我上过n_total_samples / n_class_samples每堂课。
Colllin

2
在您的示例中,类0(具有2813个示例)和类6(具有7914示例)的权重正好为1.0。这是为什么?6年级的学生要大几倍!您可能希望将0类放大,而6类缩小以使其达到相同的水平。
Vladislavs Dovgalecs

9

注意:请参阅注释,此答案已过时。

要平均加权所有类,您现在可以像这样简单地将class_weight设置为“ auto”:

model.fit(X_train, Y_train, nb_epoch=5, batch_size=32, class_weight = 'auto')

1
class_weight='auto'在Keras文档或源代码中都找不到任何参考。您能告诉我们您在哪里找到的吗?
法比奥·佩雷斯

2
这个答案可能是错误的。检查这个问题:github.com/fchollet/keras/issues/5116
法比奥·佩雷斯

奇。发布评论时,我正在使用class_balanced ='auto',但现在找不到对它的引用。也许随着Keras的快速发展,它已经改变了。
大卫·格罗佩

就像上面指出Keras问题中指定的那样,您可以传递任何随机字符串class_weight,这样不会产生任何效果。因此,此答案不正确。
ncasas

3

class_weight很好,但正如@Aalok所说,如果您是对多个标签的类进行一键编码,则此方法将无效。在这种情况下,请使用sample_weight

sample_weight:与x长度相同的可选数组,其中包含权重以应用于每个样本的模型损失。对于时间数据,您可以传递具有形状(样本,sequence_length)的2D数组,以对每个样本的每个时间步施加不同的权重。在这种情况下,您应确保在compile()中指定sample_weight_mode =“ temporal”。

sample_weights用于为每个训练样本提供权重。这意味着您应该传递与训练样本具有相同数量元素的一维数组(指示每个样本的权重)。

class_weights用于为每个输出类提供权重或偏差。这意味着您应该为要分类的每个类传递权重。

必须给sample_weight一个numpy数组,因为它将评估其形状。

另请参阅以下答案:https : //stackoverflow.com/questions/48315094/using-sample-weight-in-keras-for-sequence-labelling


2

https://github.com/keras-team/keras/issues/2115处添加解决方案。如果您需要的不仅仅是类加权,那么您需要不同的误报和误报成本。现在,使用新的keras版本,您可以覆盖如下所示的相应损失函数。注意,这weights是一个方矩阵。

from tensorflow.python import keras
from itertools import product
import numpy as np
from tensorflow.python.keras.utils import losses_utils

class WeightedCategoricalCrossentropy(keras.losses.CategoricalCrossentropy):

    def __init__(
        self,
        weights,
        from_logits=False,
        label_smoothing=0,
        reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
        name='categorical_crossentropy',
    ):
        super().__init__(
            from_logits, label_smoothing, reduction, name=f"weighted_{name}"
        )
        self.weights = weights

    def call(self, y_true, y_pred):
        weights = self.weights
        nb_cl = len(weights)
        final_mask = keras.backend.zeros_like(y_pred[:, 0])
        y_pred_max = keras.backend.max(y_pred, axis=1)
        y_pred_max = keras.backend.reshape(
            y_pred_max, (keras.backend.shape(y_pred)[0], 1))
        y_pred_max_mat = keras.backend.cast(
            keras.backend.equal(y_pred, y_pred_max), keras.backend.floatx())
        for c_p, c_t in product(range(nb_cl), range(nb_cl)):
            final_mask += (
                weights[c_t, c_p] * y_pred_max_mat[:, c_p] * y_true[:, c_t])
        return super().call(y_true, y_pred) * final_mask

0

我发现了以下使用minist数据集在损失函数中编码类权重的示例。在这里查看链接:https : //github.com/keras-team/keras/issues/2115

def w_categorical_crossentropy(y_true, y_pred, weights):
    nb_cl = len(weights)
    final_mask = K.zeros_like(y_pred[:, 0])
    y_pred_max = K.max(y_pred, axis=1)
    y_pred_max = K.reshape(y_pred_max, (K.shape(y_pred)[0], 1))
    y_pred_max_mat = K.equal(y_pred, y_pred_max)
    for c_p, c_t in product(range(nb_cl), range(nb_cl)):
        final_mask += (weights[c_t, c_p] * y_pred_max_mat[:, c_p] * y_true[:, c_t])
    return K.categorical_crossentropy(y_pred, y_true) * final_mask

0
from collections import Counter
itemCt = Counter(trainGen.classes)
maxCt = float(max(itemCt.values()))
cw = {clsID : maxCt/numImg for clsID, numImg in itemCt.items()}

这适用于生成器或标准。相对于最大类别,最大类别的权重为1,而其他类别的权重则大于1。

类权重接受字典类型输入

By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.