A convergência do GradienTape é muito mais lenta que o Keras.model.fit


8

Atualmente, estou tentando obter a API TF2.0 , mas como eu comparei o GradientTape com um keras.Model.fit normal , notei:

  1. A execução foi mais lenta (provavelmente devido à Execução Ansiosa)

  2. Ele convergiu muito mais devagar (e não sei por que).

+--------+--------------+--------------+------------------+
|  Epoch | GradientTape | GradientTape | keras.Model.fit  |
|        |              |  shuffling   |                  |
+--------+--------------+--------------+------------------+
|    1   |     0.905    |     0.918    |      0.8793      |
+--------+--------------+--------------+------------------+
|    2   |     0.352    |     0.634    |      0.2226      |
+--------+--------------+--------------+------------------+
|    3   |     0.285    |     0.518    |      0.1192      |
+--------+--------------+--------------+------------------+
|    4   |     0.282    |     0.458    |      0.1029      |
+--------+--------------+--------------+------------------+
|    5   |     0.275    |     0.421    |      0.0940      |
+--------+--------------+--------------+------------------+

Aqui está o ciclo de treinamento que usei com o GradientTape :


optimizer = keras.optimizers.Adam()
glove_model = GloveModel(vocab_size=len(labels))
train_loss = keras.metrics.Mean(name='train_loss')

@tf.function
def train_step(examples, labels):
    with tf.GradientTape() as tape:
        predictions = glove_model(examples)
        loss = glove_model.glove_loss(labels, predictions)

    gradients = tape.gradient(loss, glove_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, glove_model.trainable_variables))

    train_loss(loss)



total_step = 0
for epoch in range(epochs_number):

    pbar = tqdm(train_ds.enumerate(), total=int(len(index_data) / batch_size) + 1)

    for ix, (examples, labels) in pbar:

        train_step(examples, labels)


    print(f"Epoch {epoch + 1}, Loss {train_loss.result()}")

    # Reset the metrics for the next epoch
    train_loss.reset_states()

E aqui está o Keras.Model.fit treinamento :

glove_model.compile(optimizer, glove_model.glove_loss)
glove_model.fit(train_ds, epochs=epochs_number)

Aqui está o tf.data.Dataset fonte

train_ds = data.Dataset.from_tensor_slices(
    (np.hstack([index_rows.reshape(-1, 1), index_cols.reshape(-1, 1)]), index_data)
).shuffle(100000).batch(batch_size, drop_remainder=True)

E aqui está o modelo.

class GloveModel(keras.Model):

    def __init__(self, vocab_size, dim=100, a=3/4, x_max=100):
        super(GloveModel, self).__init__()

        self.vocab_size = vocab_size
        self.dim = dim
        self.a = a
        self.x_max = x_max

        self.target_embedding = layers.Embedding(
            input_dim=self.vocab_size, output_dim=self.dim, input_length=1, name="target_embedding"
        )
        self.target_bias = layers.Embedding(
            input_dim=self.vocab_size, output_dim=1, input_length=1, name="target_bias"
        )

        self.context_embedding = layers.Embedding(
            input_dim=self.vocab_size, output_dim=self.dim, input_length=1, name="context_embedding"
        )
        self.context_bias = layers.Embedding(
            input_dim=self.vocab_size, output_dim=1, input_length=1, name="context_bias"
        )

        self.dot_product = layers.Dot(axes=-1, name="dot")

        self.prediction = layers.Add(name="add")
        self.step = 0

    def call(self, inputs):

        target_ix = inputs[:, 0]
        context_ix = inputs[:, 1]

        target_embedding = self.target_embedding(target_ix)
        target_bias = self.target_bias(target_ix)

        context_embedding = self.context_embedding(context_ix)
        context_bias = self.context_bias(context_ix)

        dot_product = self.dot_product([target_embedding, context_embedding])
        prediction = self.prediction([dot_product, target_bias, context_bias])

        return prediction

    def glove_loss(self, y_true, y_pred):

        weight = tf.math.minimum(
            tf.math.pow(y_true/self.x_max, self.a), 1.0
        )
        loss_value = tf.math.reduce_mean(weight * tf.math.pow(y_pred - tf.math.log(y_true), 2.0))

        return loss_value


