TensorFlow | Using the tf.data.Dataset.batch() method

  • This code snippet uses TensorFlow 2.0. If you are using an earlier version of TensorFlow, enable eager execution to run the code.
  • The batch() method of the tf.data.Dataset class combines consecutive elements of a dataset into batches. The example below uses batch() first without the repeat() method and then with the repeat() method.
  • 
      import tensorflow as tf
    
      print(tf.__version__)
      
      # Create Tensor
      tensor1 = tf.range(5)
      
      #print(dir(tf.data.Dataset))
      # Create a dataset; from_tensor_slices() returns a TensorSliceDataset object
      dataset = tf.data.Dataset.from_tensor_slices(tensor1)
      print(dataset)
      print("Original dataset")
      for i in dataset:    
          print(i)
    	  
      ====== Output ======
      2.0.0
      <TensorSliceDataset shapes: (), types: tf.int32>
      Original dataset
      tf.Tensor(0, shape=(), dtype=int32)
      tf.Tensor(1, shape=(), dtype=int32)
      tf.Tensor(2, shape=(), dtype=int32)
      tf.Tensor(3, shape=(), dtype=int32)
      tf.Tensor(4, shape=(), dtype=int32)
    
     

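    Applying batch() to the dataset. Notice that the shape of each element changes from () to (2,) after batching; the last batch has shape (1,) because 5 elements do not divide evenly into batches of 2.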

    
      dataset = dataset.batch(batch_size=2)
      print("dataset after applying batch method")
      for i in dataset:
          print(i)
      
      ====== Output ======
    
      dataset after applying batch method
      tf.Tensor([0 1], shape=(2,), dtype=int32)
      tf.Tensor([2 3], shape=(2,), dtype=int32)
      tf.Tensor([4], shape=(1,), dtype=int32)
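
The grouping shown in the output above can be modelled in plain Python, which may help when reasoning about how many batches to expect. This is only a sketch of the behaviour, not TensorFlow's implementation: `batch_elements` is a hypothetical helper introduced here for illustration. Its `drop_remainder` flag is named after the `drop_remainder` argument of `tf.data.Dataset.batch()`, which discards a final batch that is smaller than `batch_size`.

```python
from itertools import islice


def batch_elements(iterable, batch_size, drop_remainder=False):
    """Yield lists of consecutive items (hypothetical helper, not a TensorFlow API)."""
    iterator = iter(iterable)
    while True:
        # Take the next batch_size items, or fewer if the input runs out.
        chunk = list(islice(iterator, batch_size))
        if not chunk:
            return
        if len(chunk) < batch_size and drop_remainder:
            # Skip the short final batch when drop_remainder is requested.
            return
        yield chunk


full = list(batch_elements(range(5), 2))
trimmed = list(batch_elements(range(5), 2, drop_remainder=True))

print(full)     # [[0, 1], [2, 3], [4]]
print(trimmed)  # [[0, 1], [2, 3]]
```

With 5 elements and a batch size of 2, the model produces two full batches plus one partial batch, matching the three tensors printed above.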
      
     
    

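  • The code snippet below applies repeat() before batch() on the dataset. Because repeat(3) runs first, batches can span the boundary between two repetitions (for example [4 0]), and only the very last batch is partial.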
  • 
    import tensorflow as tf
    print(tf.__version__)
    
    # Create Tensor
    tensor1 = tf.range(5)
    
    #print(dir(tf.data.Dataset))
    # Create a dataset; from_tensor_slices() returns a TensorSliceDataset object
    dataset = tf.data.Dataset.from_tensor_slices(tensor1)
    print(dataset)
    print("Original dataset")
    for i in dataset:    
        print(i)
    
    
    # Repeat the dataset 3 times, then group the repeated elements into batches of 2
    dataset = dataset.repeat(3).batch(batch_size=2)
    print("dataset after applying batch method with repeat()")
    for i in dataset:
        print(i)
    
    ====== Output ======
    2.0.0
    <TensorSliceDataset shapes: (), types: tf.int32>
    Original dataset
    tf.Tensor(0, shape=(), dtype=int32)
    tf.Tensor(1, shape=(), dtype=int32)
    tf.Tensor(2, shape=(), dtype=int32)
    tf.Tensor(3, shape=(), dtype=int32)
    tf.Tensor(4, shape=(), dtype=int32)
    dataset after applying batch method with repeat()
    tf.Tensor([0 1], shape=(2,), dtype=int32)
    tf.Tensor([2 3], shape=(2,), dtype=int32)
    tf.Tensor([4 0], shape=(2,), dtype=int32)
    tf.Tensor([1 2], shape=(2,), dtype=int32)
    tf.Tensor([3 4], shape=(2,), dtype=int32)
    tf.Tensor([0 1], shape=(2,), dtype=int32)
    tf.Tensor([2 3], shape=(2,), dtype=int32)
    tf.Tensor([4], shape=(1,), dtype=int32)
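
The order in which repeat() and batch() are chained changes the result. The sketch below models both orders in plain Python so the difference is easy to see. It is an illustration only, assuming the semantics shown in the output above: `batch_elements` and `repeat_elements` are hypothetical helpers, not TensorFlow functions.

```python
from itertools import chain, islice, repeat


def batch_elements(iterable, batch_size):
    """Yield lists of up to batch_size consecutive items (hypothetical helper)."""
    iterator = iter(iterable)
    while True:
        chunk = list(islice(iterator, batch_size))
        if not chunk:
            return
        yield chunk


def repeat_elements(iterable, count):
    """Yield every item of the input, count times over (hypothetical helper)."""
    items = list(iterable)
    return chain.from_iterable(repeat(items, count))


# Models dataset.repeat(3).batch(2): batches cross repetition boundaries.
repeat_then_batch = list(batch_elements(repeat_elements(range(5), 3), 2))

# Models dataset.batch(2).repeat(3): the same three batches are replayed.
batch_then_repeat = list(repeat_elements(batch_elements(range(5), 2), 3))

print(len(repeat_then_batch), repeat_then_batch)
print(len(batch_then_repeat), batch_then_repeat)
```

Repeating first gives 8 batches with a single partial batch at the end, as in the output above. Batching first gives 9 batches, because the partial batch [4] is replayed at the end of every repetition. Which order is preferable depends on whether batches are allowed to mix elements from adjacent passes over the data.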