Tensorflow: como salvar / restaurar um modelo?


552

Depois de treinar um modelo no Tensorflow:

  1. Como você salva o modelo treinado?
  2. Como você mais tarde restaura esse modelo salvo?

Você conseguiu restaurar as variáveis ​​usadas no modelo de criação? Eu também estou tentando exatamente o mesmo problema, mas eu sou incapaz de set de gravação de variáveis que foram utilizadas durante o treinamento do modelo de criação (do qual tenho arquivo ckpt)
exAres

Eu não tentei com o modelo inicial. Você tem a estrutura de rede do modelo com seus nomes? Você precisa replicar a rede e carregar os pesos e preconceitos (o arquivo ckpt), como Ryan explica. Talvez alguma coisa mudou desde Nov'15 e há uma abordagem mais simples agora, eu não tenho certeza
Mathetes

Oh tudo bem. Carreguei outros modelos pré-treinados de tensorflow anteriormente, mas estava procurando especificações variáveis ​​do modelo inicial. Obrigado.
exAres

1
Se você restaurar para continuar treinando, use os pontos de verificação do Saver. Se você salvar o modelo para fazer referência, apenas as APIs SavedModel do tensorflow.
HY G

Além disso, se você estiver usando LSTM, terá um mapa da cadeia de caracteres para uma lista de caracteres, salve e carregue essa lista na mesma ordem! Isso não é coberto ao salvar os pesos do modelo e a rede de gráficos do modelo e parecerá que seu modelo não foi carregado quando você altera as sessões ou os dados são alterados.
Devssh # 26/18

Respostas:


119

Documentos

Dos documentos:

Salve 

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  inc_v1.op.run()
  dec_v2.op.run()
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in path: %s" % save_path)

Restaurar

tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Check the values of the variables
  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())

Tensorflow 2

Ainda é beta, por isso não aconselho por enquanto. Se você ainda quiser seguir esse caminho, aqui está o tf.saved_modelguia de uso

Tensorflow <2

simple_save

Muitas boas respostas, para completar, adicionarei meus 2 centavos: simple_save . Também um exemplo de código independente usando a tf.data.DatasetAPI.

Python 3; Tensorflow 1.14

import tensorflow as tf
from tensorflow.saved_model import tag_constants

with tf.Graph().as_default():
    with tf.Session() as sess:
        ...

        # Saving
        inputs = {
            "batch_size_placeholder": batch_size_placeholder,
            "features_placeholder": features_placeholder,
            "labels_placeholder": labels_placeholder,
        }
        outputs = {"prediction": model_output}
        tf.saved_model.simple_save(
            sess, 'path/to/your/location/', inputs, outputs
        )

Restaurando:

graph = tf.Graph()
with restored_graph.as_default():
    with tf.Session() as sess:
        tf.saved_model.loader.load(
            sess,
            [tag_constants.SERVING],
            'path/to/your/location/',
        )
        batch_size_placeholder = graph.get_tensor_by_name('batch_size_placeholder:0')
        features_placeholder = graph.get_tensor_by_name('features_placeholder:0')
        labels_placeholder = graph.get_tensor_by_name('labels_placeholder:0')
        prediction = restored_graph.get_tensor_by_name('dense/BiasAdd:0')

        sess.run(prediction, feed_dict={
            batch_size_placeholder: some_value,
            features_placeholder: some_other_value,
            labels_placeholder: another_value
        })

Exemplo autônomo

Postagem original do blog

O código a seguir gera dados aleatórios para a demonstração.

  1. Começamos criando os espaços reservados. Eles manterão os dados em tempo de execução. A partir deles, criamos o Datasete então é Iterator. Nós obtemos o tensor gerado pelo iterador, chamado input_tensorque servirá como entrada para o nosso modelo.
  2. O modelo em si é construído a partir de input_tensor: um RNN bidirecional baseado em GRU, seguido por um classificador denso. Porque porque não?
  3. A perda é softmax_cross_entropy_with_logitsotimizada com Adam. Após 2 épocas (de 2 lotes cada), salvamos o modelo "treinado" tf.saved_model.simple_save. Se você executar o código como está, o modelo será salvo em uma pasta chamada simple/no seu diretório de trabalho atual.
  4. Em um novo gráfico, restauramos o modelo salvo com tf.saved_model.loader.load. Pegamos os espaços reservados e logits com graph.get_tensor_by_namee a Iteratoroperação de inicialização com graph.get_operation_by_name.
  5. Por fim, executamos uma inferência para ambos os lotes no conjunto de dados e verificamos se o modelo salvo e restaurado produz os mesmos valores. Eles fazem!

Código:

import os
import shutil
import numpy as np
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants


def model(graph, input_tensor):
    """Create the model which consists of
    a bidirectional rnn (GRU(10)) followed by a dense classifier

    Args:
        graph (tf.Graph): Tensors' graph
        input_tensor (tf.Tensor): Tensor fed as input to the model

    Returns:
        tf.Tensor: the model's output layer Tensor
    """
    cell = tf.nn.rnn_cell.GRUCell(10)
    with graph.as_default():
        ((fw_outputs, bw_outputs), (fw_state, bw_state)) = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cell,
            cell_bw=cell,
            inputs=input_tensor,
            sequence_length=[10] * 32,
            dtype=tf.float32,
            swap_memory=True,
            scope=None)
        outputs = tf.concat((fw_outputs, bw_outputs), 2)
        mean = tf.reduce_mean(outputs, axis=1)
        dense = tf.layers.dense(mean, 5, activation=None)

        return dense


