Como listar todas as operações usadas no Tensorflow SavedModel?


10

Se eu salvar meu modelo usando a tensorflow.saved_model.savefunção no formato SavedModel, como recuperar quais Tensorflow Ops são usados ​​nesse modelo posteriormente. Como o modelo pode ser restaurado, essas operações são armazenadas no gráfico, meu palpite está no saved_model.pbarquivo. Se eu carregar esse protobuf (não o modelo inteiro), a parte da biblioteca do protobuf os lista, mas isso não está documentado e marcado como um recurso experimental por enquanto. Os modelos criados no Tensorflow 1.x não terão essa parte.

Então, qual é uma maneira rápida e confiável de recuperar uma lista de operações usadas (como MatchingFilesou WriteFile) de um modelo no formato SavedModel?

Agora eu posso congelar a coisa toda, como tensorflowjs-converterfaz. Como eles também verificam as operações suportadas. Atualmente, isso não funciona quando um LSTM está no modelo, veja aqui . Existe uma maneira melhor de fazer isso, já que as Ops estão aí?

Um modelo de exemplo:

class FileReader(tf.Module):

@tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
def read_disk(self, file_name):
    input_scalar = tf.reshape(file_name, [])
    output = tf.io.read_file(input_scalar)
    return tf.stack([output], name='content')

file_reader = FileReader()

tf.saved_model.save(file_reader, 'file_reader')

Esperado na saída de todas as operações, contendo neste caso pelo menos:

  • ReadFilecomo descrito aqui
  • ...

11
É difícil dizer exatamente o que você deseja, o que é saved_model.pb, é tf.GraphDefuma SavedModelmensagem ou um protobuf? Se você recebeu uma tf.GraphDefchamada gd, pode obter a lista de operações usadas com sorted(set(n.op for n in gd.node)). Se você tem um modelo carregado, pode fazê-lo sorted(set(op.type for op in tf.get_default_graph().get_operations())). Se for um SavedModel, você pode obter o resultado tf.GraphDef(por exemplo saved_model.meta_graphs[0].graph_def).
jdehesa 14/02

Eu quero recuperar as operações de um SavedModel armazenado. Então, de fato, a última opção que você está descrevendo. Qual é a saved_modelvariável no seu último exemplo? O resultado tf.saved_model.load('/path/to/model')ou o carregamento do protobuf do arquivo saved_model.pb.
sampers

Respostas:


1

Se saved_model.pbfor uma SavedModelmensagem protobuf, você obtém as operações diretamente a partir daí. Digamos que criamos um modelo da seguinte maneira:

import tensorflow as tf

class FileReader(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
    def read_disk(self, file_name):
        input_scalar = tf.reshape(file_name, [])
        output = tf.io.read_file(input_scalar)
        return tf.stack([output], name='content')

file_reader = FileReader()
tf.saved_model.save(file_reader, 'tmp')

Agora podemos encontrar as operações usadas por esse modelo assim:

from tensorflow.core.protobuf.saved_model_pb2 import SavedModel

saved_model = SavedModel()
with open('tmp/saved_model.pb', 'rb') as f:
    saved_model.ParseFromString(f.read())
model_op_names = set()
# Iterate over every metagraph in case there is more than one
for meta_graph in saved_model.meta_graphs:
    # Add operations in the graph definition
    model_op_names.update(node.op for node in meta_graph.graph_def.node)
    # Go through the functions in the graph definition
    for func in meta_graph.graph_def.library.function:
        # Add operations in each function
        model_op_names.update(node.op for node in func.node_def)
# Convert to list, sorted if you want
model_op_names = sorted(model_op_names)
print(*model_op_names, sep='\n')
# Const
# Identity
# MergeV2Checkpoints
# NoOp
# Pack
# PartitionedCall
# Placeholder
# ReadFile
# Reshape
# RestoreV2
# SaveV2
# ShardedFilename
# StatefulPartitionedCall
# StringJoin

Eu tentei algo assim, mas infelizmente isso não é o que eu esperava: digamos que eu tenha um modelo que faça isso: input_scalar = tf.reshape(file_name, []) output = tf.io.read_file(input_scalar) return tf.stack([output], name='content')Então o ReadFile Op, conforme listado aqui, está lá, mas não é impresso.
sampers

11
@ampers Editei a resposta com um exemplo como você sugere. Eu recebo a ReadFileoperação na saída. É possível que, no seu caso real, essa operação não esteja entre a entrada e a saída do modelo salvo? Nesse caso, acho que pode ser podado.
jdehesa 14/02

De fato, com o modelo dado, ele funciona. Infelizmente para um módulo feito em tf2, isso não acontece. Se eu criar um tf.Module com 1 função com uma anotação de file_nameargumento @tf.function, contendo as chamadas que listei no meu comentário anterior, ele Const, NoOp, PartitionedCall, Placeholder, StatefulPartitionedCall
fornecerá

adicionou um modelo à minha pergunta
sampers 17/02

@sampers Atualizei minha resposta. Eu estava usando o TF 1.x antes, não estava familiarizado com as alterações nos objetos de definição de gráfico no TF 2.x, acho que a resposta agora cobre tudo no modelo salvo. Eu acho que as operações correspondentes à função Python em que você escreveu estão saved_model.meta_graphs[0].graph_def.library.function[0](a node_defcoleção dentro desse objeto de função).
jdehesa 17/02
Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.