TensorFlow: Max de un tensor a lo largo de un eje
Mi pregunta está en dos partes conectadas: {[11]]}
-
¿Cómo puedo calcular el máximo a lo largo de un determinado eje de un tensor? Por ejemplo, si tengo
x = tf.constant([[1,220,55],[4,3,-1]])
Quiero algo como
x_max = tf.max(x, axis=1) print sess.run(x_max) output: [220,4]
Sé que hay un
tf.argmax
y untf.maximum
, pero ninguno da el valor máximo a lo largo de un eje de un solo tensor. Por ahora tengo una solución:x_max = tf.slice(x, begin=[0,0], size=[-1,1]) for a in range(1,2): x_max = tf.maximum(x_max , tf.slice(x, begin=[0,a], size=[-1,1]))
Pero parece menos que óptimo. ¿Hay una mejor manera de hacer esto?
-
Dados los índices de una
argmax
de un tensor, ¿cómo indexo en otro tensor usando esos índices? Usando el ejemplo dex
anterior, cómo hago algo como lo siguiente:ind_max = tf.argmax(x, dimension=1) #output is [1,0] y = tf.constant([[1,2,3], [6,5,4]) y_ = y[:, ind_max] #y_ should be [2,6]
Sé que el corte, como la última línea, aún no existe en TensorFlow(#206).
Mi pregunta es: cuál es la mejor solución para mi caso específico (tal vez utilizando otros métodos como reunir, seleccionar, etc.)?
Información adicional: Sé que
x
yy
van a ser bidimensionales ¡tensores solamente!
1 answers
El tf.reduce_max()
operator proporciona exactamente esta funcionalidad. Por defecto calcula el máximo global del tensor dado, pero puede especificar una lista de reduction_indices
, que tiene el mismo significado que axis
en NumPy. Para completar su ejemplo:
x = tf.constant([[1, 220, 55], [4, 3, -1]])
x_max = tf.reduce_max(x, reduction_indices=[1])
print sess.run(x_max) # ==> "array([220, 4], dtype=int32)"
Si calcula el argmax usando tf.argmax()
, puede obtener los valores de un tensor diferente y
aplanando y
usando tf.reshape()
, convertir los índices argmax en índices vectoriales de la siguiente manera, y tf.gather()
para extraer los valores apropiados:
ind_max = tf.argmax(x, dimension=1)
y = tf.constant([[1, 2, 3], [6, 5, 4]])
flat_y = tf.reshape(y, [-1]) # Reshape to a vector.
# N.B. Handles 2-D case only.
flat_ind_max = ind_max + tf.cast(tf.range(tf.shape(y)[0]) * tf.shape(y)[1], tf.int64)
y_ = tf.gather(flat_y, flat_ind_max)
print sess.run(y_) # ==> "array([2, 6], dtype=int32)"
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2016-10-15 16:27:57