如何实际从randomForest :: getTree()绘制示例树?[关闭]


62

任何人都有关于如何从以下位置实际绘制几个示例树的库或代码建议:

getTree(rfobj, k, labelVar=TRUE)

(是的,我知道您不应该在操作上进行此操作,RF是一个黑匣子,依此类推。我想在视觉上检查树,以查看是否有任何违反直觉的变量,需要进行调整/组合/离散化/转换,检查我的编码因子的效果如何,等等)


先前的问题,没有合适的答案:

我实际上想绘制一个样本树。所以,现在就不要与我争论。我不是在问varImpPlot(变量重要性图)或partialPlotMDSPlot,或这些其他图,我已经知道了,但是它们不能代替查看示例树。是的,我可以目视检查的输出getTree(...,labelVar=TRUE)

(我想plot.rf.tree()贡献将是非常受欢迎的。)


6
我认为没有必要先发制人,尤其是在您要有人自愿帮助您的情况下;它遇到的不是很好。简历有礼节政策-您可能需要阅读我们的常见问题解答
gung-恢复莫妮卡

9
@gung:以前关于这个主题的每一个问题都已经沦为人们的观点,他们坚持认为没有必要,而且确实是异端,绘制一棵示例树。阅读我给出的引用。我在这里寻找有关如何编码绘制RF树的草图。
smci 2012年

3
在用户试图提供帮助并解决该问题的地方,我会看到一些答案,并附有一些质疑该想法前提的评论(我确实相信,这些想法也有帮助的精神)。当然,有可能承认有些人会不同意不作证。
gung-恢复莫妮卡

4
在一年中,任何人曾经绘制过树的地方,我看到的答案都是零。我正在寻找特定问题的具体答案。
smci 2012年

1
可以绘制使用cforest(在party包中)构建的单个树。否则,您必须将data.framereturn by randomForest::getTree转换为tree-like对象。
chl 2012年

Answers:


44

第一个(也是最简单的)解决方案:如果您不愿意像Andy Liaw的那样坚持使用经典的RF,则randomForest可以尝试party包,它提供了原始RF 算法的不同实现(使用条件树和基于聚合的方案单位重量平均值)。然后,如该R-help文章所述,您可以绘制树列表的单个成员。据我所知,它似乎运行顺利。下面是由生成的一棵树的图cforest(Species ~ ., data=iris, controls=cforest_control(mtry=2, mincriterion=0))

在此处输入图片说明

第二(几乎一样简单)溶液:大多数的基于树的技术在R( ,,tree 等)提供一个样结构进行打印/绘制一棵树。这个想法是将R 的输出转换为这样的R对象,即使从统计的角度来看这是毫无意义的。基本上,很容易从对象访问树结构,如下所示。请注意,根据任务类型(回归与分类)的不同,它会稍有不同。在后一种情况下,它将添加特定于类的概率作为的最后一列(即)。rpartTWIXtreerandomForest::getTreetreeobj$framedata.frame

> library(tree)
> tr <- tree(Species ~ ., data=iris)
> tr
node), split, n, deviance, yval, (yprob)
      * denotes terminal node

 1) root 150 329.600 setosa ( 0.33333 0.33333 0.33333 )  
   2) Petal.Length < 2.45 50   0.000 setosa ( 1.00000 0.00000 0.00000 ) *
   3) Petal.Length > 2.45 100 138.600 versicolor ( 0.00000 0.50000 0.50000 )  
     6) Petal.Width < 1.75 54  33.320 versicolor ( 0.00000 0.90741 0.09259 )  
      12) Petal.Length < 4.95 48   9.721 versicolor ( 0.00000 0.97917 0.02083 )  
        24) Sepal.Length < 5.15 5   5.004 versicolor ( 0.00000 0.80000 0.20000 ) *
        25) Sepal.Length > 5.15 43   0.000 versicolor ( 0.00000 1.00000 0.00000 ) *
      13) Petal.Length > 4.95 6   7.638 virginica ( 0.00000 0.33333 0.66667 ) *
     7) Petal.Width > 1.75 46   9.635 virginica ( 0.00000 0.02174 0.97826 )  
      14) Petal.Length < 4.95 6   5.407 virginica ( 0.00000 0.16667 0.83333 ) *
      15) Petal.Length > 4.95 40   0.000 virginica ( 0.00000 0.00000 1.00000 ) *