def get_opt_op(graph, logits, labels_tensor):
    """Create optimization operation from model's logits and labels

    Args:
        graph (tf.Graph): Tensors' graph
        logits (tf.Tensor): The model's output without activation
        labels_tensor (tf.Tensor): Target labels

    Returns:
        tf.Operation: the operation performing a stem of Adam optimizer
    """
    with graph.as_default():
        with tf.variable_scope('loss'):
            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
                    logits=logits, labels=labels_tensor, name='xent'),
                    name="mean-xent"
                    )
        with tf.variable_scope('optimizer'):
            opt_op = tf.train.AdamOptimizer(1e-2).minimize(loss)
        return opt_op


if __name__ == '__main__':
    # Set random seed for reproducibility
    # and create synthetic data
    np.random.seed(0)
    features = np.random.randn(64, 10, 30)
    labels = np.eye(5)[np.random.randint(0, 5, (64,))]

    graph1 = tf.Graph()
    with graph1.as_default():
        # Random seed for reproducibility
        tf.set_random_seed(0)
        # Placeholders
        batch_size_ph = tf.placeholder(tf.int64, name='batch_size_ph')
        features_data_ph = tf.placeholder(tf.float32, [None, None, 30], 'features_data_ph')
        labels_data_ph = tf.placeholder(tf.int32, [None, 5], 'labels_data_ph')
        # Dataset
        dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
        dataset = dataset.batch(batch_size_ph)
        iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
        dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
        input_tensor, labels_tensor = iterator.get_next()

        # Model
        logits = model(graph1, input_tensor)
        # Optimization
        opt_op = get_opt_op(graph1, logits, labels_tensor)

        with tf.Session(graph=graph1) as sess:
            # Initialize variables
            tf.global_variables_initializer().run(session=sess)
            for epoch in range(3):
                batch = 0
                # Initialize dataset (could feed epochs in Dataset.repeat(epochs))
                sess.run(
                    dataset_init_op,
                    feed_dict={
                        features_data_ph: features,
                        labels_data_ph: labels,
                        batch_size_ph: 32
                    })
                values = []
                while True:
                    try:
                        if epoch < 2:
                            # Training
                            _, value = sess.run([opt_op, logits])
                            print('Epoch {}, batch {} | Sample value: {}'.format(epoch, batch, value[0]))
                            batch += 1
                        else:
                            # Final inference
                            values.append(sess.run(logits))
                            print('Epoch {}, batch {} | Final inference | Sample value: {}'.format(epoch, batch, values[-1][0]))
                            batch += 1
                    except tf.errors.OutOfRangeError:
                        break
            # Save model state
            print('\nSaving...')
            cwd = os.getcwd()
            path = os.path.join(cwd, 'simple')
            shutil.rmtree(path, ignore_errors=True)
            inputs_dict = {
                "batch_size_ph": batch_size_ph,
                "features_data_ph": features_data_ph,
                "labels_data_ph": labels_data_ph
            }
            outputs_dict = {
                "logits": logits
            }
            tf.saved_model.simple_save(
                sess, path, inputs_dict, outputs_dict
            )
            print('Ok')
    # Restoring
    graph2 = tf.Graph()
    with graph2.as_default():
        with tf.Session(graph=graph2) as sess:
            # Restore saved values
            print('\nRestoring...')
            tf.saved_model.loader.load(
                sess,
                [tag_constants.SERVING],
                path
            )
            print('Ok')
            # Get restored placeholders
            labels_data_ph = graph2.get_tensor_by_name('labels_data_ph:0')
            features_data_ph = graph2.get_tensor_by_name('features_data_ph:0')
            batch_size_ph = graph2.get_tensor_by_name('batch_size_ph:0')
            # Get restored model output
            restored_logits = graph2.get_tensor_by_name('dense/BiasAdd:0')
            # Get dataset initializing operation
            dataset_init_op = graph2.get_operation_by_name('dataset_init')

            # Initialize restored dataset
            sess.run(
                dataset_init_op,
                feed_dict={
                    features_data_ph: features,
                    labels_data_ph: labels,
                    batch_size_ph: 32
                }

            )
            # Compute inference for both batches in dataset
            restored_values = []
            for i in range(2):
                restored_values.append(sess.run(restored_logits))
                print('Restored values: ', restored_values[i][0])

    # Check if original inference and restored inference are equal
    valid = all((v == rv).all() for v, rv in zip(values, restored_values))
    print('\nInferences match: ', valid)

Isso imprimirá:

$ python3 save_and_restore.py

Epoch 0, batch 0 | Sample value: [-0.13851789 -0.3087595   0.12804556  0.20013677 -0.08229901]
Epoch 0, batch 1 | Sample value: [-0.00555491 -0.04339041 -0.05111827 -0.2480045  -0.00107776]
Epoch 1, batch 0 | Sample value: [-0.19321944 -0.2104792  -0.00602257  0.07465433  0.11674127]
Epoch 1, batch 1 | Sample value: [-0.05275984  0.05981954 -0.15913513 -0.3244143   0.10673307]
Epoch 2, batch 0 | Final inference | Sample value: [-0.26331693 -0.13013336 -0.12553    -0.04276478  0.2933622 ]
Epoch 2, batch 1 | Final inference | Sample value: [-0.07730117  0.11119192 -0.20817074 -0.35660955  0.16990358]

Saving...
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: b'/some/path/simple/saved_model.pb'
Ok

Restoring...
INFO:tensorflow:Restoring parameters from b'/some/path/simple/variables/variables'
Ok
Restored values:  [-0.26331693 -0.13013336 -0.12553    -0.04276478  0.2933622 ]
Restored values:  [-0.07730117  0.11119192 -0.20817074 -0.35660955  0.16990358]

