Quais são as vantagens da métrica de Wasserstein em comparação com a divergência de Kullback-Leibler?


25

Qual é a diferença prática entre a métrica de Wasserstein e a divergência de Kullback-Leibler ? A métrica de Wasserstein também é chamada de distância do motor da Terra .

Da Wikipedia:

A métrica de Wasserstein (ou Vaserstein) é uma função de distância definida entre distribuições de probabilidade em um determinado espaço métrico M.

e

A divergência de Kullback-Leibler é uma medida de como uma distribuição de probabilidade diverge de uma segunda distribuição de probabilidade esperada.

Vi o KL ser usado em implementações de aprendizado de máquina, mas recentemente me deparei com a métrica de Wasserstein. Existe uma boa orientação sobre quando usar um ou outro?

(Não tenho reputação suficiente para criar uma nova tag com Wassersteinou Earth mover's distance.)



1
editar a postagem para adicionar uma tag Wasserstein com base na solicitação de pôster. Também adicionando uma resposta.
Lucas Roberts

Respostas:


28

Ao considerar as vantagens da métrica de Wasserstein em comparação com a divergência de KL, o mais óbvio é que W é uma métrica, enquanto a divergência de KL não é, uma vez que KL não é simétrica (isto é, em geral) e não satisfaz a desigualdade do triângulo (ou seja, D K L ( R | | P ) D K L ( Q | | P ) + D KDKeu(P||Q)DKeu(Q||P)não se aplica em geral).DKeu(R||P)DKeu(Q||P)+DKeu(R||Q)

Quanto à diferença prática, uma das mais importantes é que, diferentemente da KL (e de muitas outras medidas), Wasserstein leva em consideração o espaço métrico e o que isso significa em termos menos abstratos talvez seja melhor explicado por um exemplo (sinta-se à vontade para pular para a figura, código apenas para produzi-lo):

# define samples this way as scipy.stats.wasserstein_distance can't take probability distributions directly
sampP = [1,1,1,1,1,1,2,3,4,5]
sampQ = [1,2,3,4,5,5,5,5,5,5]
# and for scipy.stats.entropy (gives KL divergence here) we want distributions
P = np.unique(sampP, return_counts=True)[1] / len(sampP)
Q = np.unique(sampQ, return_counts=True)[1] / len(sampQ)
# compare to this sample / distribution:
sampQ2 = [1,2,2,2,2,2,2,3,4,5]
Q2 = np.unique(sampQ2, return_counts=True)[1] / len(sampQ2)

fig = plt.figure(figsize=(10,7))
fig.subplots_adjust(wspace=0.5)
plt.subplot(2,2,1)
plt.bar(np.arange(len(P)), P, color='r')
plt.xticks(np.arange(len(P)), np.arange(1,5), fontsize=0)
plt.subplot(2,2,3)
plt.bar(np.arange(len(Q)), Q, color='b')
plt.xticks(np.arange(len(Q)), np.arange(1,5))
plt.title("Wasserstein distance {:.4}\nKL divergence {:.4}".format(
    scipy.stats.wasserstein_distance(sampP, sampQ), scipy.stats.entropy(P, Q)), fontsize=10)
plt.subplot(2,2,2)
plt.bar(np.arange(len(P)), P, color='r')
plt.xticks(np.arange(len(P)), np.arange(1,5), fontsize=0)
plt.subplot(2,2,4)
plt.bar(np.arange(len(Q2)), Q2, color='b')
plt.xticks(np.arange(len(Q2)), np.arange(1,5))
plt.title("Wasserstein distance {:.4}\nKL divergence {:.4}".format(
    scipy.stats.wasserstein_distance(sampP, sampQ2), scipy.stats.entropy(P, Q2)), fontsize=10)
plt.show()

Métrica de Wasserstein e divergências de Kullback-Leibler para dois pares diferentes de distribuições Aqui, as medidas entre as distribuições de vermelho e azul são as mesmas para a divergência de KL, enquanto a distância de Wasserstein mede o trabalho necessário para transportar a massa de probabilidade do estado vermelho para o estado azul usando o eixo x como uma “estrada”. Essa medida é obviamente maior quanto maior a distância da massa probabilística (daí a distância do motor da terra). Então, qual você deseja usar depende da sua área de aplicação e do que você deseja medir. Como nota, em vez da divergência de KL, também existem outras opções, como a distância de Jensen-Shannon, que são métricas adequadas.


6

A métrica de Wasserstein geralmente aparece em problemas ideais de transporte, onde o objetivo é mover as coisas de uma determinada configuração para uma configuração desejada no custo mínimo ou distância mínima. O Kullback-Leibler (KL) é uma divergência (não uma métrica) e aparece frequentemente em estatística, aprendizado de máquina e teoria da informação.

Além disso, a métrica de Wasserstein não exige que ambas as medidas estejam no mesmo espaço de probabilidade, enquanto a divergência KL exige que ambas as medidas sejam definidas no mesmo espaço de probabilidade.

kμEuΣEuEu=1,2

W2(N0 0,N1)2=__μ1-μ2__22+tr(Σ1+Σ2-2(Σ21/2Σ1Σ21/2)1/2)
DKL(N0 0,N1)=12(tr(Σ1-1Σ0 0)+(μ1-μ0 0)TΣ1-1(μ1-μ0 0)-k+em(detΣ1detΣ0 0)).
Para simplificar, vamos considerar Σ1=Σ2=WEuk e μ1μ2. Com essas suposições simplificadoras, o termo traço em Wasserstein é0 0 eo termo do traço na divergência KL será 0 quando combinado com o -k termo e a razão log-determinante também é 0 0, então essas duas quantidades se tornam:
W2(N0 0,N1)2=__μ1-μ2__22
e
DKL(N0 0,N1)=(μ1-μ0 0)TΣ1-1(μ1-μ0 0).
Observe que a distância de Wasserstein não muda se a variação mudar (por exemplo, Wcomo uma grande quantidade nas matrizes de covariância) enquanto a divergência KL sim. Isso ocorre porque a distância de Wasserstein é uma função de distância nos espaços de apoio conjunto das duas medidas de probabilidade. Em contraste, a divergência KL é uma divergência e essa divergência muda com base no espaço de informação (relação sinal / ruído) das distribuições.


1

A métrica de Wasserstein é útil na validação de modelos, pois suas unidades são as da própria resposta. Por exemplo, se você estiver comparando duas representações estocásticas do mesmo sistema (por exemplo, um modelo de ordem reduzida),P e Q, e a resposta são unidades de deslocamento, a métrica de Wasserstein também está em unidades de deslocamento. Se você reduzisse sua representação estocástica a um determinístico, o CDF de cada distribuição é uma função de etapa. A métrica de Wasserstein é a diferença dos valores.

Eu acho essa propriedade uma extensão muito natural para falar sobre a diferença absoluta entre duas variáveis ​​aleatórias

Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.