Tensorflow: ¿Cómo reemplazar un nodo en un gráfico de cálculo?


Si tiene dos gráficos disjuntos, y desea vincularlos, gire esto:

x = tf.placeholder('float')
y = f(x)

y = tf.placeholder('float')
z = f(y)

En esto:

x = tf.placeholder('float')
y = f(x)
z = g(y)

¿Hay alguna manera de hacer eso? Parece que podría hacer la construcción más fácil en algunos casos.

Por ejemplo, si tiene un gráfico que tiene la imagen de entrada como tf.placeholder, y desea optimizar la imagen de entrada, al estilo de sueño profundo, ¿hay alguna manera de reemplazar el marcador de posición con un nodo tf.variable? ¿O tienes que pensar en eso antes de construir el gráfico?

Author: mdaoust, 2015-11-17

4 answers

TL;DR: Si puede definir los dos cálculos como funciones de Python, debería hacerlo. Si no puede, TensorFlow tiene una funcionalidad más avanzada para serializar e importar gráficos, lo que le permite componer gráficos de diferentes fuentes.

Una forma de hacer esto en TensorFlow es construir los cálculos disjuntos como objetos tf.Graph separados, luego convertirlos en búferes de protocolo serializados usando Graph.as_graph_def():

with tf.Graph().as_default() as g_1:
  input = tf.placeholder(tf.float32, name="input")
  y = f(input)
  # NOTE: using identity to get a known name for the output tensor.
  output = tf.identity(y, name="output")

gdef_1 = g_1.as_graph_def()

with tf.Graph().as_default() as g_2:  # NOTE: g_2 not g_1       
  input = tf.placeholder(tf.float32, name="input")
  z = g(input)
  output = tf.identity(y, name="output")

gdef_2 = g_2.as_graph_def()

Entonces podrías componer gdef_1 y gdef_2 en un tercer gráfico, usando tf.import_graph_def():

with tf.Graph().as_default() as g_combined:
  x = tf.placeholder(tf.float32, name="")

  # Import gdef_1, which performs f(x).
  # "input:0" and "output:0" are the names of tensors in gdef_1.
  y, = tf.import_graph_def(gdef_1, input_map={"input:0": x},
                           return_elements=["output:0"])

  # Import gdef_2, which performs g(y)
  z, = tf.import_graph_def(gdef_2, input_map={"input:0": y},
                           return_elements=["output:0"]
 24
Author: mrry,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2018-05-21 06:02:30

Si desea combinar modelos entrenados (por ejemplo, para reutilizar partes de un modelo preentrenado en un nuevo modelo), puede usar un Saver para guardar un punto de control del primer modelo y luego restaurar ese modelo (total o parcialmente) en otro modelo.

Por ejemplo, supongamos que desea reutilizar los pesos del modelo 1 w en el modelo 2, y también convertir x de un marcador de posición a una variable:

with tf.Graph().as_default() as g1:
    x = tf.placeholder('float')
    w = tf.Variable(1., name="w")
    y = x * w
    saver = tf.train.Saver()

with tf.Session(graph=g1) as sess:
    w.initializer.run()
    # train...
    saver.save(sess, "my_model1.ckpt")

with tf.Graph().as_default() as g2:
    x = tf.Variable(2., name="v")
    w = tf.Variable(0., name="w")
    z = x + w
    restorer = tf.train.Saver([w]) # only restore w

with tf.Session(graph=g2) as sess:
    x.initializer.run()  # x now needs to be initialized
    restorer.restore(sess, "my_model1.ckpt") # restores w=1
    print(z.eval())  # prints 3.
 4
Author: MiniQuark,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2016-10-03 19:10:41

Resulta que tf.train.import_meta_graph pasa todos los argumentos adicionales al import_scoped_meta_graph subyacente que tiene el argumento input_map y lo utiliza cuando llega a su propia invocación (interna) de import_graph_def.

No está documentado, y me tomó mucho tiempo para encontrarlo, ¡pero funciona!

 4
Author: Jonan Georgiev,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-04-22 21:51:52

Ejemplo práctico:

import tensorflow as tf
g1 = tf.Graph()
with g1.as_default():
    # set variables/placeholders
    tf.placeholder(tf.int32, [], name='g1_a')
    tf.placeholder(tf.int32, [], name='g1_b')

    # example on exacting tensor by name
    a = g1.get_tensor_by_name('g1_a:0')
    b = g1.get_tensor_by_name('g1_b:0')

    # operation ==>>     c = 2 * 3 = 6
    mul_op = tf.multiply(a, b, name='g1_mul')
    sess = tf.Session()
    g1_mul_results = sess.run(mul_op, feed_dict={'g1_a:0': 2, 'g1_b:0': 3})
    print('graph1 mul = ', g1_mul_results)  # output = 6

    print('\ngraph01 operations/variables:')
    for op in g1.get_operations():
        print(op.name)

g2 = tf.Graph()
with g2.as_default():
    # set variables/placeholders
    tf.import_graph_def(g1.as_graph_def())
    g2_c = tf.placeholder(tf.int32, [], name='g2_c')

    # example on exacting tensor by name
    g1_b = g2.get_tensor_by_name('import/g1_b:0')
    g1_mul = g2.get_tensor_by_name('import/g1_mul:0')

    # operation ==>>
    b = tf.multiply(g1_b, g2_c, name='g2_var_times_g1_a')
    f = tf.multiply(g1_mul, g1_b, name='g1_mul_times_g1_b')

    print('\ngraph01 operations/variables:')
    for op in g2.get_operations():
        print(op.name)
    sess = tf.Session()

    # graph1 variable 'a' times graph2 variable 'c'(graph2)
    ans = sess.run('g2_var_times_g1_a:0', feed_dict={'g2_c:0': 4, 'import/g1_b:0': 5})
    print('\ngraph2 g2_var_times_g1_a = ', ans)  # output = 20

    # graph1 mul_op (a*b) times graph1 variable 'b'
    ans = sess.run('g1_a_times_g1_b:0',
                   feed_dict={'import/g1_a:0': 6, 'import/g1_b:0': 7})
    print('\ngraph2 g1_mul_times_g1_b:0 = ', ans)  # output = (6*7)*7 = 294

''' output
graph1 mul =  6

graph01 operations/variables:
g1_a
g1_b
g1_mul

graph01 operations/variables:
import/g1_a
import/g1_b
import/g1_mul
g2_c
g2_var_times_g1_a
g1_a_times_g1_b

graph2 g2_var_times_g1_a =  20

graph2 g1_a_times_g1_b:0 =  294
'''

Referencia ENLACE

 0
Author: Manuel Cuevas,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2018-10-02 17:10:55