TensorFlow salvando / carregando um gráfico de um arquivo


98

Pelo que recolhi até agora, existem várias maneiras diferentes de despejar um gráfico do TensorFlow em um arquivo e, em seguida, carregá-lo em outro programa, mas não consegui encontrar exemplos / informações claras sobre como eles funcionam. O que eu já sei é o seguinte:

  1. Salve as variáveis ​​do modelo em um arquivo de checkpoint (.ckpt) usando um tf.train.Saver()e restaure-as mais tarde ( fonte )
  2. Salve um modelo em um arquivo .pb e carregue-o de volta usando tf.train.write_graph()e tf.import_graph_def()( fonte )
  3. Carregue um modelo de um arquivo .pb, treine-o novamente e despeje-o em um novo arquivo .pb usando o Bazel ( fonte )
  4. Congele o gráfico para salvar o gráfico e os pesos juntos ( fonte )
  5. Use as_graph_def()para salvar o modelo, e para pesos / variáveis, mapeie-os em constantes ( fonte )

No entanto, não consegui esclarecer várias dúvidas sobre esses métodos diferentes:

  1. Com relação aos arquivos de checkpoint, eles salvam apenas os pesos treinados de um modelo? Os arquivos de checkpoint podem ser carregados em um novo programa e usados ​​para executar o modelo, ou eles simplesmente servem como formas de salvar os pesos em um modelo em um determinado momento / estágio?
  2. Em relação tf.train.write_graph(), os pesos / variáveis ​​também são salvos?
  3. Em relação ao Bazel, ele só pode salvar / carregar arquivos .pb para treinamento? Existe um comando simples do Bazel apenas para despejar um gráfico em um .pb?
  4. Com relação ao congelamento, um gráfico congelado pode ser carregado usando tf.import_graph_def()?
  5. A demonstração do Android para TensorFlow é carregada no modelo Inception do Google a partir de um arquivo .pb. Se eu quisesse substituir meu próprio arquivo .pb, como faria para fazer isso? Eu precisaria alterar algum código / método nativo?
  6. Em geral, qual é exatamente a diferença entre todos esses métodos? Ou mais amplamente, qual é a diferença entre as_graph_def()/.ckpt/.pb?

Em suma, o que estou procurando é um método para salvar um gráfico (como em, as várias operações e outros) e seus pesos / variáveis ​​em um arquivo, que pode então ser usado para carregar o gráfico e os pesos em outro programa , para uso (não necessariamente continuando / retreinando).

A documentação sobre este tópico não é muito direta, portanto, quaisquer respostas / informações serão muito apreciadas.


2
A API mais recente / mais completa é o meta gráfico, que lhe dará uma maneira de salvar todos os três de uma vez - 1) gráfico 2) valores de parâmetro 3) coleções: tensorflow.org/versions/r0.10/how_tos/meta_graph/ index.html
Yaroslav Bulatov

Respostas:


80

Existem muitas maneiras de abordar o problema de salvar um modelo no TensorFlow, o que pode torná-lo um pouco confuso. Tomando cada uma de suas subquestões sucessivamente:

  1. Os arquivos de ponto de verificação (produzidos por exemplo, chamando saver.save()um tf.train.Saverobjeto) contêm apenas os pesos e quaisquer outras variáveis ​​definidas no mesmo programa. Para usá-los em outro programa, você deve recriar a estrutura de gráfico associada (por exemplo, executando o código para criá-lo novamente ou chamando tf.import_graph_def()), que informa ao TensorFlow o que fazer com esses pesos. Observe que a chamada saver.save()também produz um arquivo contendo um MetaGraphDef, que contém um gráfico e detalhes de como associar os pesos de um ponto de verificação a esse gráfico. Veja o tutorial para mais detalhes.

  2. tf.train.write_graph()escreve apenas a estrutura do gráfico; não os pesos.

  3. O Bazel não está relacionado à leitura ou gravação de gráficos do TensorFlow. (Talvez eu não tenha entendido sua pergunta: sinta-se à vontade para esclarecê-la em um comentário.)

  4. Um gráfico congelado pode ser carregado usando tf.import_graph_def(). Nesse caso, os pesos são (normalmente) embutidos no gráfico, então você não precisa carregar um checkpoint separado.

  5. A principal mudança seria atualizar os nomes dos tensores que são alimentados no modelo e os nomes dos tensores que são buscados no modelo. Na demonstração do TensorFlow Android, isso corresponderia às strings inputNamee outputNameque são passadas para TensorFlowClassifier.initializeTensorFlow().

  6. A GraphDefé a estrutura do programa, que normalmente não se altera durante o processo de formação. O ponto de verificação é um instantâneo do estado de um processo de treinamento, que normalmente muda a cada etapa do processo de treinamento. Como resultado, o TensorFlow usa diferentes formatos de armazenamento para esses tipos de dados, e a API de baixo nível oferece diferentes maneiras de salvá-los e carregá-los. Bibliotecas de nível superior, como as MetaGraphDefbibliotecas, Keras e skflow, se baseiam nesses mecanismos para fornecer maneiras mais convenientes de salvar e restaurar um modelo inteiro.


Isso significa que a documentação da API C ++ mente, quando diz que você pode carregar o gráfico salvo tf.train.write_graph()e executá-lo?
mnicky

2
A documentação da API C ++ não mente, mas faltam alguns detalhes. O detalhe mais importante é que, além do GraphDefsalvo por tf.train.write_graph(), você também precisa lembrar os nomes dos tensores que deseja alimentar e buscar ao executar o gráfico (item 5 acima).
março

@mrry: Tentei usar o exemplo tensorflows DeepDream. mas parece que precisa de modelos pré-treinados em formato pb! Executei o exemplo Cifar10, mas ele só cria pontos de verificação! Eu não consegui encontrar nenhum arquivo pb ou qualquer coisa! como posso converter meus pontos de verificação para o formato pb que o exemplo deepdream usa?
Rika

2
@ Coderx7 Eu realmente acho que você não pode converter um .ckpt para um .pb, pois o checkpoint contém apenas os pesos e variáveis ​​e não sabe nada sobre a estrutura do gráfico
davidivad

1
existe um código simples para carregar um arquivo .pb e depois executá-lo?
Kong

1

Você pode tentar o seguinte código:

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)
Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.