Que função de perda devo usar para pontuar um modelo RNN seq2seq?


10

Estou trabalhando no artigo de Cho 2014 , que introduziu a arquitetura codificador-decodificador para modelagem seq2seq.

No artigo, eles parecem usar a probabilidade da saída fornecida (ou é uma probabilidade negativa de log) como a função de perda para uma entrada de comprimento e saída de comprimento :xMyN

P(y1,,yN|x1,,xM)=P(y1|x1,,xm)P(y2|y1,x1,,xm)P(yN|y1,,yN1,x1,,xm)

No entanto, acho que vejo vários problemas ao usar isso como uma função de perda:

  1. Parece assumir que o professor força o treinamento durante o treinamento (ou seja, em vez de usar o palpite do decodificador para uma posição como entrada para a próxima iteração, ele usa o token conhecido.
  2. Não penalizaria longas sequências. Como a probabilidade é de a da saída, se o decodificador gerasse uma sequência mais longa, tudo após o primeiro não levaria em consideração a perda.1NN
  3. Se o modelo predizer um token de fim de cadeia anterior, a função de perda ainda exigirá etapas - o que significa que estamos gerando saídas com base em um "coletor" não treinado dos modelos. Isso parece desleixado.N

Alguma dessas preocupações é válida? Em caso afirmativo, houve algum progresso em uma função de perda mais avançada?

Respostas:


1

Parece assumir que o professor força o treinamento durante o treinamento (ou seja, em vez de usar o palpite do decodificador para uma posição como entrada para a próxima iteração, ele usa o token conhecido.

O termo "forçar professor" me incomoda um pouco, porque meio que erra a ideia: não há nada de errado ou estranho em alimentar o próximo token conhecido no modelo da RNN - é literalmente a única maneira de calcular . Se você definir uma distribuição sobre seqüências de forma autoregressiva como como é comumente feito, onde cada termo condicional é modelado com uma RNN, "forçar professor" é o verdadeiro procedimento que maximiza corretamente a probabilidade do log. (Eu omito escrever a sequência de condicionamento acima porque ela não muda nada.)logP(y1,,yN)P(y)=iP(yi|y<i)x

Dada a onipresença do MLE e a falta de boas alternativas, não acho que supor que "forçar professores" seja censurável.

Apesar de tudo, há problemas com ele - ou seja, o modelo atribui alta probabilidade a todos os pontos de dados, mas as amostras do modelo não são necessariamente prováveis ​​na verdadeira distribuição de dados (o que resulta em amostras de "baixa qualidade"). Você pode estar interessado em "Professor Forçar" (Lamb et al.), Que mitiga isso através de um procedimento de treinamento antagônico sem abrir mão do MLE.

Não penalizaria longas sequências. Como a probabilidade é de 1 a N da saída, se o decodificador gerasse uma sequência mais longa, tudo após o primeiro N não levaria em consideração a perda.

e

Se o modelo predizer um token de fim de cadeia anterior, a função de perda ainda exigirá N etapas - o que significa que estamos gerando saídas com base em um "coletor" não treinado dos modelos. Isso parece desleixado.

Nenhum desses problemas ocorre durante o treinamento. Em vez de pensar em um modelo de sequência autoregressiva como um procedimento para gerar uma previsão, pense nele como uma maneira de calcular a probabilidade de uma determinada sequência. O modelo nunca prevê nada - você pode fazer uma amostra de uma sequência ou de um token de uma distribuição, ou perguntar qual é o próximo token mais provável - mas eles são crucialmente diferentes de uma previsão (e você não faz uma amostra durante o treinamento ou).

Em caso afirmativo, houve algum progresso em uma função de perda mais avançada?

É possível que haja objetivos projetados especificamente caso a caso para diferentes tarefas de modelagem. No entanto, eu diria que o MLE ainda é dominante - o recente modelo GPT2, que alcançou desempenho de ponta em um amplo espectro de tarefas de modelagem e compreensão de linguagem natural, foi treinado com ele.

Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.