¿La mejor manera de salvar a un modelo entrenado en PyTorch?
Estaba buscando formas alternativas de guardar un modelo entrenado en PyTorch. Hasta ahora, he encontrado dos alternativas.
- antorcha.save () para guardar un modelo y antorcha.load () para cargar un modelo.
- modelo.state_dict () para guardar un modelo entrenado y un modelo .load_state_dict() para cargar el modelo guardado.
He llegado a esta discusión donde se recomienda el enfoque 2 sobre el enfoque 1.
Mi pregunta es, ¿por qué la segundo método es preferido? Es solo porque antorcha.nn los módulos tienen esas dos funciones y se nos anima a usarlos?
2 answers
He encontrado esta página en su repositorio de github, solo pegaré el contenido aquí.
Método recomendado para guardar un modelo
Hay dos enfoques principales para serializar y restaurar un modelo.
El primero (recomendado) guarda y carga solo los parámetros del modelo:
torch.save(the_model.state_dict(), PATH)
Luego más tarde:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
El segundo guarda y carga todo el modelo:
torch.save(the_model, PATH)
Luego más tarde:
the_model = torch.load(PATH)
Sin embargo, en este caso, los datos serializados son vinculado a las clases específicas y la estructura de directorios exacta utilizada, por lo que puede romperse de varias maneras cuando utilizado en otros proyectos, o después de algunos refactores serios.
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-07-19 05:58:35
Depende de lo que quieras hacer.
Caso # 1: Guarde el modelo para usarlo usted mismo para inferencia: Guarde el modelo, lo restaure y luego cambie el modelo al modo de evaluación. Esto se hace porque normalmente tienes BatchNorm
y Dropout
capas que por defecto están en modo tren en construcción:
torch.save(model.state_dict(), filepath)
#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()
Caso # 2: Guardar modelo para reanudar el entrenamiento más tarde: Si necesita seguir entrenando el modelo que está a punto de guardar, necesita guardar más de sólo el modelo. También es necesario guardar el estado del optimizador, épocas, puntuación, etc. Lo harías así:
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
...
}
torch.save(state, filepath)
Para reanudar el entrenamiento haría cosas como: state = torch.load(filepath)
, y luego, para restaurar el estado de cada objeto individual, algo como esto:
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
Ya que está reanudando el entrenamiento, NO llame a model.eval()
una vez que restaure los estados al cargar.
Caso # 3: Modelo para ser utilizado por otra persona sin acceso a su código:
En Tensorflow puede crear un archivo .pb
que defina tanto la arquitectura como los pesos del modelo. Esto es muy útil, especialmente cuando se usa Tensorflow serve
. La forma equivalente de hacer esto en Pytorch sería:
torch.save(model, filepath)
# Then later:
model = torch.load(filepath)
Esta manera todavía no es a prueba de balas y ya que pytorch todavía está experimentando muchos cambios, no lo recomendaría.
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2018-08-23 10:56:33