> tr$frame
            var   n        dev       yval splits.cutleft splits.cutright yprob.setosa yprob.versicolor yprob.virginica
1  Petal.Length 150 329.583687     setosa          <2.45           >2.45   0.33333333       0.33333333      0.33333333
2        <leaf>  50   0.000000     setosa                                  1.00000000       0.00000000      0.00000000
3   Petal.Width 100 138.629436 versicolor          <1.75           >1.75   0.00000000       0.50000000      0.50000000
6  Petal.Length  54  33.317509 versicolor          <4.95           >4.95   0.00000000       0.90740741      0.09259259
12 Sepal.Length  48   9.721422 versicolor          <5.15           >5.15   0.00000000       0.97916667      0.02083333
24       <leaf>   5   5.004024 versicolor                                  0.00000000       0.80000000      0.20000000
25       <leaf>  43   0.000000 versicolor                                  0.00000000       1.00000000      0.00000000
13       <leaf>   6   7.638170  virginica                                  0.00000000       0.33333333      0.66666667
7  Petal.Length  46   9.635384  virginica          <4.95           >4.95   0.00000000       0.02173913      0.97826087
14       <leaf>   6   5.406735  virginica                                  0.00000000       0.16666667      0.83333333
15       <leaf>  40   0.000000  virginica                                  0.00000000       0.00000000      1.00000000

然后,有一些方法可以漂亮地打印和绘制这些对象。关键功能是一种通用tree:::plot.tree方法(我放了一个三元组:,使您可以直接在R中查看代码),它依赖于tree:::treepl(图形显示)和tree:::treeco(计算节点坐标)。这些功能需要obj$frame树的表示。其他细微的问题:(1)type = c("proportional", "uniform")默认绘图方法中的参数tree:::plot.tree有助于管理节点之间的垂直距离(proportional意味着它与偏差成正比,uniform意味着它是固定的);(2)您需要plot(tr)通过调用来text(tr)补充内容,以向节点和拆分添加文本标签,在这种情况下,这意味着您还必须查看tree:::text.tree

getTreefrom中的方法randomForest返回不同的结构,该结构在联机帮助中进行了说明。典型输出如下所示,终端节点用status代码(-1)表示。(同样,输出将根据任务的类型而有所不同,但仅取决于statusprediction列。)

> library(randomForest)
> rf <- randomForest(Species ~ ., data=iris)
> getTree(rf, 1, labelVar=TRUE)
   left daughter right daughter    split var split point status prediction
1              2              3 Petal.Length        4.75      1       <NA>
2              4              5 Sepal.Length        5.45      1       <NA>
3              6              7  Sepal.Width        3.15      1       <NA>
4              8              9  Petal.Width        0.80      1       <NA>
5             10             11  Sepal.Width        3.60      1       <NA>
6              0              0         <NA>        0.00     -1  virginica
7             12             13  Petal.Width        1.90      1       <NA>
8              0              0         <NA>        0.00     -1     setosa
9             14             15  Petal.Width        1.55      1       <NA>
10             0              0         <NA>        0.00     -1 versicolor
11             0              0         <NA>        0.00     -1     setosa
12            16             17 Petal.Length        5.40      1       <NA>
13             0              0         <NA>        0.00     -1  virginica
14             0              0         <NA>        0.00     -1 versicolor
15             0              0         <NA>        0.00     -1  virginica
16             0              0         <NA>        0.00     -1 versicolor
17             0              0         <NA>        0.00     -1  virginica

如果您可以设法将上表转换为生成的表,尽管我没有这种方法的示例,但tree您也许可以自定义tree:::treepltree:::treecotree:::text.tree满足您的需求。特别是,您可能希望摆脱使用偏差,类概率等问题,这些问题在RF中没有意义。您只需要设置节点坐标和分割值。您可以使用fixInNamespace()它,但是老实说,我不确定这是正确的方法。

第三种(当然也是聪明的)解决方案:编写一个真正的as.tree帮助程序函数,以减轻上述所有“补丁”。然后,您可以使用R的绘图方法,或者更好的是使用Klimt(直接来自R)来显示单个树。