Inferences match:  True

1
Sou iniciante e preciso de mais explicações ...: Se eu tiver um modelo CNN, devo armazenar apenas 1. inputs_placeholder 2. labels_placeholder e 3. output_of_cnn? Ou todo o intermediário tf.contrib.layers?
Chovendo

2
O gráfico é totalmente restaurado. Você pode verificar em execução [n.name for n in graph2.as_graph_def().node]. Como a documentação diz, o save simples visa simplificar a interação com a veiculação do tensorflow, esse é o objetivo dos argumentos; outras variáveis ​​ainda são restauradas, caso contrário, a inferência não aconteceria. Basta pegar suas variáveis ​​de interesse, como fiz no exemplo. Verifique a documentação
ted

@ted quando eu usaria tf.saved_model.simple_save vs tf.train.Saver ()? Pela minha intuição, eu usaria tf.train.Saver () durante o treinamento e armazenaria diferentes momentos no tempo. Eu usaria tf.saved_model.simple_save quando o treinamento fosse concluído para uso na produção. (Eu pedi também o mesmo em um comentário aqui )
loco.loop

1
Bom eu acho, mas também funciona com os modelos do modo Eager e tfe.Saver?
Geoffrey Anderson

1
sem global_stepcomo argumento, se você parar e tentar retomar o treinamento, ele pensará que você é um passo. Vai estragar suas visualizações tensorboard no mínimo
Monica Heddneck

252

Estou melhorando minha resposta para adicionar mais detalhes para salvar e restaurar modelos.

Na (e depois) versão 0.11 do Tensorflow :

Salve o modelo:

import tensorflow as tf

#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}

#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#Create a saver object which will save all the variables
saver = tf.train.Saver()

#Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1 

#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)

Restaure o modelo:

import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated 

Este e alguns casos de uso mais avançados foram explicados muito bem aqui.

Um tutorial completo rápido para salvar e restaurar modelos do Tensorflow


3
+1 para este # O acesso às variáveis ​​salvas é impresso diretamente (sess.run ('viés: 0')) # Isso imprimirá 2, que é o valor do viés que salvamos. Isso ajuda muito para fins de depuração para ver se o modelo está carregado corretamente. as variáveis ​​podem ser obtidas com "All_varaibles = tf.get_collection (tf.GraphKeys.GLOBAL_VARIABLES". Além disso, "sess.run (tf.global_variables_initializer ())" deve ser anterior à restauração.
LGG

1
Tem certeza de que precisamos executar global_variables_initializer novamente? Eu restaurei meu gráfico com global_variable_initialization, e isso me dá uma saída diferente toda vez nos mesmos dados. Então, comentei a inicialização e apenas restaurei o gráfico, a variável de entrada e as operações, e agora funciona bem.
Aditya Shinde

@AdityaShinde Não entendo por que sempre obtenho valores diferentes todas as vezes. E não incluí a etapa de inicialização variável para restauração. Estou usando meu próprio código btw.
Chaine

@ AdityaShinde: você não precisa do init op, pois os valores já foram inicializados pela função de restauração, então a removeu. No entanto, não sei por que você obteve uma saída diferente usando o init op.
sankit

5
@sankit Quando você restaura os tensores, por que você adiciona :0os nomes?
Sahar Rabinoviz

177

Na versão 0.11.0RC1 (e depois) do TensorFlow, você pode salvar e restaurar seu modelo diretamente ligando tf.train.export_meta_graphe de tf.train.import_meta_graphacordo com https://www.tensorflow.org/programmers_guide/meta_graph .

Salve o modelo

w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# `save` method will call `export_meta_graph` implicitly.
# you will get saved graph files:my-model.meta

Restaurar o modelo

sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)

4
como carregar variáveis ​​do modelo salvo? Como copiar valores em alguma outra variável?
Neel

9
Não consigo obter esse código funcionando. O modelo é salvo, mas não consigo restaurá-lo. Está me dando esse erro. <built-in function TF_Run> returned a result with an error set
Saad Qureshi

2
Quando, após a restauração, eu acesso as variáveis ​​como mostrado acima, ele funciona. Mas não consigo obter as variáveis ​​mais diretamente usando tf.get_variable_scope().reuse_variables()seguido por var = tf.get_variable("varname"). Isso me dá o erro: "ValueError: variável varname não existe ou não foi criado com tf.get_variable ()." Por quê? Isso não deveria ser possível?
Johann Petrak

4
Isso funciona bem apenas para variáveis, mas como você pode acessar um espaço reservado e fornecer valores a ele após restaurar o gráfico?
Kbrose

11
Isso mostra apenas como restaurar as variáveis. Como você pode restaurar o modelo inteiro e testá-lo em novos dados sem redefinir a rede?
Chaine

127

Para a versão TensorFlow <0.11.0RC1:

Os pontos de verificação salvos contêm valores para os Variables no seu modelo, não o modelo / gráfico em si, o que significa que o gráfico deve ser o mesmo quando você restaurar o ponto de verificação.

Aqui está um exemplo de regressão linear em que há um ciclo de treinamento que salva pontos de verificação variáveis ​​e uma seção de avaliação que restaura as variáveis ​​salvas em uma execução anterior e prediz a computação. Obviamente, você também pode restaurar variáveis ​​e continuar o treinamento, se desejar.

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
y_hat = tf.add(b, tf.matmul(x, w))

