O que é entropia cruzada?


92

Eu sei que existem muitas explicações sobre o que é entropia cruzada, mas ainda estou confuso.

É apenas um método para descrever a função de perda? Podemos usar o algoritmo de descida gradiente para encontrar o mínimo usando a função de perda?


10
Não é um bom ajuste para SO. Aqui está uma pergunta semelhante no site irmão da datascience: datascience.stackexchange.com/questions/9302/…
Metropolis

Respostas:


230

A entropia cruzada é comumente usada para quantificar a diferença entre duas distribuições de probabilidade. Normalmente, a distribuição "verdadeira" (aquela que o seu algoritmo de aprendizado de máquina está tentando corresponder) é expressa em termos de uma distribuição one-hot.

Por exemplo, suponha que para uma instância de treinamento específica, o rótulo seja B (dentre os rótulos possíveis A, B e C). A distribuição one-hot para esta instância de treinamento é, portanto:

Pr(Class A)  Pr(Class B)  Pr(Class C)
        0.0          1.0          0.0

Você pode interpretar a distribuição "verdadeira" acima para significar que a instância de treinamento tem 0% de probabilidade de ser da classe A, 100% de probabilidade de ser da classe B e 0% de probabilidade de ser da classe C.

Agora, suponha que seu algoritmo de aprendizado de máquina preveja a seguinte distribuição de probabilidade:

Pr(Class A)  Pr(Class B)  Pr(Class C)
      0.228        0.619        0.153

Quão próxima está a distribuição prevista da distribuição verdadeira? Isso é o que determina a perda de entropia cruzada. Use esta fórmula:

Fórmula de perda de entropia cruzada

Onde p(x)está a probabilidade desejada e q(x)a probabilidade real. A soma é sobre as três classes A, B e C. Neste caso, a perda é de 0,479 :

H = - (0.0*ln(0.228) + 1.0*ln(0.619) + 0.0*ln(0.153)) = 0.479

Portanto, é assim que sua previsão está "errada" ou "distante" da distribuição verdadeira.

A entropia cruzada é uma das muitas funções de perda possíveis (outra função popular é a perda de dobradiça SVM). Essas funções de perda são normalmente escritas como J (theta) e podem ser usadas em gradiente descendente, que é um algoritmo iterativo para mover os parâmetros (ou coeficientes) em direção aos valores ideais. Na equação abaixo, você substituiria J(theta)por H(p, q). Mas observe que você precisa primeiro calcular a derivada de H(p, q)em relação aos parâmetros.

Gradiente descendente

Então, para responder diretamente às suas perguntas originais:

É apenas um método para descrever a função de perda?

A entropia cruzada correta descreve a perda entre duas distribuições de probabilidade. É uma das muitas funções de perda possíveis.

Então, podemos usar, por exemplo, o algoritmo de descida gradiente para encontrar o mínimo.

Sim, a função de perda de entropia cruzada pode ser usada como parte da descida do gradiente.

Leitura adicional: uma das minhas outras respostas relacionadas ao TensorFlow.


então, entropia cruzada descreve a perda por soma das probabilidades para cada exemplo X.
theateist

então, podemos em vez de descrever o erro como entropia cruzada, descrever o erro como um ângulo entre dois vetores (similaridade de cosseno / distância angular) e tentar minimizar o ângulo?
theateist 01 de

1
aparentemente não é a melhor solução, mas eu só queria saber, em teoria, se poderíamos usar cosine (dis)similaritypara descrever o erro pelo ângulo e depois tentar minimizar o ângulo.
theateist

2
@Stephen: Se você olhar o exemplo que dei, p(x)seria a lista de probabilidades de verdade para cada uma das classes, o que seria [0.0, 1.0, 0.0. Da mesma forma, q(x)é a lista de probabilidade prevista para cada uma das classes [0.228, 0.619, 0.153],. H(p, q)é então - (0 * log(2.28) + 1.0 * log(0.619) + 0 * log(0.153)), que resulta em 0,479. Observe que é comum usar a np.log()função do Python , que na verdade é o log natural; Não importa.
stackoverflowuser2010

1
@HAr: Para a codificação one-hot do rótulo verdadeiro, há apenas uma classe diferente de zero com a qual nos importamos. No entanto, a entropia cruzada pode comparar quaisquer duas distribuições de probabilidade; não é necessário que um deles tenha probabilidades um-hot.
stackoverflowuser2010

2

Em suma, entropia cruzada (CE) é a medida de quão longe está seu valor previsto do rótulo verdadeiro.

A cruz aqui se refere ao cálculo da entropia entre dois ou mais recursos / rótulos verdadeiros (como 0, 1).

E o próprio termo entropia se refere à aleatoriedade, então um grande valor dele significa que sua previsão está longe dos rótulos reais.

Assim, os pesos são alterados para reduzir CE e, assim, finalmente leva a uma diferença reduzida entre a previsão e os rótulos verdadeiros e, portanto, melhor precisão.


1

Somando-se as postagens acima, a forma mais simples de perda de entropia cruzada é conhecida como entropia cruzada binária (usada como função de perda para classificação binária, por exemplo, com regressão logística), enquanto a versão generalizada é entropia cruzada categórica (usada como função de perda para problemas de classificação multiclasse, por exemplo, com redes neurais).

A ideia continua a mesma:

  1. quando a probabilidade de classe calculada pelo modelo (softmax) torna-se próxima de 1 para o rótulo de destino para uma instância de treinamento (representado com codificação one-hot, por exemplo), a perda de CCE correspondente diminui para zero

  2. caso contrário, aumenta à medida que a probabilidade prevista correspondente à classe-alvo se torna menor.

A figura a seguir demonstra o conceito (observe a partir da figura que o BCE se torna baixo quando y e p são altos ou ambos são baixos simultaneamente, ou seja, há um acordo):

insira a descrição da imagem aqui

A entropia cruzada está intimamente relacionada à entropia relativa ou divergência KL que calcula a distância entre duas distribuições de probabilidade. Por exemplo, entre dois pmfs discretos, a relação entre eles é mostrada na figura a seguir:

insira a descrição da imagem aqui

Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.