Cómo usar tf.while loop() en tensorflow


Esta es una pregunta genérica. Encontré que en el tensorflow, después de construir el gráfico, obtener datos en el gráfico, la salida del gráfico es un tensor. pero en muchos casos, necesitamos hacer algún cálculo basado en esta salida (que es un tensor), que no está permitido en tensorflow.

Por ejemplo, estoy tratando de implementar un RNN, que hace bucles en tiempos basados en la propiedad data self. Es decir, necesito usar un tensor para juzgar si debo parar (no estoy usando dynamic_rnn ya que en mi diseño, el rnn es altamente personalizado). Creo que tf.while_loop(cond,body.....) podría ser un candidato para mi implementación. Pero el tutorial oficial es demasiado simple. No se como agregar mas funcionalidades al 'cuerpo'. ¿Puede alguien darme un ejemplo más complejo?

También, en tal caso que si el cálculo futuro se basa en la salida del tensor (por ejemplo: la parada RNN basada en el criterio de salida), que es un caso muy común. ¿Hay una forma elegante o mejor en lugar de gráfico dinámico?

Author: Hanyu Guo, 2016-05-25

1 answers

¿Qué te impide agregar más funcionalidad al cuerpo? Puedes construir cualquier grafo computacional complejo que te guste en el cuerpo y tomar cualquier entrada que te guste del grafo envolvente. Además, fuera del bucle, puede hacer lo que quiera con las salidas que devuelva. Como se puede ver por la cantidad de 'whatevers', las primitivas de flujo de control de TensorFlow se construyeron con mucha generalidad en mente. A continuación se muestra otro ejemplo 'simple', en caso de que ayude.

import tensorflow as tf
import numpy as np

def body(x):
    a = tf.random_uniform(shape=[2, 2], dtype=tf.int32, maxval=100)
    b = tf.constant(np.array([[1, 2], [3, 4]]), dtype=tf.int32)
    c = a + b
    return tf.nn.relu(x + c)

def condition(x):
    return tf.reduce_sum(x) < 100

x = tf.Variable(tf.constant(0, shape=[2, 2]))

with tf.Session():
    tf.initialize_all_variables().run()
    result = tf.while_loop(condition, body, [x])
    print(result.eval())
 41
Author: Peter Goldsborough,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2016-05-25 18:17:18