How to use EarlyStopping callback in TensorFlow with Keras

This tutorial explains how to use the Keras EarlyStopping callback API.

Callbacks in Keras

A callback is an object that can perform actions at various stages of training. Some sample use cases for callbacks are:

  • Implement early stopping
  • Get a view on states and statistics of a model during training
  • Periodically save model to disk
  • Write TensorBoard logs after every batch of training
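
Since a callback is just an object that hooks into training, the quickest way to see this is to subclass tf.keras.callbacks.Callback. Below is a minimal sketch (the class name EpochLossLogger is hypothetical, not part of this tutorial's run) that overrides on_epoch_end to print the loss after every epoch:

   
  import tensorflow as tf
  
  # minimal custom callback sketch: print the loss at the end of each epoch
  class EpochLossLogger(tf.keras.callbacks.Callback):
      def on_epoch_end(self, epoch, logs=None):
          # logs holds the metrics computed for this epoch
          logs = logs or {}
          print(f"epoch {epoch}: loss={logs.get('loss', float('nan')):.4f}")
   

Any such object can be passed to model.fit() through the callbacks argument, exactly like the built-in EarlyStopping callback used below.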

In the code snippets below, we will use the EarlyStopping callback and observe its effect on the model.fit() method.

Keras EarlyStopping callback

The EarlyStopping callback stops training when a monitored metric has stopped improving. Below is the EarlyStopping class signature with its default argument values:
 
    tf.keras.callbacks.EarlyStopping(
      monitor="loss",
      min_delta=0,
      patience=0,
      verbose=0,
      mode="auto",
      baseline=None,
      restore_best_weights=False,
    )
     

Arguments
  • monitor: Quantity to be monitored. Defaults to "val_loss".
  • min_delta: Minimum change in the monitored quantity to qualify as an improvement.
  • patience: Number of epochs with no improvement after which training will be stopped.
  • verbose: Verbosity mode, 0 or 1. Mode 0 is silent; mode 1 displays messages when the callback takes an action.
  • mode: One of {"auto", "min", "max"}. In min mode, training will stop when the quantity monitored has stopped decreasing; in "max" mode it will stop when the quantity monitored has stopped increasing; in "auto" mode, the direction is automatically inferred from the name of the monitored quantity.
  • baseline: Baseline value for the monitored quantity. Training will stop if the model doesn't show improvement over the baseline.
  • restore_best_weights: Whether to restore model weights from the epoch with the best value of the monitored quantity. If False, the model weights obtained at the last step of training are used.
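
To see how these arguments combine, here is an illustrative instantiation (a hedged sketch: the metric name assumes the model logs a metric called val_accuracy, and the values are arbitrary examples rather than recommendations):

   
  # illustrative configuration; the values below are arbitrary examples
  early_stop = tf.keras.callbacks.EarlyStopping(
      monitor='val_accuracy',  # assumes the model logs a metric named val_accuracy
      min_delta=0.001,         # gains smaller than 0.001 don't count as improvement
      patience=3,              # tolerate 3 epochs without improvement before stopping
      mode='max',              # higher accuracy is better
      baseline=0.90,           # stop if the metric doesn't improve past 0.90
  )
   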

Using the EarlyStopping callback with the MNIST dataset

1. Import required modules.

   
  # import required modules
  import tensorflow as tf
  import tensorflow_datasets as tfds
   

2. Load the MNIST dataset.

   
  # load mnist dataset
  (ds_train, ds_test), ds_info = tfds.load(
      'mnist',
      split=['train', 'test'],
      shuffle_files=True,
      as_supervised=True,
      with_info=True,
  )
  print(type(ds_train))
   

3. Define the image normalization function.

   
  # define the image normalization function
  def normalize_img(image, label):
    return tf.cast(image, tf.float32) / 255., label
   

4. Apply normalization and batching to the training and test datasets.

   
  # apply normalization and batching to the training and test datasets
  dataset_train = ds_train.map(normalize_img)
  dataset_train = dataset_train.batch(128)  # batch the normalized dataset, not ds_train
  
  dataset_test = ds_test.map(normalize_img)
  dataset_test = dataset_test.batch(128)
   

5. Create a Keras sequential model.

   
  model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),  # tfds MNIST images have shape (28, 28, 1)
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
  ])
   

6. Compile the model.

   
  model.compile(
      optimizer=tf.keras.optimizers.Adam(0.006),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
  )
   

7. Instantiate the EarlyStopping callback.

   
  callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=2)
   
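
Here the callback monitors the training loss. A common variant (a sketch, not what this tutorial's run uses) monitors the validation loss instead and restores the weights from the best epoch once training stops:

   
  # variant sketch: stop on stagnating validation loss, keep the best weights
  callback = tf.keras.callbacks.EarlyStopping(
      monitor='val_loss',
      patience=2,
      restore_best_weights=True,
  )
   

Monitoring val_loss guards against overfitting, since the training loss can keep falling while the validation loss rises.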

8. Train the model for up to 100 epochs and check how many epochs actually ran due to the EarlyStopping callback.

   
  history = model.fit(
      dataset_train,
      epochs=100,
      validation_data=dataset_test,
      callbacks=[callback]
  )
  
  print(len(history.history['loss']))
   

Output

   
  Epoch 1/100
  469/469 [==============================] - 9s 19ms/step - loss: 3.2035 - sparse_categorical_accuracy: 0.8107 - val_loss: 0.4739 - val_sparse_categorical_accuracy: 0.8911
  Epoch 2/100
  469/469 [==============================] - 2s 3ms/step - loss: 0.4325 - sparse_categorical_accuracy: 0.8966 - val_loss: 0.3721 - val_sparse_categorical_accuracy: 0.9147
  Epoch 3/100
  469/469 [==============================] - 2s 3ms/step - loss: 0.3718 - sparse_categorical_accuracy: 0.9117 - val_loss: 0.4223 - val_sparse_categorical_accuracy: 0.9026
  Epoch 4/100
  469/469 [==============================] - 2s 3ms/step - loss: 0.3774 - sparse_categorical_accuracy: 0.9109 - val_loss: 0.4949 - val_sparse_categorical_accuracy: 0.9073
  Epoch 5/100
  469/469 [==============================] - 2s 3ms/step - loss: 0.4109 - sparse_categorical_accuracy: 0.9045 - val_loss: 0.4240 - val_sparse_categorical_accuracy: 0.9076
   

As the output above shows, only 5 of the 100 requested epochs ran: the monitored loss reached its best value in Epoch 3 and then failed to improve in Epoch 4 and Epoch 5, two consecutive epochs, so with patience=2 the callback stopped training.
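
The callback object itself also records where training halted: its stopped_epoch attribute holds the zero-based index of the epoch at which training stopped, or 0 if training ran to completion. Assuming the callback from step 7 is still in scope:

   
  # zero-based index of the epoch at which training stopped (0 if never stopped early)
  print(callback.stopped_epoch)
   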

Complete code snippet for using EarlyStopping callback

   
  # import required modules
  import tensorflow as tf
  import tensorflow_datasets as tfds
  
  # load mnist dataset
  (ds_train, ds_test), ds_info = tfds.load(
      'mnist',
      split=['train', 'test'],
      shuffle_files=True,
      as_supervised=True,
      with_info=True,
  )
  print(type(ds_train))
  
  # define the image normalization function
  def normalize_img(image, label):
    return tf.cast(image, tf.float32) / 255., label
  
  
  # apply normalization and batching to the training and test datasets
  dataset_train = ds_train.map(normalize_img)
  dataset_train = dataset_train.batch(128)  # batch the normalized dataset, not ds_train
  
  dataset_test = ds_test.map(normalize_img)
  dataset_test = dataset_test.batch(128)
  
  
  model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),  # tfds MNIST images have shape (28, 28, 1)
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
  ])
  
  model.compile(
      optimizer=tf.keras.optimizers.Adam(0.006),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
  )
  
  callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=2)
  
  history = model.fit(
      dataset_train,
      epochs=100,
      validation_data=dataset_test,
      callbacks=[callback]
  )
  
  print(len(history.history['loss']))
   
