是否有必要更改Keras中Early Stopping回调所使用的指标?


13

当在Keras中使用Early Stopping回调时,某些指标(通常是验证损失)没有增加时,训练将停止。有没有一种方法可以使用其他指标(例如精度,召回率,f度量)代替验证损失?到目前为止,我所看到的所有示例都与此示例类似:callbacks.EarlyStopping(monitor ='val_loss',耐心= 5,冗长= 0,mode ='auto')

Answers:


11

您可以使用在编译模型时指定的任何度量标准函数。

假设您具有以下指标功能:

def my_metric(y_true, y_pred):
     return some_metric_computation(y_true, y_pred)

此函数的唯一要求是它接受真实的y和预测的y。

编译模型时,您可以指定此指标,类似于指定诸如“准确性”之类的内置指标的方式:

model.compile(metrics=['accuracy', my_metric], ...)

请注意,我们使用的函数名称为my_metric,不带''(与“ accuracy”中的构建相反)。

然后,如果您定义了EarlyStopping,则只需使用函数的名称(这次使用''):

EarlyStopping(monitor='my_metric', mode='min')

确保指定模式(如果越低越好,请选择最小;如果越高越好,请指定最大)。

您可以像使用任何内置指标一样使用它。这可能也可以与其他回调一起使用,例如ModelCheckpoint(但我还没有测试过)。在内部,Keras只是使用函数名称将新指标添加到该模型可用的指标列表中。

如果您在model.fit(...)中指定用于验证的数据,则还可以通过使用'val_my_metric'将其用于EarlyStopping。


3

当然,只需创建自己的!

class EarlyStopByF1(keras.callbacks.Callback):
    def __init__(self, value = 0, verbose = 0):
        super(keras.callbacks.Callback, self).__init__()
        self.value = value
        self.verbose = verbose


    def on_epoch_end(self, epoch, logs={}):
         predict = np.asarray(self.model.predict(self.validation_data[0]))
         target = self.validation_data[1]
         score = f1_score(target, prediction)
         if score > self.value:
            if self.verbose >0:
                print("Epoch %05d: early stopping Threshold" % epoch)
            self.model.stop_training = True


callbacks = [EarlyStopByF1(value = .90, verbose =1)]
model.fit(X, y, batch_size = 32, nb_epoch=nb_epoch, verbose = 1, 
validation_data(X_val,y_val), callbacks=callbacks)

我还没有测试过,但这应该是您如何使用它的一般口味。如果不起作用,请通知我,我将在周末重试。我还假设您已经实现了自己的f1得分。如果不是,仅导入sklearn。


+1仍可从2020
奥斯汀
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.