如何使用predict_generator对Keras中的流测试数据进行预测?


16

Keras从头开始培训卷积网络的博客中,代码仅显示了运行在培训和验证数据上的网络。那测试数据呢?验证数据是否与测试数据相同(我认为不是)。如果在与train和validation文件夹相似的行上有一个单独的测试文件夹,我们如何获得测试数据的混淆矩阵。我知道我们必须使用scikit Learn或其他软件包来执行此操作,但是如何从类明智的测试数据概率中获取一些信息呢?我希望将其用于混淆矩阵。


1
您可以使用生成器调用model.predict_generator(...)函数,该生成器从包含测试集的目录中读取数据。它返回预测,您可以使用这些预测来计算混淆矩阵。那是您要找的东西吗?参见此处的文档:keras.io/models/sequential
stmax

1
是的,我确实看到了。预言生成器将返回一个预测列表,该列表是一个介于0和1之间的浮点值。我该如何解释呢?它不能直接与混淆矩阵一起使用。
Raghuram

2
我还没有尝试过预报(新的),但是它似乎返回了类概率。尝试将值<= 0.5转换为0,将> = 0.5转换为1。一旦有了由0和1组成的列表,就可以将其输入到用于计算混淆矩阵的函数中。
stmax

2
顺便说一句,这对于两个类别的问题都可以正常工作,但是如果存在两个以上类别,该怎么办?
Raghuram

1
如果有两个以上的类,则您的网络需要多个输出。对于n个类,您具有n个输出,并且可以预测具有最高输出的类。看一下softmax函数(en.wikipedia.org/wiki/Softmax_function)。
stmax

Answers:


15

要从测试数据中获取混淆矩阵,您应该执行两个步骤:

  1. 对测试数据进行预测

例如,用于model.predict_generator预测测试生成器的前2000个概率。

generator = datagen.flow_from_directory(
        'data/test',
        target_size=(150, 150),
        batch_size=16,
        class_mode=None,  # only data, no labels
        shuffle=False)  # keep data in same order as labels

probabilities = model.predict_generator(generator, 2000)
  1. 根据标签预测计算混淆矩阵

例如,将概率与分别有1000只猫和1000只狗的情况进行比较。

from sklearn.metrics import confusion_matrix

y_true = np.array([0] * 1000 + [1] * 1000)
y_pred = probabilities > 0.5

confusion_matrix(y_true, y_pred)

有关测试和验证数据的附加说明

Keras文档使用三组不同的数据:培训数据,验证数据和测试数据。训练数据用于优化模型参数。验证数据用于选择元参数,例如历元数。使用最佳的元参数优化模型后,可使用测试数据对模型性能进行合理的估计。


2
感谢您的代码段。您可以将两者链接吗?在您的示例中,y_true似乎填充有伪数据。您将使用generator.classes填充数组吗?
Gegenwind

我不确定,但我认为不是np.array([0] * 1000 + [1] * 1000)您可以通过这样做来获得相同的阵列generator.classes
Mehdi Nellen

2

这是我尝试并为我工作的一些代码:

pred= model.predict_generator(validation_generator, nb_validation_samples // batch_size)
predicted_class_indices=np.argmax(pred,axis=1)
labels = (validation_generator.class_indices)
labels2 = dict((v,k) for k,v in labels.items())
predictions = [labels[k] for k in predicted_class_indices]
print(predicted_class_indices)
print (labels)
print (predictions)

然后,您可以使用:

print (confusion matrix(predicted_class_indices,labels)

在进行预测之前,请确保shuffle=False在测试生成器(在我的情况下为验证生成器)中使用并重置它validation_generator.reset()


0

对于混淆矩阵,您必须使用sklearn软件包。我认为Keras不能提供混乱的矩阵。要预测测试集上的值,只需调用model.predict()方法即可生成测试集的预测。输出值的类型取决于您的模型类型,即离散或概率。


感谢您的回答。我确实知道Keras没有自己的混淆矩阵包。我的问题是model.predict_generator返回一个不能用于计算混淆矩阵的浮点值列表。
Raghuram'9

您正在尝试哪种数据?
2002年

我正在处理图像。
Raghuram'9
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.