TensorFlow guardando / cargando un gráfico desde un archivo
De lo que he recopilado hasta ahora, hay varias formas diferentes de descargar un gráfico TensorFlow en un archivo y luego cargarlo en otro programa, pero no he podido encontrar ejemplos/información claros sobre cómo funcionan. Lo que ya sé es esto:
- Guarde las variables del modelo en un archivo de punto de control (.ckpt) usando un
tf.train.Saver()
y restaurarlos más tarde ( fuente ) - Guardar un modelo en a .pb archivo y cargarlo de nuevo en el uso de
tf.train.write_graph()
ytf.import_graph_def()
(fuente) - Cargar un modelo desde a .archivo pb, reentrenarlo, y volcarlo en un nuevo .archivo pb usando Bazel (source)
- Congele el gráfico para guardar el gráfico y los pesos juntos (fuente)
- Utilizar
as_graph_def()
para guardar el modelo, y para pesos/variables, mapa en constantes (fuente)
Sin embargo, no he sido capaz de aclarar varias preguntas con respecto a estos diferentes métodos:
- Con respecto al punto de control archivos, ¿solo guardan los pesos entrenados de un modelo? ¿Podrían los archivos checkpoint ser cargados en un nuevo programa, y ser usados para ejecutar el modelo, o simplemente sirven como formas de guardar los pesos en un modelo en un cierto tiempo/etapa?
- Con respecto a
tf.train.write_graph()
, ¿también se guardan las ponderaciones/variables? - Con respecto a Bazel, solo puede guardar en/cargar desde.pb archivos para el reciclaje? ¿Hay un simple comando Bazel solo para volcar un gráfico en una .pb?
- Con respecto a la congelación, puede un congelado gráfico ser cargado en el uso de
tf.import_graph_def()
? - La demostración de Android para TensorFlow se carga en el modelo de inicio de Google desde a .archivo pb. Si quisiera sustituir a la mía .archivo pb, ¿cómo lo haría? ¿Tendría que cambiar algún código/método nativo?
- En general, ¿cuál es exactamente la diferencia entre todos estos métodos? O más ampliamente, cuál es la diferencia entre
as_graph_def()
/.ckpt/.pb?
En resumen, lo que estoy buscando es un método para guardar tanto un gráfico (como en, el varias operaciones y tales) y sus pesos / variables en un archivo, que luego se puede utilizar para cargar el gráfico y pesos en otro programa, para su uso (no necesariamente continuar/reentrenamiento).
La documentación sobre este tema no es muy sencilla, por lo que cualquier respuesta/información sería muy apreciada.
2 answers
Hay muchas maneras de abordar el problema de guardar un modelo en TensorFlow, lo que puede hacerlo un poco confuso. Tomando cada una de sus sub-preguntas por turno:
Los archivos checkpoint (producidos, por ejemplo, llamando
saver.save()
on atf.train.Saver
objeto) contienen solo los pesos, y cualquier otra variable definida en el mismo programa. Para usarlos en otro programa, debe volver a crear la estructura del gráfico asociado (por ejemplo, ejecutando código para compilarlo de nuevo, o llamandotf.import_graph_def()
), lo que le dice a TensorFlow qué hacer con esos pesos. Tenga en cuenta que llamar asaver.save()
también produce un archivo que contieneMetaGraphDef
, que contiene un gráfico y detalles de cómo asociar los pesos de un punto de control con ese gráfico. Vea el tutorial para más detalles.tf.train.write_graph()
solo escribe la estructura del gráfico; no los pesos.Bazel no está relacionado con la lectura o escritura de gráficos TensorFlow. (Tal vez yo malinterprete su pregunta: siéntase libre de aclararla en un comentario.)
Un gráfico congelado se puede cargar usando
tf.import_graph_def()
. En este caso, los pesos están (normalmente) incrustados en el gráfico, por lo que no es necesario cargar un punto de control separado.El cambio principal sería actualizar los nombres de los tensores que se introducen en el modelo, y los nombres de los tensores que se obtienen del modelo. En la demo de TensorFlow Android, esto correspondería a las cadenas
inputName
youtputName
que se pasan aTensorFlowClassifier.initializeTensorFlow()
.El
GraphDef
es la estructura del programa, que normalmente no cambia a través del proceso de capacitación. El punto de control es una instantánea del estado de un proceso de entrenamiento, que normalmente cambia en cada paso del proceso de entrenamiento. Como resultado, TensorFlow utiliza diferentes formatos de almacenamiento para estos tipos de datos, y la API de bajo nivel proporciona diferentes formas de guardarlos y cargarlos. Bibliotecas de nivel superior, como lasMetaGraphDef
las bibliotecas Keras y skflow se basan en estos mecanismos para proporcionar formas más convenientes de guardar y restaurar un modelo completo.
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2018-08-20 17:26:05
Puedes probar el siguiente código:
with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2018-10-01 13:18:33