Florestas aleatórias dificilmente são uma caixa preta. Eles são baseados em árvores de decisão, que são muito fáceis de interpretar:
#Setup a binary classification problem
require(randomForest)
data(iris)
set.seed(1)
dat <- iris
dat$Species <- factor(ifelse(dat$Species=='virginica','virginica','other'))
trainrows <- runif(nrow(dat)) > 0.3
train <- dat[trainrows,]
test <- dat[!trainrows,]
#Build a decision tree
require(rpart)
model.rpart <- rpart(Species~., train)
Isso resulta em uma árvore de decisão simples:
> model.rpart
n= 111
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 111 35 other (0.68468468 0.31531532)
2) Petal.Length< 4.95 77 3 other (0.96103896 0.03896104) *
3) Petal.Length>=4.95 34 2 virginica (0.05882353 0.94117647) *
Se Petal.Length <4.95, essa árvore classifica a observação como "outra". Se for maior que 4,95, classifica a observação como "virginica". Uma floresta aleatória é simples, uma coleção de muitas dessas árvores, onde cada uma é treinada em um subconjunto aleatório dos dados. Cada árvore "vota" na classificação final de cada observação.
model.rf <- randomForest(Species~., train, ntree=25, proximity=TRUE, importance=TRUE, nodesize=5)
> getTree(model.rf, k=1, labelVar=TRUE)
left daughter right daughter split var split point status prediction
1 2 3 Petal.Width 1.70 1 <NA>
2 4 5 Petal.Length 4.95 1 <NA>
3 6 7 Petal.Length 4.95 1 <NA>
4 0 0 <NA> 0.00 -1 other
5 0 0 <NA> 0.00 -1 virginica
6 0 0 <NA> 0.00 -1 other
7 0 0 <NA> 0.00 -1 virginica
Você pode até extrair árvores individuais da RF e observar sua estrutura. O formato é um pouco diferente do dos rpart
modelos, mas você pode inspecionar cada árvore, se quiser, e ver como está modelando os dados.
Além disso, nenhum modelo é realmente uma caixa preta, porque você pode examinar as respostas previstas versus as respostas reais para cada variável no conjunto de dados. Essa é uma boa ideia, independentemente do tipo de modelo que você está construindo:
library(ggplot2)
pSpecies <- predict(model.rf,test,'vote')[,2]
plotData <- lapply(names(test[,1:4]), function(x){
out <- data.frame(
var = x,
type = c(rep('Actual',nrow(test)),rep('Predicted',nrow(test))),
value = c(test[,x],test[,x]),
species = c(as.numeric(test$Species)-1,pSpecies)
)
out$value <- out$value-min(out$value) #Normalize to [0,1]
out$value <- out$value/max(out$value)
out
})
plotData <- do.call(rbind,plotData)
qplot(value, species, data=plotData, facets = type ~ var, geom='smooth', span = 0.5)
Normalizei as variáveis (comprimento e largura da sépala e pétala) para um intervalo de 0 a 1. A resposta também é 0-1, onde 0 é outro e 1 é virginica. Como você pode ver, a floresta aleatória é um bom modelo, mesmo no conjunto de testes.
Além disso, uma floresta aleatória calculará várias medidas de importância variável, que podem ser muito informativas:
> importance(model.rf, type=1)
MeanDecreaseAccuracy
Sepal.Length 0.28567162
Sepal.Width -0.08584199
Petal.Length 0.64705819
Petal.Width 0.58176828
Esta tabela representa quanto a remoção de cada variável reduz a precisão do modelo. Por fim, existem muitos outros gráficos que você pode fazer a partir de um modelo de floresta aleatório, para visualizar o que está acontecendo na caixa preta:
plot(model.rf)
plot(margin(model.rf))
MDSplot(model.rf, iris$Species, k=5)
plot(outlier(model.rf), type="h", col=c("red", "green", "blue")[as.numeric(dat$Species)])
Você pode visualizar os arquivos de ajuda de cada uma dessas funções para ter uma idéia melhor do que elas exibem.