TensorFlow | How to load MNIST data with TensorFlow Datasets

This tutorial uses TensorFlow 2.0, where eager execution is enabled by default. If you are using an earlier version of TensorFlow, enable eager execution to follow this post.
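
For readers on TensorFlow 1.x, a minimal sketch of switching eager execution on might look like the following. It assumes a late 1.x release that provides the `tf.compat.v1` module, and that the call runs at program start, before any other TensorFlow operation. On TensorFlow 2.x the check is already true and nothing is changed.

```python
import tensorflow as tf

# TF 2.x runs eagerly by default; TF 1.x needs an explicit opt-in,
# which must happen before any other TensorFlow operation is created.
if not tf.executing_eagerly():
    tf.compat.v1.enable_eager_execution()

print(tf.__version__, tf.executing_eagerly())
```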

TensorFlow Datasets is a collection of ready-to-use datasets for text, audio, image, and many other ML applications. All datasets are exposed as tf.data.Dataset objects, enabling easy-to-use, high-performance input pipelines. In this post we will load the well-known MNIST image dataset and configure a simple input pipeline. Run the code below in either a Jupyter notebook or Google Colab.

  • Install TensorFlow Datasets
  • 
    pip install tensorflow-datasets
    
  • Import modules and construct tf.data.Dataset object
  • 
    import tensorflow as tf
    import tensorflow_datasets as tfds
    
    ds = tfds.load('mnist', split='train', shuffle_files=True)
    
    
  • Build input pipeline
  • 
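    # Build your input pipeline: shuffle with a 1024-example buffer,
    # repeat indefinitely, and group examples into batches of 32.
    ds = ds.shuffle(1024).repeat().batch(32)
    # repeat() makes the dataset endless, so take(1) pulls a single batch.
    for example in ds.take(1):
        image, label = example['image'], example['label']
        print(image.shape)  # (32, 28, 28, 1)
        print(label)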
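
To see what the shuffle and batch steps do without downloading anything, here is a small sketch that runs a similar pipeline on synthetic data. The `fake` dataset and the `normalize` helper are assumptions made for illustration and are not part of the original post; the blank images only mimic the shape and dtype of MNIST examples. The sketch leaves out repeat(), so iteration stops after one pass, and it adds a map step that scales pixels to floats, a common preprocessing choice rather than something the post requires.

```python
import tensorflow as tf

# Synthetic stand-in for MNIST (assumption): 100 blank 28x28x1 uint8 images.
fake = tf.data.Dataset.from_tensor_slices({
    'image': tf.zeros([100, 28, 28, 1], dtype=tf.uint8),
    'label': tf.zeros([100], dtype=tf.int64),
})


def normalize(example):
    # Hypothetical helper: scale pixel values from [0, 255] to [0.0, 1.0].
    image = tf.cast(example['image'], tf.float32) / 255.0
    return image, example['label']


# Without repeat(), the dataset ends after one pass over the 100 examples.
batches = fake.map(normalize).shuffle(1024).batch(32).prefetch(1)

# Three full batches of 32 and a final partial batch of 4.
shapes = [image.shape.as_list() for image, label in batches]
print(shapes)
```

With the real data, `fake` would be replaced by the `ds` object returned by `tfds.load`, whose elements are dictionaries with the same 'image' and 'label' keys.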
    
    
  • Visualize first batch of input data with matplotlib
  • 
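    import matplotlib.pyplot as plt
    
    # Show each image in the batch; tf.squeeze drops the trailing channel axis.
    for i in image:
        plt.imshow(tf.squeeze(i))
        plt.show()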
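
The loop above opens one figure per image, which means 32 separate plots for a batch of 32. As an alternative, the sketch below draws the whole batch in a single 4 by 8 grid. The random `images` array is an assumption that stands in for the `image` batch from the pipeline step, so the snippet runs on its own.

```python
import numpy as np
import matplotlib.pyplot as plt

# Stand-in for the `image` batch (assumption): 32 random 28x28x1 uint8 images.
images = np.random.randint(0, 256, size=(32, 28, 28, 1), dtype=np.uint8)

# One figure with a 4x8 grid of axes, one cell per image in the batch.
fig, axes = plt.subplots(4, 8, figsize=(8, 4))
for ax, img in zip(axes.flat, images):
    ax.imshow(img.squeeze(), cmap='gray')  # drop the channel axis for display
    ax.axis('off')
fig.tight_layout()
plt.show()
```

To plot the real batch, pass `image.numpy()` in place of `images`.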
    
  • Complete Code
  • 
    import tensorflow as tf
    import tensorflow_datasets as tfds
    import matplotlib.pyplot as plt
    
    # Construct a tf.data.Dataset
    ds = tfds.load('mnist', split='train', shuffle_files=True)
    
    # Build your input pipeline
    ds = ds.shuffle(1024).repeat().batch(32)
    for example in ds.take(1):
        image, label = example['image'], example['label']
        print(image.shape)
        print(label)
    
    for i in image:
        plt.imshow(tf.squeeze(i))
        plt.show() 
    
    
