使用scikit Learn选择功能后识别过滤的功能


10

这是我的Python 功能选择方法代码:

from sklearn.svm import LinearSVC
from sklearn.datasets import load_iris
iris = load_iris()
X, y = iris.data, iris.target
X.shape
(150, 4)
X_new = LinearSVC(C=0.01, penalty="l1", dual=False).fit_transform(X, y)
X_new.shape
(150, 3)

但是在获得新的X(因变量-X_new)之后,我如何知道在此新的更新变量中删除了哪些变量以及考虑了哪些变量?(已删除的一个或数据中存在的三个。)

获得此标识的原因是对新的测试数据应用相同的过滤。

Answers:


6

您可以做两件事:

  • 检查coef_参数并检测哪个列被忽略
  • 使用方法将相同模型用于输入数据转换 transform

您的示例的小修改

>>> from sklearn.svm import LinearSVC
>>> from sklearn.datasets import load_iris
>>> from sklearn.cross_validation import train_test_split
>>>
>>> iris = load_iris()
>>> x_train, x_test, y_train, y_test = train_test_split(
...     iris.data, iris.target, train_size=0.7
... )
>>>
>>> svc = LinearSVC(C=0.01, penalty="l1", dual=False)
>>>
>>> X_train_new = svc.fit_transform(x_train, y_train)
>>> print(X_train_new.shape)
(105, 3)
>>>
>>> X_test_new = svc.transform(x_test)
>>> print(X_test_new.shape)
(45, 3)
>>>
>>> print(svc.coef_)
[[ 0.          0.10895557 -0.20603044  0.        ]
 [-0.00514987 -0.05676593  0.          0.        ]
 [ 0.         -0.09839843  0.02111212  0.        ]]

如您所见,方法transform可以为您完成所有工作。而且从coef_矩阵中您可以看到最后一列只是零向量,因此您可以对数据中的最后一列进行建模


嗨,我如何识别X_train_new的列名。有功能吗?
Vignesh Prajapati

1
它们的顺序与输入数据集中的顺序相同。iris.feature_names
itdxer

是。它的。我在这里很困惑。它的顺序相同。但是我如何获得它们的名称,因为某些列已被忽略。因此,我无法获得在此过程中选择的那些特定列。您能帮我吗!
Vignesh Prajapati 2015年

您是否feature_namesiris变量中检查了方法?这对我来说可以。
itdxer 2015年

12

另外,如果在安装完SVC之后使用SelectFromModel进行特征选择,则可以使用instance方法get_support。这将返回一个布尔数组,该数组映射每个功能的选择。接下来,将其与原始要素名称数组关联,然后过滤布尔状态以生成一组相关的所选要素名称。

希望这对未来的读者有所帮助,他们也将为在选择功能后寻找最佳方法来获取相关功能名称的努力。

例:

lsvc = LinearSVC(C=0.01, penalty="l1", dual=False,max_iter=2000).fit(X, y) 
model = sk.SelectFromModel(lsvc, prefit=True) 
X_new = model.transform(X) 
print(X_new.shape) 
print(model.get_support()) 

5
这应该被接受
user0

5

基于@chinnychinchin解决方案,我通常这样做:

lsvc = LinearSVC(C=0.01, penalty="l1", dual=False,max_iter=2000).fit(X, y) 
model = sk.SelectFromModel(lsvc, prefit=True) 
X_new = model.transform(X) 
print(X.columns[model.get_support()]) 

返回如下内容:

Index([u'feature1', u'feature2', u'feature',
  u'feature4'],
  dtype='object')
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.