...more setup for optimization and what not...

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    if FLAGS.train:
        for i in xrange(FLAGS.training_steps):
            ...training loop...
            if (i + 1) % FLAGS.checkpoint_steps == 0:
                saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
                           global_step=i+1)
    else:
        # Here's where you're restoring the variables w and b.
        # Note that the graph is exactly as it was when the variables were
        # saved in a prior training run.
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            ...no checkpoint found...

        # Now you can run the model to get predictions
        batch_x = ...load some data...
        predictions = sess.run(y_hat, feed_dict={x: batch_x})

Aqui estão os documentos para Variables, que abrangem salvar e restaurar. E aqui estão os documentos para o Saver.


1
FLAGS são definidos pelo usuário. Aqui está um exemplo de defini-los: github.com/tensorflow/tensorflow/blob/master/tensorflow/...
Ryan Sepassi

em que formato batch_xprecisa estar? Binário? Matriz numpy?
pepe

@pepe Array numpy deve estar bem. E o tipo do elemento deve corresponder ao tipo do espaço reservado. [link] tensorflow.org/versions/r0.9/api_docs/python/…
Donny

FLAGS dá erro undefined. Você pode me dizer qual é def de FLAGS para este código. @RyanSepassi
Muhammad Hannan

Para torná-lo explícito: As versões recentes do Tensorflow que permitem armazenar o modelo / gráfico. [Não ficou claro para mim, quais aspectos da resposta se aplicam à restrição <0,11. . Dado o grande número de upvotes eu estava tentado a acreditar que esta afirmação geral ainda é verdade para as versões recentes]
bluenote10

78

Meu ambiente: Python 3.6, Tensorflow 1.3.0

Embora tenha havido muitas soluções, a maioria delas é baseada tf.train.Saver. Quando carregar um .ckptsalvo por Saver, temos de redefinir tanto a rede tensorflow ou usar algum nome estranho e hard-lembrado, por exemplo 'placehold_0:0', 'dense/Adam/Weight:0'. Aqui eu recomendo usar tf.saved_model, um exemplo mais simples dado abaixo, você pode aprender mais sobre Servindo um Modelo TensorFlow :

Salve o modelo:

import tensorflow as tf

# define the tensorflow network and do some trains
x = tf.placeholder("float", name="x")
w = tf.Variable(2.0, name="w")
b = tf.Variable(0.0, name="bias")

h = tf.multiply(x, w)
y = tf.add(h, b, name="y")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# save the model
export_path =  './savedmodel'
builder = tf.saved_model.builder.SavedModelBuilder(export_path)

tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

