How to use the flat_map method of tf.data.Dataset

The flat_map method of tf.data.Dataset maps a function across the dataset and flattens the result: the function given as the method argument is applied to each element, and the datasets it returns are concatenated into a single output dataset. The function must therefore return a dataset object. Let's see how flat_map works with an example.
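
Before turning to TensorFlow, the map-then-flatten idea can be sketched in plain Python. This is only an analogy, not how tf.data implements it: the flat_map helper below is a hypothetical function written for illustration, and ordinary lists stand in for datasets.

```python
from itertools import chain


def flat_map(items, map_func):
    """Hypothetical helper: apply map_func to each item, then concatenate the results."""
    # map_func must return an iterable for every item, just as the function
    # passed to Dataset.flat_map must return a dataset.
    return list(chain.from_iterable(map_func(item) for item in items))


nested = [[1, 2], [3, 4, 5]]

# Each inner list is mapped to a list of squares, and the lists are joined.
result = flat_map(nested, lambda xs: [x ** 2 for x in xs])
print(result)  # [1, 4, 9, 16, 25]
```

The output has one level of nesting less than the input, which is the same effect flat_map has on the shape of dataset items.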

Create a dummy dataset with tf.data.Dataset.from_tensor_slices


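import tensorflow as tf

# The input has shape (1, 2, 3). from_tensor_slices slices along the first
# axis, so the dataset holds a single element of shape (2, 3).
dataset = tf.data.Dataset.from_tensor_slices([[[1, 2, 3], [3, 4, 5]]])
for i in dataset:
  print(i)
  print(i.shape)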


Example Output:


tf.Tensor(
[[1 2 3]
 [3 4 5]], shape=(2, 3), dtype=int32)
(2, 3)

Apply the flat_map method to the dataset

In the code snippet below, map_func = lambda x: tf.data.Dataset.from_tensor_slices(x**2) is passed to dataset.flat_map. It squares each dataset item and wraps the result in a new dataset whose elements are the rows of the squared tensor. Note that map_func must return a dataset object; if we use map_func = lambda x: x**2 instead, flat_map will throw an error because that function returns a tensor, not a dataset.


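# Square each item, slice it into a dataset of rows, and let flat_map
# concatenate those rows into the output dataset.
dataset = dataset.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x**2))
for i in dataset:
  print(i)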

Example Output:
Note the change in shape of the dataset items: the single item of shape (2, 3) has become two items of shape (3,).


tf.Tensor([1 4 9], shape=(3,), dtype=int32)
tf.Tensor([ 9 16 25], shape=(3,), dtype=int32)
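
To make the difference from map concrete, and to show what happens when the function does not return a dataset, here is a minimal sketch that reuses the dummy dataset from above. It assumes TensorFlow 2.x with eager execution; the exact exception type and message may differ between versions, so the sketch catches broadly and only reports that an error occurred.

```python
import tensorflow as tf

# Same dummy dataset as above: one element of shape (2, 3).
dataset = tf.data.Dataset.from_tensor_slices([[[1, 2, 3], [3, 4, 5]]])

# map: the function returns a tensor, so the single (2, 3) element stays whole.
mapped = dataset.map(lambda x: x ** 2)
mapped_shapes = [tuple(item.shape.as_list()) for item in mapped]
print(mapped_shapes)

# flat_map: the function returns a dataset, and the elements of that dataset
# are concatenated into the output, giving two elements of shape (3,).
flat = dataset.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x ** 2))
flat_shapes = [tuple(item.shape.as_list()) for item in flat]
print(flat_shapes)

# flat_map with a function that returns a tensor is rejected when the
# dataset is built, before any element is read.
try:
    dataset.flat_map(lambda x: x ** 2)
    raised = False
except Exception as err:  # exception type is an assumption, so catch broadly
    raised = True
    print("flat_map rejected the function:", type(err).__name__)
```

If the items should keep their shape, map is the simpler choice; flat_map is useful when each input item should expand into several output items.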