How to Plot Model Architecture with tf.keras.utils

To understand a Keras model, it is always helpful to have a visual representation of its layers. In this article, we will see how to display a Keras model's architecture and save it to a file.

tf.keras.utils provides the plot_model function for plotting a model's architecture and saving it to a file.
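plot_model depends on two things outside TensorFlow: the pydot Python package (depending on the version, Keras may also accept its forks pydot_ng or pydotplus) and the Graphviz `dot` program. When they are missing, the call typically raises an ImportError or, inside a notebook, prints an installation hint instead of drawing. Below is a minimal sketch of a hypothetical helper, graphviz_available (not part of Keras), that checks for both up front using only the standard library.

```python
import importlib.util
import shutil

# Hypothetical helper (not part of Keras): checks the two external
# dependencies that tf.keras.utils.plot_model needs before it can draw.
PYDOT_MODULES = ("pydot", "pydot_ng", "pydotplus")  # pydot and forks Keras may try


def graphviz_available() -> bool:
    """Return True if a pydot-style package and the Graphviz `dot` binary are found."""
    has_pydot = any(importlib.util.find_spec(m) is not None for m in PYDOT_MODULES)
    has_dot = shutil.which("dot") is not None  # Graphviz installs the `dot` program
    return has_pydot and has_dot


if __name__ == "__main__":
    if graphviz_available():
        print("pydot and Graphviz found: plot_model should be able to draw.")
    else:
        print("Missing pydot or Graphviz: try `pip install pydot` and install Graphviz.")
```

This only checks that the pieces are present. Keras itself goes further and tries to run `dot`, so a broken Graphviz install can still make plot_model fail.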

Create a sample model with the code snippet below.

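  import tensorflow as tf

  # Integer token IDs: each sample is a sequence of 100 tokens.
  # Named `inputs` to avoid shadowing Python's built-in input().
  inputs = tf.keras.Input(shape=(100,), dtype='int32', name='input')

  # Map each token ID (vocabulary of 1000) to a 512-dimensional vector.
  # input_length is omitted: the sequence length already comes from the
  # Input shape, and recent Keras versions deprecate the argument.
  x = tf.keras.layers.Embedding(input_dim=1000, output_dim=512)(inputs)

  # Summarize the whole sequence into a single 32-dimensional vector.
  x = tf.keras.layers.LSTM(32)(x)

  x = tf.keras.layers.Dense(64, activation='relu')(x)

  # Single sigmoid unit producing a value between 0 and 1.
  outputs = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)

  model = tf.keras.Model(inputs=inputs, outputs=outputs)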
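With show_shapes=True, the diagram annotates each layer with its shapes. As a rough cross-check, the sketch below works out by hand the output shapes and trainable parameter counts this sample model should have, using the standard formulas for an embedding table, an LSTM (four gates, each with an input kernel, a recurrent kernel and a bias) and a dense layer. It is plain Python, so it runs without TensorFlow; the layer names are illustrative and may differ from the names Keras auto-generates, and None stands for the batch dimension.

```python
# Hand-computed output shapes and parameter counts for the sample model.
SEQ_LEN, VOCAB, EMBED_DIM, LSTM_UNITS, DENSE_UNITS = 100, 1000, 512, 32, 64


def lstm_params(input_dim: int, units: int) -> int:
    # 4 gates, each with an input kernel, a recurrent kernel and a bias vector.
    return 4 * (units * (input_dim + units) + units)


def dense_params(input_dim: int, units: int) -> int:
    # Weight matrix plus one bias per unit.
    return input_dim * units + units


layers = [
    ("input", (None, SEQ_LEN), 0),
    ("embedding", (None, SEQ_LEN, EMBED_DIM), VOCAB * EMBED_DIM),
    ("lstm", (None, LSTM_UNITS), lstm_params(EMBED_DIM, LSTM_UNITS)),
    ("dense", (None, DENSE_UNITS), dense_params(LSTM_UNITS, DENSE_UNITS)),
    ("output", (None, 1), dense_params(DENSE_UNITS, 1)),
]

for name, shape, params in layers:
    print(f"{name:<10} {str(shape):<18} {params:>8,}")

total = sum(params for _, _, params in layers)
print(f"total params: {total:,}")  # total params: 583,937
```

These totals should agree with what model.summary() reports for the same model, which is a quick way to sanity-check the plotted architecture.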


Display and save the model architecture to a file

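  img_file = './model_arch.png'

  # show_shapes adds each layer's input/output shapes to the diagram;
  # show_layer_names labels each box with the layer's name.
  tf.keras.utils.plot_model(model, to_file=img_file,
                            show_shapes=True, show_layer_names=True)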

After executing the above code snippets, you should see the image model_arch.png in your current directory and, when running in a Jupyter Notebook, the same diagram rendered as the cell output.
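To confirm from a script, rather than by looking at the directory, that the image was actually written, a minimal sketch is to check that the file exists and starts with the 8-byte signature every PNG file begins with. The helper name is_png is hypothetical, and the check only makes sense when to_file ends in .png, since plot_model chooses the output format from the file extension.

```python
from pathlib import Path

# Every PNG file begins with these 8 bytes (the PNG signature).
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def is_png(path) -> bool:
    """Return True if `path` is an existing file that starts with the PNG signature."""
    p = Path(path)
    if not p.is_file():
        return False
    with p.open("rb") as f:
        return f.read(len(PNG_SIGNATURE)) == PNG_SIGNATURE


if __name__ == "__main__":
    # Same path that was passed to plot_model as to_file.
    print(is_png("./model_arch.png"))
```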