prediction_signature = (
  tf.saved_model.signature_def_utils.build_signature_def(
      inputs={'x_input': tensor_info_x},
      outputs={'y_output': tensor_info_y},
      method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

builder.add_meta_graph_and_variables(
  sess, [tf.saved_model.tag_constants.SERVING],
  signature_def_map={
      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
          prediction_signature 
  },
  )
builder.save()

Carregue o modelo:

import tensorflow as tf
sess=tf.Session() 
signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
input_key = 'x_input'
output_key = 'y_output'

export_path =  './savedmodel'
meta_graph_def = tf.saved_model.loader.load(
           sess,
          [tf.saved_model.tag_constants.SERVING],
          export_path)
signature = meta_graph_def.signature_def

x_tensor_name = signature[signature_key].inputs[input_key].name
y_tensor_name = signature[signature_key].outputs[output_key].name

x = sess.graph.get_tensor_by_name(x_tensor_name)
y = sess.graph.get_tensor_by_name(y_tensor_name)

y_out = sess.run(y, {x: 3.0})

4
+1 para um ótimo exemplo da API SavedModel. No entanto, gostaria que sua seção Salvar o modelo mostrasse um ciclo de treinamento como a resposta de Ryan Sepassi! Sei que essa é uma pergunta antiga, mas essa resposta é um dos poucos (e valiosos) exemplos de SavedModel que encontrei no Google.
Dylan F

@ Tom esta é uma ótima resposta - apenas uma voltada para o novo SavedModel. Você poderia dar uma olhada nesta pergunta SavedModel? stackoverflow.com/questions/48540744/…
bluesummers

Agora faça com que tudo funcione corretamente com os modelos TF Eager. O Google aconselhou em sua apresentação em 2018 que todos se afastassem do código gráfico TF.
Geoffrey Anderson

55

Existem duas partes no modelo, a definição do modelo, salva Supervisorcomo graph.pbtxtno diretório do modelo e os valores numéricos dos tensores, salvos em arquivos de ponto de verificação como model.ckpt-1003418.

A definição do modelo pode ser restaurada usando tf.import_graph_defe os pesos são restaurados usando Saver.

No entanto, Saverusa uma lista especial de retenção de variáveis ​​anexadas ao modelo Graph, e essa coleção não é inicializada usando import_graph_def; portanto, você não pode usar as duas juntas no momento (está em nosso roteiro para corrigir). Por enquanto, você precisa usar a abordagem de Ryan Sepassi - construa manualmente um gráfico com nomes de nó idênticos e use Saverpara carregar os pesos nele.

(Como alternativa, você pode cortá-lo usando import_graph_def, criando variáveis ​​manualmente e usando tf.add_to_collection(tf.GraphKeys.VARIABLES, variable)para cada variável e depois usando Saver)


No exemplo classify_image.py que usa inceptionv3, apenas o graphdef é carregado. Isso significa que agora o GraphDef também contém a variável?
Jrabary

1
@jrabary O modelo provavelmente foi congelado .
precisa

1
Ei, eu sou novo no tensorflow e estou tendo problemas para salvar meu modelo. Eu realmente aprecio isso se você poderia me ajudar a stackoverflow.com/questions/48083474/...
Ruchir Baronia

39

Você também pode seguir esse caminho mais fácil.

Etapa 1: inicialize todas as suas variáveis

W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")

Similarly, W2, B2, W3, .....

Etapa 2: salve a sessão dentro do modelo Savere salve-a

model_saver = tf.train.Saver()

# Train the model and save it in the end
model_saver.save(session, "saved_models/CNN_New.ckpt")

Etapa 3: restaurar o modelo

with tf.Session(graph=graph_cnn) as session:
    model_saver.restore(session, "saved_models/CNN_New.ckpt")
    print("Model restored.") 
    print('Initialized')

Etapa 4: verifique sua variável

W1 = session.run(W1)
print(W1)

Durante a execução em diferentes instâncias python, use

with tf.Session() as sess:
    # Restore latest checkpoint
    saver.restore(sess, tf.train.latest_checkpoint('saved_model/.'))

    # Initalize the variables
    sess.run(tf.global_variables_initializer())

    # Get default graph (supply your custom graph if you have one)
    graph = tf.get_default_graph()

    # It will give tensor object
    W1 = graph.get_tensor_by_name('W1:0')

    # To get the value (numpy array)
    W1_value = session.run(W1)

Olá, Como posso salvar o modelo depois de supor 3000 iterações, semelhantes ao Caffe. Descobri que o tensorflow salva apenas os últimos modelos, apesar de concatenar o número da iteração com o modelo para diferenciá-lo entre todas as iterações. Quero dizer model_3000.ckpt, model_6000.ckpt, --- model_100000.ckpt. Você pode explicar por que ele não salva tudo, mas salva apenas as últimas 3 iterações.
Khan


3
Existe um método para obter todas as variáveis ​​/ nomes de operações salvos no gráfico?
Moondra

21

Na maioria dos casos, salvar e restaurar do disco usando a tf.train.Saveré a sua melhor opção:

... # build your model
saver = tf.train.Saver()

with tf.Session() as sess:
    ... # train the model
    saver.save(sess, "/tmp/my_great_model")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

Você também pode salvar / restaurar a própria estrutura do gráfico (consulte a documentação do MetaGraph para obter detalhes). Por padrão, Saversalva a estrutura do gráfico em um .metaarquivo. Você pode ligar import_meta_graph()para restaurá-lo. Restaura a estrutura do gráfico e retorna um Saverque você pode usar para restaurar o estado do modelo:

saver = tf.train.import_meta_graph("/tmp/my_great_model.meta")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

No entanto, há casos em que você precisa de algo muito mais rápido. Por exemplo, se você implementar uma parada antecipada, deseje salvar os pontos de verificação sempre que o modelo melhorar durante o treinamento (conforme medido no conjunto de validação); se não houver progresso por algum tempo, será necessário reverter para o melhor modelo. Se você salvar o modelo em disco toda vez que ele melhorar, ele reduzirá tremendamente o treinamento. O truque é salvar os estados das variáveis ​​na memória e restaurá-los mais tarde:

... # build your model

# get a handle on the graph nodes we need to save/restore the model
graph = tf.get_default_graph()
gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") for v in gvars]
init_values = [assign_op.inputs[1] for assign_op in assign_ops]

with tf.Session() as sess:
    ... # train the model

    # when needed, save the model state to memory
    gvars_state = sess.run(gvars)

    # when needed, restore the model state
    feed_dict = {init_value: val
                 for init_value, val in zip(init_values, gvars_state)}
    sess.run(assign_ops, feed_dict=feed_dict)

Uma explicação rápida: quando você cria uma variável X, o TensorFlow cria automaticamente uma operação de atribuição X/Assignpara definir o valor inicial da variável. Em vez de criar espaços reservados e operações extras de atribuição (o que deixaria o gráfico confuso), apenas usamos essas operações existentes. A primeira entrada de cada atribuição op é uma referência à variável que ela deve inicializar, e a segunda entrada ( assign_op.inputs[1]) é o valor inicial. Portanto, para definir qualquer valor que desejarmos (em vez do valor inicial), precisamos usar feed_dictae substituir o valor inicial. Sim, o TensorFlow permite que você alimente um valor para qualquer operação, não apenas para espaços reservados, portanto, isso funciona bem.


Obrigado pela resposta. Tenho uma pergunta semelhante sobre como converter um único arquivo .ckpt em dois arquivos .index e .data (por exemplo, para modelos de iniciação pré-treinados disponíveis no tf.slim). Minha pergunta está aqui: stackoverflow.com/questions/47762114/…
Amir

Ei, eu sou novo no tensorflow e estou tendo problemas para salvar meu modelo. Eu realmente aprecio isso se você poderia me ajudar a stackoverflow.com/questions/48083474/...
Ruchir Baronia

17

Como Yaroslav disse, você pode hackear a restauração de um graph_def e ponto de verificação importando o gráfico, criando manualmente variáveis ​​e, em seguida, usando um Saver.

Eu implementei isso para meu uso pessoal, então eu gostaria de compartilhar o código aqui.

Link: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868

(Obviamente, isso é um hack e não há garantia de que os modelos salvos dessa maneira permanecerão legíveis em versões futuras do TensorFlow.)


14

Se for um modelo salvo internamente, basta especificar um restaurador para todas as variáveis ​​como

restorer = tf.train.Saver(tf.all_variables())

e use-o para restaurar variáveis ​​em uma sessão atual:

restorer.restore(self._sess, model_file)

Para o modelo externo, você precisa especificar o mapeamento dos nomes de suas variáveis ​​para seus nomes de variáveis. Você pode visualizar os nomes das variáveis ​​do modelo usando o comando

python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt

O script inspect_checkpoint.py pode ser encontrado na pasta './tensorflow/python/tools' da fonte do Tensorflow.

Para especificar o mapeamento, você pode usar o meu Tensorflow-Worklab , que contém um conjunto de classes e scripts para treinar e treinar novamente modelos diferentes. Inclui um exemplo de reciclagem de modelos ResNet, localizado aqui


all_variables()agora está obsoleto
MiniQuark 31/05

Ei, eu sou novo no tensorflow e estou tendo problemas para salvar meu modelo. Eu realmente aprecio isso se você poderia me ajudar a stackoverflow.com/questions/48083474/...
Ruchir Baronia

12

Aqui está minha solução simples para os dois casos básicos que diferem se você deseja carregar o gráfico do arquivo ou compilá-lo durante o tempo de execução.

Esta resposta vale para o Tensorflow 0.12+ (incluindo 1.0).

Reconstruindo o Gráfico no Código

Salvando

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

Carregando

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    # now you can use the graph, continue training or whatever

Carregando também o gráfico de um arquivo

Ao usar esta técnica, verifique se todas as suas camadas / variáveis ​​definiram explicitamente nomes exclusivos.Caso contrário, o Tensorflow tornará os nomes únicos e eles serão diferentes dos nomes armazenados no arquivo. Não é um problema na técnica anterior, porque os nomes são "mutilados" da mesma maneira no carregamento e no salvamento.

Salvando

graph = ... # build the graph

for op in [ ... ]:  # operators you want to use after restoring the model
    tf.add_to_collection('ops_to_restore', op)

saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

Carregando

with ... as sess:  # your session object
    saver = tf.train.import_meta_graph('my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    ops = tf.get_collection('ops_to_restore')  # here are your operators in the same order in which you saved them to the collection

-1 Iniciar sua resposta descartando "todas as outras respostas aqui" é um pouco duro. Dito isso, votei por outros motivos: você definitivamente deve salvar todas as variáveis ​​globais, não apenas as variáveis ​​treináveis. Por exemplo, a global_stepvariável e as médias móveis da normalização do lote são variáveis ​​não treináveis, mas ambas definitivamente valem a pena ser salvas. Além disso, você deve distinguir mais claramente a construção do gráfico da execução da sessão, por exemplo Saver(...).save(), criará novos nós sempre que você o executar. Provavelmente não é o que você quer. E há mais ...: /
MiniQuark 31/05

@MiniQuark ok, obrigado por seu feedback, eu vou editar a resposta de acordo com as suas sugestões;)
Martin Pecka

10

Você também pode conferir exemplos no TensorFlow / skflow , que oferece métodos savee restoremétodos que podem ajudá-lo a gerenciar facilmente seus modelos. Possui parâmetros que você também pode controlar com que frequência deseja fazer backup do seu modelo.


9

Se você usar tf.train.MonitoredTrainingSession como a sessão padrão, não precisará adicionar código extra para salvar / restaurar as coisas. Basta passar um nome de dir de ponto de verificação para o construtor MonitoredTrainingSession, ele usará ganchos de sessão para lidar com eles.


O uso do tf.train.Supervisor cuidará da criação dessa sessão e fornecerá uma solução mais completa.
Mark

1
O supervisor está obsoleto #
Changming Sun

Você tem algum link que suporte a alegação de que o Supervisor está obsoleto? Não vi nada que indique que seja esse o caso.
Mark


Obrigado pela URL - verifiquei com a fonte original da informação e me disseram que provavelmente estará disponível até o final da série TF 1.x, mas não há garantias depois disso.
Mark

8

Todas as respostas aqui são ótimas, mas quero acrescentar duas coisas.

Primeiro, para elaborar a resposta de @ user7505159, o "./" pode ser importante para adicionar ao início do nome do arquivo que você está restaurando.

Por exemplo, você pode salvar um gráfico sem "./" no nome do arquivo, assim:

# Some graph defined up here with specific names

saver = tf.train.Saver()
save_file = 'model.ckpt'

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_file)

Mas, para restaurar o gráfico, você pode precisar acrescentar um "./" ao nome do arquivo:

# Same graph defined up here

saver = tf.train.Saver()
save_file = './' + 'model.ckpt' # String addition used for emphasis

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, save_file)

