scikit-learn .predict()默认阈值


75

我正在处理不平衡类(5%1)的分类问题。我想预测班级,而不是概率。

在二进制分类问题中,默认情况下是否classifier.predict()使用scikit 0.5?如果没有,默认方法是什么?如果可以,该如何更改?

在scikit中,某些分类器可以class_weight='auto'选择,但并非全部都可以。使用class_weight='auto',是否.predict()将实际人口比例用作阈值?

MultinomialNB不支持的分类器中执行此操作的方式是什么class_weight?除了predict_proba()自己使用然后计算类之外。

Answers:


42

classifier.predict()默认情况下,scikit是否使用0.5?

在概率分类器中,是的。正如其他人所解释的那样,从数学角度来看,这是唯一明智的阈值。

在不支持的分类器(如MultinomialNB)中,如何做到这一点class_weight

您可以设置class_prior,即每个类y的先验概率P(y)。这有效地改变了决策边界。例如

# minimal dataset
>>> X = [[1, 0], [1, 0], [0, 1]]
>>> y = [0, 0, 1]
# use empirical prior, learned from y
>>> MultinomialNB().fit(X,y).predict([1,1])
array([0])
# use custom prior to make 1 more likely
>>> MultinomialNB(class_prior=[.1, .9]).fit(X,y).predict([1,1])
array([1])

看来RandomForestClassifier没有class_prior。怎么办呢?
famargar

2
RandomForestClassifier没有class_prior参数,但具有可以使用的class_weight参数。
lbcommer '17

4
实际上,默认的0.5默认值是任意的,并且不一定是最佳值,例如,已被撤销的权威Frank Harrell在有关CV的此答案中已注意到。
蒂姆(Tim)

“在概率分类器中,是的。正如其他人所解释的那样,从数学角度来看,这是唯一明智的阈值。” -这似乎完全不合常理。例如,如果您想权衡召回率而不是精度呢?
cyniphile

39

阈值可以使用 clf.predict_proba()

例如:

from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state = 2)
clf.fit(X_train,y_train)
# y_pred = clf.predict(X_test)  # default threshold is 0.5
y_pred = (clf.predict_proba(X_test)[:,1] >= 0.3).astype(bool) # set threshold as 0.3

7
为了澄清起见,您没有设置阈值,因为这意味着您将永久更改的行为clf.predict(),而您没有这样做。
pcko1

这是正确的答案。我在MLP来源中看不到他们在哪儿做0.5阈值...
eggie5

您如何将其与GridSearchCV绑定在一起,在其中执行的预测是内部的,您无法访问?假设0.3的阈值将给我带来不同的最佳模型选择。
demongolem

2
我认为GridSearchCV将仅使用默认阈值0.5。在训练过程中更改此阈值是不合理的,因为我们希望一切都公平。仅在最后的预测阶段,我们才对概率阈值进行调整,以偏向于正面或负面的结果。例如,要获得更大的捕获率(以更高的误报为代价),我们可以手动降低阈值。
Yuchao Jiang

37

scikit learning的阈值对于二进制分类是0.5,而哪个类别对多分类的可能性最大。在许多问题中,通过调整阈值可以获得更好的结果。但是,必须谨慎进行,而不是对保留测试数据,而应对训练数据进行交叉验证。如果您对测试数据的阈值进行了任何调整,那么您只是在过度拟合测试数据。

调整阈值的大多数方法都是基于接收器工作特性(ROC)Youden的J统计量,但也可以通过其他方法来完成,例如使用遗传算法进行搜索。

这是一篇同行评议期刊文章,描述了在医学领域中的用法:

http://www.ncbi.nlm.nih.gov/pmc/articles/PMC2515362/

据我所知,没有在Python中执行此操作的程序包,但是使用Python进行暴力搜索来查找它相对简单(但效率低下)。

这是一些执行此操作的R代码。

