Vision Transformers (ViT)
Reading Time: ~50 min | Interview Relevance: Very High | Target Roles: MLE, CV Engineer, Applied Scientist, Research Engineer
The Production Scenario
It is 2021. Your team has been running a ResNet-50 backbone in production for 18 months - chest X-ray classification, 92% sensitivity, reliable batch throughput. Then a paper drops on arXiv: "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale." Vision Transformers (ViT) achieve state-of-the-art on ImageNet - but only when pretrained on JFT-300M, a proprietary 300-million-image Google dataset. The community's reaction is split: impressive, but impractical.
Within months, DeiT arrives. Training ViT on standard ImageNet-1K with a stronger training recipe and distillation, it matches ResNet performance without any extra data. The transformer's lack of CNN inductive biases - a supposed weakness - turns out to be a feature: given enough data, the model learns better representations than anything hardwired. By 2022, Swin Transformer, which hierarchically combines local attention with cross-window context, has become the go-to backbone for detection and segmentation. By 2023, ViT image encoders sit inside many of the most prominent vision models, including CLIP and SAM.
Your team's question is no longer whether ViT works. It is when and why to migrate from CNN to ViT, and what the engineering trade-offs look like in practice. This lesson answers both.
The Core Idea: Images as Sequences of Patches
A standard transformer processes sequences of tokens. Language models tokenize text into words or subwords. ViT tokenizes images into patches.
Given an input image of shape H x W x C, ViT divides it into N non-overlapping patches of size P x P:
N = (H * W) / P^2
Each patch is a P x P x C block, flattened to a vector of size P^2 * C:
x_p has shape N x (P^2 * C)
For a 224×224 RGB image with patch size 16×16: N = (224/16)^2 = 196 patches, each of size 16^2 * 3 = 768 dimensions before projection.
Each flattened patch is linearly projected to the transformer's model dimension D via a learned projection matrix E of shape (P^2 * C) x D:
z_i = x_p[i] @ E, for i = 1 ... N
This is the patch embedding - conceptually identical to a word embedding in NLP, but for image patches instead of tokens.
Input Image (224x224x3):
+--+--+--+--+--+--+--+--+--+--+--+--+--+--
| | | | | | | | | | | | | | ...
+--+--+--+--+--+--+--+--+--+--+--+--+--+--
| | | | | | | | | | | | | | ... 14x14 grid = 196 patches
each patch: 16x16x3 = 768 dims
Patch 1 Patch 2 Patch 3 ... Patch 196
[768] -> Linear(768, D) -> [D]
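The patch extraction and projection above can be sketched in a few lines of NumPy. This is an illustrative sketch only: the helper name `patchify` and the random matrix standing in for the learned projection E are assumptions, not part of any library.

```python
import numpy as np

def patchify(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Split an (H, W, C) image into flattened, non-overlapping patches.

    Returns an array of shape (N, P*P*C) with N = (H/P) * (W/P),
    with patches ordered row by row (raster order).
    """
    H, W, C = image.shape
    P = patch_size
    assert H % P == 0 and W % P == 0, "image size must be divisible by patch size"
    # (H, W, C) -> (H/P, P, W/P, P, C) -> (H/P, W/P, P, P, C) -> (N, P*P*C)
    x = image.reshape(H // P, P, W // P, P, C)
    x = x.transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, P * P * C)

rng = np.random.default_rng(0)
image = rng.standard_normal((224, 224, 3))        # stand-in for a real image
D = 768                                           # model dimension (ViT-B)
E = rng.standard_normal((16 * 16 * 3, D)) * 0.02  # stand-in for the learned projection

x_p = patchify(image, patch_size=16)  # one row per patch
z = x_p @ E                           # patch embeddings, one D-vector per patch
print(x_p.shape, z.shape)
```

Row `i` of `x_p` is the i-th patch in raster order, so `x_p[1]` holds the pixels of the second patch in the top row.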
:::note Patch Size Trade-off
Smaller patches (P=8) produce longer sequences (N=784 for 224×224) - more tokens, finer granularity, quadratically higher attention cost. Larger patches (P=32) produce shorter sequences (N=49) - faster but coarser. P=16 is the standard compromise used in ViT-B/16 and ViT-L/16.
:::
The [CLS] Token and Positional Embeddings
Class token: Before the patch embeddings are fed into the transformer, a learnable classification token x_class of shape D is prepended:
z_0 = [x_class; z_1; z_2; ...; z_N] - shape (N+1) x D
After the transformer encoder processes this sequence, only the output corresponding to position 0 (the [CLS] token) is used for classification. All other outputs are discarded. The [CLS] token acts as an aggregator - it attends to all patches and accumulates global image information.
Positional embeddings: Unlike convolution, self-attention is permutation-equivariant: shuffling the input tokens simply shuffles the outputs, so the transformer does not inherently know which patch came from where. Positional embeddings add location information:
z_0 = z_0 + E_pos where E_pos has shape (N+1) x D
Surprisingly, learned 1D positional embeddings work about as well as 2D-aware embeddings: the ViT paper reports no significant gain from encoding rows and columns explicitly. The 1D positions treat patches as a flat sequence (0 through N), and the model learns the 2D structure implicitly from data rather than having it hand-crafted.
E_pos = [e_cls, e_1, e_2, ..., e_N] - each e_i is a learnable vector of size D
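To see why positional embeddings are needed, the following NumPy sketch checks that plain self-attention cannot tell token orders apart. The single-head `self_attention` helper, the toy sizes, and the fixed permutation are illustrative assumptions.

```python
import numpy as np

def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention on an (N, D) sequence."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
N, D = 6, 8
X = rng.standard_normal((N, D))
W_q, W_k, W_v = (rng.standard_normal((D, D)) for _ in range(3))
perm = np.array([3, 0, 5, 1, 4, 2])  # a fixed, non-trivial shuffle of the tokens

# Without positional embeddings: shuffling the inputs just shuffles the outputs.
out = self_attention(X, W_q, W_k, W_v)
out_shuffled = self_attention(X[perm], W_q, W_k, W_v)
equivariant_without_pos = bool(np.allclose(out_shuffled, out[perm]))

# With an embedding tied to each slot, the same shuffle changes the result.
E_pos = rng.standard_normal((N, D))
out_pos = self_attention(X + E_pos, W_q, W_k, W_v)
out_pos_shuffled = self_attention(X[perm] + E_pos, W_q, W_k, W_v)
equivariant_with_pos = bool(np.allclose(out_pos_shuffled, out_pos[perm]))

print(equivariant_without_pos, equivariant_with_pos)
```

The first flag should come out true and the second false: once `E_pos` is added, the same patch produces a different output depending on the slot it occupies.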
The ViT Architecture Step by Step
Each transformer block consists of:
- LayerNorm → Multi-Head Self-Attention → Residual add (pre-LN variant, more stable than post-LN)
- LayerNorm → MLP (two linear layers with GELU) → Residual add
The MLP hidden dimension is typically 4 * D (e.g., D=768 → MLP hidden size 3072 for ViT-B).
Standard ViT model sizes:
| Model | Layers L | Hidden dim D | Heads | Params | ImageNet-1K top-1 (reported; pretraining data varies by row) |
|---|---|---|---|---|---|
| ViT-Ti/16 | 12 | 192 | 3 | 5.7M | 72.2% |
| ViT-S/16 | 12 | 384 | 6 | 22M | 81.4% |
| ViT-B/16 | 12 | 768 | 12 | 86M | 84.6% |
| ViT-L/16 | 24 | 1024 | 16 | 307M | 86.5% |
| ViT-H/14 | 32 | 1280 | 16 | 632M | 88.5% |
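The parameter counts in the table can be approximately reproduced from the block structure described above. The sketch below assumes a plain ViT classifier with a 1000-class head, biases on every linear layer, and an MLP ratio of 4; the function name `vit_param_count` is made up for illustration.

```python
def vit_param_count(depth, dim, patch_size=16, in_channels=3,
                    img_size=224, num_classes=1000, mlp_ratio=4):
    """Count the parameters of a plain ViT classifier (biases included)."""
    num_patches = (img_size // patch_size) ** 2
    patch_dim = patch_size * patch_size * in_channels
    hidden = mlp_ratio * dim

    patch_embed = patch_dim * dim + dim                     # linear patch projection
    cls_and_pos = dim + (num_patches + 1) * dim             # [CLS] token + positions
    attn = (dim * 3 * dim + 3 * dim) + (dim * dim + dim)    # qkv + output projection
    mlp = (dim * hidden + hidden) + (hidden * dim + dim)    # two linear layers
    norms = 2 * (2 * dim)                                   # two LayerNorms per block
    block = attn + mlp + norms
    final_norm = 2 * dim
    head = dim * num_classes + num_classes
    return patch_embed + cls_and_pos + depth * block + final_norm + head

configs = {
    "ViT-Ti/16": dict(depth=12, dim=192),
    "ViT-S/16": dict(depth=12, dim=384),
    "ViT-B/16": dict(depth=12, dim=768),
    "ViT-L/16": dict(depth=24, dim=1024),
    "ViT-H/14": dict(depth=32, dim=1280, patch_size=14),
}
for name, cfg in configs.items():
    print(f"{name}: {vit_param_count(**cfg) / 1e6:.1f}M parameters")
```

Each block contributes about 12 * D^2 parameters, so the count grows linearly with depth and quadratically with width. Small differences from published figures come from the head size and rounding conventions.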
Multi-Head Self-Attention for Images
The self-attention mechanism in ViT is identical to standard transformer self-attention:
Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V
where Q = X @ W_Q, K = X @ W_K, V = X @ W_V are linear projections of the input sequence X of shape (N+1) x D, and d_k = D / h for h attention heads (lowercase h, to avoid a clash with the image height H).
What this means for images: every patch can attend directly to every other patch in a single layer. There is no concept of locality - a patch in the top-left corner can attend to a patch in the bottom-right with no additional layers. This is the key architectural difference from CNNs, where large receptive fields require deep stacking of local convolutions.
Computational cost: The attention matrix is (N+1) x (N+1). For N=196 patches (224×224, P=16), this is a 197×197 matrix - trivially small. But for high-resolution inputs or small patch sizes, attention cost scales as O(N^2):
N = (H * W) / P^2 => attention cost ~ (H * W / P^2)^2
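For 512×512 with P=16: N=1024, so the attention matrix is roughly 1024 x 1024. That is about 5× more tokens than at 224×224, and about 27× more attention entries, because the cost grows with the square of the token count.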
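The quadratic growth is easy to check numerically. The helper below is a small sketch (the name `attention_entries` is an assumption) that counts the entries of one head's attention matrix, including the [CLS] token.

```python
def attention_entries(img_size: int, patch_size: int, cls_token: bool = True) -> int:
    """Number of entries in one head's attention matrix for a square image."""
    num_tokens = (img_size // patch_size) ** 2 + (1 if cls_token else 0)
    return num_tokens ** 2

baseline = attention_entries(224, 16)  # 197 x 197
for img_size, patch_size in [(224, 16), (384, 16), (512, 16), (224, 8)]:
    entries = attention_entries(img_size, patch_size)
    print(f"{img_size}x{img_size}, P={patch_size}: "
          f"{entries:>10,} entries, {entries / baseline:5.1f}x the 224/16 baseline")
```

Halving the patch size quadruples the token count and therefore multiplies the attention matrix by about 16, which is why small patches at high resolution get expensive quickly.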
Attention maps are interpretable: After training, the attention weights from the [CLS] token to each patch show which image regions the model attends to. These maps, especially when aggregated across layers with attention rollout, are surprisingly semantic - the model learns to focus on relevant objects without any spatial supervision.
Why ViT Needs Large-Scale Pretraining
CNNs have three strong inductive biases built in by their architecture:
- Local connectivity: each neuron sees only a local neighborhood - appropriate because nearby pixels are correlated
- Weight sharing / translation equivariance: the same filter is applied everywhere - a cat in the top-left looks the same as a cat in the bottom-right
- Hierarchical feature extraction: pooling builds up increasingly abstract features over spatial scales
ViT has almost none of these biases: only the patch embedding and the position-wise MLPs act locally. Every patch attends to every other patch globally from layer 1. The model must learn local structure, translation invariance, and hierarchical features almost entirely from data.
The consequence: ViT requires substantially more data than CNNs to reach equivalent accuracy. On ImageNet-1K alone (~1.2M images), ViT-B/16 underperforms ResNet-50. It requires pretraining on ImageNet-21K (14M images) or JFT-300M (300M images) to fully realize its potential.
This is not a fundamental flaw - it is a trade-off. With sufficient data, ViT outperforms CNNs because the lack of inductive bias means the model can learn whatever structure is actually present in the data, rather than being constrained to convolve locally.
:::warning Small Dataset Guidance
For datasets under 100K images, CNN-based architectures (ResNet, EfficientNet) pretrained on ImageNet will almost always outperform a ViT trained from scratch. If you use a ViT, fine-tune a pretrained checkpoint from timm; if pretrained weights are not an option, start with a CNN.
:::
DeiT: Data-Efficient Image Transformers
DeiT (Touvron et al., 2021) was the first work to train ViT on ImageNet-1K alone (no extra data) and match ResNet performance. It combines two ingredients: a heavily regularized training recipe with strong data augmentation, and distillation from a CNN teacher.
DeiT introduces a second special token - the distillation token - alongside the [CLS] token. The distillation token learns to match the output of a CNN teacher (RegNet-Y in the original paper):
L_DeiT = (1 - alpha) * L_CE(y_cls, y) + alpha * L_KD(y_dist, y_teacher)
where y_cls is the prediction from the [CLS] token, y_dist is from the distillation token, and y_teacher is the teacher's hard or soft prediction.
This is inductive bias injection via distillation: the CNN teacher implicitly transfers its spatial locality and translation equivariance to the ViT student. Even without distillation, DeiT-B reaches 81.8% top-1 on ImageNet-1K, in the range of strong ResNet-family models, with a training run of under 3 days on 8 GPUs; distillation raises this to 83.4%.
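The loss above can be sketched in plain Python for a single example. This sketch uses the hard-label variant, where the distillation target is the teacher's argmax class; the function names and the default `alpha=0.5` are illustrative assumptions rather than the paper's code.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]

def cross_entropy(logits, target_index):
    """Cross-entropy of one example against a hard class index."""
    return -log_softmax(logits)[target_index]

def deit_hard_distillation_loss(cls_logits, dist_logits, teacher_logits,
                                label, alpha=0.5):
    """(1 - alpha) * CE(cls head, true label) + alpha * CE(dist head, teacher argmax)."""
    teacher_label = max(range(len(teacher_logits)), key=teacher_logits.__getitem__)
    loss_cls = cross_entropy(cls_logits, label)
    loss_dist = cross_entropy(dist_logits, teacher_label)
    return (1 - alpha) * loss_cls + alpha * loss_dist

# Toy 3-class example: the true label is class 0, the teacher prefers class 1.
loss = deit_hard_distillation_loss(
    cls_logits=[2.0, 0.5, -1.0],
    dist_logits=[0.3, 1.8, -0.5],
    teacher_logits=[0.1, 3.0, 0.2],
    label=0,
)
print(f"combined loss: {loss:.4f}")
```

Note that the two heads are trained toward different targets when the teacher disagrees with the label, which is the reason DeiT uses a separate distillation token instead of reusing [CLS].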
DeiT results (ImageNet-1K only, no extra data):
| Model | Params | Top-1 | Throughput (images/s) |
|---|---|---|---|
| ResNet-50 | 25M | 79.8% | ~1200 |
| DeiT-S | 22M | 79.8% | ~1000 |
| DeiT-B | 86M | 81.8% | ~290 |
| DeiT-B (distilled) | 86M | 83.4% | ~290 |
Swin Transformer: Bridging CNN and ViT
The Swin Transformer (Liu et al., 2021) addresses two limitations of standard ViT:
- Quadratic attention cost at high resolution
- Lack of hierarchical feature maps - ViT produces a single-scale feature map, unusable as a drop-in backbone for detection/segmentation that require multi-scale features
Shifted Window (SW) Attention: Instead of global attention across all N patches, Swin computes attention within local non-overlapping windows of size M x M (typically 7 x 7). This reduces attention cost from O(N^2) to O(N * M^2) - linear in image size.
To enable cross-window communication, alternate layers shift the window partition by (floor(M/2), floor(M/2)) pixels - hence "shifted windows." This allows information to flow across window boundaries over multiple layers.
Hierarchical feature maps: Swin uses patch merging (similar to CNN pooling) to progressively downsample and increase channel count, producing a 4-stage pyramid:
Stage 1: (H/4, W/4, 96) - C1, like CNN stage 1
Stage 2: (H/8, W/8, 192) - C2
Stage 3: (H/16, W/16, 384) - C3
Stage 4: (H/32, W/32, 768) - C4
These match the strides (4, 8, 16, 32) of the feature pyramid expected by FPN-based detectors (Mask R-CNN, Cascade R-CNN). With the larger Swin-L backbone, the original paper reports 58.7 box AP on COCO, ahead of the CNN-based detectors available at the time.
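The savings from windowed attention and the shape of the four-stage pyramid can be worked out with simple arithmetic. The sketch below assumes a 224×224 input, a 4×4 initial patch size, window size M=7, and a base width of 96 channels; the function names are made up for illustration.

```python
def global_attention_pairs(h_tokens: int, w_tokens: int) -> int:
    """Token pairs scored by global attention: N^2."""
    n = h_tokens * w_tokens
    return n * n

def window_attention_pairs(h_tokens: int, w_tokens: int, window: int = 7) -> int:
    """Token pairs scored by windowed attention: each token sees M^2 tokens."""
    n = h_tokens * w_tokens
    return n * window * window

def swin_stage_shapes(img_size: int = 224, base_channels: int = 96):
    """(height, width, channels) of the feature map after each of the 4 stages."""
    return [
        (img_size // (4 * 2 ** i), img_size // (4 * 2 ** i), base_channels * 2 ** i)
        for i in range(4)
    ]

for stage, (h, w, c) in enumerate(swin_stage_shapes(), start=1):
    g = global_attention_pairs(h, w)
    l = window_attention_pairs(h, w)
    print(f"stage {stage}: {h}x{w}x{c}  global={g:,}  windowed={l:,}  ratio={g / l:.0f}x")
```

The saving is largest in stage 1, where the token grid is biggest; in the last stage the 7×7 grid fits in a single window, so windowed and global attention coincide.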
CNN vs ViT: Detailed Comparison
| Property | ResNet / CNN | ViT (standard) | Swin Transformer |
|---|---|---|---|
| Inductive biases | Local connectivity, weight sharing, translation equivariance | None - learned from data | Shifted window locality |
| Data requirement | Low - works well on 10K+ images | High - needs 1M+ or pretrained | Medium - pretrained recommended |
| Attention complexity | N/A (local convolution) | O(N^2) - global | O(N * M^2) - local windows |
| Hierarchical features | Yes - natural from pooling | No - single scale | Yes - via patch merging |
| Throughput (inference) | High - very hardware-optimized | Lower at small batches | Moderate |
| Transfer learning | Strong - widely pretrained | Very strong - large-scale pretraining | Very strong |
| Adversarial robustness | Weak to texture attacks | More robust - shape-biased | More robust |
| Detection / segmentation | Excellent backbone | Needs adaptors (e.g., ViTDet) | Excellent drop-in backbone |
| Edge / mobile deployment | Excellent (MobileNet, EfficientNet-Lite) | Poor - large, slow | Poor - large |
CNN vs ViT Decision Guide
ViT from Scratch in PyTorch
import torch
import torch.nn as nn
class PatchEmbedding(nn.Module):
    """
    Splits image into patches and linearly projects each patch to model dim D.
    Equivalent to a Conv2d with kernel_size=patch_size, stride=patch_size.
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        assert img_size % patch_size == 0, "Image size must be divisible by patch size"
        self.num_patches = (img_size // patch_size) ** 2
        # Conv2d trick: kernel and stride equal to patch_size extracts non-overlapping patches
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W) -> (B, embed_dim, H/P, W/P) -> (B, N, embed_dim)
        x = self.proj(x)       # (B, D, H/P, W/P)
        x = x.flatten(2)       # (B, D, N)
        x = x.transpose(1, 2)  # (B, N, D)
        return x
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, D = x.shape
        # Project to Q, K, V and split into heads
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        q, k, v = qkv.unbind(0)           # each: (B, heads, N, head_dim)
        # Scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, D)  # (B, N, D)
        x = self.proj_drop(self.proj(x))
        return x
class TransformerBlock(nn.Module):
    """Pre-LN transformer block (LayerNorm before attention and MLP)."""
    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0,
                 attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, attn_drop, proj_drop)
        self.norm2 = nn.LayerNorm(embed_dim)
        hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(proj_drop),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(proj_drop),
        )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm1(x))  # Attention sublayer with residual
        x = x + self.mlp(self.norm2(x))   # MLP sublayer with residual
        return x
class VisionTransformer(nn.Module):
    """
    ViT-B/16 by default: 12 layers, embed_dim=768, 12 heads, patch_size=16.
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 num_classes=1000, embed_dim=768, depth=12, num_heads=12,
                 mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches
        # Learnable [CLS] token and positional embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(proj_drop)
        # Transformer encoder
        self.blocks = nn.Sequential(*[
            TransformerBlock(embed_dim, num_heads, mlp_ratio, attn_drop, proj_drop)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        self._init_weights()
    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B = x.size(0)
        x = self.patch_embed(x)                 # (B, N, D)
        cls = self.cls_token.expand(B, -1, -1)  # (B, 1, D)
        x = torch.cat([cls, x], dim=1)          # (B, N+1, D)
        x = self.pos_drop(x + self.pos_embed)   # add positional embeddings
        x = self.blocks(x)                      # (B, N+1, D)
        x = self.norm(x)
        cls_out = x[:, 0]                       # extract [CLS] output
        return self.head(cls_out)               # (B, num_classes)
# Quick test
if __name__ == '__main__':
    model = VisionTransformer()
    dummy = torch.randn(4, 3, 224, 224)
    logits = model(dummy)
    print(f"Output shape: {logits.shape}")  # (4, 1000)
    total_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"Parameters: {total_params:.1f}M")  # ~86M
Using timm for ViT
In production, always use timm (PyTorch Image Models) - it provides pretrained ViT weights, optimized implementations, and utilities for fine-tuning:
import timm
import torch
import torch.nn as nn
# List available ViT models
print(timm.list_models('vit_*', pretrained=True)[:10])
# ['vit_base_patch16_224', 'vit_base_patch16_384', 'vit_large_patch16_224', ...]
# Load pretrained ViT-B/16
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    num_classes=10,      # replace head for your task
    drop_path_rate=0.1,  # stochastic depth regularization
)
# Get model-specific preprocessing config
data_config = timm.data.resolve_model_data_config(model)
transform = timm.data.create_transform(**data_config, is_training=False)
# Freeze all layers except the head (for quick fine-tuning)
for name, param in model.named_parameters():
    if 'head' not in name:
        param.requires_grad = False
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-4, weight_decay=0.05
)
Fine-tuning ViT: Layer-wise Learning Rate Decay
When fine-tuning ViT end-to-end (not just the head), the standard recipe applies layer-wise learning rate decay (LLRD): later layers (closer to the output) receive higher learning rates than earlier layers (closer to the input). The intuition is that earlier layers capture universal low-level features that should change little, while later layers encode task-specific features that need more adaptation.
lr_l = lr_base * gamma^(L - l), where gamma is typically in [0.65, 0.9]
L is the total number of layers, l is the layer index (0 = earliest).
import torch
def build_llrd_optimizer(model, base_lr=1e-4, weight_decay=0.05, decay_rate=0.75):
    """Layer-wise learning rate decay for ViT fine-tuning.

    Assumes the model exposes patch_embed, cls_token, pos_embed, blocks,
    norm and head, as the VisionTransformer above and timm's ViT do.
    """
    param_groups = []
    num_layers = len(model.blocks)
    # Embedding layers get the smallest LR
    param_groups.append({
        'params': [*model.patch_embed.parameters(),
                   model.cls_token, model.pos_embed],
        'lr': base_lr * (decay_rate ** (num_layers + 1)),
        'weight_decay': 0.0,
    })
    # Each transformer block gets progressively higher LR
    for i, block in enumerate(model.blocks):
        layer_lr = base_lr * (decay_rate ** (num_layers - i))
        param_groups.append({
            'params': list(block.parameters()),
            'lr': layer_lr,
            'weight_decay': weight_decay,
        })
    # Final LayerNorm and head get the full base LR. The norm must be listed:
    # parameters that appear in no group are never updated by the optimizer.
    param_groups.append({
        'params': [*model.norm.parameters(), *model.head.parameters()],
        'lr': base_lr,
        'weight_decay': weight_decay,
    })
    return torch.optim.AdamW(param_groups)
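To get a feel for how steep the decay is, the schedule can be tabulated without building a model. This sketch mirrors the exponents used in the optimizer code above for a 12-block ViT; the function name `llrd_schedule` is an assumption.

```python
def llrd_schedule(base_lr: float = 1e-4, decay_rate: float = 0.75, num_layers: int = 12):
    """Learning rate per parameter group, from embeddings up to the head."""
    schedule = {"embeddings": base_lr * decay_rate ** (num_layers + 1)}
    for i in range(num_layers):
        schedule[f"block_{i}"] = base_lr * decay_rate ** (num_layers - i)
    schedule["head"] = base_lr
    return schedule

for group, lr in llrd_schedule().items():
    print(f"{group:>10}: {lr:.2e}")
```

With gamma = 0.75 the first block trains at about 3% of the head's learning rate, so the earliest layers barely move. Values of gamma closer to 0.9 flatten the schedule.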
Positional embedding interpolation: When fine-tuning on a different resolution (e.g., pretrained at 224×224, fine-tune at 384×384), the number of patches changes. timm handles this when you request the new size (for example by passing img_size=384 to timm.create_model in recent versions): it resamples the learned patch positional embeddings with 2D interpolation from a (14, 14) grid to a (24, 24) grid, leaving the [CLS] embedding unchanged. This works surprisingly well and is the standard approach.
Production Considerations
Throughput: ViT-B/16 processes roughly 300 images/sec on a V100 at batch size 64, compared to ~1200 images/sec for ResNet-50. For latency-sensitive applications, this roughly 4× difference is significant. FlashAttention (Dao et al., 2022) speeds up the attention operation itself, by about 2–4× on long sequences; at N=197, attention is only part of the runtime, so the end-to-end gain is smaller.
Memory: The self-attention matrix is O(N^2) in memory. For N=196 (224×224, P=16), this is small. But for 512×512 with P=8: N=4096, and the attention matrix is 4096 x 4096 x 12 heads x 4 bytes ~ 800MB per layer per batch item in FP32. FlashAttention computes attention in tiles without materializing the full matrix, reducing memory to O(N).
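The memory figure above follows directly from the token count. The helper below is a sketch (its name and defaults are assumptions) that computes the bytes needed to materialize the attention weights of one layer for one image, ignoring the [CLS] token as the text does.

```python
def attention_matrix_bytes(img_size: int, patch_size: int,
                           num_heads: int = 12, bytes_per_elem: int = 4) -> int:
    """Bytes to store the attention weights of ONE layer for ONE image."""
    num_tokens = (img_size // patch_size) ** 2  # [CLS] token ignored
    return num_tokens * num_tokens * num_heads * bytes_per_elem

for img_size, patch_size in [(224, 16), (512, 16), (512, 8)]:
    fp32 = attention_matrix_bytes(img_size, patch_size)
    fp16 = attention_matrix_bytes(img_size, patch_size, bytes_per_elem=2)
    print(f"{img_size}x{img_size}, P={patch_size}: "
          f"{fp32 / 1e6:8.1f} MB in FP32, {fp16 / 1e6:8.1f} MB in FP16")
```

During training these weights are kept for the backward pass in every layer, so the total is this number multiplied by the depth and the batch size unless a fused kernel avoids materializing them.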
Quantization: ViT is generally harder to quantize to INT8 than CNNs due to the softmax in attention (numerically sensitive) and the wide dynamic range of LayerNorm. Use FP16/BF16 as the first optimization step. INT8 requires careful calibration.
Batch size sensitivity: ViT throughput degrades more than CNNs at small batch sizes (batch size 1–4). If serving single-image requests, CNN latency may be lower. Batch inference with dynamic batching (TorchServe, Triton) is important for ViT production deployment.
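:::tip FlashAttention in Practice
In PyTorch 2.x, torch.nn.functional.scaled_dot_product_attention dispatches to fused, FlashAttention-style kernels when the hardware and dtype allow it, and recent timm releases route ViT attention through it by default. For HuggingFace models, the attn_implementation argument of from_pretrained selects the backend (for example 'sdpa', or 'flash_attention_2' with the flash-attn package installed, where the model class supports it). Fused attention reduces memory and improves throughput without changing the model's outputs beyond floating-point rounding - it is a pure implementation optimization.
:::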
Why Attention Changed Everything - The Intuition
To understand why ViT is architecturally different at a fundamental level, consider what happens when a convolutional layer processes an image.
A standard 3×3 convolution at any spatial position sees only a 3×3 neighborhood - 9 pixels out of potentially millions. To build up a receptive field that spans the whole image, a ResNet must stack dozens of convolutional layers, each one only slightly widening the view. A pixel in the top-left corner of the image cannot directly influence a pixel in the bottom-right corner until information has propagated through many layers. This is the locality bottleneck: global context is expensive to build and is built gradually.
In a ViT, the very first self-attention layer computes an (N+1) x (N+1) attention matrix - every patch against every other patch simultaneously. A patch of "grass" at the bottom of the image can directly attend to "sky" patches at the top in the first layer with zero intermediate steps. A patch of "cat fur" can immediately attend to the "cat face" patch on the other side of the image.
Think of it this way:
CNN receptive field growth (stride-1 3x3 convolutions) - needs many layers to see across the image:
Layer 1: [x] sees 3x3 region
Layer 2: [ x ] sees 5x5 region
Layer 3: [ x ] sees 7x7 region
Layer 5: [ x ] sees 11x11 region, still local
...
Layer 50: finally sees global context
ViT - sees everything in layer 1:
Patch[grass] ──────────────────────────────► Patch[sky]
Patch[car] ──────────────────────────────► Patch[road]
Patch[eye] ──────────────────────────────► Patch[face]
All attending to all. Layer 1.
This has three concrete consequences:
1. Long-range dependencies need no extra depth. In CNNs, capturing a dependency between two distant pixels requires depth that grows with their distance. In ViT, any two patches can interact in layer 1 (at the price of the quadratic attention cost). This matters for tasks where far-apart image regions are semantically related - for example, recognizing that a "person" patch and a "bicycle" patch together mean "cyclist," even when they appear at opposite sides of the image.
2. Shape bias vs. texture bias. CNNs are strongly biased toward local texture because convolutions operate on small neighborhoods where texture dominates. ViT, operating globally, learns to use shape and structural relationships between distant patches. This is one reason ViT representations tend to be more robust to texture perturbations and some adversarial attacks - the model relies less on local texture.
3. Attention is interpretable by design. The attention weight from the [CLS] token to each patch directly tells you which patches the model considers most relevant for its prediction. CNNs produce Grad-CAM saliency maps only by backpropagating gradients - an indirect proxy. ViT's attention weights are a direct signal. In the first few layers, attention is fairly diffuse (all patches get some weight). In later layers, attention concentrates sharply on semantically meaningful regions - object boundaries, faces, distinctive textures.
:::note The Inductive Bias Trade-off Restated
CNNs win when data is scarce precisely because their inductive biases (locality, translation equivariance) are good priors for natural images. ViT wins when data is abundant precisely because it has no such priors - it learns whatever structure is actually in the data, including non-local structure that CNNs can only approximate through depth. This is not a flaw in either architecture; it is a data-regime trade-off.
:::
Attention Maps Visualization
Visualizing what a ViT attends to is one of its most compelling properties. The raw attention weights from a single layer and head are noisy and hard to interpret directly. The standard technique is attention rollout (Abnar & Zuidema, 2020), which propagates attention through all layers to produce a single attribution map from input patches to the [CLS] token.
How rollout works: At each layer, the attention weight matrix A_l has shape (N+1) x (N+1). Rollout accounts for residual connections by averaging the attention matrix with the identity: A_l_hat = 0.5 * A_l + 0.5 * I. The final attribution is the product of all these matrices from layer 1 to layer L:
Rollout = A_L_hat @ A_(L-1)_hat @ ... @ A_1_hat
The row corresponding to the [CLS] token in the final rollout matrix gives the attribution of each patch to the [CLS] prediction - a heatmap over the image.
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import timm
from PIL import Image
import torchvision.transforms as T
def get_attention_maps(model, x: torch.Tensor):
    """
    Extract per-layer, per-head attention weights from a timm ViT model.
    Returns list of tensors, each shape (B, num_heads, N+1, N+1).
    Assumes each block.attn exposes qkv, proj and num_heads (as timm's ViT
    attention does). Q/K normalization and dropout are skipped, which is
    fine for standard ViT checkpoints in eval mode.
    """
    attn_weights = []
    def make_patched_forward(attn_module):
        # Bind attn_module per block. Closing over the loop variable instead
        # would make every layer use the last block's weights.
        def patched_forward(x, *args, **kwargs):
            B, N, D = x.shape
            head_dim = D // attn_module.num_heads
            qkv = attn_module.qkv(x).reshape(B, N, 3, attn_module.num_heads, head_dim)
            qkv = qkv.permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)
            attn = (q @ k.transpose(-2, -1)) * head_dim ** -0.5
            attn = attn.softmax(dim=-1)
            attn_weights.append(attn.detach().cpu())
            out = (attn @ v).transpose(1, 2).reshape(B, N, D)
            return attn_module.proj(out)
        return patched_forward
    # Temporarily replace each attention forward so the weights get recorded
    originals = []
    for block in model.blocks:
        originals.append(block.attn.forward)
        block.attn.forward = make_patched_forward(block.attn)
    try:
        with torch.no_grad():
            _ = model(x)
    finally:
        # Restore the original forward methods
        for block, orig in zip(model.blocks, originals):
            block.attn.forward = orig
    return attn_weights  # list of (B, heads, N+1, N+1)
def attention_rollout(attn_weights: list, head_fusion: str = "mean") -> np.ndarray:
    """
    Compute attention rollout from a list of per-layer attention tensors.
    Args:
        attn_weights: list of (1, num_heads, N+1, N+1) tensors
        head_fusion: 'mean', 'max', or 'min' across heads
    Returns:
        rollout: (N,) array - CLS-to-patch attribution (CLS-to-CLS entry dropped)
    """
    result = torch.eye(attn_weights[0].shape[-1])  # identity: (N+1, N+1)
    for attn in attn_weights:
        # Fuse across heads
        if head_fusion == "mean":
            attn_fused = attn[0].mean(dim=0)  # (N+1, N+1)
        elif head_fusion == "max":
            attn_fused = attn[0].max(dim=0).values
        else:
            attn_fused = attn[0].min(dim=0).values
        # Add residual connection and normalize
        attn_fused = 0.5 * attn_fused + 0.5 * torch.eye(attn_fused.shape[0])
        attn_fused = attn_fused / attn_fused.sum(dim=-1, keepdim=True)
        result = attn_fused @ result
    # Take the CLS token row and drop the CLS-to-CLS position
    cls_attn = result[0, 1:]  # (N,)
    return cls_attn.numpy()
def visualize_attention(image_path: str, model_name: str = 'vit_base_patch16_224'):
    """Full pipeline: load image, extract rollout attention, overlay on image."""
    model = timm.create_model(model_name, pretrained=True)
    model.eval()
    # Use the checkpoint's own preprocessing (resize, crop, mean/std). These
    # differ between ViT checkpoints, so hard-coding ImageNet statistics is unsafe.
    data_config = timm.data.resolve_model_data_config(model)
    transform = timm.data.create_transform(**data_config, is_training=False)
    img = Image.open(image_path).convert('RGB')
    x = transform(img).unsqueeze(0)  # (1, 3, H, W), e.g. (1, 3, 224, 224)
    # Extract attention weights
    attn_weights = get_attention_maps(model, x)
    # Compute rollout
    cls_attn = attention_rollout(attn_weights)  # (196,) for 14x14 grid
    # Reshape to 2D spatial map
    grid_size = int(cls_attn.shape[0] ** 0.5)  # 14 for ViT-B/16
    attn_map = cls_attn.reshape(grid_size, grid_size)
    # Upsample to the model input size
    attn_map_tensor = torch.tensor(attn_map).unsqueeze(0).unsqueeze(0)  # (1, 1, 14, 14)
    attn_upsampled = F.interpolate(
        attn_map_tensor, size=tuple(x.shape[-2:]), mode='bilinear', align_corners=False
    )[0, 0].numpy()
    # Normalize for display
    attn_upsampled = (attn_upsampled - attn_upsampled.min()) / (
        attn_upsampled.max() - attn_upsampled.min() + 1e-8
    )
    # Undo the normalization so the displayed image is exactly the crop the model saw
    mean = torch.tensor(data_config['mean']).view(3, 1, 1)
    std = torch.tensor(data_config['std']).view(3, 1, 1)
    img_display = (x[0] * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()
    # Plot
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(img_display)
    axes[0].set_title("Original Image")
    axes[0].axis('off')
    axes[1].imshow(attn_upsampled, cmap='hot')
    axes[1].set_title("Attention Rollout Map")
    axes[1].axis('off')
    axes[2].imshow(img_display)
    axes[2].imshow(attn_upsampled, cmap='hot', alpha=0.5)
    axes[2].set_title("Overlay")
    axes[2].axis('off')
    plt.tight_layout()
    plt.savefig("attention_rollout.png", dpi=150, bbox_inches='tight')
    plt.show()
    return attn_upsampled
# Usage
# attn = visualize_attention("dog.jpg")
What you typically see in attention maps: In early layers (layers 1–4), attention is broadly distributed - the [CLS] token attends to patches everywhere, building a global summary. In middle layers (5–8), attention sharpens: the model begins distinguishing foreground from background, attending heavily to the main subject. In the final layers (9–12), attention is highly concentrated on the most discriminative regions - the dog's face, the car's logo, the text in a sign. Background patches receive near-zero attention weight.
This semantic sharpening across depth is emergent - ViT learns it purely from classification labels, with no pixel-level supervision. It is why ViT attention maps are used as weak supervision signals for segmentation tasks; self-supervised ViTs such as DINO produce even cleaner attention maps than supervised ones.
:::tip Single-Head vs Multi-Head Attention Maps
Different attention heads specialize in different aspects of the image. One head may attend to edges, another to textures, another to the object center. When computing rollout, fusing heads by "mean" gives a general attribution map. For interpretability research, visualize individual heads - head specialization is one of the most striking emergent behaviors in ViTs.
:::
MAE - Masked Autoencoders
Masked Autoencoders (MAE, He et al., 2021) are a widely used self-supervised recipe for pretraining large ViTs. The core idea is disarmingly simple: mask 75% of image patches at random, pass only the visible 25% through the encoder, then use a lightweight decoder to reconstruct the original pixel values at the masked positions.
Why 75% masking? Natural images are highly spatially redundant. Neighboring patches look similar - predicting a masked patch from its visible neighbors is trivially easy if only 10–20% of patches are masked (the model just copies nearby pixels). Masking 75% forces the model to learn genuine semantic understanding: it must reason about the global context of what is likely to appear in a masked region, which requires building meaningful representations of objects, textures, and spatial relationships.
This contrasts with BERT-style masked language modeling, which masks only about 15% of tokens. Language is information-dense: a missing word usually cannot be recovered by copying its neighbors, so even a low masking ratio yields a hard prediction task. A missing image patch can often be interpolated from adjacent patches, so images need a much higher ratio.
Architecture: MAE uses an asymmetric encoder-decoder design:
Visible patches (25%)
|
[ViT Encoder - full depth]
e.g., ViT-L: 24 layers
|
[Encoded visible patch tokens]
|
+ [Learnable mask tokens for 75% positions]
|
[MAE Decoder - shallow, 8 blocks]
|
[Reconstructed pixel values at masked positions]
|
MSE loss on masked patches only
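The encoder only sees 25% of patches, so it processes roughly 4× fewer tokens than a full forward pass, which the MAE paper reports as a training speedup of 3× or more. The decoder is lightweight (8 blocks vs. 24 encoder blocks for ViT-L) and is discarded after pretraining. This asymmetry makes MAE pretraining very efficient: the paper reports about 31 hours for a 1600-epoch ViT-L run on 128 TPU-v3 cores.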
Why MAE beats supervised ImageNet pretraining at scale: When scaling ViT to ViT-L and ViT-H, supervised ImageNet-1K pretraining saturates - 1.2M labeled images are not enough to fill a 300M or 600M parameter model. MAE needs no labels and can use any unlabeled image collection. The pretraining signal - predicting pixel values - is dense (loss computed on 75% of patches per image) and self-supervised (no human annotation required). MAE-pretrained ViT-L reaches 85.9% top-1 on ImageNet, outperforming supervised pretraining by a significant margin.
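The encoder savings from masking can be estimated with simple token arithmetic. The sketch below assumes a 224×224 input with 16×16 patches (196 patches) and one [CLS] token; the function name `mae_token_budget` is made up for illustration.

```python
def mae_token_budget(num_patches: int = 196, mask_ratio: float = 0.75,
                     cls_token: bool = True) -> dict:
    """Token counts and rough cost ratios for an MAE encoder pass."""
    n_keep = int(num_patches * (1 - mask_ratio))
    extra = 1 if cls_token else 0
    full_tokens = num_patches + extra
    visible_tokens = n_keep + extra
    return {
        "visible_patches": n_keep,
        "masked_patches": num_patches - n_keep,
        # Linear and MLP layers scale with the number of tokens.
        "token_ratio": full_tokens / visible_tokens,
        # The attention matrix scales with the square of the number of tokens.
        "attention_ratio": (full_tokens / visible_tokens) ** 2,
    }

for key, value in mae_token_budget().items():
    print(f"{key}: {value:.2f}" if isinstance(value, float) else f"{key}: {value}")
```

The encoder handles about 4× fewer tokens and about 15× fewer attention entries. The decoder still runs over all tokens, which is one reason the end-to-end speedup is lower than the token ratio alone suggests.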
import torch
import torch.nn as nn
def random_masking(x: torch.Tensor, mask_ratio: float = 0.75):
    """
    Randomly mask a fraction of patches.
    Args:
        x: patch embeddings, shape (B, N, D)
        mask_ratio: fraction of patches to mask (0.75 = 75%)
    Returns:
        x_visible: unmasked patches, shape (B, N_visible, D)
        mask: binary mask, shape (B, N), 1 = masked
        ids_restore: indices to restore original order, shape (B, N)
    """
    B, N, D = x.shape
    n_keep = int(N * (1 - mask_ratio))  # number of patches to keep
    # Random noise for shuffle - sort by noise to get random permutation
    noise = torch.rand(B, N, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)        # ascending: small values kept
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # for restoring original order
    # Keep only the first n_keep patches
    ids_keep = ids_shuffle[:, :n_keep]
    x_visible = torch.gather(x, dim=1,
                             index=ids_keep.unsqueeze(-1).expand(-1, -1, D))
    # Build binary mask: 0 = kept, 1 = masked
    mask = torch.ones(B, N, device=x.device)
    mask[:, :n_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)
    return x_visible, mask, ids_restore
class MAEEncoder(nn.Module):
    """Encoder: standard ViT, but only processes visible patches."""

    def __init__(self, patch_embed, blocks, norm, cls_token, pos_embed):
        super().__init__()
        self.patch_embed = patch_embed
        self.blocks = blocks
        self.norm = norm
        self.cls_token = cls_token
        self.pos_embed = pos_embed  # shape (1, N+1, D)

    def forward(self, x: torch.Tensor, mask_ratio: float = 0.75):
        B = x.size(0)
        x = self.patch_embed(x)  # (B, N, D)

        # Add positional embeddings (skip CLS position 0)
        x = x + self.pos_embed[:, 1:, :]

        # Mask: keep only visible patches
        x_visible, mask, ids_restore = random_masking(x, mask_ratio)

        # Prepend CLS token
        cls = self.cls_token.expand(B, -1, -1)
        cls = cls + self.pos_embed[:, :1, :]
        x_visible = torch.cat([cls, x_visible], dim=1)  # (B, 1 + N_visible, D)

        # Encode visible patches
        for block in self.blocks:
            x_visible = block(x_visible)
        x_visible = self.norm(x_visible)
        return x_visible, mask, ids_restore
class MAEDecoder(nn.Module):
    """
    Lightweight decoder: projects encoder output to decoder dim,
    adds mask tokens for masked positions, reconstructs patches.
    """

    def __init__(self, num_patches, encoder_dim=1024, decoder_dim=512,
                 decoder_depth=8, decoder_heads=16, patch_size=16, in_channels=3):
        super().__init__()
        self.decoder_embed = nn.Linear(encoder_dim, decoder_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, decoder_dim)
        )
        self.decoder_blocks = nn.Sequential(*[
            # Simplified: nn.TransformerEncoderLayer stands in for a ViT block;
            # norm_first=True and GELU match the pre-norm ViT design.
            nn.TransformerEncoderLayer(
                d_model=decoder_dim, nhead=decoder_heads,
                dim_feedforward=decoder_dim * 4, activation='gelu',
                batch_first=True, norm_first=True
            )
            for _ in range(decoder_depth)
        ])
        self.decoder_norm = nn.LayerNorm(decoder_dim)
        # Project to patch pixels: patch_size^2 * channels
        self.pred_head = nn.Linear(decoder_dim, patch_size ** 2 * in_channels)

    def forward(self, x_encoded: torch.Tensor, ids_restore: torch.Tensor):
        B = x_encoded.size(0)
        N_visible = x_encoded.size(1) - 1  # subtract CLS
        N = ids_restore.size(1)            # total patches

        # Project encoder output to decoder dim
        x = self.decoder_embed(x_encoded)  # (B, 1+N_visible, decoder_dim)

        # Expand mask tokens for all masked positions
        n_masked = N - N_visible
        mask_tokens = self.mask_token.expand(B, n_masked, -1)  # (B, N_masked, D_dec)

        # Concatenate visible tokens (without CLS) and mask tokens, then restore order
        x_no_cls = x[:, 1:, :]                               # (B, N_visible, D_dec)
        x_full = torch.cat([x_no_cls, mask_tokens], dim=1)   # (B, N, D_dec)
        x_full = torch.gather(
            x_full, dim=1,
            index=ids_restore.unsqueeze(-1).expand(-1, -1, x_full.size(-1))
        )  # restore original patch order: (B, N, D_dec)

        # Re-prepend CLS and add positional embeddings
        cls = x[:, :1, :]
        x_full = torch.cat([cls, x_full], dim=1)  # (B, N+1, D_dec)
        x_full = x_full + self.decoder_pos_embed

        # Decode
        x_full = self.decoder_blocks(x_full)
        x_full = self.decoder_norm(x_full)

        # Predict pixels for all patches (remove CLS)
        pred = self.pred_head(x_full[:, 1:, :])  # (B, N, patch_size^2 * C)
        return pred
# MAE loss - MSE on masked patches only
def mae_loss(pred: torch.Tensor, target_patches: torch.Tensor, mask: torch.Tensor):
    """
    pred: (B, N, patch_size^2 * C) - reconstructed patches
    target_patches: (B, N, patch_size^2 * C) - original patch pixels (normalized)
    mask: (B, N) - 1 = masked, 0 = visible
    """
    loss = (pred - target_patches) ** 2       # (B, N, patch_size^2 * C)
    loss = loss.mean(dim=-1)                  # (B, N) - per-patch MSE
    loss = (loss * mask).sum() / mask.sum()   # mean over masked patches only
    return loss
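`mae_loss` expects `target_patches` laid out one row per patch, a step the code above does not show. A minimal sketch of that step, assuming square images whose side is divisible by the patch size. `patchify` and `unpatchify` are illustrative helper names, and the within-patch element order chosen here (row, column, channel) is just one possible convention that the prediction head would learn to match.

```python
import torch


def patchify(images: torch.Tensor, patch_size: int = 16) -> torch.Tensor:
    """(B, C, H, W) -> (B, N, patch_size^2 * C), patches in row-major order."""
    B, C, H, W = images.shape
    assert H % patch_size == 0 and W % patch_size == 0
    h, w = H // patch_size, W // patch_size
    x = images.reshape(B, C, h, patch_size, w, patch_size)
    x = x.permute(0, 2, 4, 3, 5, 1)  # (B, h, w, p, p, C)
    return x.reshape(B, h * w, patch_size * patch_size * C)


def unpatchify(patches: torch.Tensor, patch_size: int, channels: int = 3) -> torch.Tensor:
    """Inverse of patchify for square images: (B, N, p^2 * C) -> (B, C, H, W)."""
    B, N, _ = patches.shape
    h = w = int(N ** 0.5)
    assert h * w == N, "expects a square grid of patches"
    x = patches.reshape(B, h, w, patch_size, patch_size, channels)
    x = x.permute(0, 5, 1, 3, 2, 4)  # (B, C, h, p, w, p)
    return x.reshape(B, channels, h * patch_size, w * patch_size)


images = torch.randn(2, 3, 224, 224)
target_patches = patchify(images, patch_size=16)
print(target_patches.shape)  # torch.Size([2, 196, 768])
```

The MAE paper also reports better representations when each target patch is normalized by its own mean and variance before the loss is computed.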
:::tip MAE vs Supervised Pretraining - When to Use Which
For ViT-B fine-tuning on a domain-specific dataset (medical imaging, satellite imagery, retail products), supervised ImageNet-21K pretraining is usually the better starting point - the labeled features transfer well. For ViT-L and ViT-H, or when you have access to large unlabeled domain data, MAE pretraining on your domain data often outperforms ImageNet-supervised weights. MAE-style pretraining has also been extended to video understanding (VideoMAE) and point cloud learning.
:::
Practical Fine-tuning Guide
Fine-tuning a pretrained ViT correctly requires several techniques working together. Below is a step-by-step recipe that reliably achieves strong performance across domains.
Step 1: Load Pretrained ViT from timm
import timm
import torch
import torch.nn as nn

# Choose your model size based on dataset size:
# ViT-S/16 for < 200K images, ViT-B/16 for 200K–2M, ViT-L/16 for > 2M
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    num_classes=0,        # remove head - we'll add our own
    drop_path_rate=0.1,   # stochastic depth (see Step 6)
)

# Get the encoder output dimension
encoder_dim = model.num_features  # 768 for ViT-B
Step 2: Replace the Classification Head
num_classes = 10  # your task

# Option A: simple linear head (works well for large pretrained models)
model.head = nn.Linear(encoder_dim, num_classes)
final_layer = model.head

# Option B: two-layer MLP head (useful when the domain shift is large).
# Use this INSTEAD of Option A, not in addition to it.
# model.head = nn.Sequential(
#     nn.Linear(encoder_dim, encoder_dim // 2),
#     nn.GELU(),
#     nn.Dropout(0.1),
#     nn.Linear(encoder_dim // 2, num_classes),
# )
# final_layer = model.head[-1]

# Initialize the output layer of whichever head was chosen
nn.init.trunc_normal_(final_layer.weight, std=0.02)
nn.init.zeros_(final_layer.bias)
Step 3: Freeze Early Layers
When your dataset is small (< 50K images) or very different from ImageNet, freeze early layers to prevent catastrophic forgetting:
def freeze_vit_layers(model, freeze_up_to_block: int):
    """
    Freeze embeddings and the first N transformer blocks.
    freeze_up_to_block=6 means blocks 0-5 are frozen, 6-11 are trainable.
    """
    # Always freeze embedding layers
    for param in model.patch_embed.parameters():
        param.requires_grad = False
    model.cls_token.requires_grad = False
    model.pos_embed.requires_grad = False

    # Freeze early blocks
    for i, block in enumerate(model.blocks):
        if i < freeze_up_to_block:
            for param in block.parameters():
                param.requires_grad = False

    # Always keep head trainable
    for param in model.head.parameters():
        param.requires_grad = True

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable: {trainable/1e6:.1f}M / {total/1e6:.1f}M params")


# For small datasets: freeze first 8 of 12 blocks
freeze_vit_layers(model, freeze_up_to_block=8)

# For large datasets: fine-tune everything with LLRD
# freeze_vit_layers(model, freeze_up_to_block=0)
Step 4: Apply Layer-wise LR Decay
def build_vit_optimizer(model, base_lr: float = 5e-5,
                        weight_decay: float = 0.05,
                        decay_rate: float = 0.75) -> torch.optim.Optimizer:
    """
    AdamW with layer-wise LR decay (L = number of blocks, l = 0-based block index).
    Embedding layers: lr * decay_rate^(L+1)
    Block l:          lr * decay_rate^(L-l)
    Head:             lr (full)
    """
    param_groups = []
    num_layers = len(model.blocks)

    # Patch embed + positional embeddings + CLS token
    embed_params = (
        list(model.patch_embed.parameters()) +
        [model.cls_token, model.pos_embed]
    )
    param_groups.append({
        'params': [p for p in embed_params if p.requires_grad],
        'lr': base_lr * (decay_rate ** (num_layers + 1)),
        'weight_decay': 0.0,  # no WD on embeddings
    })

    # Transformer blocks
    for i, block in enumerate(model.blocks):
        block_lr = base_lr * (decay_rate ** (num_layers - i))
        param_groups.append({
            'params': [p for p in block.parameters() if p.requires_grad],
            'lr': block_lr,
            'weight_decay': weight_decay,
        })

    # Final norm + head
    norm_head_params = (
        list(model.norm.parameters()) +
        list(model.head.parameters())
    )
    param_groups.append({
        'params': [p for p in norm_head_params if p.requires_grad],
        'lr': base_lr,
        'weight_decay': weight_decay,
    })
    return torch.optim.AdamW(param_groups, betas=(0.9, 0.999), eps=1e-8)
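To make the decay rule in the docstring concrete, the sketch below evaluates it for a 12-block ViT-B with the default settings. `llrd_learning_rates` is an illustrative helper that mirrors the formula only; it does not touch the model or the optimizer.

```python
def llrd_learning_rates(base_lr: float, num_layers: int, decay_rate: float) -> dict:
    """Per-group learning rates under layer-wise LR decay."""
    rates = {"embeddings": base_lr * decay_rate ** (num_layers + 1)}
    for i in range(num_layers):
        rates[f"block_{i}"] = base_lr * decay_rate ** (num_layers - i)
    rates["head"] = base_lr
    return rates


rates = llrd_learning_rates(base_lr=5e-5, num_layers=12, decay_rate=0.75)
for name in ("embeddings", "block_0", "block_6", "block_11", "head"):
    print(f"{name:>10}: {rates[name]:.2e}")
```

With these defaults the embeddings train at roughly 1/40 of the head's learning rate, while the last block trains at 75% of it.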
Step 5: Cosine LR Schedule with Warmup
import math


def get_cosine_schedule_with_warmup(optimizer, warmup_epochs: int,
                                    total_epochs: int, min_lr_ratio: float = 0.01):
    """
    Linear warmup for warmup_epochs, then cosine decay to min_lr_ratio * base_lr.
    Call scheduler.step() once per epoch.
    """
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            # (epoch + 1) so the first epoch does not train with lr = 0
            return float(epoch + 1) / float(max(1, warmup_epochs))
        progress = float(epoch - warmup_epochs) / float(
            max(1, total_epochs - warmup_epochs)
        )
        cosine_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
        # Interpolate between 1.0 and min_lr_ratio so the floor is reached smoothly
        return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_factor
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
Step 6: ViT-Specific Training Tricks
Label smoothing (0.1): ViT training recipes such as DeiT use label smoothing by default. It regularizes overconfident predictions, which matters for high-capacity models fine-tuned on small datasets.
Stochastic depth (drop path): Randomly skips the residual branch of a transformer block for individual samples during training. The drop probability grows linearly with depth, from 0 at the first block to drop_path_rate at the last, so deeper blocks (closer to output) are dropped more often. Passing drop_path_rate=0.1 to timm.create_model applies this schedule automatically.
Repeated augmentation: Each image appears multiple times per batch with different augmentations, so the model sees several views of the same sample in one update. The DeiT authors report that this helps ViT training. In timm it is controlled by the num_aug_repeats argument of the data loader (e.g. num_aug_repeats=3).
MixUp + CutMix: These augmentations are especially effective for ViTs. Use mixup_alpha=0.8 and cutmix_alpha=1.0 with a 50/50 switch probability. timm provides the Mixup utility for this.
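The stochastic depth rule can be made concrete with a small sketch. It assumes the common linear schedule, in which the drop probability rises from 0 at the first block to `drop_path_rate` at the last. `drop_path_schedule` and `drop_path` are hypothetical helper names, and the per-sample rescaling shown is the usual formulation rather than a copy of timm's code.

```python
import random


def drop_path_schedule(depth: int, drop_path_rate: float) -> list:
    """Linearly increasing drop probability per block: 0 for the first block,
    drop_path_rate for the last."""
    if depth == 1:
        return [drop_path_rate]
    return [drop_path_rate * i / (depth - 1) for i in range(depth)]


def drop_path(residual: float, drop_prob: float, training: bool,
              rng: random.Random) -> float:
    """Apply stochastic depth to one sample's residual-branch output.

    With probability drop_prob the branch is zeroed; otherwise it is scaled by
    1 / (1 - drop_prob) so its expected value is unchanged.
    """
    if not training or drop_prob == 0.0:
        return residual
    keep_prob = 1.0 - drop_prob
    return residual / keep_prob if rng.random() < keep_prob else 0.0


rates = drop_path_schedule(depth=12, drop_path_rate=0.1)
print([round(r, 3) for r in rates])
```

At inference time (`training=False`) every branch is kept and no rescaling is applied.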
Step 7: Full Training Loop
import timm
import torch
from timm.data import Mixup
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import AverageMeter


def train_one_epoch(model, loader, optimizer, scheduler, device, epoch,
                    mixup_fn=None, label_smoothing: float = 0.1):
    model.train()
    # MixUp/CutMix returns soft targets (label smoothing already applied),
    # so the loss must accept a probability vector per sample.
    if mixup_fn is not None:
        criterion = SoftTargetCrossEntropy()
    else:
        criterion = LabelSmoothingCrossEntropy(smoothing=label_smoothing)
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    for step, (images, labels) in enumerate(loader):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        hard_labels = labels  # original class indices, kept for the accuracy metric

        # Apply MixUp / CutMix if enabled (labels become soft targets)
        if mixup_fn is not None:
            images, labels = mixup_fn(images, labels)

        optimizer.zero_grad()
        # BF16 mixed precision needs no loss scaling. For FP16, wrap the
        # backward pass and optimizer step in a GradScaler instead.
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            logits = model(images)
            loss = criterion(logits, labels)
        loss.backward()

        # Gradient clipping - important for ViTs
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        loss_meter.update(loss.item(), images.size(0))
        # Accuracy against the original hard labels. With MixUp/CutMix the inputs
        # are blended, so treat this training accuracy as a rough signal only.
        acc = (logits.argmax(dim=1) == hard_labels).float().mean().item()
        acc_meter.update(acc, images.size(0))

    scheduler.step()  # the schedule is defined per epoch
    return loss_meter.avg, acc_meter.avg


# --- Full setup ---
# Uses build_vit_optimizer (Step 4) and get_cosine_schedule_with_warmup (Step 5).
# train_loader and val_loader are assumed to be existing DataLoaders. Mixup in
# 'batch' mode needs an even batch size, so build train_loader with drop_last=True.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = timm.create_model('vit_base_patch16_224', pretrained=True,
                          num_classes=10, drop_path_rate=0.1)
model = model.to(device)
optimizer = build_vit_optimizer(model, base_lr=5e-5, decay_rate=0.75)
scheduler = get_cosine_schedule_with_warmup(
    optimizer, warmup_epochs=5, total_epochs=50
)

# MixUp + CutMix
mixup_fn = Mixup(
    mixup_alpha=0.8,
    cutmix_alpha=1.0,
    prob=1.0,
    switch_prob=0.5,
    mode='batch',
    label_smoothing=0.1,
    num_classes=10,
)

# Training loop
for epoch in range(50):
    loss, acc = train_one_epoch(
        model, train_loader, optimizer, scheduler,
        device, epoch, mixup_fn=mixup_fn
    )
    print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Acc: {acc:.4f}")

    # Validate every 5 epochs
    if epoch % 5 == 0:
        model.eval()
        with torch.no_grad():
            val_correct = val_total = 0
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                preds = model(images).argmax(dim=1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)
        print(f"  Val Acc: {val_correct / val_total:.4f}")
:::tip Resolution Fine-tuning
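For maximum accuracy, fine-tune at a higher resolution than pretraining (e.g., pretrain at 224, fine-tune at 384). In recent timm versions, timm.create_model('vit_base_patch16_224', pretrained=True, img_size=384) loads the 224 weights and resizes the positional embeddings to the new grid, whereas 'vit_base_patch16_384' loads a checkpoint that was already fine-tuned at 384. At 384 the sequence grows from 196 to 576 patch tokens, so expect roughly 3× or more compute and latency per image in exchange for a typical 1–2% accuracy gain. Do this only in the final fine-tuning stage.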
:::
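Changing resolution changes the number of patches, so the learned positional embeddings have to be resized to the new grid. A minimal sketch of the usual approach (bicubic interpolation over the 2D grid, with the CLS embedding passed through unchanged), assuming a single prefix token and square grids. `resize_pos_embed` is an illustrative helper, not timm's own function.

```python
import torch
import torch.nn.functional as F


def resize_pos_embed(pos_embed: torch.Tensor, new_grid: int) -> torch.Tensor:
    """Resize ViT positional embeddings (1, 1 + N, D) to a new square patch grid."""
    cls_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    old_grid = int(patch_pos.shape[1] ** 0.5)
    dim = patch_pos.shape[-1]
    # (1, N, D) -> (1, D, old_grid, old_grid) so interpolate sees a 2D map
    patch_pos = patch_pos.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    patch_pos = F.interpolate(patch_pos, size=(new_grid, new_grid),
                              mode="bicubic", align_corners=False)
    # Back to sequence layout: (1, new_grid^2, D)
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
    return torch.cat([cls_pos, patch_pos], dim=1)


pos_224 = torch.randn(1, 1 + 14 * 14, 768)   # 224 / 16 = 14 patches per side
pos_384 = resize_pos_embed(pos_224, new_grid=384 // 16)
print(pos_384.shape)  # torch.Size([1, 577, 768])
```

The sequence length goes from 197 to 577 tokens, which is where the extra compute at 384 comes from.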
ViT in Production
When to Choose ViT vs CNN in Real Deployments
The architecture choice depends on four factors: dataset size, task type, latency budget, and hardware target.
| Scenario | Recommendation | Reason |
|---|---|---|
| < 50K labeled images, classification | EfficientNet-B3 or ResNet-50 | CNNs generalize better with limited data |
| 50K–500K images, classification | DeiT-S or DeiT-B (pretrained) | Competitive with CNNs, better long-range |
| > 500K images, classification | ViT-B/16 or ViT-L/16 | Superior representation learning at scale |
| Object detection / segmentation | Swin-T or Swin-B with FPN | Hierarchical features required |
| Medical imaging (CT, MRI) | ViT-B with MAE pretraining on medical data | Domain-specific pretraining dominates |
| Satellite / aerial imagery | Swin or ViT with large patch size | High resolution, spatial structure matters |
| Mobile / edge device | MobileViT-S or EfficientNet-Lite | ViT-B too large for on-device inference |
| Real-time video (< 10ms/frame) | MobileNetV3 or YOLOv8n | ViT latency too high even with optimization |
| Multi-modal (vision + language) | ViT encoder (CLIP-pretrained) | ViT is the standard vision encoder in VLMs |
Memory Requirements at Different Batch Sizes
Memory usage includes model weights, activations (proportional to batch size), and optimizer states (during training). At inference, only model weights and activations matter.
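Approximate inference memory (FP16, forward pass only). Actual numbers, including whether a given batch size runs out of memory, depend on framework overhead, the attention kernel, and the GPU's capacity, so profile on your target hardware: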
| Model | Params | B=1 | B=8 | B=32 | B=64 | Throughput (A100) |
|---|---|---|---|---|---|---|
| ResNet-50 | 25M | ~0.5 GB | ~0.7 GB | ~1.2 GB | ~2.0 GB | ~3500 img/s |
| ViT-S/16 | 22M | ~0.6 GB | ~0.9 GB | ~1.8 GB | ~3.2 GB | ~1800 img/s |
| ViT-B/16 | 86M | ~0.9 GB | ~1.5 GB | ~3.5 GB | ~6.5 GB | ~900 img/s |
| ViT-L/16 | 307M | ~2.0 GB | ~4.5 GB | ~11 GB | OOM | ~280 img/s |
| Swin-T | 28M | ~0.7 GB | ~1.1 GB | ~2.5 GB | ~4.5 GB | ~1500 img/s |
| Swin-B | 88M | ~1.0 GB | ~2.0 GB | ~5.0 GB | ~9.5 GB | ~700 img/s |
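The weight component of these figures is easy to estimate, since it is just parameter count times bytes per parameter. A rough sketch using the parameter counts from the table; `weight_memory_mb` is an illustrative helper, and activations, the CUDA context, and workspace buffers come on top of what it reports.

```python
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "int8": 1}


def weight_memory_mb(num_params: float, precision: str = "fp16") -> float:
    """Memory needed to hold the weights alone, in megabytes (1 MB = 1e6 bytes)."""
    return num_params * BYTES_PER_PARAM[precision] / 1e6


models = {
    "ResNet-50": 25e6,
    "ViT-S/16": 22e6,
    "ViT-B/16": 86e6,
    "ViT-L/16": 307e6,
    "Swin-T": 28e6,
    "Swin-B": 88e6,
}
for name, params in models.items():
    print(f"{name:>10}: {weight_memory_mb(params, 'fp16'):7.0f} MB in FP16, "
          f"{weight_memory_mb(params, 'fp32'):7.0f} MB in FP32")
```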
:::warning High-Resolution Memory Explosion
For input resolution 512×512 with ViT-B/16: N = 1024 patches. If the attention matrix is materialized in full, it takes 1024 x 1024 x 12 heads x 2 bytes (FP16) ≈ 25 MB per image per layer. At batch size 32, that is ~800 MB per layer for attention matrices alone. Use FlashAttention (which avoids materializing the full matrix) or switch to Swin for high-resolution inputs.
:::
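The arithmetic in the warning generalizes to other resolutions. A small sketch, assuming the attention matrix is materialized in full for one layer and ignoring the CLS token by default; `attention_matrix_bytes` is an illustrative helper name.

```python
def attention_matrix_bytes(image_size: int, patch_size: int, num_heads: int,
                           batch_size: int = 1, bytes_per_elem: int = 2,
                           cls_token: bool = False) -> int:
    """Bytes for one layer's full attention matrices (batch x heads x N x N)."""
    n = (image_size // patch_size) ** 2 + (1 if cls_token else 0)
    return batch_size * num_heads * n * n * bytes_per_elem


for size in (224, 384, 512, 1024):
    per_image = attention_matrix_bytes(size, patch_size=16, num_heads=12)
    per_batch = attention_matrix_bytes(size, patch_size=16, num_heads=12, batch_size=32)
    print(f"{size:>4}px: {per_image / 1e6:8.1f} MB per image, "
          f"{per_batch / 1e9:6.2f} GB at batch 32 (per layer, FP16)")
```

Doubling the image side multiplies the patch count by 4 and the attention matrix size by 16.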
ONNX Export
ONNX export allows running ViT on non-PyTorch runtimes (ONNX Runtime, TensorRT, OpenVINO, Core ML):
import numpy as np
import onnx
import onnxruntime as ort
import timm
import torch


def export_vit_to_onnx(model_name: str = 'vit_base_patch16_224',
                       num_classes: int = 10,
                       output_path: str = 'vit_b16.onnx',
                       opset_version: int = 17):
    """Export a timm ViT to ONNX with dynamic batch size and sanity-check it."""
    model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)
    model.eval()
    dummy_input = torch.randn(1, 3, 224, 224)

    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        opset_version=opset_version,
        input_names=['input'],
        output_names=['logits'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'logits': {0: 'batch_size'},
        },
        do_constant_folding=True,
    )
    print(f"Exported to {output_path}")

    # Validate ONNX model
    onnx_model = onnx.load(output_path)
    onnx.checker.check_model(onnx_model)
    print("ONNX model check passed")

    # Run the exported graph in ONNX Runtime, preferring CUDA when it is available
    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    available = ort.get_available_providers()
    providers = [p for p in ('CUDAExecutionProvider', 'CPUExecutionProvider')
                 if p in available]
    session = ort.InferenceSession(
        output_path,
        sess_options=sess_options,
        providers=providers
    )

    # Compare ONNX Runtime output against PyTorch on the same input
    ort_outputs = session.run(None, {'input': dummy_input.numpy()})
    with torch.no_grad():
        torch_out = model(dummy_input).numpy()
    max_diff = np.abs(ort_outputs[0] - torch_out).max()
    print(f"ONNX Runtime output shape: {ort_outputs[0].shape}, "
          f"max abs diff vs PyTorch: {max_diff:.2e}")
    return session


session = export_vit_to_onnx(num_classes=10)
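To compare runtimes fairly, latency should be measured after a warmup phase and summarized with robust statistics. A minimal, runtime-agnostic sketch; `benchmark` is a hypothetical helper, and the lambda at the bottom is a stand-in workload so the snippet runs without a model.

```python
import statistics
import time


def benchmark(fn, warmup: int = 10, iters: int = 50) -> dict:
    """Time a zero-argument callable; returns median and p99 latency in ms."""
    for _ in range(warmup):          # warm up caches, allocators, lazy init
        fn()
    timings_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        timings_ms.append((time.perf_counter() - start) * 1e3)
    timings_ms.sort()
    p99_index = min(len(timings_ms) - 1, int(0.99 * len(timings_ms)))
    return {
        "median_ms": statistics.median(timings_ms),
        "p99_ms": timings_ms[p99_index],
    }


# Stand-in workload. With the session above, one would instead pass something
# like: lambda: session.run(None, {'input': batch_np})
stats = benchmark(lambda: sum(range(10_000)), warmup=2, iters=20)
print(f"median {stats['median_ms']:.3f} ms, p99 {stats['p99_ms']:.3f} ms")
```

When timing PyTorch on a GPU, the callable should also synchronize the device before returning, because CUDA kernels are launched asynchronously.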
TensorRT Optimization
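TensorRT (NVIDIA) typically provides the highest throughput for ViT on NVIDIA GPUs. In the illustrative numbers at the end of this section, it is about 1.7× faster than ONNX Runtime at FP16 and about 4.7× faster than PyTorch eager FP32, with INT8 adding a further ~1.5× over TensorRT FP16: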
import numpy as np
import tensorrt as trt
import torch


def build_trt_engine(onnx_path: str, engine_path: str,
                     use_fp16: bool = True, max_batch_size: int = 32):
    """
    Build TensorRT engine from ONNX model.
    Requires: tensorrt >= 8.6, CUDA toolkit installed.
    """
    logger = trt.Logger(trt.Logger.WARNING)
    with trt.Builder(logger) as builder, \
         builder.create_network(
             1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
         ) as network, \
         trt.OnnxParser(network, logger) as parser:
        config = builder.create_builder_config()
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 * (1 << 30))  # 4GB
        if use_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
            print("Building FP16 TensorRT engine...")

        # Parse ONNX
        with open(onnx_path, 'rb') as f:
            if not parser.parse(f.read()):
                for idx in range(parser.num_errors):
                    print(f"ONNX parse error: {parser.get_error(idx)}")
                raise RuntimeError("ONNX parsing failed")

        # Dynamic batch size profile
        profile = builder.create_optimization_profile()
        profile.set_shape(
            'input',
            min=(1, 3, 224, 224),
            opt=(8, 3, 224, 224),
            max=(max_batch_size, 3, 224, 224)
        )
        config.add_optimization_profile(profile)

        # Build and serialize engine
        engine_bytes = builder.build_serialized_network(network, config)
        if engine_bytes is None:
            raise RuntimeError("TensorRT engine build failed")
        with open(engine_path, 'wb') as f:
            f.write(engine_bytes)
        print(f"TensorRT engine saved to {engine_path}")


def run_trt_inference(engine_path: str, input_tensor: torch.Tensor) -> np.ndarray:
    """Run inference with a serialized TensorRT engine."""
    logger = trt.Logger(trt.Logger.WARNING)
    with open(engine_path, 'rb') as f:
        runtime = trt.Runtime(logger)
        engine = runtime.deserialize_cuda_engine(f.read())
    context = engine.create_execution_context()

    # TensorRT reads raw device pointers, so the input must be a contiguous
    # float32 CUDA tensor (engine I/O stays FP32 even when built with FP16)
    input_tensor = input_tensor.to(device='cuda', dtype=torch.float32).contiguous()
    context.set_input_shape('input', tuple(input_tensor.shape))
    output_shape = (input_tensor.shape[0], engine.get_tensor_shape('logits')[-1])
    output = torch.empty(output_shape, dtype=torch.float32, device='cuda')

    # Bindings are passed in tensor order: input first, then output
    context.execute_v2(bindings=[
        input_tensor.data_ptr(),
        output.data_ptr()
    ])
    return output.cpu().numpy()


# Usage
# build_trt_engine('vit_b16.onnx', 'vit_b16_fp16.trt', use_fp16=True)
# logits = run_trt_inference('vit_b16_fp16.trt', images_cuda)
Typical speedup from optimization pipeline on A100 (ViT-B/16, batch=8, 224×224):
| Runtime | Precision | Latency (ms) | Throughput (img/s) |
|---|---|---|---|
| PyTorch eager | FP32 | 28 ms | 285 |
| PyTorch + FlashAttention | FP16 | 12 ms | 667 |
| ONNX Runtime | FP16 | 10 ms | 800 |
| TensorRT | FP16 | 6 ms | 1333 |
| TensorRT | INT8 | 4 ms | 2000 |
:::danger INT8 Quantization Accuracy Drop
ViT is more sensitive to INT8 quantization than CNNs. The attention softmax inputs and the activations around LayerNorm have a wide dynamic range with occasional large outliers, which a single 8-bit scale represents poorly. Always measure accuracy after INT8 quantization. A drop of > 1% top-1 accuracy is common without careful calibration. Use post-training quantization with at least 500 calibration images representative of your deployment distribution. Consider INT8 only if FP16 throughput is insufficient for your SLA.
:::
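The effect of a wide dynamic range can be illustrated with a toy example. It assumes symmetric per-tensor INT8 quantization with the scale set by the largest absolute value, which is one simple scheme and not necessarily what TensorRT's calibrators select. The activation values are made up for illustration.

```python
def quantize_dequantize_int8(values):
    """Symmetric per-tensor INT8: scale by max |x|, round, clamp, map back."""
    scale = max(abs(v) for v in values) / 127.0
    quantized = [max(-127, min(127, round(v / scale))) for v in values]
    return [q * scale for q in quantized]


def error_on_first(values, n):
    """Mean absolute round-trip error over the first n values."""
    restored = quantize_dequantize_int8(values)
    return sum(abs(a - b) for a, b in zip(values[:n], restored[:n])) / n


typical = [0.01 * i for i in range(-50, 51)]   # activations in [-0.5, 0.5]
with_outlier = typical + [40.0]                # same values plus one large outlier

err_clean = error_on_first(typical, len(typical))
err_outlier = error_on_first(with_outlier, len(typical))
print(f"mean abs error without outlier: {err_clean:.5f}")
print(f"mean abs error with outlier:    {err_outlier:.5f}")
```

A single outlier stretches the quantization step so far that most of the typical values collapse onto a handful of integer levels, which is why calibration data and clipping ranges matter.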
:::note Serving ViT in Production
For serving ViT at scale, use NVIDIA Triton Inference Server with the TensorRT backend. Triton handles dynamic batching (it accumulates requests up to a target batch size or timeout), model versioning, and multi-GPU load balancing. A typical production setup is Triton + TensorRT FP16 + a dynamic batching window of 10ms. Under sufficient load, a setup like this can reach 90%+ GPU utilization while keeping p99 latency under 50ms for ViT-B/16 at 224×224, but validate this against your own traffic pattern.
:::
Interview Q&A
Q: Explain the ViT patch embedding and why it is equivalent to a Conv2d.
A patch embedding divides the input image into N = H*W/P^2 non-overlapping patches of size P x P, flattens each to a vector of P^2 * C elements, and linearly projects to model dimension D. This is mathematically identical to a Conv2d with kernel_size=patch_size, stride=patch_size, out_channels=D - each filter position sees exactly one non-overlapping patch. Using Conv2d is the efficient implementation because it is highly optimized on modern hardware (CUDA kernels, cuDNN).
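The equivalence can be checked numerically. The sketch below applies the same weights once as a strided convolution and once as an explicit "cut patches, flatten, matrix multiply" pipeline; it uses `F.unfold` to extract patches because its flattening order (channel, row, column) matches the layout of the Conv2d weight.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
P, C, D = 16, 3, 768                      # patch size, channels, model dim
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
images = torch.randn(2, C, 224, 224)

with torch.no_grad():
    # Path 1: strided convolution, then flatten the 14x14 grid into a sequence
    tokens_conv = conv(images).flatten(2).transpose(1, 2)        # (2, 196, 768)

    # Path 2: extract non-overlapping patches, flatten each to P*P*C values,
    # and apply the conv weights as an ordinary linear projection
    patches = F.unfold(images, kernel_size=P, stride=P).transpose(1, 2)  # (2, 196, 768)
    tokens_linear = patches @ conv.weight.reshape(D, -1).T + conv.bias

same = torch.allclose(tokens_conv, tokens_linear, atol=1e-4)
print(tokens_conv.shape, same)
```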
Q: Why does ViT need positional encodings, and why do 1D learned positions work despite the image being 2D?
Without positional encodings, the transformer's self-attention is permutation-equivariant - shuffling the patch order would produce the same output (up to permutation). The model must know the spatial position of each patch to reason about spatial relationships. Learned 1D positions work because the model learns to encode 2D spatial structure within the 1D embedding: neighboring positions in the 1D sequence (which correspond to spatially adjacent patches in row-major order) end up with similar embeddings. The model effectively learns a 2D-aware representation through the 1D learnable vectors.
Q: What is the computational complexity of self-attention in ViT, and how does Swin Transformer address it?
Standard ViT self-attention has O(N^2 * D) complexity, where N = H*W/P^2 is the number of patches and D is the model dimension. Because N grows with the pixel count, the cost is quadratic in the number of pixels (quartic in the image side length), which becomes prohibitive for high-resolution inputs. Swin Transformer computes attention within local non-overlapping windows of size M x M rather than globally, reducing complexity to O(N * M^2 * D) - linear in image size for fixed window size M. Shifted windows in alternating layers allow cross-window information exchange over multiple layers.
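A quick calculation shows the size of the gap. The sketch counts query-key pairs per layer for global versus windowed attention, assuming Swin's first-stage setting of 4×4 patches and 7×7 windows; the helper names are illustrative, and the shared factor D is omitted.

```python
def global_attention_pairs(num_tokens: int) -> int:
    """Every token attends to every token: N^2 pairs."""
    return num_tokens ** 2


def windowed_attention_pairs(num_tokens: int, window: int) -> int:
    """Each token attends only within its M x M window: N * M^2 pairs."""
    return num_tokens * window ** 2


for image_size in (224, 448, 896):
    n = (image_size // 4) ** 2            # 4x4 patches, as in Swin's first stage
    g = global_attention_pairs(n)
    w = windowed_attention_pairs(n, window=7)
    print(f"{image_size:>4}px: N={n:>6}, global={g:.2e}, windowed={w:.2e}, "
          f"ratio={g // w}x")
```

The ratio equals N / M^2, so the advantage of windowed attention grows linearly with the number of tokens.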
Q: Why does ViT underperform CNNs on small datasets, and what does DeiT do to fix this?
CNNs have inductive biases - local connectivity, weight sharing, translation equivariance - that match the structure of natural images. These biases allow CNNs to generalize from relatively few examples. ViT has no such biases; it must learn spatial structure from data. On small datasets, there is not enough data for ViT to learn these structures, so it underperforms. DeiT addresses this in two ways: a strong augmentation and regularization recipe that makes ImageNet-1K-only training viable, and distillation from a pretrained CNN teacher via a dedicated distillation token. The CNN teacher implicitly transfers its inductive biases (local texture sensitivity, translation equivariance) to the ViT student, enabling strong performance with ImageNet-1K alone.
Q: What is layer-wise learning rate decay and why is it used for ViT fine-tuning?
Layer-wise learning rate decay assigns smaller learning rates to earlier transformer layers and larger learning rates to later layers, with the classification head receiving the full base learning rate. The multiplier follows lr_l = lr_base * gamma^(L-l), where L is the number of blocks, l is the block index, and gamma ≈ 0.75. The rationale: early layers learn universal low-level features (edges, textures) during pretraining that remain useful across domains, and large updates would destroy them. Later layers encode task-specific features that need more adaptation to the new task. LLRD is often reported to add on the order of 0.5 to 1.5 points of accuracy on fine-tuning benchmarks compared to a uniform learning rate, though the gain depends on the model and dataset.
Q: What is MAE and why is it preferred over supervised ImageNet pretraining for large ViTs?
MAE (Masked Autoencoders, He et al. 2021) masks 75% of image patches at random and trains the model to reconstruct the masked patches from only the visible 25%. The high masking ratio forces the model to learn genuine semantic understanding rather than copying nearby patches. For large models (ViT-L, ViT-H), MAE outperforms supervised ImageNet pretraining because 1.2M labeled images are insufficient to fill a 300M+ parameter model, while MAE can leverage any unlabeled image data at scale. In the MAE paper, MAE-pretrained ViT-L reaches 85.9% top-1 on ImageNet-1K vs. 82.5% for the same architecture trained from scratch with labels.
Q: How do you choose between ViT and CNN for a new computer vision project?
The key factors are dataset size, task type, and deployment target. For small datasets (< 50K images), CNNs generalize better due to built-in inductive biases. For large datasets or when using large-scale pretrained weights, ViT produces better representations. For detection and segmentation, Swin Transformer is preferred because it produces hierarchical feature maps compatible with FPN-based detectors. For edge or mobile deployment, CNNs (MobileNet, EfficientNet-Lite) are far superior - ViT-B is too large and slow for on-device inference. When building multimodal systems (vision + language), use a CLIP-pretrained ViT encoder as it aligns well with language model embedding spaces.
