Diferença entre Variável e get_variable no TensorFlow


125

Até onde eu sei, Variableé a operação padrão para criar uma variável e get_variableé usada principalmente para o compartilhamento de peso.

Por um lado, algumas pessoas sugerem o uso get_variableda Variableoperação primitiva , sempre que você precisar de uma variável. Por outro lado, apenas vejo uso get_variablenos documentos e demonstrações oficiais do TensorFlow.

Portanto, quero conhecer algumas regras práticas sobre como usar corretamente esses dois mecanismos. Existem princípios "padrão"?


6
get_variable é nova maneira, Variable é maneira antiga (que pode ser suportada para sempre) como Lukasz diz (PS: ele escreveu grande parte do escopo nome de variável em TF)
Yaroslav Bulatov

Respostas:


90

Eu recomendo o uso sempre tf.get_variable(...)- facilitará a refatoração do seu código, se você precisar compartilhar variáveis ​​a qualquer momento, por exemplo, em uma configuração de várias gpu (veja o exemplo CIFAR de várias gpu). Não há desvantagem nisso.

Puro tf.Variableé de nível inferior; em algum momento tf.get_variable()não existia, então algum código ainda usa a maneira de baixo nível.


5
Muito obrigado pela sua resposta. Mas ainda tenho uma pergunta sobre como substituir tf.Variablepor tf.get_variabletodos os lugares. É quando eu quero inicializar uma variável com uma matriz numpy, não consigo encontrar uma maneira limpa e eficiente de fazê-lo, como faço com tf.Variable. Como você resolve isso? Obrigado.
Lifu Huang

68

tf.Variable é uma classe e existem várias maneiras de criar tf.Variable, incluindo tf.Variable.__init__e tf.get_variable.

tf.Variable.__init__: Cria uma nova variável com valor_inicial .

W = tf.Variable(<initial-value>, name=<optional-name>)

tf.get_variable: Obtém uma variável existente com esses parâmetros ou cria uma nova. Você também pode usar o inicializador.

W = tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None,
       regularizer=None, trainable=True, collections=None)

É muito útil usar inicializadores como xavier_initializer:

W = tf.get_variable("W", shape=[784, 256],
       initializer=tf.contrib.layers.xavier_initializer())

Mais informações aqui .


Sim, na Variableverdade quero dizer usando o seu __init__. Como get_variableé tão conveniente, eu me pergunto por que a maioria dos códigos do TensorFlow que eu vi usar em Variablevez de get_variable. Existem convenções ou fatores a serem considerados ao escolher entre eles. Obrigado!
Lifu Huang 8/16

Se você deseja ter um determinado valor, o uso de Variável é simples: x = tf.Variable (3).
Sung Kim

@SungKim normalmente, quando usamos tf.Variable(), podemos inicializá-lo como um valor aleatório a partir de uma distribuição normal truncada. Aqui está o meu exemplo w1 = tf.Variable(tf.truncated_normal([5, 50], stddev = 0.01), name = 'w1'). Qual seria o equivalente disso? como posso dizer que quero um normal truncado? Devo apenas fazer w1 = tf.get_variable(name = 'w1', shape = [5,50], initializer = tf.truncated_normal, regularizer = tf.nn.l2_loss)?
Euler_Salter

@Euler_Salter: Você pode usar tf.truncated_normal_initializer()para obter o resultado desejado.
Beta

46

Eu posso encontrar duas diferenças principais entre uma e outra:

  1. A primeira é que tf.Variablesempre criará uma nova variável, enquanto tf.get_variableobtém uma variável existente com parâmetros especificados no gráfico e, se não existir, cria uma nova.

  2. tf.Variable requer que um valor inicial seja especificado.

É importante esclarecer que a função tf.get_variableprefixa o nome com o escopo da variável atual para executar verificações de reutilização. Por exemplo:

with tf.variable_scope("one"):
    a = tf.get_variable("v", [1]) #a.name == "one/v:0"
with tf.variable_scope("one"):
    b = tf.get_variable("v", [1]) #ValueError: Variable one/v already exists
with tf.variable_scope("one", reuse = True):
    c = tf.get_variable("v", [1]) #c.name == "one/v:0"

with tf.variable_scope("two"):
    d = tf.get_variable("v", [1]) #d.name == "two/v:0"
    e = tf.Variable(1, name = "v", expected_shape = [1]) #e.name == "two/v_1:0"

assert(a is c)  #Assertion is true, they refer to the same object.
assert(a is d)  #AssertionError: they are different objects
assert(d is e)  #AssertionError: they are different objects

O último erro de asserção é interessante: Duas variáveis ​​com o mesmo nome no mesmo escopo devem ser a mesma variável. Mas se você testar os nomes das variáveis de eperceberá que o Tensorflow mudou o nome da variável e:

d.name   #d.name == "two/v:0"
e.name   #e.name == "two/v_1:0"

Ótimo exemplo! No que diz respeito d.namee e.name, acabo de vir através de um doc este TensorFlow no tensor operação gráfico de nomenclatura que explica:If the default graph already contained an operation named "answer", the TensorFlow would append "_1", "_2", and so on to the name, in order to make it unique.
Atlas7

2

Outra diferença reside em que um está na ('variable_store',)coleção, mas o outro não.

Por favor, veja o código fonte :

def _get_default_variable_store():
  store = ops.get_collection(_VARSTORE_KEY)
  if store:
    return store[0]
  store = _VariableStore()
  ops.add_to_collection(_VARSTORE_KEY, store)
  return store

Deixe-me ilustrar isso:

import tensorflow as tf
from tensorflow.python.framework import ops

embedding_1 = tf.Variable(tf.constant(1.0, shape=[30522, 1024]), name="word_embeddings_1", dtype=tf.float32) 
embedding_2 = tf.get_variable("word_embeddings_2", shape=[30522, 1024])

graph = tf.get_default_graph()
collections = graph.collections

for c in collections:
    stores = ops.get_collection(c)
    print('collection %s: ' % str(c))
    for k, store in enumerate(stores):
        try:
            print('\t%d: %s' % (k, str(store._vars)))
        except:
            print('\t%d: %s' % (k, str(store)))
    print('')

A saída:

collection ('__variable_store',): 0: {'word_embeddings_2': <tf.Variable 'word_embeddings_2:0' shape=(30522, 1024) dtype=float32_ref>}

Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.