Atenção é um método para agregar um conjunto de vetores vi em apenas um vetor, geralmente por meio do vetor de pesquisa u . Geralmente, vi são as entradas para o modelo ou os estados ocultos das etapas de tempo anteriores ou os estados ocultos um nível abaixo (no caso de LSTMs empilhados).
O resultado costuma ser chamado de vetor de contexto c , pois contém o contexto relevante para o atual intervalo de tempo.
Esse vetor de contexto adicional c é alimentado no RNN / LSTM (pode ser simplesmente concatenado com a entrada original). Portanto, o contexto pode ser usado para ajudar na previsão.
A maneira mais simples de fazer isso é a computação probabilidade vector p=softmax(VTu) e c=∑ipivi , onde V é a concatenação de todos os anteriores vi . Um vetor de pesquisa comumu é o estado oculto atualht .
Existem muitas variações nisso, e você pode tornar as coisas tão complicadas quanto desejar. Por exemplo, em vez de usar vTiu como os logits, pode-se escolher f(vi,u) em vez disso, ondef é uma rede neural arbitrária.
Um mecanismo de atenção comum para modelos de sequência a sequência usa p=softmax(qTtanh(W1vi+W2ht)) , onde v são os estados ocultos do codificador e ht é o estado oculto atual do decodificador. q e ambos os W s são parâmetros.
Alguns trabalhos que mostram diferentes variações na idéia de atenção:
As redes de ponteiros prestam atenção às entradas de referência para resolver problemas de otimização combinatória.
As redes de entidades recorrentes mantêm estados de memória separados para diferentes entidades (pessoas / objetos) durante a leitura de texto e atualizam o estado correto da memória usando atenção.
Os modelos de transformadores também fazem uso extensivo de atenção. A sua formulação de atenção é ligeiramente mais geral e também envolve vectores chave ki : os pesos atenção p são, na verdade, calculado entre as chaves e a pesquisa, e o contexto é então construído com a vi .
Aqui está uma rápida implementação de uma forma de atenção, embora eu não possa garantir a correção além do fato de ter passado em alguns testes simples.
RNN básico:
def rnn(inputs_split):
bias = tf.get_variable('bias', shape = [hidden_dim, 1])
weight_hidden = tf.tile(tf.get_variable('hidden', shape = [1, hidden_dim, hidden_dim]), [batch, 1, 1])
weight_input = tf.tile(tf.get_variable('input', shape = [1, hidden_dim, in_dim]), [batch, 1, 1])
hidden_states = [tf.zeros((batch, hidden_dim, 1), tf.float32)]
for i, input in enumerate(inputs_split):
input = tf.reshape(input, (batch, in_dim, 1))
last_state = hidden_states[-1]
hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )
hidden_states.append(hidden)
return hidden_states[-1]
Com atenção, adicionamos apenas algumas linhas antes que o novo estado oculto seja calculado:
if len(hidden_states) > 1:
logits = tf.transpose(tf.reduce_mean(last_state * hidden_states[:-1], axis = [2, 3]))
probs = tf.nn.softmax(logits)
probs = tf.reshape(probs, (batch, -1, 1, 1))
context = tf.add_n([v * prob for (v, prob) in zip(hidden_states[:-1], tf.unstack(probs, axis = 1))])
else:
context = tf.zeros_like(last_state)
last_state = tf.concat([last_state, context], axis = 1)
hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )
o código completo