
TensorFlow | How to use the Dataset API take() method in TensorFlow

The take() method of tf.data.Dataset limits the number of elements in a dataset: dataset.take(count) returns a new dataset containing at most the first count elements of the original. The code in this post uses TensorFlow 2.0. If you are using an earlier version of TensorFlow, enable eager execution to run the code.
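As a quick illustration of the "at most count elements" behavior, here is a minimal sketch assuming TensorFlow 2.x with eager execution. It uses tf.data.Dataset.range as a stand-in dataset (not part of the original snippet). Per the tf.data documentation, a count larger than the dataset, or a count of -1, keeps every element.

```python
import tensorflow as tf

# A small stand-in dataset with five scalar elements: 0, 1, 2, 3, 4
dataset = tf.data.Dataset.range(5)

# take(count) yields at most `count` elements from the start of the dataset
first_two = [int(x) for x in dataset.take(2)]
# A count larger than the dataset size simply returns every element
too_many = [int(x) for x in dataset.take(10)]
# A count of -1 also returns every element
all_items = [int(x) for x in dataset.take(-1)]

print(first_two)  # [0, 1]
print(too_many)   # [0, 1, 2, 3, 4]
print(all_items)  # [0, 1, 2, 3, 4]
```

Because take() never raises when count exceeds the dataset size, it is safe to use as an upper bound, for example when previewing a dataset of unknown length.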

Let's look at the snippet below to understand the take() method.

Create a dataset with tf.data.Dataset.from_tensor_slices


import tensorflow as tf

# Create a tensor holding the values 0, 1, 2, 3, 4
tensor1 = tf.range(5)

# Create a dataset; this returns a TensorSliceDataset object
dataset = tf.data.Dataset.from_tensor_slices(tensor1)


Apply repeat() and batch() to the dataset


print("dataset after applying batch and repeat")
dataset = dataset.repeat(6).batch(batch_size=2)
for i in dataset:
  print(i)

Example Output:


dataset after applying batch and repeat
tf.Tensor([0 1], shape=(2,), dtype=int32)
tf.Tensor([2 3], shape=(2,), dtype=int32)
tf.Tensor([4 0], shape=(2,), dtype=int32)
tf.Tensor([1 2], shape=(2,), dtype=int32)
tf.Tensor([3 4], shape=(2,), dtype=int32)
tf.Tensor([0 1], shape=(2,), dtype=int32)
tf.Tensor([2 3], shape=(2,), dtype=int32)
tf.Tensor([4 0], shape=(2,), dtype=int32)
tf.Tensor([1 2], shape=(2,), dtype=int32)
tf.Tensor([3 4], shape=(2,), dtype=int32)
tf.Tensor([0 1], shape=(2,), dtype=int32)
tf.Tensor([2 3], shape=(2,), dtype=int32)
tf.Tensor([4 0], shape=(2,), dtype=int32)
tf.Tensor([1 2], shape=(2,), dtype=int32)
tf.Tensor([3 4], shape=(2,), dtype=int32)
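The 15 batches above follow from simple arithmetic: 5 values repeated 6 times gives 30 values, and 30 / 2 = 15 batches. The sketch below checks this with Dataset.cardinality() (assuming TensorFlow 2.3 or later) and also shows why the order of repeat() and batch() matters: batching first restarts the batches at every epoch, so each epoch ends with a short batch.

```python
import tensorflow as tf

base = tf.data.Dataset.from_tensor_slices(tf.range(5))

# Same pipeline as above: 5 values x 6 repeats = 30 values, 2 per batch -> 15 batches
pipeline = base.repeat(6).batch(batch_size=2)
num_batches = int(pipeline.cardinality())
print(num_batches)  # 15

# batch() before repeat(): each epoch is batched on its own,
# so the leftover value 4 forms a one-element batch at the end of every epoch
batch_then_repeat = [b.numpy().tolist() for b in base.batch(2).repeat(2)]
print(batch_then_repeat)  # [[0, 1], [2, 3], [4], [0, 1], [2, 3], [4]]
```

With repeat() first, as in the original snippet, epochs are concatenated before batching, which is why batches such as [4 0] span an epoch boundary.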



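Apply take() to select the first few elements from the dataset. Because take() is applied after batch(), each element of the dataset is a batch of two values, so take(3) returns the first three batches (six values).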


print("dataset after applying take() method")
for i in dataset.take(3):
  print(i)


Example Output:


dataset after applying take() method
tf.Tensor([0 1], shape=(2,), dtype=int32)
tf.Tensor([2 3], shape=(2,), dtype=int32)
tf.Tensor([4 0], shape=(2,), dtype=int32)
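Since take() counts elements of whatever dataset it is called on, its position in the pipeline changes what gets counted. This sketch, assuming TensorFlow 2.x with eager execution, compares calling take() after batch() (counting batches) with calling it before batch() (counting individual values).

```python
import tensorflow as tf

base = tf.data.Dataset.from_tensor_slices(tf.range(5))

# take() after batch(): counts batches, so take(2) returns 2 batches (4 values)
batches_first = [b.numpy().tolist() for b in base.batch(2).take(2)]

# take() before batch(): counts individual values, so take(3) keeps 3 values,
# which batch(2) then groups into one full batch and one partial batch
take_first = [b.numpy().tolist() for b in base.take(3).batch(2)]

print(batches_first)  # [[0, 1], [2, 3]]
print(take_first)     # [[0, 1], [2]]
```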