¿Cómo agregar la condición if en un gráfico TensorFlow?


Digamos que tengo el siguiente código:

x = tf.placeholder("float32", shape=[None, ins_size**2*3], name = "x_input")
condition = tf.placeholder("int32", shape=[1, 1], name = "condition")
W = tf.Variable(tf.zeros([ins_size**2*3,label_option]), name = "weights")
b = tf.Variable(tf.zeros([label_option]), name = "bias")

if condition > 0:
    y = tf.nn.softmax(tf.matmul(x, W) + b)
else:
    y = tf.nn.softmax(tf.matmul(x, W) - b)  

¿Funcionaría la declaración if en el cálculo (no lo creo)? Si no, ¿cómo puedo agregar una instrucción if en el gráfico de cálculo de TensorFlow?

Author: P-Gn, 2016-03-07

1 answers

Tiene razón en que la instrucción if no funciona aquí, porque la condición se evalúa en el tiempo de construcción del gráfico, mientras que presumiblemente desea que la condición dependa del valor introducido en el marcador de posición en tiempo de ejecución. (De hecho, siempre tomará la primera rama, porque condition > 0 evalúa a un Tensor, que es "verdadero" en Python.)

Para soportar el flujo de control condicional, TensorFlow proporciona el tf.cond() operador, que evalúa una de las dos ramas, dependiendo de una condición booleana. Para mostrarte cómo usarlo, reescribiré tu programa para que condition sea un valor escalar tf.int32 por simplicidad:

x = tf.placeholder(tf.float32, shape=[None, ins_size**2*3], name="x_input")
condition = tf.placeholder(tf.int32, shape=[], name="condition")
W = tf.Variable(tf.zeros([ins_size**2 * 3, label_option]), name="weights")
b = tf.Variable(tf.zeros([label_option]), name="bias")

y = tf.cond(condition > 0, lambda: tf.matmul(x, W) + b, lambda: tf.matmul(x, W) - b)
 65
Author: mrry,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-07-12 23:18:21