在Keras从头开始培训卷积网络的博客中,代码仅显示了运行在培训和验证数据上的网络。那测试数据呢?验证数据是否与测试数据相同(我认为不是)。如果在与train和validation文件夹相似的行上有一个单独的测试文件夹,我们如何获得测试数据的混淆矩阵。我知道我们必须使用scikit Learn或其他软件包来执行此操作,但是如何从类明智的测试数据概率中获取一些信息呢?我希望将其用于混淆矩阵。
1
您可以使用生成器调用model.predict_generator(...)函数,该生成器从包含测试集的目录中读取数据。它返回预测,您可以使用这些预测来计算混淆矩阵。那是您要找的东西吗?参见此处的文档:keras.io/models/sequential
—
stmax
是的,我确实看到了。预言生成器将返回一个预测列表,该列表是一个介于0和1之间的浮点值。我该如何解释呢?它不能直接与混淆矩阵一起使用。
—
Raghuram
我还没有尝试过预报(新的),但是它似乎返回了类概率。尝试将值<= 0.5转换为0,将> = 0.5转换为1。一旦有了由0和1组成的列表,就可以将其输入到用于计算混淆矩阵的函数中。
—
stmax
顺便说一句,这对于两个类别的问题都可以正常工作,但是如果存在两个以上类别,该怎么办?
—
Raghuram
如果有两个以上的类,则您的网络需要多个输出。对于n个类,您具有n个输出,并且可以预测具有最高输出的类。看一下softmax函数(en.wikipedia.org/wiki/Softmax_function)。
—
stmax