Como dividir um conjunto de dados para validação cruzada 10 vezes


14

Agora eu tenho um Rquadro de dados (treinamento), alguém pode me dizer como dividir aleatoriamente esse conjunto de dados para fazer a validação cruzada de 10 vezes?


2
Certifique-se de repetir todo o processo 100 vezes para obter uma precisão satisfatória.
precisa

Certifique-se de amostrar a caixa e controlar a amostra separadamente e depois combiná-las em cada bloco.
Shicheng Guo

Se você usa caret :: train, nem precisa se preocupar com isso. Isso será feito internamente, você pode escolher a quantidade de dobras. Se você insistir em fazer isso "manualmente", use amostragem estratificada da classe conforme implementada em caret :: createFolds.
Marcel

Eu bloqueei esse tópico porque cada uma das muitas respostas o trata apenas como uma questão de codificação e não como de interesse estatístico geral.
whuber

Respostas:


22

caret tem uma função para isso:

require(caret)
flds <- createFolds(y, k = 10, list = TRUE, returnTrain = FALSE)
names(flds)[1] <- "train"

Cada elemento de fldsé uma lista de índices para cada conjunto de dados. Se o seu conjunto de dados for chamado dat, dat[flds$train,]você receberá o conjunto de treinamento, dat[ flds[[2]], ]o segundo conjunto de dobras etc.


12

Aqui está uma maneira simples de executar 10 vezes sem usar pacotes:

#Randomly shuffle the data
yourData<-yourData[sample(nrow(yourData)),]

#Create 10 equally size folds
folds <- cut(seq(1,nrow(yourData)),breaks=10,labels=FALSE)

#Perform 10 fold cross validation
for(i in 1:10){
    #Segement your data by fold using the which() function 
    testIndexes <- which(folds==i,arr.ind=TRUE)
    testData <- yourData[testIndexes, ]
    trainData <- yourData[-testIndexes, ]
    #Use the test and train data partitions however you desire...
}

-1: as funções de cursor fazem amostragem estratificada que você não está fazendo. Qual é o sentido de reinventar o weel se alguém fez as coisas mais simples para você?
Marcel

10
Você está de brincadeira? O objetivo inteiro da resposta é executar 10 vezes sem precisar instalar o pacote de interpolação inteiro. O único ponto positivo que você coloca é que as pessoas devem entender o que seu código realmente faz. Gafanhoto jovem, amostragem estratificada nem sempre é a melhor abordagem. Por exemplo, dá mais importância aos subgrupos com mais dados, o que nem sempre é desejável. (Esp, se você não sabe o que está acontecendo). Trata-se de usar a melhor abordagem para seus dados. Troll com cuidado meu amigo :)
Jake de Drew

@JakeDrew Sei que agora é uma publicação antiga, mas seria possível pedir algumas orientações sobre como usar os dados de teste e treinamento para obter o erro médio médio de um modelo VAR (p) para cada iteração?
youjustreadthis


@JakeDrew IMHO ambas as respostas merecem um plus 1. Um com um pacote, o outro com código ...
natbusa

2

Provavelmente não é o melhor caminho, mas aqui está uma maneira de fazê-lo. Tenho certeza de que, quando escrevi esse código, havia emprestado um truque de outra resposta aqui, mas não consegui encontrá-lo.

# Generate some test data
x <- runif(100)*10 #Random values between 0 and 10
y <- x+rnorm(100)*.1 #y~x+error
dataset <- data.frame(x,y) #Create data frame
plot(dataset$x,dataset$y) #Plot the data

#install.packages("cvTools")
library(cvTools) #run the above line if you don't have this library

k <- 10 #the number of folds

folds <- cvFolds(NROW(dataset), K=k)
dataset$holdoutpred <- rep(0,nrow(dataset))

for(i in 1:k){
  train <- dataset[folds$subsets[folds$which != i], ] #Set the training set
  validation <- dataset[folds$subsets[folds$which == i], ] #Set the validation set

  newlm <- lm(y~x,data=train) #Get your new linear model (just fit on the train data)
  newpred <- predict(newlm,newdata=validation) #Get the predicitons for the validation set (from the model just fit on the train data)

  dataset[folds$subsets[folds$which == i], ]$holdoutpred <- newpred #Put the hold out prediction in the data set for later use
}

dataset$holdoutpred #do whatever you want with these predictions

1

por favor, encontre abaixo algum outro código que eu uso (emprestado e adaptado de outra fonte). Copiei direto de um script que acabei de usar, deixado na rotina rpart. A parte provavelmente mais interessante são as linhas na criação das dobras. Como alternativa - você pode usar a função crossval do pacote de inicialização.

#define error matrix
err <- matrix(NA,nrow=1,ncol=10)
errcv=err

#creation of folds
for(c in 1:10){

n=nrow(df);K=10; sizeblock= n%/%K;alea=runif(n);rang=rank(alea);bloc=(rang-1)%/%sizeblock+1;bloc[bloc==K+1]=K;bloc=factor(bloc); bloc=as.factor(bloc);print(summary(bloc))

for(k in 1:10){

#rpart
fit=rpart(type~., data=df[bloc!=k,],xval=0) ; (predict(fit,df[bloc==k,]))
answers=(predict(fit,df[bloc==k,],type="class")==resp[bloc==k])
err[1,k]=1-(sum(answers)/length(answers))

}

err
errcv[,c]=rowMeans(err, na.rm = FALSE, dims = 1)

}
errcv

1
# Evaluate models uses k-fold cross-validation
install.packages("DAAG")
library("DAAG")

cv.lm(data=dat, form.lm=mod1, m= 10, plotit = F)

Tudo feito para você em uma linha de código!

?cv.lm for information on input and output

0

Como não fiz minha abordagem nesta lista, pensei em compartilhar outra opção para pessoas que não desejam instalar pacotes para uma rápida validação cruzada

# get the data from somewhere and specify number of folds
data <- read.csv('my_data.csv')
nrFolds <- 10

# generate array containing fold-number for each sample (row)
folds <- rep_len(1:nrFolds, nrow(data))

# actual cross validation
for(k in 1:nrFolds) {
    # actual split of the data
    fold <- which(folds == k)
    data.train <- data[-fold,]
    data.test <- data[fold,]

    # train and test your model with data.train and data.test
}

Observe que o código acima pressupõe que os dados já estão embaralhados. Se não for esse o caso, considere adicionar algo como

folds <- sample(folds, nrow(data))
Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.