Como o LSTM evita o problema de gradiente de fuga?


35

O LSTM foi inventado especificamente para evitar o problema do gradiente de fuga. Supõe-se que isso seja feito com o Constant Error Carousel (CEC), que no diagrama abaixo (de Greff et al. ) Corresponde ao loop em torno da célula .

LSTM
(fonte: deeplearning4j.org )

E eu entendo que essa parte pode ser vista como uma espécie de função de identidade, então a derivada é uma e o gradiente permanece constante.

O que eu não entendo é como ele não desaparece devido a outras funções de ativação? Os portões de entrada, saída e esquecimento usam um sigmóide, cuja derivada é no máximo 0,25, e g e h eram tradicionalmente tanh . Como a retropropagação daqueles que não fazem o gradiente desaparecer?


2
O LSTM é um modelo de rede neural recorrente que é muito eficiente para lembrar dependências de longo prazo e que não é vulnerável ao problema do gradiente de fuga. Não tenho a certeza que tipo de explicação que você está procurando
TheWalkingCube

LSTM: memória de curto prazo. (Ref: Hochreiter, S. e Schmidhuber, J. (1997). Memória de curto prazo. Neural Computation 9 (8): 1735-80 · dezembro de 1997)
horaceT

Os gradientes nos LSTMs desaparecem, apenas mais lentamente que nos RNNs de baunilha, permitindo que eles capturem dependências mais distantes. Evitar o problema de desaparecer gradientes ainda é uma área de pesquisa ativa.
Artem Sobolev

11
Importa-se em apoiar o desaparecimento mais lento com uma referência?
bayerj

Respostas:


22

O gradiente de fuga é melhor explicado no caso unidimensional. A multidimensional é mais complicada, mas essencialmente análoga. Você pode revisá-lo neste excelente artigo [1].

Suponha que temos um estado oculto no momento t . Se simplificarmos as coisas e removermos vieses e entradas, teremos h t = σ ( w h t - 1 ) . Então você pode mostrar quehtt

ht=σ(wht1).

O fatorado marcado com !!! é o crucial. Se o peso não for igual a 1, ele decairá para zero exponencialmente rápido emt-tou crescerá exponencialmente rápido.

htht=k=1ttwσ(whtk)=wtt!!!k=1ttσ(whtk)
tt

Nos LSTMs, você tem o estado da célula . O derivado não é da forma s t 'st Aquivté a entrada para o gate de esquecer. Como você pode ver, não há fator de decomposição exponencialmente rápido envolvido. Conseqüentemente, há pelo menos um caminho em que o gradiente não desaparece. Para a derivação completa, consulte [2].

stst=k=1ttσ(vt+k).
vt

[1] Pascanu, Razvan, Tomas Mikolov e Yoshua Bengio. "Sobre a dificuldade de treinar redes neurais recorrentes." ICML (3) 28 (2013): 1310-1318.

[2] Bayer, Justin Simon. Representações da sequência de aprendizado. Diss. München, Technische Universität München, Diss., 2015, 2015.


3
Para lstm, h_t também não depende de h_ {t-1}? O que você quer dizer com seu artigo quando diz que ds_t / d_s {t-1} "é a única parte em que os gradientes fluem ao longo do tempo"?
user3243135

@ user3243135 h_t depende de h_ {t-1}. No entanto, suponha que ds_t / d_s {t-1} seja mantido, mesmo se outros fluxos de gradiente desaparecerem, todo o fluxo de gradiente não desaparecerá. Isso resolve o desaparecimento do gradiente.
soloice

Eu sempre pensei que a questão principal era o termo porque se σ ( z ) é geralmente a derivada de um sigmóide (ou algo com uma derivada menor que 1) que causou o gradiente de fuga com certeza (por exemplo, sigmoides são <1 em magnitude e sua derivada é σ ( x ) = σ ( z ) ( 1 - σ ( z ) )
ttσ(whtk)
σ(z)σ(x)=σ(z)(1σ(z))que é <1 com certeza). Não foi por isso que as ReLUs foram aceitas nas CNNs? Isso é algo que sempre me confundiu com a diferença de como o gradiente de fuga foi abordado nos modelos de feed forward e nos modelos recorrentes. Algum esclarecimento para isso?
Pinóquio

O gradiente do sigmóide também pode se tornar um problema, assumindo uma distribuição de entradas com grande variação e / ou média de 0. No entanto, mesmo se você usar ReLUs, o principal problema persiste: multiplicar repetidamente por uma matriz de pesos (geralmente pequena ) causa gradientes de fuga ou, em alguns casos, onde a regularização não foi adequada, explodindo gradientes.
Ataxias

3

A imagem do bloco LSTM de Greff et al. (2015) descreve uma variante que os autores chamam de LSTM de baunilha . É um pouco diferente da definição original de Hochreiter e Schmidhuber (1997). A definição original não incluía as conexões do gate e do olho mágico.

O termo Carrossel com erro constante foi usado no artigo original para indicar a conexão recorrente do estado da célula. Considere a definição original em que o estado da célula é alterado apenas por adição, quando a porta de entrada é aberta. O gradiente do estado da célula em relação ao estado da célula em uma etapa anterior é zero.

O erro ainda pode entrar no CEC através da porta de saída e da função de ativação. A função de ativação reduz um pouco a magnitude do erro antes de ser adicionado ao CEC. O CEC é o único local em que o erro pode fluir inalterado. Novamente, quando a porta de entrada é aberta, o erro sai através da porta de entrada, função de ativação e transformação afim, reduzindo a magnitude do erro.

Portanto, o erro é reduzido quando é retropropagado por meio de uma camada LSTM, mas somente quando entra e sai do CEC. O importante é que ele não mude no CEC, independentemente da distância percorrida. Isso resolve o problema na RNN básica de que cada etapa do tempo aplica uma transformação afim e não linearidade, significando que quanto maior a distância do tempo entre a entrada e a saída, menor o erro.


2

http://www.felixgers.de/papers/phd.pdf Consulte as seções 2.2 e 3.2.2, onde é explicada a parte do erro truncado. Eles não propagam o erro se vazar na memória da célula (ou seja, se houver uma porta de entrada fechada / ativada), mas eles atualizam os pesos da porta com base no erro apenas naquele instante. Mais tarde, é zerado durante a propagação posterior. Isso é meio que um hack, mas o motivo é que o fluxo de erros ao longo dos portões diminui com o tempo.


7
Você poderia expandir um pouco sobre isso? No momento, a resposta não terá valor se o local do link mudar ou se o papel for colocado offline. No mínimo, ajudaria a fornecer uma citação completa (referência) que permita que o artigo seja encontrado novamente se o link parar de funcionar, mas seria melhor um pequeno resumo que torne essa resposta independente.
precisa

2

Gostaria de acrescentar alguns detalhes à resposta aceita, porque acho que é um pouco mais sutil e a nuance pode não ser óbvia para alguém que está aprendendo sobre RNNs pela primeira vez.

htht=k=1ttwσ(whtk)

stst=k=1ttσ(vt+k)

  • tt
  • a resposta é sim , e é por isso que o LSTM também sofrerá com gradientes de fuga, mas não tanto quanto o RNN de baunilha

wσ()σ()

σ()1
vt+k=wxwxw

x=1w=10 vt+k=10σ()=0.99995 , ou o gradiente morre como:

(0.99995)tt

For the vanilla RNN, there is no set of weights which can be learned such that

wσ(whtk)1

e.g. In the 1D case, suppose htk=1. The function wσ(w1) achieves a maximum of 0.224 at w=1.5434. This means the gradient will decay as,

(0.224)tt

Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.