40

我迟到了四年,但是如果您真的想坚持使用该randomForest软件包(这样做有很多充分的理由),并且想真正地可视化该树,则可以使用reprtree软件包。

该软件包的文档不是很好(您可以在此处找到文档),但是一切都很简单。要安装该软件包,请在存储库中引用initialize.R,因此只需运行以下命令:

options(repos='http://cran.rstudio.org')
have.packages <- installed.packages()
cran.packages <- c('devtools','plotrix','randomForest','tree')
to.install <- setdiff(cran.packages, have.packages[,1])
if(length(to.install)>0) install.packages(to.install)

library(devtools)
if(!('reprtree' %in% installed.packages())){
  install_github('araastat/reprtree')
}
for(p in c(cran.packages, 'reprtree')) eval(substitute(library(pkg), list(pkg=p)))

然后继续制作模型和树:

library(randomForest)
library(reprtree)

model <- randomForest(Species ~ ., data=iris, importance=TRUE, ntree=500, mtry = 2, do.trace=100)

reprtree:::plot.getTree(model)

然后你去了!美丽而简单。

从plot.getTree(model)生成的树

您可以检查github存储库以了解包中的其他方法。实际上,如果您检查plot.getTree.R,您会注意到作者使用了自己的实现,as.tree()chl♦建议您在自己的答案中构建自己。这意味着您可以执行以下操作:

tree <- getTree(model, k=1, labelVar=TRUE)
realtree <- reprtree:::as.tree(tree, model)

然后可能realtree与其他树状绘图包(例如tree)一起使用


非常感谢,我仍然很高兴地接受答案,这似乎是人们对产品不屑一顾的领域。我想新事物也将得到支持xgboost
smci

6
没问题。我花了几个小时才能找到该库/程序包,所以我认为,如果这对您没有用处,那将是对其他尝试在仍然坚持使用该randomForest程序包的同时绘制树木的人。
jgozal

2
很酷的发现。注意:从某种意义上说,它绘制了有代表性的树,该树是集合中所有其他树平均“最接近”的树
Chris

2
@Chris函数plot.getTree()绘制一棵单独的树。该plot.reprtree()软件包中的函数绘制了代表性的树。
春丽

1
我从插入符中获取模型,并想使用,将其馈入reptree reprtree:::plot.getTree(mod_rf_1$finalModel),但是,“ data.frame(var = fr $ var,splits = as.character(gTree [,“ split point”]])中出现错误:参数暗示不同行数:
2631,0

15

我创建了一些函数来提取树的规则。

#**************************
#return the rules of a tree
#**************************
getConds<-function(tree){
  #store all conditions into a list
  conds<-list()
  #start by the terminal nodes and find previous conditions
  id.leafs<-which(tree$status==-1)
	  j<-0
	  for(i in id.leafs){
		j<-j+1
		prevConds<-prevCond(tree,i)
		conds[[j]]<-prevConds$cond
		while(prevConds$id>1){
		  prevConds<-prevCond(tree,prevConds$id)
		  conds[[j]]<-paste(conds[[j]]," & ",prevConds$cond)
        }
		if(prevConds$id==1){
			conds[[j]]<-paste(conds[[j]]," => ",tree$prediction[i])
    }
    }

  }

  return(conds)
}

#**************************
#find the previous conditions in the tree
#**************************
prevCond<-function(tree,i){
  if(i %in% tree$right_daughter){
		id<-which(tree$right_daughter==i)
		cond<-paste(tree$split_var[id],">",tree$split_point[id])
	  }
	  if(i %in% tree$left_daughter){
    id<-which(tree$left_daughter==i)
		cond<-paste(tree$split_var[id],"<",tree$split_point[id])
  }

  return(list(cond=cond,id=id))
}

#remove spaces in a word
collapse<-function(x){
  x<-sub(" ","_",x)

  return(x)
}


data(iris)
require(randomForest)
mod.rf <- randomForest(Species ~ ., data=iris)
tree<-getTree(mod.rf, k=1, labelVar=TRUE)
#rename the name of the column
colnames(tree)<-sapply(colnames(tree),collapse)
rules<-getConds(tree)
print(rules)
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.