Qual é a intuição por trás de uma rede neural recorrente de Long Short Term Memory (LSTM)?


11

A ideia por trás da Rede Neural Recorrente (RNN) é clara para mim. Entendo da seguinte maneira:
Temos uma sequência de observações ( ) (ou, em outras palavras, séries temporais multivariadas). Cada observação única é um vetor numérico dimensional. No modelo RNN, assumimos que a próxima observação é uma função da observação anterior , bem como do "estado oculto" anterior , onde estados ocultos também são representados por vetores (as dimensões dos estados observados e ocultos podem ser diferentes). Também se supõe que os próprios estados ocultos dependam da observação anterior e do estado oculto:o1,o2,,onoiNoi+1oihi

oi,hi=F(oi1,hi1)

Finalmente, no modelo RNN, a função é assumida como uma rede neural. Nós treinamos (ajustamos) a rede neural usando os dados disponíveis (uma sequência de observações). Nosso objetivo no treinamento é poder prever a próxima observação com a maior precisão possível, usando as observações anteriores.F

Agora, a rede LSTM é uma modificação da rede RNN. Até onde eu entendi, a motivação por trás do LSTM é resolver o problema de memória curta que é peculiar à RNN (a RNN convencional tem problemas em relacionar eventos que são separados demais no tempo).

Entendo como as redes LSTM funcionam. Aqui está a melhor explicação do LSTM que eu encontrei. A ideia básica é a seguinte:

Além do vetor de estado oculto, introduzimos um vetor chamado "estado da célula" que tem o mesmo tamanho (dimensionalidade) que o vetor de estado oculto ( ). Eu acho que o vetor "estado da célula" é introduzido para modelar a memória de longo prazo. Como no caso da RNN convencional, a rede LSTM obtém o estado observado e oculto como entrada. Usando essa entrada, calculamos um novo "estado da célula" da seguinte maneira:ci

ci+1=ω1(oi,hi)ci+ω2(oi,hi)cint(oi,hi),

onde as funções de , e são modeladas por redes neurais. Para simplificar a expressão, basta remover os argumentos:ω1ω2cint

ci+1=ω1ci+ω2cint

Portanto, podemos ver que o novo "vetor de estado da célula" ( ) é uma soma ponderada do vetor de estado antigo ( ) e um vetor de estado da célula "intermediário" ( ). A multiplicação entre os vetores é em termos de componentes (multiplicamos dois vetores dimensionais N e obtemos, como resultado, outro vetor dimensional N). Em outras palavras, misturamos dois vetores de estados celulares (o antigo e o intermediário) usando pesos específicos do componente.cici1cint

Aqui está a intuição entre as operações descritas. O vetor de estado da célula pode ser interpretado como um vetor de memória. O segundo vetor de pesos (calculado por uma rede neural) é um portão "manter" (ou esquecer). Seus valores decidem se mantemos ou esquecemos (apagamos) um valor correspondente do vetor de estado da célula (ou vetor de memória de longo prazo). O primeiro vetor de pesos ( ), calculado por outra rede neural, é chamado de porta "write" ou "memorize". Ele decide se uma nova memória (o vetor de estado da célula "intermediário") deve ser salva (ou mais precisamente, se um componente específico dela deve ser salvo / gravado). O "intermediário"ω2ω1ω1vetor). Na verdade, seria mais preciso dizer que, com os dois vetores de pesos ( e ), "misturamos" a memória antiga e a nova.ω1ω2

Assim, após a mistura descrita acima (ou esquecimento e memorização), temos um novo vetor de estado celular. Em seguida, calculamos um estado oculto "intermediário" usando outra rede neural (como antes, usamos o estado observado e o estado oculto como entradas). Finalmente, combinamos o novo estado da célula (memória) com o estado oculto "intermediário" ( ) para obter o novo estado oculto (ou "final") que realmente produzimos:oEuhEuhEunt

hEu+1 1=hEuntS(cEu+1 1),

onde é uma função sigmóide aplicada a cada componente do vetor de estado celular.S

Então, minha pergunta é: por que (ou exatamente) essa arquitetura resolve o problema?

Em particular, eu não entendo o seguinte:

  1. Usamos uma rede neural para gerar memória "intermediária" (vetor de estado da célula) que é misturada com a memória "antiga" (ou estado da célula) para obter uma "nova" memória (estado da célula). Os fatores de ponderação para a mistura também são calculados por redes neurais. Mas por que não podemos usar apenas uma rede neural para calcular o "novo" estado da célula (ou memória). Ou, em outras palavras, por que não podemos usar o estado observado, o estado oculto e a memória antiga como entradas para uma rede neural que calcula a "nova" memória?
  2. No final, usamos os estados observados e ocultos para calcular um novo estado oculto e, em seguida, usamos o "novo" estado da célula (ou memória (de longo prazo)) para corrigir o componente do recém-calculado estado oculto. Em outras palavras, os componentes do estado da célula são usados ​​exatamente como pesos que apenas reduzem os componentes correspondentes do estado oculto calculado. Mas por que o vetor de estado celular é usado dessa maneira específica? Por que não podemos calcular o novo estado oculto colocando o vetor de estado da célula (memória de longo prazo) na entrada de uma rede neural (que também leva os estados observados e ocultos como entrada)?

