Problema de alimentación de una lista en dict de alimentación en TensorFlow


Estoy tratando de pasar una lista a feed_dict, sin embargo estoy teniendo problemas para hacerlo. Digamos que tengo:

inputs = 10 * [tf.placeholder(tf.float32, shape=(batch_size, input_size))]

Donde las entradas se introducen en alguna función outputs que quiero calcular. Así que para ejecutar esto en tensorflow, creé una sesión y ejecuté lo siguiente:

sess.run(outputs, feed_dict = {inputs: data}) 
#data is my list of inputs, which is also of length 10

Pero tengo un error, TypeError: unhashable type: 'list'. Sin embargo, soy capaz de pasar el elemento de datos de esta manera:

sess.run(outputs, feed_dict = {inputs[0]: data[0], ..., inputs[9]: data[9]}) 

Así que me pregunto si hay una manera de resolver este problema. También he intentado construir un diccionario (usando un bucle for ), sin embargo, esto resulta en un diccionario con un solo elemento, donde la clave es: tensorflow.python.framework.ops.Tensor at 0x107594a10

Author: Mooncrater, 2015-11-13

2 answers

Hay dos cuestiones que están causando problemas aquí:

La primera cuestión es que la Session.run() call solo acepta un pequeño número de tipos como claves del feed_dict. En particular, las listas de tensores son no apoyado como llaves, así que usted tiene que poner cada tensor como llave separada.* Una forma conveniente de hacer esto es usar una comprensión de diccionario:

inputs = [tf.placeholder(...), ...]
data = [np.array(...), ...]
sess.run(y, feed_dict={i: d for i, d in zip(inputs, data)})

El segundo problema es que la sintaxis 10 * [tf.placeholder(...)] en Python crea una lista con diez elementos, donde cada elemento es el mismo objeto tensor (es decir, tiene la misma propiedad name, la misma propiedad id, y es idéntica a la referencia si compara dos elementos de la lista usando inputs[i] is inputs[j]). Esto explica por qué, cuando intentó crear un diccionario usando los elementos de la lista como claves, terminó con un diccionario con un solo elemento, porque todos los elementos de la lista eran idénticos.

Para crear 10 tensores de marcador de posición diferentes, como pretendía, en su lugar debe hacer lo siguiente:

inputs = [tf.placeholder(tf.float32, shape=(batch_size, input_size))
          for _ in xrange(10)]

Si imprime los elementos de esta lista, verá que cada elemento es un tensor con un nombre diferente.


EDITAR: * Ahora puede pasar tuplas como las claves de un feed_dict, porque estas pueden usarse como claves de diccionario.

 41
Author: mrry,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2016-12-19 16:23:35

Aquí hay un ejemplo correcto:

batch_size, input_size, n = 2, 3, 2
# in your case n = 10
x = tf.placeholder(tf.types.float32, shape=(n, batch_size, input_size))
y = tf.add(x, x)

data = np.random.rand(n, batch_size, input_size)

sess = tf.Session()
print sess.run(y, feed_dict={x: data})

Y aquí hay cosas extrañas que veo en su enfoque. Por alguna razón se utiliza 10 * [tf.placeholder(...)], que crea 10 tensores de tamaño (batch_size, input_size). No tengo idea de por qué haces esto, si solo puedes crear en Tensor de rango 3 (donde la primera dimensión es 10).

Debido a que tiene una lista de tensores (y no un tensor), no puede alimentar sus datos a esta lista (pero en mi caso puedo alimentar a mi tensor).

 4
Author: Salvador Dali,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2015-11-13 02:11:21