它并不适合(取决于定义)。测试集的目标信息被保留。半监督允许生成额外的合成数据集以训练模型。在所描述的方法中,原始训练数据与合成数据以4:3的比例不加权地混合。因此,如果合成数据的质量较差,该方法将带来灾难性的后果。我猜想,对于预测不确定性较高的任何问题,合成数据集的准确性都会很差。我想,如果底层结构非常复杂且系统噪声较低,则生成合成数据可能有所帮助。我认为半监督学习在深度学习(不是我的专业领域)中非常重要,因为其中还需要学习特征表示。
我试图通过对rf和xgboost在多个数据集上进行半监督训练来提高准确性,但没有得到任何正面结果。[欢迎随意修改我的代码。]我注意到在kaggle报告中,使用半监督带来的实际准确率提升相当有限,也许只是随机波动?
rm(list=ls()) # NOTE(review): clearing the workspace inside a script is discouraged; kept only to reproduce the original post
# Simulated data generators: fy2 draws class labels, fX2 maps labels to 3-D helix features
# Draw `nobs` class labels uniformly at random from {0, 1, ..., nclass - 1}.
#
# nobs   : number of labels to generate
# nclass : number of distinct classes (labels are zero-based)
# Returns a numeric vector of length `nobs`.
fy2 <- function(nobs = 2000, nclass = 9) {
  labels <- seq_len(nclass) - 1  # zero-based coding, as expected by fX2/xgboost
  sample(labels, nobs, replace = TRUE)
}
# Simulate a 3-column feature matrix for integer class labels y in {0, ..., K-1}.
# Each class lies on its own helix arm: x1 is the position along the arm, and
# (x2, x3) trace a circle of radius |x1| + min.width whose phase is offset per
# class, plus Gaussian noise.
#
# y         : numeric vector of class labels coded 0, 1, ..., K-1
# noise     : sd of the Gaussian noise added to x2 and x3
# twist     : arm length; x1 ~ Uniform(0, twist)
# min.width : minimum radius, keeps the arms separated near the origin
# Returns a length(y) x 3 matrix with columns x1, x2, x3.
fX2 <- function(y, noise = .05, twist = 8, min.width = .7) {
  # Derive the class count from the label coding, not from the labels actually
  # present: length(unique(y)) made helixStart too short (NA features) whenever
  # a class was absent from y.
  nclass <- max(y) + 1
  x1 <- runif(length(y)) * twist
  # One phase offset per class; drop the duplicated endpoint at 2*pi.
  helixStart <- seq(0, 2 * pi, length.out = nclass + 1)[-1]
  x2 <- sin(helixStart[y + 1] + x1) * (abs(x1) + min.width) + rnorm(length(y)) * noise
  x3 <- cos(helixStart[y + 1] + x1) * (abs(x1) + min.width) + rnorm(length(y)) * noise
  cbind(x1, x2, x3)
}
#define a wrapper to predict n-1 folds of test set and retrain and predict last fold
# Semi-supervised ("self-training") wrapper:
#   1. Fit `model` on the labelled training data.
#   2. Split the test set into `nfold` folds; for each fold, pseudo-label the
#      OTHER folds with the supervised model, refit on train + pseudo-labelled
#      rows, and predict the held-out fold.
#   3. Return the fold predictions re-assembled in test-set row order.
#
# model  : fitting function called as model(X, y, ...) whose result has a
#          predict(object, newdata) method (e.g. randomForest, xgboost)
# trainX : training feature matrix
# trainy : training labels (factor or numeric)
# testX  : test feature matrix to predict
# nfold  : number of test-set folds
# ...    : forwarded to `model`
smartTrainPred <- function(model, trainX, trainy, testX, nfold = 4, ...) {
  obj <- model(trainX, trainy, ...)
  # Folds must index the TEST set; the original used nrow(trainX), which is
  # only correct when train and test happen to have the same number of rows.
  n_test <- nrow(testX)
  folds <- split(sample(seq_len(n_test)), rep_len(seq_len(nfold), n_test))
  predDF <- do.call(rbind, lapply(folds, function(fold) {
    pseudo <- predict(obj, testX[-fold, , drop = FALSE])
    bigX <- rbind(trainX, testX[-fold, , drop = FALSE])
    # c() on factors is version-dependent (integer codes before R 4.1, a
    # combined factor from R 4.1 on), so combine via character labels instead
    # of the old `factor(bigy - 1)` code-arithmetic hack.
    bigy <- if (is.factor(trainy)) {
      factor(c(as.character(trainy), as.character(pseudo)))
    } else {
      c(trainy, pseudo)
    }
    bigModel <- model(bigX, bigy, ...)
    predFold <- predict(bigModel, testX[fold, , drop = FALSE])
    data.frame(sampleID = fold, pred = predFold)
  }))
  # Restore original test-row order.
  predDF$pred[order(predDF$sampleID)]
}
library(xgboost)
library(randomForest)
# Complex but perfectly separable classes: each class sits on its own noisy 3-D helix arm
trainy = fy2(); trainX = fX2(trainy)
testy = fy2(); testX = fX2(testy )
pairs(trainX,col=trainy+1)
# Baseline vs. semi-supervised with randomForest
rf = randomForest(trainX,factor(trainy))
normPred = predict(rf,testX)
# misclassification rate; numeric testy vs. factor prediction compares via character labels
cat("\n supervised rf", mean(testy!=normPred))
smartPred = smartTrainPred(randomForest,trainX,factor(trainy),testX,nfold=4)
cat("\n semi-supervised rf",mean(testy!=smartPred))
# Baseline vs. semi-supervised with xgboost (multi:softmax returns hard 0-8 labels)
xgb = xgboost(trainX,trainy,
nrounds=35,verbose=F,objective="multi:softmax",num_class=9)
normPred = predict(xgb,testX)
cat("\n supervised xgboost",mean(testy!=normPred))
smartPred = smartTrainPred(xgboost,trainX,trainy,testX,nfold=4,
nrounds=35,verbose=F,objective="multi:softmax",num_class=9)
cat("\n semi-supervised xgboost",mean(testy!=smartPred))
Printed prediction errors (misclassification rates):
supervised rf 0.007
semi-supervised rf 0.0085
supervised xgboost 0.046
semi-supervised xgboost 0.049