¿La mejor manera de salvar a un modelo entrenado en PyTorch?


Estaba buscando formas alternativas de guardar un modelo entrenado en PyTorch. Hasta ahora, he encontrado dos alternativas.

  1. antorcha.save () para guardar un modelo y antorcha.load () para cargar un modelo.
  2. modelo.state_dict () para guardar un modelo entrenado y un modelo .load_state_dict() para cargar el modelo guardado.

He llegado a esta discusión donde se recomienda el enfoque 2 sobre el enfoque 1.

Mi pregunta es, ¿por qué la segundo método es preferido? Es solo porque antorcha.nn los módulos tienen esas dos funciones y se nos anima a usarlos?

Author: kmario23, 2017-03-09

2 answers

He encontrado esta página en su repositorio de github, solo pegaré el contenido aquí.


Método recomendado para guardar un modelo

Hay dos enfoques principales para serializar y restaurar un modelo.

El primero (recomendado) guarda y carga solo los parámetros del modelo:

torch.save(the_model.state_dict(), PATH)

Luego más tarde:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

El segundo guarda y carga todo el modelo:

torch.save(the_model, PATH)

Luego más tarde:

the_model = torch.load(PATH)

Sin embargo, en este caso, los datos serializados son vinculado a las clases específicas y la estructura de directorios exacta utilizada, por lo que puede romperse de varias maneras cuando utilizado en otros proyectos, o después de algunos refactores serios.

 90
Author: dontloo,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-07-19 05:58:35

Depende de lo que quieras hacer.

Caso # 1: Guarde el modelo para usarlo usted mismo para inferencia: Guarde el modelo, lo restaure y luego cambie el modelo al modo de evaluación. Esto se hace porque normalmente tienes BatchNorm y Dropout capas que por defecto están en modo tren en construcción:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

Caso # 2: Guardar modelo para reanudar el entrenamiento más tarde: Si necesita seguir entrenando el modelo que está a punto de guardar, necesita guardar más de sólo el modelo. También es necesario guardar el estado del optimizador, épocas, puntuación, etc. Lo harías así:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

Para reanudar el entrenamiento haría cosas como: state = torch.load(filepath), y luego, para restaurar el estado de cada objeto individual, algo como esto:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Ya que está reanudando el entrenamiento, NO llame a model.eval() una vez que restaure los estados al cargar.

Caso # 3: Modelo para ser utilizado por otra persona sin acceso a su código: En Tensorflow puede crear un archivo .pb que defina tanto la arquitectura como los pesos del modelo. Esto es muy útil, especialmente cuando se usa Tensorflow serve. La forma equivalente de hacer esto en Pytorch sería:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

Esta manera todavía no es a prueba de balas y ya que pytorch todavía está experimentando muchos cambios, no lo recomendaría.

 43
Author: Jadiel de Armas,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2018-08-23 10:56:33