TensorFlow, ¿por qué hay 3 archivos después de guardar el modelo?


Después de leer los documentos , guardé un modelo en TensorFlow, aquí está mi código de demostración:

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

Pero después de eso, encontré que hay 3 archivos

model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta

Y no puedo restaurar el modelo restaurando el archivo model.ckpt, ya que no existe dicho archivo. Aquí está mi código

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")

Entonces, ¿por qué hay 3 archivos?

 69
Author: GoingMyWay, 2016-12-21

4 answers

Prueba esto:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
    saver.restore(sess, "/tmp/model.ckpt")

El método TensorFlow save guarda tres tipos de archivos porque almacena la estructura del gráfico por separado de los valores variables . El archivo .meta describe la estructura del gráfico guardado, por lo que debe importarlo antes de restaurar el punto de control (de lo contrario, no sabe a qué variables corresponden los valores de punto de control guardados).

Alternativamente, puedes hacer esto:

# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")

...

# Now load the checkpoint variable values
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, "/tmp/model.ckpt")

Aunque no hay ningún archivo llamado model.ckpt, todavía consulte el punto de control guardado con ese nombre al restaurarlo. Desde el saver.py código fuente :

Los usuarios solo necesitan interactuar con el prefijo especificado por el usuario... en su lugar de cualquier pathname físico.

 72
Author: T.K. Bartel,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2018-05-20 23:24:01
  • Meta file : describe la estructura del gráfico guardado, incluye GraphDef, SaverDef, etc.; luego aplica tf.train.import_meta_graph('/tmp/model.ckpt.meta'), restaurará Saver y Graph.

  • Index file : es una tabla inmutable string-string(tensorflow::table::Table). Cada clave es el nombre de un tensor y su valor es un BundleEntryProto serializado. Cada BundleEntryProto describe los metadatos de un tensor: cuál de los archivos "data" contiene el contenido de un tensor, el desplazamiento en ese archivo, suma de comprobación, algunos datos auxiliares, etc.

  • Data file : es la colección TensorBundle, guarda los valores de todas las variables.

 31
Author: Guangcong Liu,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-07-11 11:38:45

Estoy restaurando incrustaciones de palabras entrenadas desde Word2Vec tutorial tensorflow.

En caso de que haya creado varios puntos de control:

Por ejemplo, los archivos creados se ven así

Modelo.ckpt-55695.data-00000-of-00001

Modelo.ckpt-55695.índice

Modelo.ckpt-55695.meta

Prueba esto

def restore_session(self, session):
   saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta')
   saver.restore(session, './tmp/model.ckpt-55695')

Cuando se llama a restore_session ():

def test_word2vec():
   opts = Options()    
   with tf.Graph().as_default(), tf.Session() as session:
       with tf.device("/cpu:0"):            
           model = Word2Vec(opts, session)
           model.restore_session(session)
           model.get_embedding("assistance")
 2
Author: Steven Wong,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2018-02-22 20:41:44

Si entrenaste a una CNN con deserción, por ejemplo, podrías hacer esto:

def predict(image, model_name):
    """
    image -> single image, (width, height, channels)
    model_name -> model file that was saved without any extensions
    """
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./' + model_name + '.meta')
        saver.restore(sess, './' + model_name)
        # Substitute 'logits' with your model
        prediction = tf.argmax(logits, 1)
        # 'x' is what you defined it to be. In my case it is a batch of RGB images, that's why I add the extra dimension
        return prediction.eval(feed_dict={x: image[np.newaxis,:,:,:], keep_prob_dnn: 1.0})
 0
Author: Sashank Aryal,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2018-01-25 04:41:38