Transfer Learning and Fine-Tuning
Reading Time: ~40 min | Interview Relevance: Very High | Target Roles: MLE, Applied Scientist, MLOps Engineer
A startup is building a mobile app to detect plant diseases. Farmers photograph their crops; the app diagnoses the disease and suggests treatment. The dataset: 800 labeled images across 15 disease classes - generously collected by agronomists over six months.
The first ML engineer trains ResNet-50 from scratch. After 100 epochs, validation accuracy is 62%. The model barely learns - 800 images is nowhere near enough for a 25-million-parameter network. The engineer adds more augmentation, tunes the learning rate, tries smaller architectures. Nothing gets past 68%.
A senior ML engineer joins. She loads ResNet-50 with ImageNet pretrained weights, replaces only the final classification layer for 15 classes, and runs training for 20 epochs. Validation accuracy: 91%.
Same architecture. Same 800 images. Same hardware. 91% vs 62%.
The CTO, watching over her shoulder, asks the obvious question: "How is it possible that a model trained to recognize cats, cars, and hot dogs helps us identify plant diseases? What do those things have in common?"
The answer is one of the most important insights in modern deep learning - and understanding it properly will change how you approach every new ML problem.
What Does a Neural Network Actually Learn?
Before we can answer the CTO's question, we need to understand what ImageNet training actually produces in the weights of a CNN.
In 2013, Zeiler and Fergus published "Visualizing and Understanding Convolutional Networks" - the first systematic study of what CNNs learn layer by layer. They used deconvolution networks to project activations back to pixel space, revealing what patterns in the input triggered each neuron. What they found was startling in its structure.
Layer 1: Universal Edge Detectors
The first convolutional layer of an ImageNet-trained CNN learns filters that look like Gabor filters - oriented edge detectors and color blob detectors. These are nearly identical across:
- CNNs trained on ImageNet
- CNNs trained on CIFAR-10 (tiny images of objects)
- CNNs trained on face recognition datasets
- CNNs trained on medical X-rays
They're nearly identical because these are the fundamental patterns of any natural image. Images everywhere are made of edges, gradients, and color transitions at fine scale. A network that can't detect edges can't see anything.
Layer 1 filters (visualized as 11×11 patches):
//// ──── \\\\   Oriented edges at multiple angles
████ ░░░░        Color patches (R, G, B, etc.)
Layer 2: Textures and Simple Shapes
The second layer combines layer 1 outputs to detect textures - grid patterns, curved edges, simple shapes, repetitive structures. A texture is what you'd see if you zoomed in on fabric, wood grain, skin, leaves, or circuit boards. Again, these patterns exist across all visual domains.
Layer 3: Object Parts
By layer 3, filters respond to recognizable structures: honeycomb patterns, wheels, faces in profile, text characters. These are more domain-specific but still broadly general.
Layer 4–5: High-Level Semantics
The deep layers become increasingly specific to the ImageNet task. They respond to dog breeds, car models, specific types of food. These are the features that don't transfer as cleanly.
Feature Generality by Layer:
Layer 1: ████████████████████ (99% general across all domains)
Layer 2: ████████████████████ (95% general)
Layer 3: ████████████████░░░░ (80% general)
Layer 4: ████████████░░░░░░░░ (60% general)
Layer 5: ████████░░░░░░░░░░░░ (40% general)
FC/Head: ████░░░░░░░░░░░░░░░░ (ImageNet-specific)
◄─ FREEZE ──────────────────► ◄─ FINE-TUNE ──►
The Key Insight: Visual Features Are Hierarchical and General
Here is why the plant disease model works. Plant leaves have edges (layer 1 detects them). Leaves have textures - the fine structure of veins, spots, discoloration (layers 2-3 detect them). The specific semantic meaning of "leaf rust" or "bacterial blight" lives in layers 4-5 and the classifier head, which the model learns from the 800 training images.
The first 80% of the network - the part that learns to see - transfers for free. Only the final interpretation of what those visual features mean needs to be learned from scratch on the target task.
Yosinski et al. (2014) "How transferable are features in deep neural networks?" formalized this experimentally. They transferred feature layers between networks trained on different ImageNet splits and showed that the first layers transfer nearly perfectly, while the last layers transfer poorly. The "transferability cliff" occurs around layers 4-5 for typical networks.
Why Transfer Works: The Feature Hierarchy Argument
Think of it this way. Learning to recognize plant diseases from scratch requires the network to simultaneously solve two problems:
- How to see: build edge detectors, texture analyzers, shape detectors from random pixel noise
- What diseases look like: learn the visual signature of each disease class
With 800 images, there's not enough signal to solve both problems at once. A 25-million-parameter network can memorize 800 images without ever learning features that generalize, so validation accuracy stalls.
Transfer learning solves problem 1 for free. The ImageNet-pretrained backbone is an excellent, general-purpose visual feature extractor. You only need your 800 images to solve problem 2: teach the network's high-level reasoning about what those general visual features mean in the context of plant diseases.
The Source-Target Similarity Spectrum
How well transfer works depends on how similar the source domain (typically ImageNet) is to the target domain:
| Source | Target | Feature Match | Best Strategy |
|---|---|---|---|
| ImageNet | General object photos | Excellent | Feature extraction or light fine-tuning |
| ImageNet | Medical X-rays | Good (edges, textures) | Fine-tune top 2 layers + head |
| ImageNet | Satellite imagery | Moderate (shapes, textures) | Fine-tune top 3 layers + head |
| ImageNet | Fluorescence microscopy | Moderate | Full fine-tuning |
| ImageNet | Sonar images | Poor | Full fine-tuning or from scratch |
| ImageNet | Spectral / thermal images | Poor | Adapt input layer + full fine-tune |
The surprising entry is medical imaging. Doctors often object: "An ImageNet model trained on everyday objects shouldn't help with X-rays." But it does. The reason: radiologists are also looking for edges (bone boundaries), textures (lung tissue patterns), and shapes (tumor outlines) - just in grayscale. The low-level feature detectors transfer, even when the semantic domain is completely different.
Three Transfer Learning Strategies
Strategy 1: Feature Extraction (Frozen Backbone)
Freeze all pretrained layers completely. Train only the new classification head - typically a linear layer or small MLP attached to the backbone's output.
When to use:
- Very small dataset (< 1,000 samples total)
- Target domain is very similar to ImageNet (e.g., classifying different everyday objects)
- Minimal compute budget - only a tiny fraction of parameters need gradients
- Need a reliable baseline in 15 minutes of training
What happens during training: Only the head has requires_grad=True. The backbone is a frozen feature extractor - essentially a fixed transformation from image pixels to a 2048-dimensional vector (for ResNet-50). You're training a linear (or shallow) classifier on those features.
Risk profile: Very low overfitting risk (few free parameters). Limitation: the backbone cannot adapt to patterns specific to your domain. If there's a domain shift, accuracy plateaus early.
Typical result: Often surprisingly good - 80-90% of the benefit of full fine-tuning, at 1% of the training time.
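To make the "few free parameters" point concrete, the arithmetic below counts how many weights actually train when a single linear head sits on ResNet-50's 2048-dimensional features. It is a sketch under stated assumptions: the 25,557,032 total is the parameter count of torchvision's stock ResNet-50 with its 1000-class head, and the helper name is illustrative.

```python
def linear_head_params(in_features: int, num_classes: int) -> int:
    """Weights plus biases of a single fully connected layer."""
    return in_features * num_classes + num_classes

RESNET50_TOTAL = 25_557_032  # assumed: torchvision ResNet-50 with its 1000-class head
FEATURE_DIM = 2048           # width of the pooled feature vector

original_head = linear_head_params(FEATURE_DIM, 1000)  # ImageNet classifier
new_head = linear_head_params(FEATURE_DIM, 15)         # 15 disease classes
backbone = RESNET50_TOTAL - original_head              # everything that stays frozen
total_after_swap = backbone + new_head
trainable_fraction = new_head / total_after_swap

print(f"trainable: {new_head:,} / {total_after_swap:,} ({trainable_fraction:.2%})")
# trainable: 30,735 / 23,538,767 (0.13%)
```

With roughly 0.13% of the weights free to move, 800 images are fitting about 31K parameters instead of about 25M, which is why the overfitting risk is so low.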
Strategy 2: Fine-Tuning the Top Layers
Freeze the early layers (conv1, layer1, and layer2 in ResNet) and fine-tune the last N layer groups plus the new head.
When to use:
- Moderate dataset size (1,000–10,000 samples)
- Domain is similar to ImageNet but not identical (plant diseases, product photos, satellite imagery)
- Standard production setup for most fine-tuning tasks
Rationale: Early layers (edges, textures) are already general enough for almost any visual task, so leave them alone. Late layers are ImageNet-specific and need to adapt to your domain. Unfreeze only what needs to change.
Learning rate structure: Different rates per layer group. Newly initialized head: large LR (1e-3). Last pretrained block: medium LR (1e-4). Earlier blocks: small LR (1e-5). This prevents catastrophic forgetting of general features while allowing domain-specific adaptation.
Strategy 3: Full Fine-Tuning
Unfreeze all layers. Train the entire network end-to-end with a small learning rate.
When to use:
- Large target dataset (10,000+ samples)
- Source and target domains are significantly different
- You have enough compute and training time
Critical warning: Full fine-tuning with a uniform large learning rate will destroy pretrained representations. The randomly initialized head needs a large update; the pretrained layers should barely move. Always use:
- Very small LR for pretrained layers: 1e-5 to 1e-6
- Warmup for the first few epochs: start at 10% of target LR, ramp up linearly
- Gradient clipping: max norm of 1.0
Catastrophic forgetting is the primary failure mode of full fine-tuning. When the learning rate is too large, the gradient signal from the new classification task overwrites the general visual representations learned during ImageNet training. You end up with a model that fits the training set but has worse generalization than a properly fine-tuned model. Symptoms: training accuracy climbs fast, but validation accuracy plateaus lower than expected.
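The warmup advice above can be sketched as a plain function of the epoch index. This is a minimal illustration, not a replacement for a framework scheduler: the function name, the cosine decay after warmup, and the default values are assumptions chosen to match the numbers in the list (start at 10% of the target LR, ramp linearly).

```python
import math

def lr_at_epoch(
    epoch: int,
    total_epochs: int,
    target_lr: float,
    warmup_epochs: int = 3,
    start_factor: float = 0.1,
    min_lr: float = 1e-6,
) -> float:
    """Linear warmup from start_factor * target_lr, then cosine decay to min_lr."""
    if epoch < warmup_epochs:
        frac = epoch / warmup_epochs
        return target_lr * (start_factor + (1.0 - start_factor) * frac)
    # Fraction of the post-warmup schedule that has elapsed, in [0, 1]
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return min_lr + 0.5 * (target_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

schedule = [lr_at_epoch(e, total_epochs=15, target_lr=1e-5) for e in range(16)]
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch:02d}: lr={lr:.2e}")
```

The pretrained layers see their largest step only after the freshly initialized head has had a few epochs to settle, which is the point of warming up.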
Layer-wise Learning Rate Decay
This is one of the most important practical techniques in transfer learning, and it's often omitted from tutorials.
The Problem
When you fine-tune a pretrained network with a single learning rate, you're making an implicit assumption that all layers should change at the same rate. This is wrong.
- Head (randomly initialized): needs a large LR to learn from scratch
- Last pretrained block: needs moderate adjustment for your domain
- First pretrained blocks: already encode universal edge detectors - should barely change
Using head LR everywhere causes early layers to lose their general features. Using early-layer LR everywhere means the head learns too slowly. You need different learning rates per layer.
The Implementation: PyTorch Parameter Groups
PyTorch's optimizer accepts a list of parameter groups, each with its own learning rate:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

def build_optimizer_with_lr_decay(
    model: nn.Module,
    base_lr: float = 1e-3,
    decay_factor: float = 0.1,
    weight_decay: float = 1e-4
) -> optim.Optimizer:
    """
    Layer-wise learning rate decay for ResNet-50 fine-tuning.
    Layer groups (shallowest to deepest):
    - backbone early: conv1 + bn1 + layer1 + layer2
    - backbone late: layer3 + layer4
    - head: fc
    LR multipliers (with the default decay_factor=0.1): head=1.0, late=0.1, early=0.01
    """
    # Group 1: early backbone (very small LR)
    early_params = (
        list(model.conv1.parameters()) +
        list(model.bn1.parameters()) +
        list(model.layer1.parameters()) +
        list(model.layer2.parameters())
    )
    # Group 2: late backbone (moderate LR)
    late_params = (
        list(model.layer3.parameters()) +
        list(model.layer4.parameters())
    )
    # Group 3: classification head (full LR)
    head_params = list(model.fc.parameters())
    param_groups = [
        {
            "params": early_params,
            "lr": base_lr * decay_factor * decay_factor,  # base_lr * 0.01
            "weight_decay": weight_decay
        },
        {
            "params": late_params,
            "lr": base_lr * decay_factor,  # base_lr * 0.1
            "weight_decay": weight_decay
        },
        {
            "params": head_params,
            "lr": base_lr,  # base_lr * 1.0
            "weight_decay": weight_decay
        },
    ]
    return optim.AdamW(param_groups)

# Example: base_lr=1e-3 → early=1e-5, late=1e-4, head=1e-3
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
model.fc = nn.Linear(model.fc.in_features, 15)  # 15 disease classes
optimizer = build_optimizer_with_lr_decay(model, base_lr=1e-3)
print(f"LR by group: {[g['lr'] for g in optimizer.param_groups]}")
# Approximately [1e-05, 0.0001, 0.001]; the printed digits may differ slightly due to float rounding
This pattern is sometimes called discriminative learning rates, a term from ULMFiT (Howard & Ruder, 2018), which popularized the technique for NLP; it was quickly adopted in computer vision.
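The same idea extends to any number of layer groups by decaying the learning rate geometrically with distance from the head. The helper below is a hypothetical sketch (its name, the group names, and the 0.5 decay are illustrative choices, not part of any library): the head gets the full rate and each group below it gets the rate of the group above multiplied by the decay.

```python
from typing import Dict, List

def layerwise_lrs(group_names: List[str], head_lr: float, decay: float) -> Dict[str, float]:
    """
    Geometric layer-wise LR decay.
    group_names must be ordered shallowest -> deepest, with the head last.
    """
    n = len(group_names)
    # The head is 0 steps from the top, the group below it 1 step, and so on.
    return {
        name: head_lr * decay ** (n - 1 - i)
        for i, name in enumerate(group_names)
    }

groups = ["stem", "layer1", "layer2", "layer3", "layer4", "fc"]
lrs = layerwise_lrs(groups, head_lr=1e-3, decay=0.5)
for name, lr in lrs.items():
    print(f"{name:<7} {lr:.2e}")
```

A decay close to 1.0 approaches a single uniform learning rate; a small decay approaches a frozen backbone, so the factor is a dial between Strategy 1 and Strategy 3.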
The Fine-Tuning Workflow: Step by Step
Here is a complete two-phase fine-tuning pipeline: Phase 1 trains only the new head on a frozen backbone, and Phase 2 unfreezes the top layers and fine-tunes them with discriminative learning rates.
Phase 1: Train the Head Only
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from typing import Tuple

def freeze_backbone(model: nn.Module) -> None:
    """Freeze all layers except the classification head."""
    for name, param in model.named_parameters():
        if "fc" not in name:  # 'fc' is ResNet's classifier
            param.requires_grad = False

def unfreeze_layer_group(model: nn.Module, layer_name: str) -> None:
    """Selectively unfreeze a named layer group."""
    layer = getattr(model, layer_name)
    for param in layer.parameters():
        param.requires_grad = True

def count_trainable(model: nn.Module) -> Tuple[int, int]:
    """Return (total parameter count, trainable parameter count)."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

def run_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
    is_train: bool
) -> Tuple[float, float]:
    """Run one epoch of training or evaluation. Returns (loss, accuracy)."""
    model.train(is_train)  # train mode if is_train, eval mode otherwise
    total_loss, correct, total = 0.0, 0, 0
    context = torch.enable_grad() if is_train else torch.no_grad()
    with context:
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if is_train:
                optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            if is_train:
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
            # Weight the batch loss by batch size so the epoch average is per-sample
            total_loss += loss.item() * len(labels)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += len(labels)
    return total_loss / total, correct / total
def fine_tune_two_phase(
    num_classes: int,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    phase1_epochs: int = 10,
    phase2_epochs: int = 15,
    head_lr: float = 1e-3,
    finetune_lr: float = 1e-4,
) -> nn.Module:
    """
    Two-phase fine-tuning pipeline.
    Phase 1: Freeze backbone, train head only.
    Phase 2: Unfreeze top layers, fine-tune with discriminative LRs.
    """
    # Load pretrained backbone
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    # Replace classification head for target task
    in_features = model.fc.in_features  # 2048 for ResNet-50
    model.fc = nn.Sequential(
        nn.Dropout(p=0.3),
        nn.Linear(in_features, num_classes)
    )
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    # ─── Phase 1: Head Only ───────────────────────────────────────────
    print("Phase 1: Training head only (backbone frozen)")
    freeze_backbone(model)
    total, trainable = count_trainable(model)
    print(f" Trainable: {trainable:,} / {total:,} params ({trainable/total:.1%})")
    optimizer_p1 = optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=head_lr,
        weight_decay=1e-4
    )
    scheduler_p1 = optim.lr_scheduler.CosineAnnealingLR(
        optimizer_p1, T_max=phase1_epochs, eta_min=head_lr / 10
    )
    best_val_acc = 0.0
    for epoch in range(phase1_epochs):
        train_loss, train_acc = run_epoch(
            model, train_loader, criterion, optimizer_p1, device, is_train=True
        )
        val_loss, val_acc = run_epoch(
            model, val_loader, criterion, optimizer_p1, device, is_train=False
        )
        scheduler_p1.step()
        print(f" Epoch {epoch+1:02d}: "
              f"train={train_acc:.3f} val={val_acc:.3f} "
              f"lr={scheduler_p1.get_last_lr()[0]:.2e}")
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "best_phase1.pt")

    # ─── Phase 2: Unfreeze Top Layers ────────────────────────────────
    print(f"\nPhase 2: Fine-tuning top layers (best phase1 val_acc={best_val_acc:.3f})")
    model.load_state_dict(torch.load("best_phase1.pt"))
    # Unfreeze layer3, layer4 (keep layer1, layer2, conv1 frozen)
    unfreeze_layer_group(model, "layer3")
    unfreeze_layer_group(model, "layer4")
    total, trainable = count_trainable(model)
    print(f" Trainable: {trainable:,} / {total:,} params ({trainable/total:.1%})")
    # Discriminative learning rates: earlier layers get smaller LR
    param_groups = [
        {"params": model.layer3.parameters(), "lr": finetune_lr * 0.1},
        {"params": model.layer4.parameters(), "lr": finetune_lr},
        {"params": model.fc.parameters(), "lr": finetune_lr * 5},
    ]
    optimizer_p2 = optim.AdamW(param_groups, weight_decay=1e-4)
    # Warmup for 3 epochs, then cosine decay
    warmup = optim.lr_scheduler.LinearLR(
        optimizer_p2, start_factor=0.1, end_factor=1.0, total_iters=3
    )
    cosine = optim.lr_scheduler.CosineAnnealingLR(
        optimizer_p2, T_max=phase2_epochs - 3, eta_min=1e-6
    )
    scheduler_p2 = optim.lr_scheduler.SequentialLR(
        optimizer_p2, schedulers=[warmup, cosine], milestones=[3]
    )
    best_val_acc_p2 = 0.0
    for epoch in range(phase2_epochs):
        train_loss, train_acc = run_epoch(
            model, train_loader, criterion, optimizer_p2, device, is_train=True
        )
        val_loss, val_acc = run_epoch(
            model, val_loader, criterion, optimizer_p2, device, is_train=False
        )
        scheduler_p2.step()
        print(f" Epoch {epoch+1:02d}: "
              f"train={train_acc:.3f} val={val_acc:.3f}")
        if val_acc > best_val_acc_p2:
            best_val_acc_p2 = val_acc
            torch.save(model.state_dict(), "best_phase2.pt")
    print(f"\nFinal best val accuracy: {best_val_acc_p2:.3f}")
    model.load_state_dict(torch.load("best_phase2.pt"))
    return model
The Data Transforms: Non-Negotiable Normalization
# CRITICAL: Normalize inputs with the same statistics the pretrained weights
# were trained with. For torchvision's ImageNet models these are the per-channel
# mean and std below, computed on ImageNet training images. Using different
# values means the backbone receives out-of-distribution inputs and its
# pretrained features are degraded. Weights from other sources may expect
# different values, so check the model's documented preprocessing.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),  # REQUIRED
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),  # REQUIRED
])
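What Normalize does to a single pixel is simple arithmetic: subtract the channel mean and divide by the channel standard deviation. The sketch below works through that in plain Python, with no tensors; the helper name is illustrative, and the input is assumed to be already scaled to [0, 1] as ToTensor produces.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Per-channel (value - mean) / std for one RGB pixel with values in [0, 1]."""
    return tuple((v - m) / s for v, m, s in zip(rgb, mean, std))

mean_pixel = normalize_pixel(IMAGENET_MEAN)    # the "average" ImageNet color
white_pixel = normalize_pixel((1.0, 1.0, 1.0))
black_pixel = normalize_pixel((0.0, 0.0, 0.0))

print(mean_pixel)  # (0.0, 0.0, 0.0)
print(white_pixel)
print(black_pixel)
```

A pretrained first layer has only ever seen inputs in roughly this range (about -2 to +2.6), which is why feeding it raw [0, 1] or [0, 255] values degrades its features.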
Practical Considerations
Batch Normalization in Transfer Learning
Batch normalization has two sets of parameters:
- Learnable: scale (γ) and shift (β) - trainable parameters
- Running statistics: running_mean and running_var - accumulated during training, used during inference
When you freeze the backbone, you should also freeze the BatchNorm statistics by putting the backbone in eval mode. Otherwise, the running statistics update based on your small target dataset and pollute the pretrained normalization.
def freeze_backbone_with_bn(model: nn.Module) -> None:
    """Freeze backbone parameters AND fix BatchNorm statistics."""
    for name, module in model.named_modules():
        if name == "fc" or name.startswith("fc."):
            continue  # Don't touch the head
        # Freeze parameters owned directly by this module
        for param in module.parameters(recurse=False):
            param.requires_grad = False
        # Fix BatchNorm in eval mode (freezes running stats)
        if isinstance(module, nn.BatchNorm2d):
            module.eval()  # running_mean and running_var won't update

# Critical: model.train() switches every BN layer back to training mode,
# so re-apply eval mode to the BN layers after each model.train() call.
# (ResNet's head contains no BatchNorm, so every BN layer belongs to the backbone.)
def freeze_bn_on_train(m: nn.Module) -> None:
    if isinstance(m, nn.BatchNorm2d):
        m.eval()

model.train()                    # Set model to train mode
model.apply(freeze_bn_on_train)  # But force BN layers to eval
When you switch from Phase 1 (frozen backbone) to Phase 2 (unfreezing layers), remember to put the unfrozen BatchNorm layers back into training mode so their statistics can adapt to your domain.
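The reason requires_grad=False alone is not enough can be checked directly on a single BatchNorm layer. The snippet below is a small self-contained demonstration (the tensor shapes and the +5 offset are arbitrary choices for illustration): the running mean drifts in train mode even though the layer's parameters are frozen and no gradients are computed, and it stays fixed in eval mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
for p in bn.parameters():
    p.requires_grad = False            # frozen as far as the optimizer is concerned

x = torch.randn(8, 3, 4, 4) + 5.0      # batch whose mean is far from the stored 0.0
initial_mean = bn.running_mean.clone()

bn.train()                             # train mode: forward passes update running stats
with torch.no_grad():
    bn(x)
drifted_in_train = not torch.equal(bn.running_mean, initial_mean)
mean_after_train = bn.running_mean.clone()

bn.eval()                              # eval mode: running stats are read, never written
with torch.no_grad():
    bn(x)
stable_in_eval = torch.equal(bn.running_mean, mean_after_train)

print(drifted_in_train, stable_in_eval)  # True True
```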
Class Imbalance
Real-world datasets rarely have balanced class distributions. ImageNet-1k is roughly balanced, with about 730 to 1,300 training images per class. Your plant disease dataset might have 200 images of healthy leaves and 15 images of a rare blight.
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
from collections import Counter

def make_balanced_sampler(labels: list) -> WeightedRandomSampler:
    """
    Create a sampler that oversamples rare classes and undersamples
    common classes so each epoch sees roughly uniform class distribution.
    """
    class_counts = Counter(labels)
    total = len(labels)
    # Weight for each sample = inverse frequency of its class
    weights = [total / class_counts[label] for label in labels]
    weights_tensor = torch.tensor(weights, dtype=torch.double)
    return WeightedRandomSampler(
        weights=weights_tensor,
        num_samples=len(labels),
        replacement=True
    )

# Usage with DataLoader
# (assumes train_dataset exposes a .targets list, as ImageFolder does)
sampler = make_balanced_sampler(train_dataset.targets)
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    sampler=sampler,  # replaces shuffle=True
    num_workers=4,
    pin_memory=True
)
For severe imbalance, weighted sampling can be combined with a class-weighted loss or focal loss. Label smoothing regularizes the classifier but does not by itself correct imbalance.
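It helps to work out what inverse-frequency weights do to the example above (200 healthy, 15 blight). The sketch below uses plain Python and hypothetical helper names; it computes the share of draws each class receives and how often a single rare image is expected to repeat in one epoch.

```python
from collections import Counter

def class_sampling_mass(labels):
    """Fraction of draws each class receives under inverse-frequency weights."""
    counts = Counter(labels)
    total = len(labels)
    weights = [total / counts[y] for y in labels]
    weight_sum = sum(weights)
    mass = Counter()
    for y, w in zip(labels, weights):
        mass[y] += w / weight_sum
    return dict(mass)

def expected_draws_per_image(labels, label):
    """Expected times one image of `label` is drawn in an epoch of len(labels) draws."""
    counts = Counter(labels)
    total = len(labels)
    weight_sum = sum(total / counts[y] for y in labels)
    return total * (total / counts[label]) / weight_sum

labels = ["healthy"] * 200 + ["blight"] * 15
mass = class_sampling_mass(labels)
blight_repeats = expected_draws_per_image(labels, "blight")
healthy_repeats = expected_draws_per_image(labels, "healthy")

print(mass)
print(f"each blight image: ~{blight_repeats:.1f} draws/epoch")
print(f"each healthy image: ~{healthy_repeats:.2f} draws/epoch")
```

Both classes end up with half of the draws, but each blight image is seen about seven times per epoch while roughly half of the healthy images are skipped. That repetition is why oversampling should be paired with strong augmentation.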
Pretrained Model Sources
| Source | Best For | How to Use |
|---|---|---|
| torchvision.models | Standard architectures (ResNet, EfficientNet, ConvNeXt, ViT) | models.resnet50(weights=ResNet50_Weights.DEFAULT) |
| timm (PyTorch Image Models) | 700+ architectures, all well-maintained | timm.create_model("resnet50", pretrained=True) |
| HuggingFace transformers | ViT, Swin, CLIP, DINO, and vision-language models | AutoModel.from_pretrained("google/vit-base-patch16-224") |
| HuggingFace timm integration | Access timm models through HF hub | timm.create_model("hf_hub:timm/resnet50.a1_in1k", pretrained=True) |
timm is the go-to for production - it has the most architectures, the most up-to-date pretrained weights, and the features_only=True API for multi-scale feature extraction needed by detection/segmentation heads.
Data Augmentation for Transfer Learning
Augmentation is especially important in transfer learning because your dataset is typically small. Well-chosen augmentations can be worth more than architectural choices or longer training.
What Works and Why
Augmentations that consistently help:
- RandomResizedCrop (scale 0.7–1.0): randomly crops and resizes to target size, teaching scale invariance. Keep the scale parameter above 0.5 to avoid cropping out the object entirely.
- RandomHorizontalFlip: free invariance for most tasks. Skip it for tasks where orientation matters - medical scans with anatomical asymmetry, digit recognition.
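- ColorJitter (brightness, contrast, saturation): teaches lighting invariance. Field photos vary in exposure and white balance far more than a small training set can cover, so jitter pushes the model toward features that survive lighting changes. Keep the jitter mild when color itself is diagnostic, as with leaf discoloration.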
- RandomErasing: randomly removes a rectangle of the image, forcing the model to not rely on any single region. Effective regularizer.
AutoAugment (Cubuk et al., 2019) used reinforcement learning to search for the optimal augmentation policy for ImageNet. The discovered policy applies sequences of operations like Equalize, AutoContrast, Posterize, Solarize, and Rotate, each with a learned probability and magnitude.
RandAugment (Cubuk et al., 2020) simplified AutoAugment by uniformly sampling N augmentations from a fixed set, each applied at magnitude M. Two hyperparameters replace the complex policy search. N=2, M=9 is the standard starting point.
from torchvision import transforms

# RandAugment - the modern default for fine-tuning
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# AutoAugment with the ImageNet policy
train_transform_aa = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.IMAGENET),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
Mixup creates convex combinations of training pairs: image = lambda * img1 + (1-lambda) * img2, with labels mixed in the same proportion. CutMix pastes a rectangular crop from one image into another and mixes the labels by the area of the pasted region. Both are standard in modern training recipes (ConvNeXt, EfficientNet-V2, DeiT) and often improve fine-tuning accuracy, especially on small datasets.
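The Mixup formula translates to a few lines of PyTorch. This is a minimal sketch, not a reference implementation: the function name, the alpha=0.2 default, and the use of a single lambda per batch are assumptions made for brevity. It relies on cross_entropy accepting class-probability targets, which PyTorch supports from version 1.10.

```python
import torch
import torch.nn.functional as F

def mixup_batch(images, labels, num_classes, alpha=0.2):
    """Blend each image (and its one-hot label) with a randomly chosen partner."""
    # One mixing coefficient for the whole batch, drawn from Beta(alpha, alpha)
    lam = float(torch.distributions.Beta(alpha, alpha).sample())
    perm = torch.randperm(images.size(0))  # partner index for each sample
    mixed_images = lam * images + (1.0 - lam) * images[perm]
    onehot = F.one_hot(labels, num_classes).float()
    mixed_targets = lam * onehot + (1.0 - lam) * onehot[perm]
    return mixed_images, mixed_targets

torch.manual_seed(0)
images = torch.rand(8, 3, 32, 32)            # stand-in batch of 8 small RGB images
labels = torch.randint(0, 15, (8,))          # 15 disease classes
mixed_images, mixed_targets = mixup_batch(images, labels, num_classes=15)

logits = torch.randn(8, 15)                  # stand-in for model(mixed_images)
loss = F.cross_entropy(logits, mixed_targets)  # soft-label cross-entropy
print(mixed_images.shape, mixed_targets.sum(dim=1))
```

Each row of mixed_targets still sums to 1, so the loss treats it as a probability distribution over classes.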
When NOT to Use Transfer Learning
Transfer learning is the right default, but three situations call for a different approach.
1. Domain is statistically incompatible. Spectrogram images (audio as time-frequency maps), sonar and radar imagery, purely synthetic geometric data - these have fundamentally different statistics from natural images, so ImageNet features may transfer poorly. Consider self-supervised pretraining on your own unlabeled data (SimCLR, DINO, MoCo) or training from scratch.
2. You have 1M+ domain-specific labeled images. At very large scale, training from scratch on your target domain may outperform ImageNet transfer. Domain-specific pretraining matches the distribution exactly and avoids ImageNet biases. The crossover point depends on domain similarity; as a rough rule of thumb rather than a measured threshold, it is rarely below 5M images for natural photos and around 500K for medical imaging. Always run a scratch-training baseline at this scale before assuming transfer wins.
3. Model size is severely constrained. ResNet-50 has 25.6M parameters; EfficientNet-B0 has 5.3M. If deployment requires under 500K parameters (some MCU targets), there may not be an appropriate pretrained architecture. A tiny custom model with knowledge distillation from a large teacher is often better.
Even in these three cases, always try transfer learning first as a baseline - it frequently outperforms expectations even across large domain gaps.
Popular Pretrained Models - How to Choose
| Model | Params | Top-1 (ImageNet) | Best For |
|---|---|---|---|
| MobileNetV3-Small | 2.5M | 67.7% | Mobile, edge, MCU |
| EfficientNet-B0 | 5.3M | 77.1% | Mobile-friendly server |
| ResNet-18 | 11.7M | 69.8% | Lightweight baseline |
| DenseNet-121 | 8M | 74.4% | Medical, small datasets |
| ResNet-50 | 25.6M | 80.9%* | Universal production backbone |
| EfficientNet-B3 | 12M | 81.6% | Accuracy/size sweet spot |
| EfficientNet-B4 | 19.3M | 82.9% | Accuracy-focused production |
| ConvNeXt-Tiny | 28.6M | 82.1% | Modern CNN baseline |
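| ViT-B/16 | 86.6M | 85.3%* | High accuracy, GPU only |
*With modern training recipes or weights: IMAGENET1K_V2 for ResNet-50; SWAG-pretrained weights fine-tuned end-to-end at 384×384 for ViT-B/16. A DeiT-style recipe at 224×224 gives ViT-B/16 about 81%.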
Decision guide for transfer learning specifically:
- Under 50ms CPU: MobileNetV3-Large or EfficientNet-B0
- Accuracy + reasonable cost: EfficientNet-B3 or ConvNeXt-Tiny
- Medical imaging: DenseNet-121 (feature reuse on small data) or ResNet-50
- Detection/segmentation: ResNet-50 or ResNet-101 (widest head support)
- Maximum accuracy, GPU available: ViT-B/16 or ConvNeXt-Base
import timm

# timm - consistent access to 700+ pretrained models
for model_name in ["resnet50", "efficientnet_b4", "convnext_tiny", "densenet121"]:
    model = timm.create_model(model_name, pretrained=False)
    n_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"{model_name:<30} {n_params:.1f}M params")

# Always check the expected input size - EfficientNet-B4 was designed for 380x380, not 224,
# and the value timm reports depends on which checkpoint the model name resolves to.
model = timm.create_model("efficientnet_b4", pretrained=True, num_classes=0)
cfg = model.default_cfg
print(f"Input size: {cfg['input_size']}")  # larger than (3, 224, 224); exact size varies by checkpoint
print(f"Mean: {cfg['mean']}")              # e.g. (0.485, 0.456, 0.406)
Domain Adaptation
Transfer learning assumes the pretrained features are useful. What happens when there's a significant distribution gap between ImageNet and your target domain?
Recognizing a Domain Gap
Signs that your domain has a gap with ImageNet:
- Images have different statistics (grayscale, unusual color distributions)
- Different spatial scales (microscopy, aerial imagery)
- Domain-specific textures not represented in natural images
Simple Fixes First
Before reaching for complex domain adaptation methods, try:
- Aggressive augmentation during fine-tuning: RandomErasing, CutMix, MixUp. This forces the model to rely on robust features rather than texture shortcuts.
- Smaller LR everywhere: If the domain is very different, use LRs 5–10× smaller than normal. The pretrained features still provide a useful starting point; you just need more time to adapt.
- Longer warmup: Extend the warmup phase to 10-20% of total training. This prevents large early updates from destroying pretrained representations.
- Adapt the input layer for non-RGB images (see below).
Adapting the First Layer for Non-RGB Images
ImageNet models expect 3-channel RGB input. Medical images are often grayscale. Satellite images may have 4-12 spectral bands.
import torch
import torch.nn as nn
import torchvision.models as models

def adapt_first_conv(
    model: nn.Module,
    in_channels: int,
) -> nn.Module:
    """
    Replace the first conv layer to accept `in_channels` input channels,
    preserving as much pretrained knowledge as possible.
    Strategy:
    - 1 channel: average the pretrained weights over the 3 RGB input channels
    - 2 channels: keep the weights for the first 2 of the 3 RGB input channels
    - 3+ channels: copy the 3 RGB weight slices, initialize extras with their mean
    """
    old_conv = model.conv1
    old_weight = old_conv.weight.data  # shape: (64, 3, 7, 7)
    new_conv = nn.Conv2d(
        in_channels=in_channels,
        out_channels=old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=old_conv.bias is not None
    )
    with torch.no_grad():
        if in_channels == 1:
            # Average across the 3 RGB input channels
            new_weight = old_weight.mean(dim=1, keepdim=True)
        elif in_channels < 3:
            # Keep the weights for the first `in_channels` input channels
            new_weight = old_weight[:, :in_channels, :, :].clone()
        else:
            # in_channels >= 3: copy RGB weights, fill extras with the channel mean
            mean_weight = old_weight.mean(dim=1, keepdim=True)  # (64, 1, 7, 7)
            new_weight = mean_weight.expand(-1, in_channels, -1, -1).clone()
            new_weight[:, :3, :, :] = old_weight  # Restore original RGB channels
        new_conv.weight.copy_(new_weight)
        if old_conv.bias is not None:
            new_conv.bias.copy_(old_conv.bias)  # ResNet's conv1 has no bias; kept for generality
    model.conv1 = new_conv
    return model

# Grayscale medical images
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
model = adapt_first_conv(model, in_channels=1)
x = torch.randn(4, 1, 224, 224)
out = model.conv1(x)
print(out.shape)  # torch.Size([4, 64, 112, 112]) ✓

# 4-channel satellite (RGB + near-infrared)
model4 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
model4 = adapt_first_conv(model4, in_channels=4)
x4 = torch.randn(4, 4, 224, 224)
print(model4.conv1(x4).shape)  # torch.Size([4, 64, 112, 112]) ✓
Advanced Domain Adaptation
When simple fixes aren't enough, two techniques deserve mention:
DANN (Domain-Adversarial Neural Networks): Train a feature extractor to be simultaneously good at your task and bad at distinguishing source from target domain. A gradient reversal layer makes the feature distribution domain-invariant.
MMD loss (Maximum Mean Discrepancy): Directly minimize the statistical distance between source and target feature distributions. Implemented as an additional loss term during fine-tuning.
Both require unlabeled target-domain data, which is usually easy to obtain. They're worth the engineering effort when you have a severe domain gap and collecting more labeled data is expensive.
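Both ideas are compact in code. The sketch below shows a gradient reversal layer, the core of DANN, and a simple biased RBF-kernel estimate of squared MMD. It is illustrative only: the class and function names, the single fixed kernel bandwidth, and the toy tensors are assumptions, and a real setup would add a domain classifier and tune the loss weighting.

```python
import torch

class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by -lambd on the way back."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient is returned for lambd itself
        return -ctx.lambd * grad_output, None

def grad_reverse(x, lambd=1.0):
    return GradReverse.apply(x, lambd)

def rbf_mmd2(x, y, sigma=1.0):
    """Biased estimate of squared MMD between feature batches x and y (RBF kernel)."""
    def kernel(a, b):
        sq_dist = torch.cdist(a, b).pow(2)
        return torch.exp(-sq_dist / (2.0 * sigma ** 2))
    return kernel(x, x).mean() + kernel(y, y).mean() - 2.0 * kernel(x, y).mean()

# Gradient reversal: the features feed a domain classifier, but the feature
# extractor receives the negated gradient and so learns to confuse it.
feats = torch.ones(3, requires_grad=True)
grad_reverse(feats, lambd=0.5).sum().backward()
print(feats.grad)  # tensor([-0.5000, -0.5000, -0.5000])

# MMD: zero for identical batches, positive when the distributions differ.
torch.manual_seed(0)
source = torch.randn(16, 4)
target = source + 3.0
same_mmd = rbf_mmd2(source, source)
shifted_mmd = rbf_mmd2(source, target)
```

In training, the MMD term would be added to the task loss with a small weight, computed on backbone features from a labeled batch and an unlabeled target-domain batch.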
Evaluating Transfer Quality: The Linear Probe
Before committing to a fine-tuning strategy, you can quickly assess how useful the pretrained features are for your task - without any fine-tuning at all.
Linear probe: freeze the backbone entirely, extract features for your entire training set, train a logistic regression (or single linear layer) on those features.
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader

def linear_probe(
    backbone: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device
) -> float:
    """
    Assess pretrained feature quality via linear probe.
    Returns validation accuracy of a linear classifier on frozen features.
    """
    backbone = backbone.to(device)  # keep the model on the same device as the inputs
    backbone.eval()

    def extract_features(loader):
        feats, labs = [], []
        with torch.no_grad():
            for images, labels in loader:
                images = images.to(device)
                features = backbone(images)  # (B, D) - backbone with no head
                feats.append(features.cpu().numpy())
                labs.append(labels.numpy())
        return np.concatenate(feats), np.concatenate(labs)

    print("Extracting features...")
    train_feats, train_labels = extract_features(train_loader)
    val_feats, val_labels = extract_features(val_loader)
    # Normalize features (important for linear classifiers)
    scaler = StandardScaler()
    train_feats = scaler.fit_transform(train_feats)
    val_feats = scaler.transform(val_feats)
    # Train logistic regression
    clf = LogisticRegression(max_iter=1000, C=1.0)
    clf.fit(train_feats, train_labels)
    return clf.score(val_feats, val_labels)

# Usage: strip the classification head from the backbone first
# (train_loader, val_loader, and device are assumed to be defined as in the pipeline above)
backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
backbone.fc = nn.Identity()  # Remove head - output is 2048-dim feature vector
probe_acc = linear_probe(backbone, train_loader, val_loader, device)
print(f"Linear probe accuracy: {probe_acc:.3f}")
Interpreting the linear probe:
| Linear Probe Accuracy | Interpretation |
|---|---|
| > 85% | Pretrained features are excellent. Feature extraction alone may suffice. |
| 70–85% | Good features. Light fine-tuning (top 1-2 layers) will improve significantly. |
| 50–70% | Moderate domain gap. Fine-tune 3-4 layers with discriminative LRs. |
| < 50% | Large domain gap. Full fine-tuning required; consider from-scratch if you have enough data. |
The fine-tuning gap: the difference between fine-tuning accuracy and linear probe accuracy tells you how much the backbone gains from adaptation. If linear probe is 82% and full fine-tuning is 84%, the features are already excellent and fine-tuning gives marginal gains. If linear probe is 60% and fine-tuning reaches 88%, the backbone is adapting substantially and fine-tuning is essential.
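The table and the gap calculation can be written down as a small helper, sketched below. The function names are hypothetical, the thresholds are simply the table's heuristics (with boundary values assigned to the lower band as an assumption), and they do not account for the number of classes: 50% means something very different on a 2-class task than on a 15-class one.

```python
def interpret_linear_probe(probe_acc: float) -> str:
    """Map linear-probe validation accuracy (as a fraction) to the table's suggestion."""
    if not 0.0 <= probe_acc <= 1.0:
        raise ValueError("probe_acc must be a fraction between 0 and 1")
    if probe_acc > 0.85:
        return "feature extraction"
    if probe_acc >= 0.70:
        return "light fine-tuning (top 1-2 layers)"
    if probe_acc >= 0.50:
        return "fine-tune 3-4 layers with discriminative LRs"
    return "full fine-tuning"


def fine_tuning_gap(finetune_acc: float, probe_acc: float) -> float:
    """Accuracy gained by fine-tuning over the frozen-feature linear probe."""
    return finetune_acc - probe_acc


print(interpret_linear_probe(0.82))
print(round(fine_tuning_gap(0.88, 0.60), 2))
```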
Architecture Selection for Transfer Learning
Not all architectures fine-tune equally well. Some practical guidance:
For almost all production fine-tuning tasks, start with ResNet-50. It has:
- The most thoroughly characterized fine-tuning behavior across domains
- The widest compatibility with detection/segmentation heads in frameworks like Detectron2 and MMDetection
- Well-understood failure modes
- Pretrained weights from multiple providers with different recipes
Once you have a ResNet-50 baseline, upgrade to EfficientNet-B4 or ConvNeXt if you need better accuracy and can accept slightly more engineering overhead.
Common Pitfalls
Large learning rate for pretrained layers. The single most common mistake. A randomly initialized head needs an LR of roughly 1e-3; pretrained layers need roughly 1e-4 to 1e-5. Using a uniformly high LR destroys pretrained representations in the first few batches. If you see training accuracy shoot up in epoch 1 followed by poor generalization, check your learning rates.
Missing ImageNet normalization. mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] is not optional for torchvision's ImageNet-pretrained models. These are the per-channel mean and standard deviation of ImageNet training images, computed on pixel values scaled to [0, 1]. The pretrained weights expect inputs with this distribution. If you feed un-normalized or differently normalized images, the activations in the first few layers are out of distribution and the pretrained features provide little benefit. Models pretrained with a different recipe may use different constants, so check the preprocessing that ships with the weights.
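A minimal sketch of the normalization step, written with plain tensor operations so the arithmetic is visible. It assumes a float batch of shape (B, 3, H, W) with values already scaled to [0, 1]; `normalize_imagenet` is a hypothetical helper name, and `torchvision.transforms.Normalize(mean, std)` performs the same per-channel operation.

```python
import torch

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


def normalize_imagenet(batch: torch.Tensor) -> torch.Tensor:
    """Normalize a float (B, 3, H, W) batch whose values lie in [0, 1]."""
    if batch.dim() != 4 or batch.shape[1] != 3:
        raise ValueError(f"expected (B, 3, H, W), got {tuple(batch.shape)}")
    mean = IMAGENET_MEAN.to(batch.device)
    std = IMAGENET_STD.to(batch.device)
    return (batch - mean) / std


# An image filled with the ImageNet mean color maps to all zeros.
mean_image = IMAGENET_MEAN.expand(2, 3, 8, 8).clone()
normalized = normalize_imagenet(mean_image)
```

After this step, pixel values fall roughly between -2.1 and 2.6, which is the range the pretrained first layer was trained on.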
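Evaluating in train mode. In train mode, BatchNorm normalizes with the current batch's statistics and dropout is active; in eval mode, BatchNorm uses its stored running statistics and dropout is disabled. Always call model.eval() before validation, and model.train() again before the next training epoch. This bug is especially insidious because it does not crash: it shows up only as noisy, batch-size-dependent validation metrics, and forward passes in train mode also update the running statistics with validation data.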
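The difference between the two modes is easy to see on a single BatchNorm layer. The sketch below uses a standalone `nn.BatchNorm2d` and random data as stand-ins for a real model and validation batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
# A batch whose statistics differ from the layer's initial running stats (mean 0, var 1).
x = torch.randn(8, 3, 4, 4) * 5 + 2

bn.train()
before = bn.running_mean.clone()
out_train = bn(x)  # normalizes with batch statistics AND updates the running stats
changed_in_train = not torch.equal(before, bn.running_mean)

bn.eval()
before = bn.running_mean.clone()
with torch.no_grad():
    out_eval = bn(x)  # normalizes with the stored running stats, no update
changed_in_eval = not torch.equal(before, bn.running_mean)

same_output = torch.allclose(out_train, out_eval)
print(changed_in_train, changed_in_eval, same_output)
```

The same input produces different outputs in the two modes, and only the train-mode pass modifies the layer's state.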
Not loading the best checkpoint. Always save and reload the best validation checkpoint, not the final epoch. Overfitting often starts before the last epoch.
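One way to keep the best checkpoint is a small tracker that deep-copies the state whenever the validation metric improves. This is a sketch: `BestCheckpoint` is a hypothetical helper, and the validation accuracies below are made-up numbers used only to exercise it. The deep copy matters because `model.state_dict()` returns references to the live tensors, which later training steps would overwrite.

```python
import copy


class BestCheckpoint:
    """Keep an in-memory copy of the state with the best validation metric."""

    def __init__(self):
        self.best_metric = float("-inf")
        self.best_epoch = -1
        self.best_state = None

    def update(self, epoch: int, metric: float, state: dict) -> bool:
        """Store a copy of `state` if `metric` improves; return True if it did."""
        if metric > self.best_metric:
            self.best_metric = metric
            self.best_epoch = epoch
            self.best_state = copy.deepcopy(state)
            return True
        return False


tracker = BestCheckpoint()
fake_val_acc = [0.71, 0.84, 0.88, 0.86, 0.85]  # illustrative values only
for epoch, acc in enumerate(fake_val_acc):
    state = {"w": epoch}  # stands in for model.state_dict()
    tracker.update(epoch, acc, state)

print(tracker.best_epoch, tracker.best_metric)
```

With a real model, you would pass `model.state_dict()` to `update` after each validation pass and call `model.load_state_dict(tracker.best_state)` once training ends.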
Forgetting to unfreeze BatchNorm statistics. When transitioning from Phase 1 to Phase 2, if you unfreeze layer3 and layer4 but leave their BatchNorm layers in eval mode, their running statistics stay at the ImageNet values and adaptation is suboptimal. Setting requires_grad=True does not change a module's mode, so after unfreezing, call model.train() to restore all modules to training mode.
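A quick diagnostic for this pitfall is to list the BatchNorm modules that are still in eval mode. The sketch below uses a toy `TinyNet` as a hypothetical stand-in for a ResNet, and `bn_layers_in_eval` is an illustrative helper name.

```python
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def bn_layers_in_eval(model: nn.Module) -> list:
    """Names of BatchNorm modules that are currently in eval mode."""
    return [name for name, m in model.named_modules()
            if isinstance(m, BN_TYPES) and not m.training]


class TinyNet(nn.Module):  # toy stand-in with ResNet-like group names
    def __init__(self):
        super().__init__()
        self.layer3 = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
        self.layer4 = nn.Sequential(nn.Conv2d(8, 8, 3), nn.BatchNorm2d(8))


model = TinyNet()

# Phase 1: frozen backbone kept in eval mode.
for p in model.parameters():
    p.requires_grad = False
model.eval()

# Phase 2 done incompletely: gradients enabled, module mode untouched.
for p in model.layer4.parameters():
    p.requires_grad = True
stale = bn_layers_in_eval(model)  # both BatchNorm layers are still in eval mode

model.train()
fixed = bn_layers_in_eval(model)  # empty once training mode is restored
print(stale, fixed)
```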
Interview Q&A
Q1: Why do ImageNet-pretrained features transfer to domains like medical imaging, when the semantic content is completely different?
CNNs learn a hierarchical feature representation. Early layers learn low-level features - oriented edge detectors, color blobs, texture patterns - that are universal across all natural images because they reflect fundamental properties of how images are formed (gradients, local intensity patterns, frequency content). These same features are useful for detecting edges in X-rays, texture in histology slides, or shapes in satellite imagery. The domain mismatch occurs primarily in the high-level semantic features (layers 4-5 and classifier), not in the early feature detectors. As a rough rule of thumb rather than a measured figure, transfer works because ~80% of the network learns to "see" - detecting visual primitives that are domain-agnostic - and only ~20% learns to "interpret" in an ImageNet-specific way. Fine-tuning the interpretation layers on the target domain reuses the visual perception machinery built by ImageNet training.
Q2: What are the three main fine-tuning strategies and when would you use each?
(1) Feature extraction (frozen backbone): freeze all pretrained layers, train only the new head. Use when dataset is very small (< 1,000 samples) or domain is very similar to ImageNet. Fast, low overfitting risk, but limited adaptation. (2) Top-layer fine-tuning: freeze early layers, unfreeze last 2-3 layer groups plus the new head with discriminative learning rates (earlier layers get 10-100× smaller LR). Use for moderate datasets (1,000–10,000 samples). Best accuracy/efficiency tradeoff in most cases. (3) Full fine-tuning: unfreeze everything with a very small uniform LR (1e-5) and warmup. Use when dataset is large (> 10,000 samples) or domain is significantly different. Risk: catastrophic forgetting if LR is too large or warmup is insufficient.
Q3: What is layer-wise learning rate decay (discriminative learning rates) and why does it matter?
Different layers in a pretrained network should change at different rates during fine-tuning. The randomly initialized classification head needs a large LR (e.g., 1e-3) to learn from scratch. The last pretrained layer group needs a moderate LR (e.g., 1e-4) to adapt to the new domain. Early layers - which already encode universal visual features - need a tiny LR (e.g., 1e-5) to preserve their general representations. Implementing this in PyTorch means passing a list of {"params": ..., "lr": ...} dicts to the optimizer, one per layer group. A uniform LR forces a bad compromise: large enough for the head, it destroys the early layers; small enough for the early layers, it slows head learning unacceptably. The typical decay factor is 0.1 per layer group, moving from the head toward the input.
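A minimal sketch of building those per-group dicts with a fixed decay factor. The three `nn.Linear` modules are toy stand-ins for an early group, a late group, and the head, and `build_param_groups` is a hypothetical helper name.

```python
import torch
import torch.nn as nn


def build_param_groups(groups, head_lr: float = 1e-3, decay: float = 0.1):
    """`groups` is ordered from the head toward the input; each step multiplies the LR by `decay`."""
    return [
        {"params": list(module.parameters()), "lr": head_lr * decay ** depth}
        for depth, module in enumerate(groups)
    ]


# Toy stand-ins for real layer groups (e.g. model.layer1, model.layer4, model.fc).
early = nn.Linear(16, 16)
late = nn.Linear(16, 16)
head = nn.Linear(16, 15)

optimizer = torch.optim.AdamW(build_param_groups([head, late, early]))
lrs = [group["lr"] for group in optimizer.param_groups]
print(lrs)  # approximately [1e-3, 1e-4, 1e-5]
```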
Q4: How do you handle fine-tuning when your input images are grayscale (1 channel) instead of the 3-channel RGB the pretrained model expects?
Adapt the first convolutional layer. The pretrained first conv has weights of shape (out_channels, 3, kernel_h, kernel_w). For single-channel input, average across the 3 input channels: new_weight = old_weight.mean(dim=1, keepdim=True). This gives a 1-channel conv that encodes the average response of the RGB filters. The rationale: grayscale images correspond roughly to a weighted average of R, G, and B channels, so averaging the pretrained RGB filters gives a reasonable initialization for grayscale. Don't randomly initialize the adapted first layer - you'd lose the learned edge detectors and force the network to re-learn low-level features from scratch.
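One detail worth checking when you do this: averaging the RGB filters gives one third of the response the original layer would produce on a grayscale image replicated to three channels, while summing the filters reproduces that response exactly. The sketch below demonstrates this on a randomly initialized conv standing in for a pretrained conv1; `to_single_channel` is a hypothetical helper. Whether the scale difference matters in practice likely depends on what follows the layer (a BatchNorm layer in train mode will re-normalize it), so treat it as something to be aware of rather than a correction.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in for a pretrained ResNet conv1 (random weights, no download needed).
rgb_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)


def to_single_channel(conv: nn.Conv2d, reduce: str = "mean") -> nn.Conv2d:
    """Collapse the input-channel dimension of `conv` by averaging or summing."""
    new_conv = nn.Conv2d(1, conv.out_channels, kernel_size=conv.kernel_size,
                         stride=conv.stride, padding=conv.padding,
                         bias=conv.bias is not None)
    w = conv.weight.data
    new_conv.weight.data = (w.mean(dim=1, keepdim=True) if reduce == "mean"
                            else w.sum(dim=1, keepdim=True))
    if conv.bias is not None:
        new_conv.bias.data = conv.bias.data.clone()
    return new_conv


gray = torch.randn(2, 1, 32, 32)
with torch.no_grad():
    ref = rgb_conv(gray.repeat(1, 3, 1, 1))  # grayscale replicated to 3 channels
    out_mean = to_single_channel(rgb_conv, "mean")(gray)
    out_sum = to_single_channel(rgb_conv, "sum")(gray)

print(torch.allclose(ref, out_sum, atol=1e-4), torch.allclose(ref, 3 * out_mean, atol=1e-4))
```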
Q5: What is catastrophic forgetting in fine-tuning and how do you prevent it?
Catastrophic forgetting happens when fine-tuning on a new task overwrites the pretrained general representations. The gradient signal from the new task propagates through all layers and, with a large learning rate, substantially updates even early layers that encoded universal visual features. After fine-tuning, the network has lost the general-purpose visual understanding and now only knows your specific task - it generalizes less well because the feature representations are less rich. Prevention: (1) Discriminative LRs - give early layers LR 100-1000× smaller than the head. (2) Start with frozen backbone (Phase 1) before any fine-tuning. (3) Warmup: start at 10% of target LR, ramp up linearly over the first 5-10 epochs. (4) Gradient clipping (max norm = 1.0). (5) For very small datasets, keep the backbone fully frozen throughout.
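Prevention steps (3) and (4) can be sketched with PyTorch's built-in `LinearLR` scheduler and `clip_grad_norm_`. The linear model, the random data, and the choice of one optimizer step per epoch are simplifications for illustration; in a real loop the clipping call goes between `loss.backward()` and `optimizer.step()` for every batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 2)  # stand-in for the network being fine-tuned
target_lr, warmup_epochs = 1e-4, 5

optimizer = torch.optim.SGD(model.parameters(), lr=target_lr)
# Start at 10% of the target LR and ramp linearly to 100% over the warmup epochs.
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_epochs
)

lrs = []
for epoch in range(warmup_epochs + 2):
    lrs.append(optimizer.param_groups[0]["lr"])
    loss = model(torch.randn(4, 10) * 100).pow(2).mean()  # deliberately large loss
    optimizer.zero_grad()
    loss.backward()
    # Rescale gradients in place so their global L2 norm is at most 1.0.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()

clipped_norm = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))
print(lrs, clipped_norm.item())
```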
Production Notes
Model versioning and reproducibility: always record the exact pretrained weight version used, not just the architecture name. torchvision.models.ResNet50_Weights.IMAGENET1K_V1 and IMAGENET1K_V2 differ by nearly 5 points of ImageNet top-1 accuracy (about 76.1% vs 80.9% in the torchvision documentation). The legacy pretrained=True flag has been deprecated since torchvision 0.13 and does not say which weights you get, and weights=ResNet50_Weights.DEFAULT may point to a different checkpoint in a later release. Pin both the library version and the explicit weight enum in your model config.
Checkpoint strategy: save the full model state (model.state_dict()), optimizer state (optimizer.state_dict()), scheduler state, and the epoch number. This allows resuming interrupted training. Also save the pretrained weight source so you can reproduce the initialization.
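A minimal sketch of such a checkpoint, assuming a single-file layout with the dictionary keys shown (the key names, `save_checkpoint`, and `load_checkpoint` are illustrative choices, not a standard format):

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path, model, optimizer, scheduler, epoch, pretrained_source):
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "epoch": epoch,
        "pretrained_source": pretrained_source,  # e.g. the weight enum name
    }, path)


def load_checkpoint(path, model, optimizer, scheduler):
    """Restore all states in place; return the epoch to resume from and the weight source."""
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    return ckpt["epoch"] + 1, ckpt["pretrained_source"]


def make_training_objects():
    model = nn.Linear(4, 2)  # stand-in model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
    return model, optimizer, scheduler


model, optimizer, scheduler = make_training_objects()
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "ckpt.pt")
    save_checkpoint(path, model, optimizer, scheduler, epoch=3,
                    pretrained_source="ResNet50_Weights.IMAGENET1K_V2")
    new_model, new_optimizer, new_scheduler = make_training_objects()
    start_epoch, source = load_checkpoint(path, new_model, new_optimizer, new_scheduler)

print(start_epoch, source)
```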
When to retrain from scratch: transfer learning almost always wins in small to medium data regimes. With very large datasets (> 1M domain-specific samples), training from scratch on your own domain can outperform ImageNet transfer, because the learned features match the target distribution directly instead of being adapted from ImageNet. The crossover point varies by domain but is rarely below 500K samples. Default to transfer learning unless you have both a very large dataset and strong evidence that the domain gap is too wide for ImageNet features to help.
Input validation at serving time: the normalization constants and input resolution are part of your model's interface contract. If the serving pipeline applies different normalization than training - a surprisingly common production bug - accuracy degrades silently. Add input validation and logging for image statistics at inference time.
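A sketch of what such a check might look like for a model trained with ImageNet normalization at 224×224. The function name, the `max_abs` threshold, and the specific heuristics are assumptions for illustration; they can produce false positives (for example on a nearly all-white image), so they are better suited to logging and alerting than to hard rejection.

```python
import torch


def check_model_input(batch: torch.Tensor, expected_hw=(224, 224), max_abs: float = 10.0) -> list:
    """Return a list of suspected problems; an empty list means the batch looks plausible."""
    problems = []
    if batch.dim() != 4 or batch.shape[1] != 3:
        return [f"expected shape (B, 3, H, W), got {tuple(batch.shape)}"]
    if tuple(batch.shape[-2:]) != tuple(expected_hw):
        problems.append(f"expected resolution {expected_hw}, got {tuple(batch.shape[-2:])}")
    if not batch.is_floating_point():
        problems.append(f"expected a float tensor, got {batch.dtype}")
        return problems
    # ImageNet-normalized pixels fall roughly in [-2.1, 2.6].
    if batch.min().item() >= 0:
        problems.append("no negative values: normalization may be missing")
    if batch.abs().max().item() > max_abs:
        problems.append("values far outside the normalized range: input may still be in [0, 255]")
    return problems


torch.manual_seed(0)
good = torch.randn(2, 3, 224, 224)        # roughly what a normalized batch looks like
raw = torch.rand(2, 3, 224, 224) * 255.0  # un-normalized pixel values
print(check_model_input(good))
print(check_model_input(raw))
```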