Você nem sempre precisará do "./", mas isso pode causar problemas dependendo do ambiente e da versão do TensorFlow.

Também é necessário mencionar que isso sess.run(tf.global_variables_initializer())pode ser importante antes de restaurar a sessão.

Se você estiver recebendo um erro sobre variáveis ​​não inicializadas ao tentar restaurar uma sessão salva, inclua sess.run(tf.global_variables_initializer())antes da saver.restore(sess, save_file)linha. Pode poupar uma dor de cabeça.


7

Conforme descrito na edição 6255 :

use '**./**model_name.ckpt'
saver.restore(sess,'./my_model_final.ckpt')

ao invés de

saver.restore('my_model_final.ckpt')

7

De acordo com a nova versão do Tensorflow, tf.train.Checkpointé a maneira preferível de salvar e restaurar um modelo:

Checkpoint.savee Checkpoint.restoreescreva e leia pontos de verificação baseados em objeto, em contraste com tf.train.Saver, que grava e lê pontos de verificação baseados em variável.nome. O ponto de verificação baseado em objeto salva um gráfico de dependências entre objetos Python (Camadas, Otimizadores, Variáveis, etc.) com arestas nomeadas, e esse gráfico é usado para corresponder variáveis ​​ao restaurar um ponto de verificação. Ele pode ser mais robusto às alterações no programa Python e ajuda a oferecer suporte à restauração na criação de variáveis ​​ao executar com entusiasmo. Prefira tf.train.Checkpointsobre tf.train.Saverpara o novo código .

