TensorFlow, ¿por qué hay 3 archivos después de guardar el modelo?
Después de leer los documentos , guardé un modelo en TensorFlow
, aquí está mi código de demostración:
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
..
# Save the variables to disk.
save_path = saver.save(sess, "/tmp/model.ckpt")
print("Model saved in file: %s" % save_path)
Pero después de eso, encontré que hay 3 archivos
model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta
Y no puedo restaurar el modelo restaurando el archivo model.ckpt
, ya que no existe dicho archivo. Aquí está mi código
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/tmp/model.ckpt")
Entonces, ¿por qué hay 3 archivos?
4 answers
Prueba esto:
with tf.Session() as sess:
saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
saver.restore(sess, "/tmp/model.ckpt")
El método TensorFlow save guarda tres tipos de archivos porque almacena la estructura del gráfico por separado de los valores variables . El archivo .meta
describe la estructura del gráfico guardado, por lo que debe importarlo antes de restaurar el punto de control (de lo contrario, no sabe a qué variables corresponden los valores de punto de control guardados).
Alternativamente, puedes hacer esto:
# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Now load the checkpoint variable values
with tf.Session() as sess:
saver = tf.train.Saver()
saver.restore(sess, "/tmp/model.ckpt")
Aunque no hay ningún archivo llamado model.ckpt
, todavía consulte el punto de control guardado con ese nombre al restaurarlo. Desde el saver.py
código fuente :
Los usuarios solo necesitan interactuar con el prefijo especificado por el usuario... en su lugar de cualquier pathname físico.
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2018-05-20 23:24:01
Meta file : describe la estructura del gráfico guardado, incluye GraphDef, SaverDef, etc.; luego aplica
tf.train.import_meta_graph('/tmp/model.ckpt.meta')
, restauraráSaver
yGraph
.Index file : es una tabla inmutable string-string(tensorflow::table::Table). Cada clave es el nombre de un tensor y su valor es un BundleEntryProto serializado. Cada BundleEntryProto describe los metadatos de un tensor: cuál de los archivos "data" contiene el contenido de un tensor, el desplazamiento en ese archivo, suma de comprobación, algunos datos auxiliares, etc.
Data file : es la colección TensorBundle, guarda los valores de todas las variables.
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-07-11 11:38:45
Estoy restaurando incrustaciones de palabras entrenadas desde Word2Vec tutorial tensorflow.
En caso de que haya creado varios puntos de control:
Por ejemplo, los archivos creados se ven así
Modelo.ckpt-55695.data-00000-of-00001
Modelo.ckpt-55695.índice
Modelo.ckpt-55695.meta
Prueba esto
def restore_session(self, session):
saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta')
saver.restore(session, './tmp/model.ckpt-55695')
Cuando se llama a restore_session ():
def test_word2vec():
opts = Options()
with tf.Graph().as_default(), tf.Session() as session:
with tf.device("/cpu:0"):
model = Word2Vec(opts, session)
model.restore_session(session)
model.get_embedding("assistance")
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2018-02-22 20:41:44
Si entrenaste a una CNN con deserción, por ejemplo, podrías hacer esto:
def predict(image, model_name):
"""
image -> single image, (width, height, channels)
model_name -> model file that was saved without any extensions
"""
with tf.Session() as sess:
saver = tf.train.import_meta_graph('./' + model_name + '.meta')
saver.restore(sess, './' + model_name)
# Substitute 'logits' with your model
prediction = tf.argmax(logits, 1)
# 'x' is what you defined it to be. In my case it is a batch of RGB images, that's why I add the extra dimension
return prediction.eval(feed_dict={x: image[np.newaxis,:,:,:], keep_prob_dnn: 1.0})
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2018-01-25 04:41:38