Save and restore a TensorFlow model using Keras for continuous model training

Sometimes, we want to stop fitting the model and keep either the current model weights or the best weights found so far. When do we do that? Usually, when fitting runs for too long and we don’t see any improvement.

Table of Contents

  1. Saving the Keras model into a file
  2. Restore a Keras model from a file and continue fitting the model
    1. Does it really resume fitting?
  3. Doing it all in one script.
  4. What about optimizer parameters (learning rate, momentum, etc.)?

Occasionally, we want to resume model training after a script failure.

On other occasions, we use cheap Amazon spot instances (or their equivalents from other cloud providers), which can be reclaimed at any time, so we must prepare our code to be interrupted and resumed.

In all of those situations, we can use Keras checkpoints to store the intermediate state of the model and resume training later.

To use checkpoints, we first need to define the model. I am not going to do it in this example because the model structure is not relevant. Once the model.compile function has been called, we are ready to go.

So here is the last line of the model definition. We must write the model saving/restoring code after that line.

model.compile(loss='categorical_crossentropy', optimizer='adam')

Saving the Keras model into a file

To save the model, we are going to use the Keras ModelCheckpoint callback.
In this example, I am going to store only the best version of the model.

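To decide which version should be stored, ModelCheckpoint observes the monitored value (here, the training loss) at the end of every epoch and overwrites the file only when that value improves on the best one seen so far. As a result, the file always holds the version with the minimal loss.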

from keras.callbacks import ModelCheckpoint

filepath = "model.h5"

checkpoint = ModelCheckpoint(filepath, monitor = 'loss', verbose = 1, save_best_only = True, mode = 'min')
model.fit(X, Y, epochs=5, batch_size=2000, verbose = 1, callbacks = [checkpoint])
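Here is a minimal runnable sketch of this first training run. It assumes TensorFlow 2.x with its bundled Keras (from tensorflow import keras) instead of the standalone keras package, and it replaces the article's model and data with a hypothetical two-layer network and random data. The file name ends in .keras because recent Keras 3 releases expect that suffix for full-model checkpoints; with older Keras versions, model.h5 works as shown above.

```python
import os

import numpy as np
from tensorflow import keras

# Hypothetical stand-in data: 256 samples, 8 features, 3 one-hot classes.
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 8)).astype("float32")
Y = keras.utils.to_categorical(rng.integers(0, 3, size=256), num_classes=3)

# Hypothetical stand-in model; the article deliberately leaves its model out.
model = keras.Sequential([
    keras.Input(shape=(8,)),
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(3, activation="softmax"),
])
model.compile(loss="categorical_crossentropy", optimizer="adam")

# Assumption: ".keras" suffix, which recent Keras 3 releases expect for full models.
filepath = "model.keras"

# Write the whole model to filepath only when the epoch's training loss is a new minimum.
checkpoint = keras.callbacks.ModelCheckpoint(
    filepath, monitor="loss", verbose=1, save_best_only=True, mode="min"
)
history = model.fit(X, Y, epochs=5, batch_size=32, verbose=0, callbacks=[checkpoint])
print(os.path.exists(filepath))
```

The batch size is smaller than the article's 2000 only because the stand-in data set has 256 samples.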

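If we want to track accuracy instead of loss, we must change both the monitor and the mode parameters, because a higher accuracy is better. The model must also be compiled with metrics=['accuracy'], otherwise there is no accuracy value to monitor. Older Keras versions log it as 'acc' (as below), while newer ones log it as 'accuracy'.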

checkpoint = ModelCheckpoint(filepath, monitor = 'acc', verbose = 1, save_best_only = True, mode = 'max')
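The selection rule behind save_best_only can be illustrated without TensorFlow. The following is a simplified pure-Python model of that rule, not Keras's own implementation: a fresh tracker starts from +inf for mode='min' (or -inf for mode='max') and records a save only when the monitored value strictly improves. The function name best_only_saves is hypothetical.

```python
import math


def best_only_saves(values, mode="min", best=None):
    """Return the 1-based epochs at which a save_best_only checkpoint would write.

    Simplified model of the rule, not Keras's own code: the tracker starts at
    +inf for mode="min" (-inf for "max") unless a previous best is passed in,
    and a save happens only when the monitored value strictly improves.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    if best is None:
        best = math.inf if mode == "min" else -math.inf
    saved = []
    for epoch, value in enumerate(values, start=1):
        improved = value < best if mode == "min" else value > best
        if improved:
            best = value
            saved.append(epoch)
    return saved


# Training losses from the first run's log: every epoch improves, so every epoch saves.
print(best_only_saves([3.1464, 3.0030, 2.9952, 2.9812, 2.9357]))  # [1, 2, 3, 4, 5]
# Accuracy-style tracking: epoch 3 is worse than epoch 2, so it is skipped.
print(best_only_saves([0.60, 0.70, 0.65, 0.80], mode="max"))  # [1, 2, 4]
# A fresh tracker starts at inf, so even a poor first epoch is saved...
print(best_only_saves([3.5]))  # [1]
# ...whereas a tracker that remembered the previous best (2.93567) would skip it.
print(best_only_saves([3.5], best=2.93567))  # []
```

Fed with the losses from the first run's log, it reports a save at every epoch, which matches the log. The last two calls show why a freshly created callback always saves on its first epoch.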

Restore a Keras model from a file and continue fitting the model

Now, we can restore the model from the file. All we need is the load_model function. After loading the model, we can resume fitting it. Callbacks are not stored in the model file, so we create the ModelCheckpoint again.

from keras.models import load_model

new_model = load_model("model.h5")

checkpoint = ModelCheckpoint(filepath, monitor = 'loss', verbose = 1, save_best_only = True, mode = 'min')
new_model.fit(X, Y, epochs=5, batch_size=2000, callbacks = [checkpoint], verbose = 1)
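A minimal sketch of both runs in one place, assuming TensorFlow 2.x with its bundled Keras, a hypothetical stand-in model, random data, and a hypothetical file name. Unlike the code above, it also passes initial_epoch (a standard fit parameter) when resuming: with initial_epoch=2 and epochs=4, Keras runs the epochs it labels 3/4 and 4/4, because epochs is the index of the last epoch, not the number of additional epochs.

```python
import numpy as np
from tensorflow import keras

# Hypothetical stand-in data and file name.
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 8)).astype("float32")
Y = keras.utils.to_categorical(rng.integers(0, 3, size=256), num_classes=3)
filepath = "resume_demo.keras"


def make_checkpoint():
    # Callbacks are not saved with the model, so each run creates a new one.
    return keras.callbacks.ModelCheckpoint(
        filepath, monitor="loss", verbose=0, save_best_only=True, mode="min"
    )


# First run (hypothetical stand-in model): epochs with index 0 and 1.
model = keras.Sequential([
    keras.Input(shape=(8,)),
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(3, activation="softmax"),
])
model.compile(loss="categorical_crossentropy", optimizer="adam")
model.fit(X, Y, epochs=2, batch_size=32, verbose=0, callbacks=[make_checkpoint()])

# Second run: reload architecture, weights and optimizer state, then continue.
new_model = keras.models.load_model(filepath)
history = new_model.fit(
    X, Y, initial_epoch=2, epochs=4, batch_size=32, verbose=0,
    callbacks=[make_checkpoint()],
)
print(history.epoch)  # 0-based indices of the epochs run in this call
```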

Does it really resume fitting?

Yes. Let’s look at the log output. Here are the messages logged during the first run of the fit function (the run that created the model file).

Epoch 1/5
74394/74394 [==============================] - 121s 2ms/step - loss: 3.1464

Epoch 00001: loss improved from inf to 3.14638, saving model to model.h5
Epoch 2/5
74394/74394 [==============================] - 115s 2ms/step - loss: 3.0030

Epoch 00002: loss improved from 3.14638 to 3.00302, saving model to model.h5
Epoch 3/5
74394/74394 [==============================] - 114s 2ms/step - loss: 2.9952

Epoch 00003: loss improved from 3.00302 to 2.99524, saving model to model.h5
Epoch 4/5
74394/74394 [==============================] - 115s 2ms/step - loss: 2.9812

Epoch 00004: loss improved from 2.99524 to 2.98121, saving model to model.h5
Epoch 5/5
74394/74394 [==============================] - 114s 2ms/step - loss: 2.9357

Epoch 00005: loss improved from 2.98121 to 2.93567, saving model to model.h5

Now, let’s look at the output of the second run (after loading the model from a file and calling the fit function again):

Epoch 1/5
74394/74394 [==============================] - 118s 2ms/step - loss: 2.8460

Epoch 00001: loss improved from inf to 2.84600, saving model to model.h5
Epoch 2/5
74394/74394 [==============================] - 115s 2ms/step - loss: 2.7889

Epoch 00002: loss improved from 2.84600 to 2.78892, saving model to model.h5
Epoch 3/5
74394/74394 [==============================] - 116s 2ms/step - loss: 2.7534

Epoch 00003: loss improved from 2.78892 to 2.75342, saving model to model.h5
Epoch 4/5
74394/74394 [==============================] - 115s 2ms/step - loss: 2.7183

Epoch 00004: loss improved from 2.75342 to 2.71827, saving model to model.h5
Epoch 5/5
74394/74394 [==============================] - 115s 2ms/step - loss: 2.6811

Epoch 00005: loss improved from 2.71827 to 2.68112, saving model to model.h5

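It did not start from scratch. The first epoch of the second run ended with a loss of 2.8460, already lower than the 2.9357 reached at the end of the first run, so Keras continued fitting from the saved state. Two details are worth noticing, though. The epoch counter restarts at 1/5, and the new ModelCheckpoint starts comparing from inf, so the first epoch of a resumed run always overwrites the file, even if its loss is worse than the best loss of the previous run.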
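The log is indirect evidence. A more direct check, sketched below under the assumption of TensorFlow 2.x with its bundled Keras and a hypothetical stand-in model, compares the weights and the evaluated loss of a trained model with those of the same model loaded back from a file. It calls model.save directly, which writes the same kind of full-model file that ModelCheckpoint writes when save_weights_only is False.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras

# Hypothetical stand-in data.
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 8)).astype("float32")
Y = keras.utils.to_categorical(rng.integers(0, 3, size=256), num_classes=3)

# Hypothetical stand-in model, trained briefly and saved in full.
model = keras.Sequential([
    keras.Input(shape=(8,)),
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(3, activation="softmax"),
])
model.compile(loss="categorical_crossentropy", optimizer="adam")
model.fit(X, Y, epochs=3, batch_size=32, verbose=0)

path = os.path.join(tempfile.mkdtemp(), "model.keras")
model.save(path)
restored = keras.models.load_model(path)

# Identical weights and identical loss on the same data mean training resumes from here.
same_weights = all(
    np.allclose(a, b) for a, b in zip(model.get_weights(), restored.get_weights())
)
loss_before = model.evaluate(X, Y, verbose=0)
loss_after = restored.evaluate(X, Y, verbose=0)
print(same_weights, loss_before, loss_after)
```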

Doing it all in one script.

When we write a single script, we must somehow distinguish between the first run of the fit function and subsequent runs. During the first run, we have to define and compile the model, but during later runs, all we need is the load_model function.

Don’t overthink it. It is simple. Just check if the model file exists. If not, define the model and run the fit function for the first time. If the file exists, load the model from it and call the fit function again.
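A minimal sketch of that check, assuming TensorFlow 2.x with its bundled Keras. The helpers build_model, get_model and train are hypothetical names, and build_model stands in for the article's real model definition.

```python
import os

import numpy as np
from tensorflow import keras

FILEPATH = "model.keras"  # assumption: Keras 3 style suffix


def build_model():
    """Hypothetical stand-in for the article's model definition."""
    model = keras.Sequential([
        keras.Input(shape=(8,)),
        keras.layers.Dense(16, activation="relu"),
        keras.layers.Dense(3, activation="softmax"),
    ])
    model.compile(loss="categorical_crossentropy", optimizer="adam")
    return model


def get_model(filepath):
    """Return (model, resumed): load the checkpoint if it exists, else build a new model."""
    if os.path.exists(filepath):
        return keras.models.load_model(filepath), True
    return build_model(), False


def train(X, Y, filepath=FILEPATH, epochs=5):
    model, resumed = get_model(filepath)
    checkpoint = keras.callbacks.ModelCheckpoint(
        filepath, monitor="loss", verbose=1, save_best_only=True, mode="min"
    )
    model.fit(X, Y, epochs=epochs, batch_size=32, verbose=0, callbacks=[checkpoint])
    return model, resumed


if __name__ == "__main__":
    # Hypothetical stand-in data; run the script twice to see it resume.
    rng = np.random.default_rng(0)
    X = rng.normal(size=(256, 8)).astype("float32")
    Y = keras.utils.to_categorical(rng.integers(0, 3, size=256), num_classes=3)
    _, resumed = train(X, Y)
    print("resumed from checkpoint" if resumed else "started from scratch")
```

The first run prints "started from scratch"; every later run finds the file and continues from it.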

What about optimizer parameters (learning rate, momentum, etc.)?

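When we use the code from the examples above, the whole model is stored. The file contains the model structure (its architecture), the model weights, the compilation settings (loss and optimizer), and the optimizer state, such as the learning rate and Adam’s moment estimates. That is why the model returned by load_model can continue training right away.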
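A small sketch that checks part of this claim, assuming TensorFlow 2.x with its bundled Keras and a hypothetical stand-in model compiled with a non-default learning rate. It only inspects the optimizer's class and configuration after loading; the internal optimizer variables are saved too but are not checked here.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras

# Hypothetical stand-in data.
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 8)).astype("float32")
Y = keras.utils.to_categorical(rng.integers(0, 3, size=256), num_classes=3)

# Hypothetical stand-in model with a non-default learning rate.
model = keras.Sequential([
    keras.Input(shape=(8,)),
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(3, activation="softmax"),
])
model.compile(
    loss="categorical_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=0.005),
)
model.fit(X, Y, epochs=1, batch_size=32, verbose=0)

path = os.path.join(tempfile.mkdtemp(), "model.keras")
model.save(path)

# The loaded model comes back already compiled, with the same optimizer settings.
restored = keras.models.load_model(path)
optimizer_name = type(restored.optimizer).__name__
learning_rate = restored.optimizer.get_config()["learning_rate"]
print(optimizer_name, learning_rate)
```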

To store only the model weights, we set the save_weights_only parameter of ModelCheckpoint to True.

checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min', save_weights_only = True)

In that case, the file contains neither the architecture nor the optimizer configuration, so we can no longer use the load_model function.

Instead, we have to define the same model architecture again, set the optimizer parameters, and compile the model.

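After all of that, we can finally call the model.load_weights function. Keep in mind that a weights-only HDF5 file written by Keras 2 does not include the optimizer’s internal state (for example, Adam’s moment estimates), so the optimizer starts over even though the weights are restored.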
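A minimal sketch of the weights-only workflow, assuming TensorFlow 2.x with its bundled Keras. The build_model helper is hypothetical and must recreate exactly the architecture that produced the weights. The .weights.h5 suffix is used because recent Keras 3 releases require it when save_weights_only=True; older versions also accept it, since it ends in .h5.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras


def build_model():
    """Hypothetical stand-in: must match the architecture that produced the weights."""
    model = keras.Sequential([
        keras.Input(shape=(8,)),
        keras.layers.Dense(16, activation="relu"),
        keras.layers.Dense(3, activation="softmax"),
    ])
    model.compile(
        loss="categorical_crossentropy",
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
    )
    return model


# Hypothetical stand-in data.
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 8)).astype("float32")
Y = keras.utils.to_categorical(rng.integers(0, 3, size=256), num_classes=3)

weights_path = os.path.join(tempfile.mkdtemp(), "model.weights.h5")
checkpoint = keras.callbacks.ModelCheckpoint(
    weights_path, monitor="loss", verbose=1, save_best_only=True,
    mode="min", save_weights_only=True,
)
model = build_model()
model.fit(X, Y, epochs=3, batch_size=32, verbose=0, callbacks=[checkpoint])

# Restore: rebuild and compile the same architecture, then load the stored weights.
restored = build_model()
fresh_weights = restored.get_weights()
restored.load_weights(weights_path)

# Sanity check: loading replaced the freshly initialized weights.
changed = any(
    not np.allclose(a, b) for a, b in zip(fresh_weights, restored.get_weights())
)
print("weights loaded:", changed)
```

Only the weights come back from the file; the optimizer settings come from the compile call inside build_model.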

So, for resuming training, we should probably stick to storing the whole model.
