Attention Maps of Vision Transformers
While reading Vision Transformers Need Registers, I came across the term "attention maps" quite a lot. There is a really nice paragraph in the paper that explains what it means.
You see, in a Vision Transformer (ViT), we not only feed in the image as a sequence of patches, we also prepend a learnable [CLS] token. Intuitively, the [CLS] token summarizes the contents of the image. This means that if you train a ViT and want to use it for image classification, you just grab the final [CLS] embedding and predict on it.
This also creates a really nice phenomenon. The attention map of the [CLS] token with respect to all the image patches tells us "what are the important patches to attend to, which could summarize the contents of the image".
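To make the idea concrete, here is a minimal sketch with random numbers rather than a real ViT. It assumes a single attention head, a 4 x 4 patch grid, and that the [CLS] token sits at index 0. The names `cls_attention_map`, `queries`, and `keys` are illustrative and do not come from any library.

```python
import numpy as np


def cls_attention_map(queries, keys, grid_size):
    """Return the [CLS] row of a softmax attention matrix as a 2D patch grid.

    Assumes token 0 is [CLS] and the remaining tokens are patches in
    row-major order (an assumption for this toy example).
    """
    dim = queries.shape[-1]
    # Scaled dot-product scores: (num_tokens, num_tokens)
    scores = queries @ keys.T / np.sqrt(dim)
    # Softmax over the key axis so each row sums to 1
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Row 0 is "how much [CLS] attends to every token"; drop the [CLS] column
    cls_to_patches = weights[0, 1:]
    return cls_to_patches.reshape(grid_size, grid_size)


rng = np.random.default_rng(0)
grid_size = 4
num_tokens = 1 + grid_size * grid_size  # [CLS] + patches
queries = rng.normal(size=(num_tokens, 8))
keys = rng.normal(size=(num_tokens, 8))

attn_map = cls_attention_map(queries, keys, grid_size)
print(attn_map.shape)  # (4, 4)
```

Each cell of `attn_map` is the weight the [CLS] token puts on one patch, which is exactly the quantity the rest of this tutorial pulls out of a real model and draws as a heatmap.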
What if we could visualize this? What if we could see what the ViTs attend to?
Sayak Paul and I did a project a while back called "Probing ViTs", where we not only visualized the attention maps of ViTs but also probed the models in a number of other ways. A curious mind can always go to our repository to learn more.
With Hugging Face transformers making it more than easy to draw the attention maps and visualize them, let's do a quick tutorial on it!
Setup and Imports
These are the packages we would need.
import requests
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
import torch
from transformers import AutoImageProcessor, AutoModel
Get the data
Let's get some data.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
image
Get the model
We will use facebook/dinov2-large for our visualization. Please feel free to use other ViT models from the Hub.
model_id = "facebook/dinov2-large"
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)
Run the model on the inputs. Notice the output_attentions=True parameter here. We want the model to output the attention maps for us to visualize.
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(output_attentions=True, **inputs)
Get the maps
The model returns attention maps for each layer. We would like to visualize the maps from the last layer. After we have the maps, we reshape and resize the maps for better visualization.
num_heads = 16  # dinov2-large has 16 attention heads
# Take the scores from the last layer: (batch, heads, tokens, tokens)
attention_scores = outputs.attentions[-1]
# Take the [CLS] row (token 0) and drop its attention to itself
attentions = attention_scores[0:1, :, 0, 1:].reshape(1, num_heads, -1)
# Reshape the patch scores into a 16 x 16 grid (224 / 14 = 16 patches per side)
attentions = attentions.reshape(1, num_heads, 16, 16)
# Interpolate the attention maps up to the input resolution
attentions = torch.nn.functional.interpolate(
    attentions,
    size=(224, 224)
).squeeze()
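The 16 x 16 grid used in the reshape follows from the input size and the patch size. Here is a small sketch of that arithmetic, assuming a 224 x 224 center crop and 14 x 14 patches (what I understand the default dinov2-large setup to be). The helper `patch_grid` is a hypothetical name, not part of transformers.

```python
def patch_grid(image_size, patch_size):
    """Return (rows, cols) of the patch grid for a square-patch ViT."""
    height, width = image_size
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by the patch size")
    return height // patch_size, width // patch_size


# Assumed values for facebook/dinov2-large with its default processor:
# a 224 x 224 center crop and 14 x 14 patches.
rows, cols = patch_grid((224, 224), 14)
num_tokens = 1 + rows * cols  # one [CLS] token plus the patches

print(rows, cols, num_tokens)  # 16 16 257
```

If you swap in another ViT from the Hub, these numbers will likely change. Rather than typing them by hand, you could read them from `model.config.patch_size`, `model.config.num_attention_heads`, and the spatial size of `inputs["pixel_values"]`.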
Now let's visualize the maps overlaid on the image.
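# Build unnormalized_image by undoing the processor's normalization
# (assumed here; any displayable copy of the 224 x 224 model input works)
mean = np.array(processor.image_mean)[:, None, None]
std = np.array(processor.image_std)[:, None, None]
unnormalized_image = inputs["pixel_values"][0].numpy() * std + mean
unnormalized_image = (unnormalized_image * 255).clip(0, 255).astype(np.uint8)
unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)  # CHW -> HWC
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(13, 13))
img_count = 0
for i in range(4):
    for j in range(4):
        if img_count < len(attentions):
            axes[i, j].imshow(unnormalized_image)
            axes[i, j].imshow(attentions[img_count], alpha=0.8)
            axes[i, j].title.set_text(f"Attention head: {img_count}")
            axes[i, j].axis("off")
            img_count += 1
fig.tight_layout()
fig.savefig("dino_attention_heads", dpi=300, bbox_inches="tight")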
We can see that the attention heads attend to the cats, specifically their tails, heads and bodies. Visualizing attention maps gives you a clearer picture of what the model looks at, and a rough sanity check on how well it has trained.
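If sixteen separate heads are too much to take in at once, one optional variation (not something done in the code above) is to average the heads into a single map. Below is a minimal sketch on stand-in random data; `average_heads` is a hypothetical helper name.

```python
import numpy as np


def average_heads(head_maps):
    """Collapse (num_heads, H, W) attention maps into one (H, W) map in [0, 1]."""
    head_maps = np.asarray(head_maps, dtype=np.float64)
    mean_map = head_maps.mean(axis=0)
    # Min-max scale so the result can be used directly as a heatmap
    span = mean_map.max() - mean_map.min()
    if span == 0:
        return np.zeros_like(mean_map)
    return (mean_map - mean_map.min()) / span


# Stand-in data: 16 random "heads" on a 16 x 16 grid (not real model output)
rng = np.random.default_rng(0)
fake_heads = rng.random((16, 16, 16))
summary = average_heads(fake_heads)
print(summary.shape)  # (16, 16)
```

With the tensors from this tutorial, something like `average_heads(attentions.numpy())` should give a single 224 x 224 map that can be overlaid with `imshow` in the same way as the per-head maps. Keep in mind that averaging hides the differences between heads, which is often the interesting part.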
Hope you liked this tutorial. If you want more such short tutorials do let me know!