Adicionado:

Aqui está um vídeo que pode ajudar a esclarecer como diferentes portais ("manter", "escrever" e "ler") são organizados.


11
Parece que você entende os LSTMs melhor do que eu, por isso não postarei uma resposta real, pois pode não ser o que você está procurando: os portões do LSTM (principalmente os portões do esquecimento) permitem manter ativações e gradientes por tanto tempo como necessário. Portanto, as informações no tempo t podem ser mantidas disponíveis até o tempo t + n, para n arbitrariamente grande.
Rcpinto 5/04

@rcpinto, também acho que a principal idéia por trás da "arquitetura" proposta é permitir manter as informações por muito tempo (muitas etapas). Mas não entendo o que exatamente torna isso possível. As redes dos dois portões ("keep" e "write") podem aprender que os pesos de manutenção devem ser grandes e os de gravação devem ser pequenos (então mantemos a memória por muito tempo). Mas isso não pode ser alcançado apenas por uma rede? Uma rede neural (que usa o estado oculto (memória) e o estado observável como entrada) não pode aprender que o estado oculto deve ser mantido sem alterações?
Roman

Na verdade, basta definir a matriz de peso recorrente para a identidade e ela sempre manterá a última ativação. O problema é sempre a parte, o que significa que novas entradas se acumularão e saturarão rapidamente a ativação do neurônio, o que também é uma espécie de esquecimento. Assim, a capacidade de apagar uma memória anterior ou bloquear a formação de novas memórias é crucial.
Rcpinto 06/04

@rcpinto, mas a "capacidade de apagar uma memória anterior ou bloquear a formação de uma nova" não pode ser alcançada em uma única rede neural? A rede neural obtém o vetor de memória (ou um vetor de estado oculto) e o vetor de estado observado como entrada. Essa rede não pode "decidir" manter ou substituir alguns componentes do estado oculto (sua memória) com base nos valores dos componentes no vetor de estado observado?
Roman

Os LSTMs podem fazer isso porque os portões abrem ou fecham de acordo com a entrada e o estado atuais. Não existe esse conceito em RNNs simples. Nesse caso, o próprio estado sempre reage à entrada / estado direta e imediatamente, impedindo-o de "escolher" se deseja ou não armazenar as novas informações. Além disso, não há mecanismo para apagar alguma memória armazenada, ela sempre se acumula de acordo com os pesos de entrada do neurônio, e esses pesos não podem ser alterados após o treinamento. Como os portões LSTM são multiplicativos, eles simulam a alteração de peso durante a inferência, reagindo à entrada / estado atual.
Rcpinto 06/04

Respostas:


1

Pelo que entendi, o que você imagina é basicamente concatenar a entrada, o estado oculto anterior e o estado anterior da célula e passá-los por uma ou várias camadas totalmente conectadas para calcular o estado oculto da saída e o estado da célula, em vez de calcular independentemente "fechado" "atualizações que interagem aritmeticamente com o estado da célula. Isso basicamente criaria uma RNN regular que produzia apenas parte do estado oculto.

A principal razão para não fazer isso é que a estrutura dos cálculos do estado celular do LSTM garante fluxo constante de erros através de longas sequências . Se você usou pesos para calcular diretamente o estado da célula, seria necessário retropropagá-los a cada passo! Evitar essas operações resolve em grande parte gradientes de fuga / explosão que, de outra forma, afetam os RNNs.

Além disso, a capacidade de reter informações facilmente em períodos mais longos é um bônus interessante. Intuitivamente, seria muito mais difícil para a rede aprender do zero para preservar o estado da célula em períodos mais longos.

Vale notar que a alternativa mais comum ao LSTM, o GRU , calcula similarmente as atualizações de estado oculto sem aprender pesos que operam diretamente no próprio estado oculto.


0

Se entendi corretamente, as duas perguntas se resumem a isso. Dois lugares em que usamos tanh e sigmoid para processar as informações. Em vez disso, devemos usar uma única rede neural que capte todas as informações.

Não conheço as desvantagens de usar uma única rede neural. Na minha opinião, podemos usar uma única rede neural com não linearidade sigmóide que aprende corretamente o vetor que será usado adequadamente (adicionado no estado da célula no primeiro caso ou passado como estado oculto no segundo caso).

No entanto, da maneira como estamos fazendo isso agora, estamos dividindo a tarefa em duas partes, uma parte que usa a não linearidade sigmóide para aprender a quantidade de dados a serem mantidos. A outra parte que usa tanh como não linearidade está apenas realizando a tarefa de aprender as informações importantes.

Em termos simples, o sigmoid aprende quanto economizar e tanh aprende o que salvar e quebrá-lo em duas partes facilitará o treinamento.

Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.