How to use EarlyStopping callback in TensorFlow with Keras

This tutorial explains how to use the EarlyStopping callback with the TensorFlow Keras API.

What is a Callback in TensorFlow?

A callback is an object that can perform actions at various stages of training. Some common use cases for callbacks are listed below, followed by a minimal custom-callback sketch:

  • Implement early stopping
  • Get a view on states and statistics of a model during training
  • Periodically save model to disk
  • Write TensorBoard logs after every batch of training
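In practice, a custom callback is written by subclassing tf.keras.callbacks.Callback and overriding the hooks of interest (on_epoch_begin, on_epoch_end, on_train_batch_end, and so on). Below is a minimal sketch; the class name and printed messages are only illustrative:

# minimal custom callback sketch (illustrative only)
import tensorflow as tf

class EpochLogger(tf.keras.callbacks.Callback):
  def on_epoch_begin(self, epoch, logs=None):
    print(f"starting epoch {epoch}")

  def on_epoch_end(self, epoch, logs=None):
    # logs is a dict of the metrics computed for this epoch
    print(f"finished epoch {epoch} with logs: {logs}")

A callback like this is passed to model.fit() through the callbacks argument, exactly as EarlyStopping is used below.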

The code snippets below use the EarlyStopping callback and show its effect on the model.fit() method.

EarlyStopping callback

The EarlyStopping callback stops training when a monitored metric has stopped improving. Below is the EarlyStopping class signature with its default argument values, followed by a description of each argument and an example configuration:
   
tf.keras.callbacks.EarlyStopping(
  monitor="loss",
  min_delta=0,
  patience=0,
  verbose=0,
  mode="auto",
  baseline=None,
  restore_best_weights=False,
)
 

Arguments
  • monitor: Quantity to be monitored. Defaults to "val_loss".
  • min_delta: Minimum change in the monitored quantity to qualify as an improvement.
  • patience: Number of epochs with no improvement after which training will be stopped.
  • verbose: Verbosity mode, 0 or 1. Mode 1 prints a message when the callback stops training.
  • mode: One of {"auto", "min", "max"}. In min mode, training will stop when the quantity monitored has stopped decreasing; in "max" mode it will stop when the quantity monitored has stopped increasing; in "auto" mode, the direction is automatically inferred from the name of the monitored quantity.
  • baseline: Baseline value for the monitored quantity. Training will stop if the model doesn't show improvement over the baseline.
  • restore_best_weights: Whether to restore model weights from the epoch with the best value of the monitored quantity. If False, the weights obtained at the last step of training are used.
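For example, a common configuration monitors the validation loss, requires a minimum improvement to count, and restores the best weights when training stops. The argument values below are illustrative, not prescriptive:

# example: stop when val_loss fails to improve by at least 0.001
# for 3 consecutive epochs, then restore the best weights
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=0.001,
    patience=3,
    mode='min',
    restore_best_weights=True,
)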

Using EarlyStopping callback on Keras model for MNIST dataset

1. Import required modules.

   
# import required modules
import tensorflow as tf
import tensorflow_datasets as tfds
 

2. Load MNIST dataset.

   
# load mnist dataset
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
print(type(ds_train))
 
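Since tfds.load() was called with with_info=True, the returned ds_info object can be used to sanity-check the splits, for example:

# optional: verify the split sizes via ds_info
print(ds_info.splits['train'].num_examples)  # 60000
print(ds_info.splits['test'].num_examples)   # 10000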

3. Define image normalization function.

   
# define image normalize function
def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label
 

4. Apply normalization and batch on training and test datasets.

   
# apply normalization and batch on training and test datasets
dataset_train = ds_train.map(normalize_img)
dataset_train = dataset_train.batch(128)

dataset_test = ds_test.map(normalize_img)
dataset_test = dataset_test.batch(128)
 
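Optionally, the input pipeline can be sped up with caching and prefetching. This is the usual tf.data idiom, assuming a TensorFlow version (2.4+) where tf.data.AUTOTUNE is available:

# optional: cache the datasets and prefetch batches for performance
dataset_train = dataset_train.cache().prefetch(tf.data.AUTOTUNE)
dataset_test = dataset_test.cache().prefetch(tf.data.AUTOTUNE)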

5. Create Keras sequential model.

   
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
 
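To double-check the architecture, the layer shapes and parameter counts can be printed:

# optional: print layer shapes and parameter counts
model.summary()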

6. Compile the model.

   
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.006),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
 

7. Instantiate EarlyStopping callback.

   
callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=2)
 

8. Train the model for up to 100 epochs and check how many epochs actually ran because of EarlyStopping.

   
history = model.fit(
    dataset_train,
    epochs=100,
    validation_data=dataset_test,
    callbacks=[callback]
)

print(len(history.history['loss']))
 

Output

   
Epoch 1/100
469/469 [==============================] - 9s 19ms/step - loss: 3.2035 - sparse_categorical_accuracy: 0.8107 - val_loss: 0.4739 - val_sparse_categorical_accuracy: 0.8911
Epoch 2/100
469/469 [==============================] - 2s 3ms/step - loss: 0.4325 - sparse_categorical_accuracy: 0.8966 - val_loss: 0.3721 - val_sparse_categorical_accuracy: 0.9147
Epoch 3/100
469/469 [==============================] - 2s 3ms/step - loss: 0.3718 - sparse_categorical_accuracy: 0.9117 - val_loss: 0.4223 - val_sparse_categorical_accuracy: 0.9026
Epoch 4/100
469/469 [==============================] - 2s 3ms/step - loss: 0.3774 - sparse_categorical_accuracy: 0.9109 - val_loss: 0.4949 - val_sparse_categorical_accuracy: 0.9073
Epoch 5/100
469/469 [==============================] - 2s 3ms/step - loss: 0.4109 - sparse_categorical_accuracy: 0.9045 - val_loss: 0.4240 - val_sparse_categorical_accuracy: 0.9076
 

As the output above shows, only 5 of the 100 requested epochs ran: the training loss reached its best value in Epoch 3 and failed to improve in Epoch 4 and Epoch 5, so with patience=2 the callback stopped training. Note that monitor='loss' watches the training loss; in practice the validation loss ('val_loss', the default) is often a better signal of overfitting.
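To confirm this programmatically, the EarlyStopping instance exposes a stopped_epoch attribute (0 if it never fired), and the trained model can be evaluated on the test set:

# optional: inspect when early stopping triggered and evaluate the model
print(callback.stopped_epoch)  # 0 means early stopping never fired

loss, accuracy = model.evaluate(dataset_test)
print(loss, accuracy)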

Complete code snippet for using EarlyStopping callback

   
# import required modules
import tensorflow as tf
import tensorflow_datasets as tfds

# load mnist dataset
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
print(type(ds_train))

# define image normalize function
def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label


# apply normalization and batch on training and test datasets
dataset_train = ds_train.map(normalize_img)
dataset_train = dataset_train.batch(128)

dataset_test = ds_test.map(normalize_img)
dataset_test = dataset_test.batch(128)


model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.006),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=2)

history = model.fit(
    dataset_train,
    epochs=100,
    validation_data=dataset_test,
    callbacks=[callback]
)

print(len(history.history['loss']))
 