Tentei várias configurações e otimizadores, mas nada parece alterar a taxa de convergência.


1
Uma coisa a considerar é a baralhamento dos dados antes de cada época.
THN

Eu tenho exatamente o mesmo embaralhamento entre o método de ajuste e o GradientTape porque eu uso a API tf.Data.
Benjamin Breton

1
Eu acho que eles não são exatamente iguais. Você pode mostrar o código do seu tfds? Observe que o .fitpadrão de keras é o embaralhamento antes de cada época. Você pode testar desativando a reprodução aleatória em keras e comparar sua taxa de convergência.
THN

@THN vou enviá-lo para você, mas eu já faço um shuffle com a API tf.Dataset para que não mude nada, certo?
Benjamin Breton

@THN Eu adicionei o tf.data.Dataset
Benjamin Breton

Respostas:


2

Dataset.shuffle()embaralhe cada minibatch apenas para que cada época tenha a mesma ordem. Keras .fit()usa algumas mágicas para embaralhar todo o conjunto de dados antes de cada época. Para fazer isso no TF, você precisa usar o Dataset .repeat(epochs_number)e .shuffle(..., reshuffle_each_iteration=True):

train_ds = data.Dataset.from_tensor_slices(
    (np.hstack([index_rows.reshape(-1, 1), index_cols.reshape(-1, 1)]), index_data)
    ).shuffle(100000, reshuffle_each_iteration=True
    ).batch(batch_size, drop_remainder=True
    ).repeat(epochs_number)

for ix, (examples, labels) in train_ds.enumerate():
    train_step(examples, labels)
    current_epoch = ix // (len(index_data) // batch_size)

Esta solução alternativa não é bonita nem natural. No momento, você pode usá-la para embaralhar cada época. É um problema conhecido e será corrigido, no futuro, você pode usar for epoch in range(epochs_number)em vez de .repeat().


Desculpe, eu adicionei seu código, mas a convergência é ainda mais lenta. Adicionei os resultados na coluna GradientTape shuffle. Não faz sentido para mim ...
Benjamin Breton

@BenjaminBreton Neste momento, duvido que haja outros erros ocultos em seu código. Talvez seja melhor vincular seu repositório para mostrar o código completo. Se você tiver certeza de que suas experiências foram conduzidas corretamente, abra um problema no repositório de tensorflow.
THN

Muito obrigado pela sua ajuda @THN Publiquei o problema no repositório TF2.0 TF2.0 github.com/tensorflow/tensorflow/issues/33898 . Vou tentar reproduzir o erro com um modelo diferente.
Benjamin Breton

1
Acontece que você estava certo @THN Eu embaralhei usando numpy e resolveu o problema. Vou postar uma resposta abrangente
Benjamin Breton

0

O problema veio do embaralhamento usando o método tf.Dataset . Ele passou pelo conjunto de dados apenas um depósito por vez. O uso do Keras.Model.fit produziu melhores resultados porque provavelmente adiciona outro embaralhamento.

Eu adicionei um embaralhamento numpy.random.shufflee melhorou o desempenho com os dois métodos de treinamento:

A geração do conjunto de dados é agora:

numpy_data = np.hstack([index_rows.reshape(-1, 1), index_cols.reshape(-1, 1), index_data.reshape(-1, 1)])

np.random.shuffle(numpy_data)

indexes = np.array(numpy_data[:, :2], dtype=np.uint32)
labels = np.array(numpy_data[:, 2].reshape(-1, 1), dtype=np.float32)

train_ds = data.Dataset.from_tensor_slices(
    (indexes, labels)
).shuffle(100000).batch(batch_size, drop_remainder=True)

E os resultados são:

+--------+--------------+------------------+
|  Epoch | GradientTape |  keras.Model.fit |
+--------+--------------+------------------+
|    1   |     0.294    |      0.294       |
+--------+--------------+------------------+
|    2   |     0.111    |      0.110       |
+--------+--------------+------------------+
|    3   |     0.089    |      0.089       |
+--------+--------------+------------------+
|    4   |     0.074    |      0.075       |
+--------+--------------+------------------+
|    5   |     0.063    |      0.063       |
+--------+--------------+------------------+

O tipo de treinamento por época é aproximadamente o mesmo em minutos por época .

Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.