Cargando un modelo de Keras entrenado y continuar la capacitación


Me preguntaba si era posible guardar un modelo de Keras parcialmente entrenado y continuar el entrenamiento después de cargar el modelo nuevamente.

La razón de esto es que tendré más datos de entrenamiento en el futuro y no quiero volver a entrenar todo el modelo de nuevo.

Las funciones que estoy usando son:

#Partly train model
model.fit(first_training, first_classes, batch_size=32, nb_epoch=20)

#Save partly trained model
model.save('partly_trained.h5')

#Load partly trained model
from keras.models import load_model
model = load_model('partly_trained.h5')

#Continue training
model.fit(second_training, second_classes, batch_size=32, nb_epoch=20)

Edición 1: se agregó un ejemplo completamente funcional

Con el primer conjunto de datos después de 10 épocas la pérdida de la última época será 0.0748 y el precisión 0.9863.

Después de guardar, eliminar y recargar el modelo, la pérdida y la precisión del modelo entrenado en el segundo conjunto de datos serán 0.1711 y 0.9504 respectivamente.

¿Esto es causado por los nuevos datos de entrenamiento o por un modelo completamente re-entrenado?

"""
Model by: http://machinelearningmastery.com/
"""
# load (downloaded if needed) the MNIST dataset
import numpy
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import np_utils
from keras.models import load_model
numpy.random.seed(7)

def baseline_model():
    model = Sequential()
    model.add(Dense(num_pixels, input_dim=num_pixels, init='normal', activation='relu'))
    model.add(Dense(num_classes, init='normal', activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

if __name__ == '__main__':
    # load data
    (X_train, y_train), (X_test, y_test) = mnist.load_data()

    # flatten 28*28 images to a 784 vector for each image
    num_pixels = X_train.shape[1] * X_train.shape[2]
    X_train = X_train.reshape(X_train.shape[0], num_pixels).astype('float32')
    X_test = X_test.reshape(X_test.shape[0], num_pixels).astype('float32')
    # normalize inputs from 0-255 to 0-1
    X_train = X_train / 255
    X_test = X_test / 255
    # one hot encode outputs
    y_train = np_utils.to_categorical(y_train)
    y_test = np_utils.to_categorical(y_test)
    num_classes = y_test.shape[1]

    # build the model
    model = baseline_model()

    #Partly train model
    dataset1_x = X_train[:3000]
    dataset1_y = y_train[:3000]
    model.fit(dataset1_x, dataset1_y, nb_epoch=10, batch_size=200, verbose=2)

    # Final evaluation of the model
    scores = model.evaluate(X_test, y_test, verbose=0)
    print("Baseline Error: %.2f%%" % (100-scores[1]*100))

    #Save partly trained model
    model.save('partly_trained.h5')
    del model

    #Reload model
    model = load_model('partly_trained.h5')

    #Continue training
    dataset2_x = X_train[3000:]
    dataset2_y = y_train[3000:]
    model.fit(dataset2_x, dataset2_y, nb_epoch=10, batch_size=200, verbose=2)
    scores = model.evaluate(X_test, y_test, verbose=0)
    print("Baseline Error: %.2f%%" % (100-scores[1]*100))
Author: Wilmar van Ommeren, 2017-03-08

4 answers

En realidad - model.save guarda toda la información necesaria para reiniciar el entrenamiento en su caso. Lo único que podría estropearse al recargar el modelo es su estado optimizador. Para comprobarlo-intente save y vuelva a cargar el modelo y entrénelo en los datos de entrenamiento.

 11
Author: Marcin Możejko,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-03-08 11:45:28

Observe que Keras a veces tiene problemas con los modelos cargados, como en aquí. Esto podría explicar los casos en los que no se parte de la misma precisión entrenada.

 2
Author: shahar_m,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-07-26 08:42:45

El problema puede ser que use un optimizador diferente, o argumentos diferentes a su optimizador. Acabo de tener el mismo problema con un modelo preentrenado personalizado, usando

reduce_lr = ReduceLROnPlateau(monitor='loss', factor=lr_reduction_factor,
                              patience=patience, min_lr=min_lr, verbose=1)

Para el modelo preentrenado, donde la tasa de aprendizaje original comienza en 0.0003 y durante el pre-entrenamiento se reduce a la tasa min_learning, que es 0.000003

Acabo de copiar esa línea a la secuencia de comandos que utiliza el modelo pre-entrenado y consiguió muy malas precisiones. Hasta que me di cuenta de que el último la tasa de aprendizaje del modelo preentrenado fue la tasa mínima de aprendizaje, es decir, 0,000003. Y si empiezo con esa tasa de aprendizaje, obtengo exactamente las mismas precisiones para comenzar como la salida del modelo preentrenado, lo que tiene sentido, ya que comenzar con una tasa de aprendizaje que es 100 veces mayor que la última tasa de aprendizaje utilizada en el modelo preentrenado resultará en un gran exceso de GD y, por lo tanto, en una gran disminución de las precisiones.

 2
Author: Wolfgang,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2018-02-08 14:37:14

Todo lo anterior ayuda, debe reanudar desde la misma tasa de aprendizaje() que el LR cuando se guardaron el modelo y los pesos. Configúrelo directamente en el optimizador.

Tenga en cuenta que la mejora a partir de ahí no está garantizada, porque el modelo puede haber alcanzado el mínimo local, que puede ser global. No tiene sentido reanudar un modelo con el fin de buscar otro mínimo local, a menos que la intención de aumentar la tasa de aprendizaje de una manera controlada y empujar el modelo en un posiblemente mejor mínimo no muy lejos.

 0
Author: flowgrad,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2018-05-29 18:55:32