Aqui está um exemplo:

import tensorflow as tf
import os

tf.enable_eager_execution()

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
  optimizer.minimize( ... )  # Variables will be restored on creation.
status.assert_consumed()  # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)

Mais informações e exemplo aqui.


7

Para o tensorflow 2.0 , é tão simples quanto

# Save the model
model.save('path_to_my_model.h5')

Restaurar:

new_model = tensorflow.keras.models.load_model('path_to_my_model.h5')

E todas as operações e variáveis ​​personalizadas de tf que não fazem parte do objeto de modelo? Eles serão salvos de alguma forma quando você chamar save () no modelo? Eu tenho várias expressões personalizadas de perda e probabilidade de fluxo de tensor que são usadas na rede de inferência e geração, mas elas não fazem parte do meu modelo. O objeto de modelo My Keras contém apenas as camadas densa e conv. No TF 1, chamei o método save e posso ter certeza de que todas as operações e tensores usados ​​no meu gráfico serão salvos. No TF2, não vejo como as operações que não são de alguma forma adicionadas ao modelo keras serão salvas.
Kristof

Há mais informações sobre a restauração de modelos no TF 2.0? Não consigo restaurar pesos de arquivos de ponto de verificação gerados por meio da API do C, consulte: stackoverflow.com/questions/57944786/…
jregalad


5

tf.keras Salvamento do modelo com TF2.0

Vejo ótimas respostas para salvar modelos usando o TF1.x. Quero fornecer mais alguns indicadores para salvartensorflow.keras modelos, o que é um pouco complicado, pois há muitas maneiras de salvar um modelo.

Aqui estou fornecendo um exemplo de salvar um tensorflow.kerasmodelo na model_pathpasta no diretório atual. Isso funciona bem com o fluxo tensor mais recente (TF2.0). Atualizarei esta descrição se houver alguma alteração no futuro próximo.

Salvando e carregando o modelo inteiro

import tensorflow as tf
from tensorflow import keras
mnist = tf.keras.datasets.mnist

#import data
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# create a model
def create_model():
  model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])
# compile the model
  model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
  return model

# Create a basic model instance
model=create_model()

model.fit(x_train, y_train, epochs=1)
loss, acc = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc))

# Save entire model to a HDF5 file
model.save('./model_path/my_model.h5')

# Recreate the exact same model, including weights and optimizer.
new_model = keras.models.load_model('./model_path/my_model.h5')
loss, acc = new_model.evaluate(x_test, y_test)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

Salvando e carregando apenas pesos do modelo

Se você estiver interessado em salvar apenas pesos do modelo e depois carregar pesos para restaurar o modelo,

model.fit(x_train, y_train, epochs=5)
loss, acc = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc))

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Restore the weights
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')

loss,acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

Salvando e restaurando usando o retorno de chamada do keras checkpoint

# include the epoch in the file name. (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1, save_weights_only=True,
    # Save weights, every 5-epochs.
    period=5)

model = create_model()
model.save_weights(checkpoint_path.format(epoch=0))
model.fit(train_images, train_labels,
          epochs = 50, callbacks = [cp_callback],
          validation_data = (test_images,test_labels),
          verbose=0)

latest = tf.train.latest_checkpoint(checkpoint_dir)

new_model = create_model()
new_model.load_weights(latest)
loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

salvando modelo com métricas personalizadas

import tensorflow as tf
from tensorflow import keras
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Custom Loss1 (for example) 
@tf.function() 
def customLoss1(yTrue,yPred):
  return tf.reduce_mean(yTrue-yPred) 

# Custom Loss2 (for example) 
@tf.function() 
def customLoss2(yTrue, yPred):
  return tf.reduce_mean(tf.square(tf.subtract(yTrue,yPred))) 

def create_model():
  model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),  
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])
  model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy', customLoss1, customLoss2])
  return model

# Create a basic model instance
model=create_model()

# Fit and evaluate model 
model.fit(x_train, y_train, epochs=1)
loss, acc,loss1, loss2 = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc))

model.save("./model.h5")

new_model=tf.keras.models.load_model("./model.h5",custom_objects={'customLoss1':customLoss1,'customLoss2':customLoss2})

Salvando o modelo keras com operações personalizadas

Quando temos operações personalizadas, como no caso a seguir ( tf.tile), precisamos criar uma função e agrupar com uma camada Lambda. Caso contrário, o modelo não pode ser salvo.

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda
from tensorflow.keras import Model

def my_fun(a):
  out = tf.tile(a, (1, tf.shape(a)[0]))
  return out

a = Input(shape=(10,))
#out = tf.tile(a, (1, tf.shape(a)[0]))
out = Lambda(lambda x : my_fun(x))(a)
model = Model(a, out)

x = np.zeros((50,10), dtype=np.float32)
print(model(x).numpy())

model.save('my_model.h5')

#load the model
new_model=tf.keras.models.load_model("my_model.h5")

Acho que abordei algumas das muitas maneiras de salvar o modelo tf.keras. No entanto, existem muitas outras maneiras. Comente abaixo se o seu caso de uso não estiver coberto acima. Obrigado!


3

Use tf.train.Saver para salvar um modelo, remerber, você precisará especificar a var_list, se desejar reduzir o tamanho do modelo. A lista val_ pode ser tf.trainable_variables ou tf.global_variables.


3

