Explicação da perda de entropia cruzada


35

Suponha que eu crie um NN para classificação. A última camada é uma camada densa com ativação softmax. Eu tenho cinco classes diferentes para classificar. Suponha que, para um único exemplo de treinamento, true labelseja [1 0 0 0 0]enquanto estiverem as previsões [0.1 0.5 0.1 0.1 0.2]. Como eu calcularia a perda de entropia cruzada para este exemplo?

Respostas:


50

A fórmula da entropia cruzada assume duas distribuições, , a distribuição verdadeira e , a distribuição estimada, definida sobre a variável discreta e é dada porp(x)q(x)x

H(p,q)=xp(x)log(q(x))

Para uma rede neural, o cálculo é independente do seguinte:

  • Que tipo de camada foi usada.

  • Que tipo de ativação foi usada - embora muitas ativações não sejam compatíveis com o cálculo porque suas saídas não são interpretáveis ​​como probabilidades (ou seja, suas saídas são negativas, maiores que 1 ou não somam 1). O Softmax é frequentemente usado para a classificação em várias classes porque garante uma função de distribuição de probabilidade bem comportada.

Para uma rede neural, você geralmente verá a equação escrita em uma forma em que é o vetor da verdade fundamental e (ou algum outro valor obtido diretamente da saída da última camada) é a estimativa. Para um único exemplo, seria assim:yy^

L=ylog(y^)

onde é o produto de ponto vetorial.

Seu exemplo de base da verdade fornece toda a probabilidade ao primeiro valor, e os outros valores são zero, para que possamos ignorá-los e apenas usar o termo correspondente de suas estimativasyy^

L=(1×log(0.1)+0×log(0.5)+...)

L=log(0.1)2.303

Um ponto importante dos comentários

Isso significa que a perda seria a mesma, independentemente de as previsões serem ou ?[0.1,0.5,0.1,0.1,0.2][0.1,0.6,0.1,0.1,0.1]

Sim, esse é um recurso importante do logloss multiclass, que recompensa / penaliza apenas probabilidades de classes corretas. O valor é independente de como a probabilidade restante é dividida entre classes incorretas.

Você verá frequentemente essa equação com a média de todos os exemplos como uma função de custo . Nem sempre é rigorosamente respeitado nas descrições, mas geralmente uma função de perda é de nível inferior e descreve como uma única instância ou componente determina um valor de erro, enquanto uma função de custo é de nível superior e descreve como um sistema completo é avaliado para otimização. Uma função de custo baseada na perda de log de várias classes para um conjunto de dados de tamanho pode se parecer com isso:N

J=1N(i=1Nyilog(y^i))

Muitas implementações exigirão que seus valores de verdade básicos sejam codificados com um hot hot (com uma única classe verdadeira), pois isso permite uma otimização extra. No entanto, em princípio, a perda de entropia cruzada pode ser calculada - e otimizada - quando este não for o caso.


1
OK. Isso significa que a perda seria a mesma, independentemente de as previsões serem [0,1 0,5 0,1 0,1 0,2] ou [0,1 0,6 0,1 0,1 0,1]?
Nain

@ Naain: Isso está correto para o seu exemplo. A perda de entropia cruzada não depende de quais são os valores das probabilidades de classe incorretas.
Neil Slater

8

A resposta de Neil está correta. No entanto, acho importante ressaltar que, embora a perda não dependa da distribuição entre as classes incorretas (apenas a distribuição entre a classe correta e as demais), o gradiente dessa função de perda afeta as classes incorretas de maneira diferente, dependendo de como eles estão errados. Portanto, ao usar o cross-ent no aprendizado de máquina, você alterará os pesos de maneira diferente para [0,1 0,5 0,1 0,1 0,2] e [0,1 0,6 0,1 0,1 0,1]. Isso ocorre porque a pontuação da classe correta é normalizada pela pontuação de todas as outras classes para transformá-la em probabilidade.


3
Você pode elaborar com um exemplo adequado?
Naim

@ Lucas Adams, você pode dar um exemplo, por favor?
28418 koryakinp

A derivada de EACH y_i (saída softmax) escrita EACH logit z (ou o próprio parâmetro w) depende de CADA y_i. medium.com/@aerinykim/…
Aaron

2

Vamos ver como o gradiente da perda se comporta ... Temos a entropia cruzada como uma função de perda, que é dada por

