How to use batch_flatten in tf.keras.backend

tf.keras.backend.batch_flatten method in TensorFlow flattens the each data samples of a batch. If batch_flatten is applied on a Tensor having dimension like 3D,4D,5D or ND it always turn that tensor to 2D. 0th dimension would remain same in both input tensor and output tensor. Lets see with below example.

Example 1

Create a 4D tensor with tf.ones


import tensorflow as tf

t1_batch = tf.ones((2,3,2,2))

print(t1_batch)


tf.Tensor(
[[[[1. 1.]
   [1. 1.]]

  [[1. 1.]
   [1. 1.]]

  [[1. 1.]
   [1. 1.]]]


 [[[1. 1.]
   [1. 1.]]

  [[1. 1.]
   [1. 1.]]

  [[1. 1.]
   [1. 1.]]]], shape=(2, 3, 2, 2), dtype=float32)

Apply batch_flatten to convert t1_batch to 2D tensor


t1_batch_flatten = tf.keras.backend.batch_flatten(t1_batch)

print(t1_batch_flatten)


tf.Tensor: shape=(2, 12), dtype=float32, numpy=
array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)

Example 2

Create a 5D tensor with tf.ones


import tensorflow as tf

t2_batch = tf.ones((3,4,2,2,4))


Apply batch_flatten to convert t2_batch to 2D tensor


t2_batch_flatten = tf.keras.backend.batch_flatten(t2_batch)

print(t2_batch_flatten)



tf.Tensor: shape=(3, 64), dtype=float32, numpy=
array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],
      dtype=float32)