Lista de nombres de tensor en el gráfico en Tensorflow
El objeto graph en Tensorflow tiene un método llamado "get_tensor_by_name(name)". ¿Hay alguna manera de obtener una lista de nombres de tensores válidos?
Si no, ¿alguien sabe los nombres válidos para el modelo preentrenado inception-v3 desde aquí? De su ejemplo, pool_3, es un tensor válido pero una lista de todos ellos sería agradable. Miré el documento al que se refería y algunas de las capas parecen corresponder a los tamaños de la tabla 1, pero no todas.
5 answers
El documento no refleja con precisión el modelo. Si descarga el código fuente de arxiv tiene una descripción precisa del modelo como modelo.txt, y los nombres de allí se correlacionan fuertemente con los nombres en el modelo lanzado.
Para responder a su primera pregunta, sess.graph.get_operations()
le da una lista de operaciones. Para un op, op.name
le da el nombre y op.values()
le da una lista de tensores que produce (en el modelo inception-v3, todos los nombres de tensores son el nombre de op con un ":0" adjunto a él, por lo que pool_3:0
es el tensor producido por el pooling final)
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-02-27 15:41:25
Para ver las operaciones en el gráfico (Verá muchas, por lo que para abreviar he dado aquí solo la primera cadena).
sess = tf.Session()
op = sess.graph.get_operations()
[m.values() for m in op][1]
out:
(<tf.Tensor 'conv1/weights:0' shape=(4, 4, 3, 32) dtype=float32_ref>,)
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-02-04 06:09:27
Las respuestas anteriores son correctas. Me encontré con un código fácil de entender / simple para la tarea anterior. Así que compartirlo aquí: -
import tensorflow as tf
def printTensors(pb_file):
# read pb into graph_def
with tf.gfile.GFile(pb_file, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# import graph_def
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def)
# print operations
for op in graph.get_operations():
print(op.name)
printTensors("path-to-my-pbfile.pb")
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2018-05-31 09:17:12
No es necesario crear una sesión para ver los nombres de todos los nombres de operaciones en el gráfico. Para hacer esto, solo necesita tomar un gráfico predeterminadotf.get_default_graph()
y extraer todas las operaciones: .get_operations
. Cada operación tiene muchos campos , el que necesita es name.
Aquí está el código:
import tensorflow as tf
a = tf.Variable(5)
b = tf.Variable(6)
c = tf.Variable(7)
d = (a + b) * c
for i in tf.get_default_graph().get_operations():
print i.name
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-04-30 06:33:23
Como una lista anidada comprensión:
tensor_names = [t.name for op in tf.get_default_graph().get_operations() for t in op.values()]
Función para obtener nombres de tensores en un gráfico (por defecto es el gráfico por defecto):
def get_names(graph=tf.get_default_graph()):
return [t.name for op in graph.get_operations() for t in op.values()]
Función para obtener tensores en un gráfico (por defecto es el gráfico por defecto):
def get_tensors(graph=tf.get_default_graph()):
return [t for op in graph.get_operations() for t in op.values()]
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2018-02-28 23:58:04