How to extract features from layers in TensorFlow

Feature extraction is quite common when using transfer learning in ML. In this tutorial you will learn how to extract features from the intermediate layers of a tf.keras.Sequential model.

Once a Sequential model has been built, each of its layers has an input and an output attribute. Using these attributes, the outputs of intermediate layers can be extracted by wrapping them in a new tf.keras.Model that shares the original model's inputs.

Let's understand this with the code snippets below.

Instantiate a Sequential model with three layers


import tensorflow as tf

model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(3, activation="tanh", name="layer1"),
        tf.keras.layers.Dense(4, activation="relu", name="layer2"),
        tf.keras.layers.Dense(2, activation="sigmoid", name="layer3"),
])
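To see the input and output attributes directly, the following minimal sketch rebuilds the same model and lists each layer's name together with the shape of its symbolic output tensor. The leading None is the batch dimension. The dictionary name shapes is only illustrative.

```python
import tensorflow as tf

# Same three-layer model as above, with the input shape written as a tuple.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation="tanh", name="layer1"),
    tf.keras.layers.Dense(4, activation="relu", name="layer2"),
    tf.keras.layers.Dense(2, activation="sigmoid", name="layer3"),
])

# model.layers holds the three Dense layers (the Input itself is not listed).
# Each layer's .output is a symbolic tensor; its first dimension (None) is the batch size.
shapes = {layer.name: tuple(layer.output.shape) for layer in model.layers}

for name, shape in shapes.items():
    print(name, shape)
```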

Pass a random input through the model


# One random sample with 4 features (batch size 1)
input = tf.random.normal((1,4))
# Forward pass through all three layers
final_output = model(input)

print(input)
print(final_output)

Output


tf.Tensor([[-0.93511236 -1.5531812   0.03609114 -1.3348812 ]], shape=(1, 4), dtype=float32)
tf.Tensor([[0.5384151  0.33594143]], shape=(1, 2), dtype=float32)

Extract features of "layer1"


features_layer1 = tf.keras.models.Model(
    inputs=model.inputs,
    outputs=model.get_layer(name="layer1").output,
)

print(features_layer1(input))

Output


tf.Tensor([[-0.753606    0.9534933  -0.56991524]], shape=(1, 3), dtype=float32)
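Because a Dense layer computes activation(x · kernel + bias), the features extracted from "layer1" can be cross-checked by hand. This is a minimal sketch assuming NumPy is available; it reads the weights with get_weights(), which returns [kernel, bias] for a Dense layer. The names x, manual and extracted are illustrative only.

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation="tanh", name="layer1"),
    tf.keras.layers.Dense(4, activation="relu", name="layer2"),
    tf.keras.layers.Dense(2, activation="sigmoid", name="layer3"),
])

x = tf.random.normal((1, 4))

features_layer1 = tf.keras.models.Model(
    inputs=model.inputs,
    outputs=model.get_layer(name="layer1").output,
)

# For a Dense layer, get_weights() returns [kernel, bias]: here (4, 3) and (3,).
kernel, bias = model.get_layer(name="layer1").get_weights()

# Recompute layer1 by hand: tanh(x @ kernel + bias).
manual = np.tanh(x.numpy() @ kernel + bias)
extracted = features_layer1(x).numpy()

print(np.allclose(manual, extracted, atol=1e-5))  # True
```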

Extract features of "layer2"


features_layer2 = tf.keras.models.Model(
    inputs=model.inputs,
    outputs=model.get_layer(name="layer2").output,
)

print(features_layer2(input))

Output


tf.Tensor([[0.6845998  0.47238696 0.         0.04091616]], shape=(1, 4), dtype=float32)
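Layers can also be looked up by position rather than by name with get_layer(index=...). A minimal sketch, assuming the same model: the index counts positions in model.layers, which for this Sequential model starts at the first Dense layer because the Input is not listed there, so index 1 is "layer2". The variable names are illustrative.

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation="tanh", name="layer1"),
    tf.keras.layers.Dense(4, activation="relu", name="layer2"),
    tf.keras.layers.Dense(2, activation="sigmoid", name="layer3"),
])

x = tf.random.normal((1, 4))

# Index 1 in model.layers is the second Dense layer, "layer2".
layer2 = model.get_layer(index=1)
print(layer2.name)  # layer2

features_by_index = tf.keras.models.Model(inputs=model.inputs, outputs=layer2.output)
features_by_name = tf.keras.models.Model(
    inputs=model.inputs,
    outputs=model.get_layer(name="layer2").output,
)

# Both extractors point at the same layer, so they return the same features.
print(np.allclose(features_by_index(x).numpy(), features_by_name(x).numpy()))  # True
```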

Extract features of "layer3"


features_layer3 = tf.keras.models.Model(
    inputs=model.inputs,
    outputs=model.get_layer(name="layer3").output,
)

print(features_layer3(input))

Output


tf.Tensor([[0.5384151  0.33594143]], shape=(1, 2), dtype=float32)

The output of "layer3" is the same as the model output stored in "final_output", since "layer3" is the last layer of the model.


print(final_output)
print(features_layer3(input))

Output


tf.Tensor([[0.5384151  0.33594143]], shape=(1, 2), dtype=float32)
tf.Tensor([[0.5384151  0.33594143]], shape=(1, 2), dtype=float32)
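Instead of building one extractor per layer, a single Model can return the outputs of every layer in one forward pass by passing a list as outputs. This is a sketch assuming NumPy is available; the last element of the returned list should match the model's own output, in line with the comparison above. The names extractor and features are illustrative.

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation="tanh", name="layer1"),
    tf.keras.layers.Dense(4, activation="relu", name="layer2"),
    tf.keras.layers.Dense(2, activation="sigmoid", name="layer3"),
])

# One extractor with a list of outputs: the features of every layer.
extractor = tf.keras.models.Model(
    inputs=model.inputs,
    outputs=[layer.output for layer in model.layers],
)

x = tf.random.normal((1, 4))
features = extractor(x)  # a list with one tensor per layer

for layer, feature in zip(model.layers, features):
    print(layer.name, tuple(feature.shape))

# The last layer's features are the model's final output.
print(np.allclose(features[-1].numpy(), model(x).numpy()))  # True
```

Sharing one extractor avoids re-running the earlier layers once per requested layer, which matters more as the model grows.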
