¿Cómo obtener las dimensiones del tensor Tensorflow (shape)como valores int?


Supongamos que tengo un tensor de flujo tensor. ¿Cómo obtengo las dimensiones (forma) del tensor como valores enteros? Sé que hay dos métodos, tensor.get_shape() y tf.shape(tensor), pero no puedo obtener los valores de la forma como valores enteros int32.

Por ejemplo, a continuación he creado un tensor 2-D, y necesito obtener el número de filas y columnas como int32 para poder llamar a reshape() para crear un tensor de forma (num_rows * num_cols, 1). Sin embargo, el método tensor.get_shape() devuelve valores como Dimension no int32.

import tensorflow as tf
import numpy as np

sess = tf.Session()    
tensor = tf.convert_to_tensor(np.array([[1001,1002,1003],[3,4,5]]), dtype=tf.float32)

sess.run(tensor)    
# array([[ 1001.,  1002.,  1003.],
#        [    3.,     4.,     5.]], dtype=float32)

tensor_shape = tensor.get_shape()    
tensor_shape
# TensorShape([Dimension(2), Dimension(3)])    
print tensor_shape    
# (2, 3)

num_rows = tensor_shape[0] # ???
num_cols = tensor_shape[1] # ???

tensor2 = tf.reshape(tensor, (num_rows*num_cols, 1))    
# Traceback (most recent call last):
#   File "<stdin>", line 1, in <module>
#   File "/usr/local/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 1750, in reshape
#     name=name)
#   File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 454, in apply_op
#     as_ref=input_arg.is_ref)
#   File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 621, in convert_to_tensor
#     ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
#   File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/constant_op.py", line 180, in _constant_tensor_conversion_function
#     return constant(v, dtype=dtype, name=name)
#   File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/constant_op.py", line 163, in constant
#     tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape))
#   File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 353, in make_tensor_proto
#     _AssertCompatible(values, dtype)
#   File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 290, in _AssertCompatible
#     (dtype.name, repr(mismatch), type(mismatch).__name__))
# TypeError: Expected int32, got Dimension(6) of type 'Dimension' instead.
Author: stackoverflowuser2010, 2016-11-18

3 answers

Para obtener la forma como una lista de ints, haga tensor.get_shape().as_list().

Para completar tu llamada tf.shape(), prueba tensor2 = tf.reshape(tensor, tf.TensorShape([num_rows*num_cols, 1])). O puedes hacer directamente tensor2 = tf.reshape(tensor, tf.TensorShape([-1, 1])) donde se puede inferir su primera dimensión.

 70
Author: yuefengz,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-05-06 19:14:56

Otra forma de resolver esto es así:

tensor_shape[0].value

Esto devolverá el valor int del objeto Dimension.

 18
Author: tijmen Verhulsdonck,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-05-04 08:48:39

Para un tensor 2-D, puede obtener el número de filas y columnas como int32 utilizando el siguiente código:

rows, columns = map(lambda i: i.value, tensor.get_shape())
 4
Author: Anna,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-10-25 14:00:56