H(p,q)=i=1np(xi)log(q(xi))=(p(x1)log(q(x1))++p(xn)log(q(xn))

Indo daqui .. gostaríamos de saber a derivada com relação a alguns : Como todos os outros termos são cancelados devido à diferenciação. Podemos levar essa equação um passo adiante para xi

xiH(p,q)=xip(xi)log(q(xi)).
xiH(p,q)=p(xi)1q(xi)q(xi)xi.

A partir disso, podemos ver que ainda estamos penalizando apenas as classes verdadeiras (para as quais há valor para ). Caso contrário, temos apenas um gradiente de zero.p(xi)

Eu me pergunto como os pacotes de software lidam com um valor previsto de 0, enquanto o valor verdadeiro era maior que zero ... Como estamos dividindo por zero nesse caso.


Eu acho que o que você quer é obter derivada wrt do parâmetro, não wrt x_i.
Aaron

1

Vamos começar entendendo a entropia na teoria da informação: suponha que você queira comunicar uma sequência de alfabetos "aaaaaaaa". Você poderia fazer isso facilmente como 8 * "a". Agora pegue outra string "jteikfqa". Existe uma maneira compactada de comunicar essa string? Não existe. Podemos dizer que a entropia da 2ª corda é mais uma vez que, para comunicá-la, precisamos de mais "bits" de informação.

Essa analogia também se aplica às probabilidades. Se você tiver um conjunto de itens, frutos por exemplo, a codificação binária desses frutos seria que n é o número de frutos. Para 8 frutas, você precisa de 3 bits e assim por diante. Outra maneira de analisar isso é que, dada a probabilidade de alguém selecionar uma fruta aleatoriamente, ser 1/8, a redução da incerteza se uma fruta for selecionada é que é 3. Mais especificamente,log2(n)log2(1/8)

i=1818log2(18)=3
Essa entropia nos fala sobre a incerteza envolvida em certas distribuições de probabilidade; quanto mais incerteza / variação em uma distribuição de probabilidade, maior é a entropia (por exemplo, para 1024 frutos, seria 10).

Na entropia "cruzada", como o nome sugere, focamos no número de bits necessários para explicar a diferença em duas distribuições de probabilidade diferentes. O melhor cenário é que ambas as distribuições são idênticas; nesse caso, é necessária a menor quantidade de bits, ou seja, entropia simples. Em termos matemáticos,

H(y,y^)=iyiloge(y^i)

Onde é o vetor de probabilidade previsto (saída Softmax) e é o vetor da verdade do terreno (por exemplo, um quente). A razão pela qual usamos log natural é porque é fácil diferenciar (ref. Cálculo de gradientes) e a razão pela qual não registramos o vetor de verdade do solo é porque ele contém muitos 0s que simplificam a soma.y^y

Conclusão: Em termos leigos, pode-se pensar em entropia cruzada como a distância entre duas distribuições de probabilidade em termos da quantidade de informação (bits) necessária para explicar essa distância. É uma maneira elegante de definir uma perda que diminui à medida que os vetores de probabilidade se aproximam.


0

Eu discordo de Lucas. Os valores acima já são probabilidades. Observe que o post original indicou que os valores tiveram uma ativação softmax.

O erro é propagado apenas de volta na classe "hot" e a probabilidade Q (i) não muda se as probabilidades nas outras classes se alternarem.


2
Lucas está correto. Com a arquitetura descrita pelo OP, o gradiente em todos os logits (em oposição às saídas) não é zero, porque a função softmax conecta todos eles. Portanto, o erro [gradiente do] na classe "quente" se propaga a todos os neurônios de saída.
Neil Slater

+1 para Neil e Lucas
Aaron

-1

O problema é que as probabilidades são provenientes de uma função 'complicada' que incorpora as outras saídas no valor fornecido. Os resultados estão interconectados; portanto, não estamos obtendo resultados reais, mas todas as entradas da última função de ativação (softmax), para cada resultado.

Eu encontrei uma descrição muito boa em deepnotes.io/softmax-crossentropy, onde o autor mostra que a derivada real é .piyi

Outra descrição interessante em gombru.github.io/2018/05/23/cross_entropy_loss .

Penso que o uso de um sigmóide simples como última camada de ativação levaria à resposta aprovada, mas o uso do softmax indica uma resposta diferente.


1
Bem-vindo ao Stack Exchange. No entanto, o que você escreveu não parece ser uma resposta da pergunta do OP sobre o cálculo da perda de entropia cruzada.
User12075
Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.