Que desvio o glmnet está usando para comparar valores de


8

Um critério para selecionar o valor ideal de λ com uma rede elástica ou regressão penalizada semelhante é examinar um gráfico do desvio em relação à faixa de λ e selecione λ quando o desvio é minimizado (ou λ dentro de um erro padrão do mínimo).

No entanto, estou tendo dificuldade em entender o que, precisamente, glmnetexibe com plot.cv.glmnet, porque o gráfico exibido não se parece com os resultados de plotar o desvio contraλ.

set.seed(4567)
N       <- 500
P       <- 100
coefs   <- NULL
for(p in 1:P){
    coefs[p]    <- (-1)^p*100*2^(-p)
}
inv.logit <- function(x) exp(x)/(1+exp(x))
X   <- matrix(rnorm(N*P), ncol=P, nrow=N)
Y   <- rbinom(N, size=1, p=inv.logit(cbind(1, X)%*%c(-4, coefs)))
plot(test   <- cv.glmnet(x=X, y=Y, family="binomial", nfolds=10, alpha=0.8))
plot(log(test$lambda), deviance(test$glmnet.fit))

insira a descrição da imagem aqui insira a descrição da imagem aqui

Parece que o segundo gráfico não incorpora a penalidade líquida elástica e também é dimensionado incorretamente na vertical. Baseei a afirmação na base de que o formato da curva para valores maiores deλse assemelha ao da glmnetsaída. No entanto, quando tentei calcular a penalidade sozinho, minha tentativa também parece ser imprecisa.

penalized.dev.fn    <- function(lambda, alpha=0.2, data, cv.model.obj){
    dev <- deviance(cv.model.obj$glmnet.fit)[seq_along(cv.model.obj$lambda)[cv.model.obj$lambda==lambda]]
    beta <- coef(cv.model.obj, s=lambda)[rownames(coef(cv.model.obj))!="(Intercept)"]
    penalty <- lambda * ( (1-alpha)/2*(beta%*%beta) + alpha*sum(abs(beta)) )
    penalized.dev <- penalty+dev
    return(penalized.dev)
}

out <- sapply(test$lambda, alpha=0.2, cv.model.obj=test, FUN=penalized.dev.fn)
    plot(log(test$lambda), out)

Minha pergunta é: como alguém calcula manualmente o desvio relatado no plot.cv.glmnetdiagrama padrão ? Qual é a fórmula e o que fiz de errado na minha tentativa de calculá-la?


Você está ciente de que cv.glmnetestá executando uma validação cruzada de 10 vezes, certo? Então, ele está plotando o erro padrão médio de +/- 1 do desvio nos dados de espera de 10%?
Andrew M

Estou ciente disso, sim.
Sycorax diz Restabelecer Monica

Respostas:


6

Eu só queria adicionar à entrada, mas no momento não tenho uma resposta concisa e é muito longo para um comentário. Espero que isso dê mais informações.

Parece que a função de interesse está na biblioteca glmnet descompactada e é chamada cv.lognet.R É difícil rastrear tudo explicitamente, assim como o código S3 / S4, mas a função acima está listada como uma 'função glmnet interna , 'usado pelos autores e parece coincidir com a forma como o cv.glmnet está calculando o desvio binomial.

Embora eu não tenha visto isso em nenhum lugar do artigo, desde o rastreamento do código glmnet até o cv.lognet, o que eu entendo é que ele está usando algo chamado desvio binomial limitado descrito aqui .

[Ylog10(E)+(1Y)log10(1E)]

predmat é uma matriz dos valores de probabilidade máxima (E, 1-E) de saída para cada lambda, que são comparados aos valores de complemento de y e y resultando em lp. Em seguida, eles são colocados no formato de desvio 2 * (ly-lp) e calculados a média das dobras validadas cruzadas para obter cvm - O erro médio de validação cruzada - e os intervalos de cv que você plotou na primeira imagem.

Penso que a função de desvio manual (2ª parcela) não é calculada da mesma forma que esta interna (1ª parcela).

    # from cv.lognet.R

    cvraw=switch(type.measure,
    "mse"=(y[,1]-(1-predmat))^2 +(y[,2]-predmat)^2,
    "mae"=abs(y[,1]-(1-predmat)) +abs(y[,2]-predmat),
    "deviance"= {
      predmat=pmin(pmax(predmat,prob_min),prob_max)
      lp=y[,1]*log(1-predmat)+y[,2]*log(predmat)
      ly=log(y)
      ly[y==0]=0
      ly=drop((y*ly)%*%c(1,1))
      2*(ly-lp)

   # cvm output
   cvm=apply(cvraw,2,weighted.mean,w=weights,na.rm=TRUE)

Obrigado pela resposta, pat. Isso aborda todas as perguntas que eu tinha sobre como o procedimento funciona e os conceitos estatísticos subjacentes, não apenas o software.
Sycorax diz Restabelecer Monica

2

Então visitei o site da CRAN e baixei o que acho que é a fonte do pacote glmnet . Em ./glmnet/R/plot.cv.glmnet.R, parece que você encontrará o código fonte que procura. É bastante breve, então colarei aqui, mas provavelmente é melhor se você mesmo verificar para ter certeza de que é realmente o código que está sendo executado.

plot.cv.glmnet=function(x,sign.lambda=1,...){
  cvobj=x
  xlab="log(Lambda)"
  if(sign.lambda<0)xlab=paste("-",xlab,sep="")
  plot.args=list(x=sign.lambda*log(cvobj$lambda),y=cvobj$cvm,ylim=range(cvobj$cvup,cvobj$cvlo),xlab=xlab,ylab=cvobj$name,type="n")
      new.args=list(...)
      if(length(new.args))plot.args[names(new.args)]=new.args
    do.call("plot",plot.args)
    error.bars(sign.lambda*log(cvobj$lambda),cvobj$cvup,cvobj$cvlo,width=0.01,col="darkgrey")
  points(sign.lambda*log(cvobj$lambda),cvobj$cvm,pch=20,col="red")
axis(side=3,at=sign.lambda*log(cvobj$lambda),labels=paste(cvobj$nz),tick=FALSE,line=0)
abline(v=sign.lambda*log(cvobj$lambda.min),lty=3)
    abline(v=sign.lambda*log(cvobj$lambda.1se),lty=3)
  invisible()
}

1
Os métodos S3 estão um pouco ocultos no R, mas para ver exatamente o que está sendo executado, você pode digitar getS3method('plot', 'cv.glmnet')sem precisar se preocupar em baixar o pacote de origem. (Internamente, glmnetacabou de definir uma função chamada, plot.cv.glmnetmas não a exportou. Você ainda pode vê-la espiando dentro do espaço de nome com o :::operador :) glmnet:::plot.cv.glmnet.
Andrew M

(+1) Obrigado pela resposta, Diego. Isso me aponta na direção certa e indica implicitamente onde errei. No entanto, vou adiar a aceitação por enquanto, porque isso não responde à pergunta estatística (vice-programação) específica, que é declarada na parte inferior do meu post.
Sycorax diz Restabelecer Monica
Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.