## load data
DD73OP <- read.table("/my_probabilites.txt", header=T, quote="\"")

library("pROC")
# No smoothing
roc_OP <- roc(DD73OP$tc, DD73OP$prob)
auc_OP <- auc(roc_OP)
auc_OP
Area under the curve: 0.8909
plot(roc_OP)

# Best threshold
# Method: Youden
#Youden's J statistic (Youden, 1950) is employed. The optimal cut-off is the threshold that maximizes the distance to the identity (diagonal) line. Can be shortened to "y".
#The optimality criterion is:
#max(sensitivities + specificities)
coords(roc_OP, "best", ret=c("threshold", "specificity", "sensitivity"), best.method="youden")
#threshold specificity sensitivity 
#0.7276835   0.9092466   0.7559022

5
很棒的帖子!最重要的一点是:“如果您对测试数据的阈值进行任何调整,都将过度拟合测试数据。”
Sven R. Kunze

7

您似乎在这里混淆了概念。阈值不是“通用分类器”的概念-最基本的方法基于某些可调阈值,但是大多数现有方法创建了复杂的分类规则,这些规则不能(或至少不应)视为阈值。

所以首先-无法回答scikit的分类器默认阈值的问题,因为没有此类问题。

第二类加权不是阈值,它是分类器处理不平衡类的能力,它取决于特定的分类器。例如-在SVM情况下,这是对优化问题中的松弛变量进行加权的方法,或者,如果您愿意-与特定类相关的lagrange乘数值的上限。将其设置为“自动”意味着要使用一些默认的启发式方法,但是再一次-不能简单地将其转换为某些阈值。

另一方面,朴素贝叶斯(Naive Bayes)直接从训练集中估计班级概率。它称为“类优先”,您可以在构造函数中使用“ class_prior”变量进行设置。

文档中

该类的先验概率。如果指定了先验,则不会根据数据进行调整。


2
让我以不同的方式对此进行解释,然后随意地说我还是很困惑:-)。说我有两节课。大多数分类器将预测概率。我可以使用概率来评估我的模型,例如使用ROC。但是,如果我想预测一个类别,则需要选择一个临界值,例如0.5,然后说:“每个p <0.5的观察结果进入0级,而p> 0.5的观察结果进入1级。这通常是一个好选择。 。选择,如果你的前科是0.5-0.5,但不平衡的问题,我需要一个不同的截止我的问题真的被问如何是截止在scikit使用.predict时(处理)。
ADJ

大多数分类器不是概率分类器。他们可以以某种方式“乘积”该概率(估计)的事实并不意味着他们实际上是“使用它”来进行预测的。这就是为什么我将其称为可能的混乱。Predict调用用于进行预测的原始模型的例程,它可以是概率(NB),几何(SVM),基于回归(NN)或基于规则(Trees)的,因此对predict()中的概率值的疑问似乎像一个概念上的混淆。
lejlot

2
@lejlot,如果是这种情况,那么用predict_proba绘制的roc曲线的整个概念也不会变得无关紧要吗?roc曲线的不同点在不同的阈值处绘制出来的值是否不适用于predict_proba的结果?
尤金·布拉金

2

如果有人访问此线程,希望可以使用现成的功能(python 2.7)。在此示例中,截止值设计为反映原始数据集df中事件与非事件的比率,而y_prob可能是.predict_proba方法(假设分层训练/测试拆分)的结果。

def predict_with_cutoff(colname, y_prob, df):
    n_events = df[colname].values
    event_rate = sum(n_events) / float(df.shape[0]) * 100
    threshold = np.percentile(y_prob[:, 1], 100 - event_rate)
    print "Cutoff/threshold at: " + str(threshold)
    y_pred = [1 if x >= threshold else 0 for x in y_prob[:, 1]]
    return y_pred

随时批评/修改。希望在少数情况下无法解决类平衡并且数据集本身高度不平衡的情况提供帮助。

By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.