Você pode salvar as variáveis ​​na rede usando

saver = tf.train.Saver() 
saver.save(sess, 'path of save/fileName.ckpt')

Para restaurar a rede para reutilização mais tarde ou em outro script, use:

saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint('path of save/')
sess.run(....) 

Pontos importantes:

  1. sess deve ser o mesmo entre as execuções iniciais e posteriores (estrutura coerente).
  2. saver.restore precisa do caminho da pasta dos arquivos salvos, não de um caminho de arquivo individual.

2

Onde você quiser salvar o modelo,

self.saver = tf.train.Saver()
with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            ...
            self.saver.save(sess, filename)

Certifique-se de que todos os seus tf.Variablenomes tenham nomes, pois você poderá restaurá-los posteriormente usando os nomes deles. E onde você deseja prever,

saver = tf.train.import_meta_graph(filename)
name = 'name given when you saved the file' 
with tf.Session() as sess:
      saver.restore(sess, name)
      print(sess.run('W1:0')) #example to retrieve by variable name

Verifique se a proteção é executada dentro da sessão correspondente. Lembre-se de que, se você usar o tf.train.latest_checkpoint('./'), apenas o ponto de verificação mais recente será usado.


2

Estou na versão:

tensorflow (1.13.1)
tensorflow-gpu (1.13.1)

Maneira simples é

Salve :

model.save("model.h5")

Restaurar:

model = tf.keras.models.load_model("model.h5")

2

Para tensorflow-2.0

é muito simples.

import tensorflow as tf

SALVE 

model.save("model_name")

RESTAURAR

model = tf.keras.models.load_model('model_name')

1

Seguindo a resposta de @Vishnuvardhan Janapati, aqui está outra maneira de salvar e recarregar o modelo com camada / métrica / perda personalizada no TensorFlow 2.0.0

import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.utils.generic_utils import get_custom_objects

# custom loss (for example)  
def custom_loss(y_true,y_pred):
  return tf.reduce_mean(y_true - y_pred)
get_custom_objects().update({'custom_loss': custom_loss}) 

# custom loss (for example) 
class CustomLayer(Layer):
  def __init__(self, ...):
      ...
  # define custom layer and all necessary custom operations inside custom layer

get_custom_objects().update({'CustomLayer': CustomLayer})  

Dessa forma, depois de executar esses códigos e salvar seu modelo com tf.keras.models.save_modelou model.saveou com ModelCheckpointretorno de chamada, você poderá recarregar seu modelo sem a necessidade de objetos personalizados precisos, tão simples quanto

new_model = tf.keras.models.load_model("./model.h5"})

0

Na nova versão do tensorflow 2.0, o processo de salvar / carregar um modelo é muito mais fácil. Por causa da implementação da API Keras, uma API de alto nível para o TensorFlow.

Para salvar um modelo: Verifique a documentação para referência: https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/save_model

tf.keras.models.save_model(model_name, filepath, save_format)

Para carregar um modelo:

https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/load_model

model = tf.keras.models.load_model(filepath)

0

Aqui está um exemplo simples usando o formato SavedModel do Tensorflow 2.0 (que é o formato recomendado, de acordo com a documentação ) para um classificador de conjunto de dados MNIST simples, usando a API funcional Keras sem muita imaginação:

# Imports
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Flatten
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt

# Load data
mnist = tf.keras.datasets.mnist # 28 x 28
(x_train,y_train), (x_test, y_test) = mnist.load_data()

# Normalize pixels [0,255] -> [0,1]
x_train = tf.keras.utils.normalize(x_train,axis=1)
x_test = tf.keras.utils.normalize(x_test,axis=1)

# Create model
input = Input(shape=(28,28), dtype='float64', name='graph_input')
x = Flatten()(input)
x = Dense(128, activation='relu')(x)
x = Dense(128, activation='relu')(x)
output = Dense(10, activation='softmax', name='graph_output', dtype='float64')(x)
model = Model(inputs=input, outputs=output)

model.compile(optimizer='adam',
             loss='sparse_categorical_crossentropy',
             metrics=['accuracy'])

# Train
model.fit(x_train, y_train, epochs=3)

# Save model in SavedModel format (Tensorflow 2.0)
export_path = 'model'
tf.saved_model.save(model, export_path)

# ... possibly another python program 

# Reload model
loaded_model = tf.keras.models.load_model(export_path) 

# Get image sample for testing
index = 0
img = x_test[index] # I normalized the image on a previous step

# Predict using the signature definition (Tensorflow 2.0)
predict = loaded_model.signatures["serving_default"]
prediction = predict(tf.constant(img))

# Show results
print(np.argmax(prediction['graph_output']))  # prints the class number
plt.imshow(x_test[index], cmap=plt.cm.binary)  # prints the image

O que é serving_default?

É o nome da definição de assinatura da tag que você selecionou (nesse caso, a servetag padrão foi selecionada). Além disso, aqui explica como encontrar as tags e assinaturas de um modelo usando saved_model_cli.

Isenções de responsabilidade

Este é apenas um exemplo básico, se você deseja colocá-lo em funcionamento, mas não é de modo algum uma resposta completa - talvez eu possa atualizá-lo no futuro. Eu só queria dar um exemplo simples usando oSavedModel TF 2.0, porque eu não vi um, mesmo assim simples, em qualquer lugar.

A resposta de @ Tom é um exemplo de SavedModel, mas não funcionará no Tensorflow 2.0, porque infelizmente existem algumas mudanças.

@ A resposta de Janishati em Vishnuvardhan diz TF 2.0, mas não é para o formato SavedModel.

Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.