How to use the flat_map method in tf.data.Dataset

The flat_map method of tf.data.Dataset maps the function passed as its argument across the dataset and flattens the result. That function must return a Dataset object. Let's see how flat_map works with an example.
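To make "map, then flatten" concrete before touching TensorFlow, here is a minimal pure-Python sketch. The helper flat_map below is hypothetical and is not part of tf.data. It applies map_func to each item and then yields every element of each returned iterable in order, which flattens the result by one level. The nested-list input mirrors the dummy dataset used in the rest of this post.

```python
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def flat_map(items: Iterable[T], map_func: Callable[[T], Iterable[U]]) -> Iterator[U]:
    """Hypothetical pure-Python analogue of Dataset.flat_map (illustration only)."""
    for item in items:
        # map_func must return something iterable (tf.data requires a Dataset here)
        for sub_item in map_func(item):
            yield sub_item  # flatten: emit the inner elements one by one


# One "item" holding a 2x3 nested list, like the dummy dataset.
items = [[[1, 2, 3], [3, 4, 5]]]
squared_rows = list(flat_map(items, lambda rows: [[v ** 2 for v in row] for row in rows]))
print(squared_rows)  # [[1, 4, 9], [9, 16, 25]]
```

The single 2x3 item turns into two separate rows. The TensorFlow version reproduces the same effect.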

Create a dummy dataset with tf.data.Dataset.from_tensor_slices

import tensorflow as tf

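# The outer list has length 1, so the dataset holds a single (2, 3) item
dataset = tf.data.Dataset.from_tensor_slices([[[1, 2, 3], [3, 4, 5]]])
for i in dataset:
    print(i)        # the whole (2, 3) tensor
    print(i.shape)  # its shape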

Example Output:

tf.Tensor(
[[1 2 3]
 [3 4 5]], shape=(2, 3), dtype=int32)
(2, 3)
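The triple brackets matter because from_tensor_slices slices its input along the first axis. The short sketch below compares the dummy dataset with a version that drops the outer brackets. The variable names one_item and two_items are only for illustration.

```python
import tensorflow as tf

# Input of shape (1, 2, 3): slicing the first axis gives ONE element of shape (2, 3).
one_item = tf.data.Dataset.from_tensor_slices([[[1, 2, 3], [3, 4, 5]]])
# Input of shape (2, 3): slicing the first axis gives TWO elements of shape (3,).
two_items = tf.data.Dataset.from_tensor_slices([[1, 2, 3], [3, 4, 5]])

one_item_shapes = [t.shape.as_list() for t in one_item]
two_items_shapes = [t.shape.as_list() for t in two_items]
print(one_item_shapes)   # [[2, 3]]
print(two_items_shapes)  # [[3], [3]]
```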

Apply the flat_map method to the dataset

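In the code snippet below, map_func = lambda x: tf.data.Dataset.from_tensor_slices(x**2) is passed to dataset.flat_map. It squares every element of each dataset item and slices the squared (2, 3) tensor into a new dataset of its rows. flat_map then flattens those per-item datasets into a single output dataset. Note that map_func must return a Dataset object. If we use map_func = lambda x: x**2 instead, flat_map raises an error because that function returns a tensor, not a dataset.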

dataset = dataset.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x**2))
for i in dataset:
    print(i)  # one row of the squared tensor per iteration

Example Output:
Note the change in shape of the dataset items: the single (2, 3) item has become two items of shape (3,).

tf.Tensor([1 4 9], shape=(3,), dtype=int32)
tf.Tensor([ 9 16 25], shape=(3,), dtype=int32)
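The sketch below checks both points above. First, it passes a tensor-returning function to flat_map and catches the error. In the TensorFlow 2.x releases I am aware of, this is a TypeError raised when flat_map is called, but treat the exact exception type as an assumption. Second, it shows that for this particular dataset, map followed by unbatch (both standard tf.data.Dataset methods) produces the same rows as the flat_map version. This holds because unbatch also splits each element along its first axis, just as from_tensor_slices does.

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([[[1, 2, 3], [3, 4, 5]]])

# Returning a plain tensor from map_func is rejected: flat_map expects a Dataset.
# (Assumption: TF 2.x reports this as a TypeError.)
raised = False
try:
    dataset.flat_map(lambda x: x ** 2)
except TypeError as err:
    raised = True
    print("flat_map rejected the function:", type(err).__name__)

# Same rows two ways: flat_map + from_tensor_slices vs. map + unbatch.
via_flat_map = dataset.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x ** 2))
via_map_unbatch = dataset.map(lambda x: x ** 2).unbatch()

flat_map_rows = [t.numpy().tolist() for t in via_flat_map]
map_unbatch_rows = [t.numpy().tolist() for t in via_map_unbatch]
print(flat_map_rows)     # [[1, 4, 9], [9, 16, 25]]
print(map_unbatch_rows)  # [[1, 4, 9], [9, 16, 25]]
```

flat_map is the more general tool, because map_func can return a dataset of any length per item. The map + unbatch form only matches it when each item should be split along its first axis.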