Estou um pouco surpreso que ninguém tenha mencionado a principal (e única) razão para o aviso dado! Como parece, esse código deve implementar a variante generalizada da função Bump; no entanto, basta dar uma olhada nas funções implementadas novamente:
def f_True(x):
# Compute Bump Function
bump_value = 1-tf.math.pow(x,2)
bump_value = -tf.math.pow(bump_value,-1)
bump_value = tf.math.exp(bump_value)
return(bump_value)
def f_False(x):
# Compute Bump Function
x_out = 0*x
return(x_out)
O erro é evidente: não há uso do peso treinável da camada nessas funções! Portanto, não é de surpreender que você receba a mensagem dizendo que não existe gradiente para isso: você não o está usando, portanto, não há gradiente para atualizá-lo! Pelo contrário, esta é exatamente a função Bump original (ou seja, sem peso treinável).
Mas, você pode dizer que: "pelo menos, usei o peso treinável na condição de tf.cond
, então deve haver alguns gradientes ?!"; no entanto, não é assim e deixe-me esclarecer a confusão:
Primeiro de tudo, como você também notou, estamos interessados no condicionamento por elementos. Então, em vez de tf.cond
você precisar usar tf.where
.
O outro equívoco é afirmar que, uma vez que tf.less
é usado como condição, e como não é diferenciável, ou seja, não possui gradiente em relação a suas entradas (o que é verdade: não há gradiente definido para uma função com saída booleana, que é real. entradas valiosas!), então isso resulta no aviso dado!
- Isso é simplesmente errado! A derivada aqui seria retirada da saída do peso treinável da camada e a condição de seleção NÃO está presente na saída. Pelo contrário, é apenas um tensor booleano que determina o ramo de saída a ser selecionado. É isso aí! A derivada da condição não é aceita e nunca será necessária. Portanto, esse não é o motivo do aviso dado; a razão é única e apenas o que mencionei acima: nenhuma contribuição do peso treinável na saída da camada. (Nota: se o ponto sobre a condição é um pouco surpreendente para você, pense em um exemplo simples: a função ReLU, que é definida como
relu(x) = 0 if x < 0 else x
. Se a derivada da condição, ou seja,x < 0
, é considerado / necessário, o que não existe, então não poderíamos usar ReLU em nossos modelos e treiná-los usando métodos de otimização baseados em gradiente!)
(Nota: a partir daqui, eu me referiria e denotaria o valor do limiar como sigma , como na equação).
Tudo certo! Encontramos a razão por trás do erro na implementação. Podemos consertar isso? Claro! Aqui está a implementação de trabalho atualizada:
import tensorflow as tf
from tensorflow.keras.initializers import RandomUniform
from tensorflow.keras.constraints import NonNeg
class BumpLayer(tf.keras.layers.Layer):
def __init__(self, *args, **kwargs):
super(BumpLayer, self).__init__(*args, **kwargs)
def build(self, input_shape):
self.sigma = self.add_weight(
name='sigma',
shape=[1],
initializer=RandomUniform(minval=0.0, maxval=0.1),
trainable=True,
constraint=tf.keras.constraints.NonNeg()
)
super().build(input_shape)
def bump_function(self, x):
return tf.math.exp(-self.sigma / (self.sigma - tf.math.pow(x, 2)))
def call(self, inputs):
greater = tf.math.greater(inputs, -self.sigma)
less = tf.math.less(inputs, self.sigma)
condition = tf.logical_and(greater, less)
output = tf.where(
condition,
self.bump_function(inputs),
0.0
)
return output
Alguns pontos em relação a esta implementação:
Nós substituímos tf.cond
por tf.where
para fazer o condicionamento por elementos.
Além disso, como você pode ver, ao contrário de sua implementação que só verificado para um dos lados da desigualdade, estamos usando tf.math.less
, tf.math.greater
e também tf.logical_and
para descobrir se os valores de entrada têm magnitudes inferior a sigma
(alternativamente, poderíamos fazer isso usando apenas tf.math.abs
e tf.math.less
, sem diferença !). E vamos repetir: usar funções de saída booleana dessa maneira não causa problemas e não tem nada a ver com derivadas / gradientes.
Também estamos usando uma restrição de não negatividade no valor sigma aprendido por camada. Por quê? Como os valores de sigma menores que zero não fazem sentido (ou seja, o intervalo (-sigma, sigma)
é mal definido quando o sigma é negativo).
E considerando o ponto anterior, tomamos o cuidado de inicializar o valor sigma corretamente (ou seja, para um pequeno valor não negativo).
E também, por favor, não faça coisas como 0.0 * inputs
! É redundante (e um pouco estranho) e é equivalente a 0.0
; e ambos têm um gradiente de 0.0
(wrt inputs
). A multiplicação de zero por um tensor não adiciona nada ou resolve qualquer problema existente, pelo menos não neste caso!
Agora, vamos testá-lo para ver como funciona. Escrevemos algumas funções auxiliares para gerar dados de treinamento com base em um valor sigma fixo e também para criar um modelo que contém um único BumpLayer
com formato de entrada de (1,)
. Vamos ver se ele pode aprender o valor sigma usado para gerar dados de treinamento:
import numpy as np
def generate_data(sigma, min_x=-1, max_x=1, shape=(100000,1)):
assert sigma >= 0, 'Sigma should be non-negative!'
x = np.random.uniform(min_x, max_x, size=shape)
xp2 = np.power(x, 2)
condition = np.logical_and(x < sigma, x > -sigma)
y = np.where(condition, np.exp(-sigma / (sigma - xp2)), 0.0)
dy = np.where(condition, xp2 * y / np.power((sigma - xp2), 2), 0)
return x, y, dy
def make_model(input_shape=(1,)):
model = tf.keras.Sequential()
model.add(BumpLayer(input_shape=input_shape))
model.compile(loss='mse', optimizer='adam')
return model
# Generate training data using a fixed sigma value.
sigma = 0.5
x, y, _ = generate_data(sigma=sigma, min_x=-0.1, max_x=0.1)
model = make_model()
# Store initial value of sigma, so that it could be compared after training.
sigma_before = model.layers[0].get_weights()[0][0]
model.fit(x, y, epochs=5)
print('Sigma before training:', sigma_before)
print('Sigma after training:', model.layers[0].get_weights()[0][0])
print('Sigma used for generating data:', sigma)
# Sigma before training: 0.08271004
# Sigma after training: 0.5000002
# Sigma used for generating data: 0.5
Sim, ele pode aprender o valor do sigma usado para gerar dados! Mas, é garantido que ele realmente funciona para todos os diferentes valores de dados de treinamento e inicialização do sigma? A resposta é não! Na verdade, é possível que você execute o código acima e obtenha nan
o valor da sigma após o treinamento ou inf
o valor da perda! Então qual é o problema? Por que isso nan
ou inf
valores podem ser produzidos? Vamos discutir abaixo ...
Lidar com a estabilidade numérica
Uma das coisas importantes a considerar, ao construir um modelo de aprendizado de máquina e usar métodos de otimização baseados em gradiente para treiná-los, é a estabilidade numérica das operações e cálculos em um modelo. Quando valores extremamente grandes ou pequenos são gerados por uma operação ou seu gradiente, quase certamente isso atrapalha o processo de treinamento (por exemplo, essa é uma das razões por trás da normalização dos valores de pixel de imagem nas CNNs para evitar esse problema).
Então, vamos dar uma olhada nessa função de bump generalizada (e vamos descartar o limiar por enquanto). É óbvio que esta função possui singularidades (isto é, pontos em que a função ou seu gradiente não está definido) em x^2 = sigma
(isto é, quando x = sqrt(sigma)
ou x=-sqrt(sigma)
). O diagrama animado abaixo mostra a função bump (a linha vermelha sólida), sua derivada wrt sigma (a linha verde pontilhada) e x=sigma
and x=-sigma
lines (duas linhas verticais tracejadas em azul), quando o sigma começa do zero e é aumentado para 5:
Como você pode ver, em torno da região das singularidades, a função não é bem-comportada para todos os valores de sigma, no sentido de que a função e sua derivada assumem valores extremamente grandes nessas regiões. Assim, dado um valor de entrada nessas regiões para um valor específico de sigma, seriam gerados valores explosivos de saída e gradiente, daí a questão do inf
valor da perda.
Além disso, há um comportamento problemático tf.where
que causa a emissão de nan
valores para a variável sigma na camada: surpreendentemente, se o valor produzido no ramo inativo de tf.where
for extremamente grande ou inf
que, com a função bump, resulta em valores extremamente grandes ou inf
gradientes , então o gradiente de tf.where
seria nan
, apesar do fato de inf
estar no ramo inativo e nem ser selecionado (consulte esta edição do Github que discute exatamente isso) !!
Portanto, existe alguma solução alternativa para esse comportamento de tf.where
? Sim, na verdade, há um truque para resolver de alguma forma esse problema, explicado nesta resposta : basicamente podemos usar um adicional tf.where
para impedir que a função seja aplicada nessas regiões. Em outras palavras, em vez de aplicar self.bump_function
em qualquer valor de entrada, filtramos os valores que NÃO estão no intervalo (-self.sigma, self.sigma)
(ou seja, o intervalo real em que a função deve ser aplicada) e, em vez disso, alimentamos a função com zero (que sempre produz valores seguros, ou seja, é igual a exp(-1)
):
output = tf.where(
condition,
self.bump_function(tf.where(condition, inputs, 0.0)),
0.0
)
A aplicação dessa correção resolveria completamente a questão dos nan
valores para o sigma. Vamos avaliá-lo nos valores dos dados de treinamento gerados com diferentes valores sigma e ver como ele funcionaria:
true_learned_sigma = []
for s in np.arange(0.1, 10.0, 0.1):
model = make_model()
x, y, dy = generate_data(sigma=s, shape=(100000,1))
model.fit(x, y, epochs=3 if s < 1 else (5 if s < 5 else 10), verbose=False)
sigma = model.layers[0].get_weights()[0][0]
true_learned_sigma.append([s, sigma])
print(s, sigma)
# Check if the learned values of sigma
# are actually close to true values of sigma, for all the experiments.
res = np.array(true_learned_sigma)
print(np.allclose(res[:,0], res[:,1], atol=1e-2))
# True
Poderia aprender todos os valores sigma corretamente! Isso é bom. Essa solução funcionou! Embora exista uma ressalva: é garantido que funcione corretamente e aprenda qualquer valor sigma se os valores de entrada para essa camada forem maiores que -1 e menores que 1 (ou seja, este é o caso padrão de nossa generate_data
função); caso contrário, ainda existe a questão do inf
valor da perda que pode ocorrer se os valores de entrada tiverem uma magnitude maior que 1 (consulte os pontos 1 e 2 abaixo).
Aqui estão alguns pensamentos para os curiosos e a mente interessada:
Acabamos de mencionar que, se os valores de entrada para essa camada forem maiores que 1 ou menores que -1, isso poderá causar problemas. Você pode argumentar por que esse é o caso? (Dica: use o diagrama animado acima e considere os casos em que sigma > 1
e o valor de entrada está entre sqrt(sigma)
e sigma
(ou entre -sigma
e -sqrt(sigma)
.)
Você pode fornecer uma correção para o problema no ponto 1, ou seja, para que a camada funcione para todos os valores de entrada? (Dica: como a solução alternativa tf.where
, pense em como você pode filtrar ainda mais os valores não seguros nos quais a função bump pode ser aplicada e produzir saída / gradiente explosivos.)
No entanto, se você não está interessado em corrigir esse problema e gostaria de usar essa camada em um modelo como está agora, como garantir que os valores de entrada nessa camada estejam sempre entre -1 e 1? (Dica: como uma solução, existe uma função de ativação comumente usada que produz valores exatamente nesse intervalo e pode ser potencialmente usada como a função de ativação da camada anterior a essa camada.)
Se você der uma olhada no último trecho de código, verá que usamos epochs=3 if s < 1 else (5 if s < 5 else 10)
. Por que é que? Por que grandes valores de sigma precisam de mais épocas para serem aprendidas? (Dica: novamente, use o diagrama animado e considere a derivada da função para valores de entrada entre -1 e 1 à medida que o valor sigma aumenta. Qual é a magnitude deles?)
Você também precisa verificar os dados de treinamento gerados para qualquer nan
, inf
ou extremamente grandes valores de y
e filtrá-los? (Dica: sim, se sigma > 1
e faixa de valores, ou seja, min_x
e max_x
, cair fora do (-1, 1)
!?!, Caso contrário, não que não é necessário Por que isso é deixado como um exercício)
input
? é um escalar?