TensorFlow | How to use tf.data.Dataset.repeat() method in TensorFlow

This post explains how to use the tf.data.Dataset.repeat() method in TensorFlow.

The repeat() method of the tf.data.Dataset class repeats the elements of a dataset a given number of times (count). If repeat(count=None) or repeat(count=-1) is specified, the dataset is repeated indefinitely. Because count defaults to None, calling repeat() with no arguments also repeats indefinitely.

Below is the code snippet for using tf.data.Dataset.repeat() in TensorFlow.

1. Create a TensorSliceDataset object
   
import tensorflow as tf

print(tf.__version__)

# Create a 1-D tensor holding the values 0 to 4
tensor1 = tf.range(5)

# Create the dataset; this returns a TensorSliceDataset object with one element per value
dataset = tf.data.Dataset.from_tensor_slices(tensor1)
print(type(dataset))
print(dataset)
for i in dataset:
    print(i)
 

Output (the version line is omitted; the exact class name and repr vary between TensorFlow versions)
   
<class 'tensorflow.python.data.ops.dataset_ops.TensorSliceDataset'>
<TensorSliceDataset shapes: (), types: tf.int32>
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
 

2. Repeat the dataset 3 times using tf.data.Dataset.repeat() with count=3

The dataset object contains 5 elements. Calling repeat with count=3 repeats the whole dataset 3 times, so the sequence 0 to 4 is emitted three times in order and each original value appears 3 times in the output.


dataset = dataset.repeat(count=3)
for i in dataset:
    print(i)
 
 

Output
   
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
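
The element count can also be checked programmatically instead of reading the printed output. The following is a minimal sketch, assuming TensorFlow 2.3 or later (where Dataset.cardinality() is available); the variable names base, repeated, values and n are illustrative and not part of the original snippet.

```python
import tensorflow as tf

# Rebuild the 5-element dataset so this sketch is self-contained.
base = tf.data.Dataset.from_tensor_slices(tf.range(5))

# Repeat the whole dataset 3 times.
repeated = base.repeat(count=3)

# Collect the elements as plain Python ints.
values = [int(v) for v in repeated.as_numpy_iterator()]
print(values)       # the sequence 0..4, three times in a row
print(len(values))  # 5 elements * 3 repetitions

# cardinality() reports the number of elements without iterating.
n = int(repeated.cardinality().numpy())
print(n)
```

Here the repeated dataset should contain 15 elements, which matches the 15 lines of output above.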
 

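3. Repeat the dataset indefinitely using tf.data.Dataset.repeat() with count=None or count=-1

An indefinitely repeated dataset never ends, so a plain for loop over it would not terminate. The snippets below use take(7) to print only the first 7 elements.
   
# count=None repeats the dataset indefinitely
infinite_dataset = dataset.repeat(count=None)
print(infinite_dataset)
for i in infinite_dataset.take(7):
    print(i)

# count=-1 has the same effect
infinite_dataset = dataset.repeat(count=-1)
print(infinite_dataset)
for i in infinite_dataset.take(7):
    print(i)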
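
Whether a dataset repeats indefinitely can be confirmed without iterating over it. The following is a minimal sketch, assuming TensorFlow 2.3 or later (for Dataset.cardinality() and tf.data.INFINITE_CARDINALITY); the names base, endless, endless_minus_one and the result variables are illustrative only.

```python
import tensorflow as tf

# Rebuild the 5-element dataset so this sketch is self-contained.
base = tf.data.Dataset.from_tensor_slices(tf.range(5))

# Both forms request indefinite repetition.
endless = base.repeat(count=None)
endless_minus_one = base.repeat(count=-1)

# An indefinitely repeated dataset reports INFINITE_CARDINALITY.
is_infinite = bool(
    (endless.cardinality() == tf.data.INFINITE_CARDINALITY).numpy()
)
print(is_infinite)

# take() bounds the iteration, so the loop is guaranteed to finish.
first_seven = [int(v) for v in endless.take(7).as_numpy_iterator()]
first_seven_minus_one = [
    int(v) for v in endless_minus_one.take(7).as_numpy_iterator()
]
print(first_seven)

# Both spellings should yield the same sequence.
same = first_seven == first_seven_minus_one
print(same)
```

The sequence wraps around after the fifth element, which shows that the dataset restarts from the beginning each time it is exhausted.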
 
