Semantic Segmentation
Reading Time: ~55 min | Interview Relevance: High | Target Roles: MLE, CV Engineer, Applied Scientist
The scenario brief is three sentences long, but its implications run deep. A camera mounted on the hood of an autonomous vehicle is capturing 30 frames per second. Each frame must be analyzed and acted upon in under 33 milliseconds. And the question the system must answer is not "is there a pedestrian in this image?" - that is object detection, and it is already solved. The question is: which specific pixels belong to the pedestrian, and which pixels belong to the road, the sidewalk, the sky, the car in the next lane?
The distinction matters because the vehicle's decision-making system does not reason in bounding boxes. It reasons in occupancy: can I drive through this pixel, or is something there? A bounding box around a pedestrian is a rectangle that includes pixels of road, car doors, and sky in its corners. Driving decisions made on bounding box occupancy would be dangerously inaccurate. What the system needs is a pixel-level map - every pixel labeled with its true semantic class - so the trajectory planning algorithm can precisely calculate whether the intended path collides with any pedestrian pixel.
Your team has been given three weeks to build the segmentation pipeline. On the first day you discover that your best classification network - a ResNet-50 that achieves 94% accuracy on the validation set - outputs a single 1,000-dimensional vector per image. That vector tells you what class dominates the scene, but the spatial information has been discarded by five stages of stride-2 downsampling (a total stride of 32) followed by global average pooling. To produce a segmentation map, you need a network that outputs a label for each of the 1,280 × 720 = 921,600 pixels in the frame. The ResNet-50 you have produces exactly one label for all of them. You are starting over.
By the end of week one, you understand why. The compression-then-classify paradigm of standard CNNs is fundamentally incompatible with dense prediction. Building a system that can simultaneously understand "what is in this scene?" and "exactly which pixels is it occupying?" requires an entirely different architectural philosophy. This lesson walks through that philosophy - from the first attempt in 2015 to the state-of-the-art designs used in production perception systems today.
The Segmentation Spectrum
"Segmentation" is not one task. It is a family of related tasks with increasing information requirements. Knowing which one you need determines which architecture you use.
Semantic segmentation assigns every pixel a class label from a fixed vocabulary (road, pedestrian, car, sky, building). Crucially, it does not distinguish between instances of the same class - two cars are both "car" pixels, with no indication of where one car ends and the other begins.
Instance segmentation assigns each pixel of a detected object both a class label AND an instance identity; pixels that belong to no object are left unlabeled. Car #1 gets one color, car #2 gets another color. Likewise, pedestrian #1 and pedestrian #2 are distinguished. This requires detecting individual objects first, then predicting which pixels belong to each one.
Panoptic segmentation combines both. "Things" (countable objects: cars, pedestrians, cyclists) get per-instance labels. "Stuff" (amorphous background regions: road, sky, vegetation) get semantic labels without instance IDs. This is the most comprehensive scene representation and the hardest task.
Original Image        Semantic Seg          Instance Seg          Panoptic Seg
┌──────────────┐      ┌──────────────┐      ┌──────────────┐      ┌──────────────┐
│ sky          │      │ sky (blue)   │      │ (unlabeled)  │      │ sky (blue)   │
│ car1   car2  │  →   │ car (green)  │  →   │ car#1 car#2  │  →   │ car#1 car#2  │
│ road         │      │ road (grey)  │      │ (unlabeled)  │      │ road (grey)  │
└──────────────┘      └──────────────┘      └──────────────┘      └──────────────┘
raw input             car = one class       car1 ≠ car2           thing+stuff unified
When you need which:
- Autonomous driving lane detection → semantic (which pixels are drivable?)
- Medical tumor segmentation → instance (which exact region is tumor #1?)
- Scene understanding for robots → panoptic (what can I interact with and where?)
- Satellite land cover mapping → semantic (what class covers this land patch?)
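To make the three label formats concrete, here is a minimal sketch in plain Python. The toy 2×4 scene, the class ids, and the `class_id * 1000 + instance_id` panoptic encoding are illustrative assumptions, not something the lesson prescribes.

```python
# Hypothetical class ids for a toy scene (assumptions for illustration only).
SKY, ROAD, CAR = 0, 1, 2
THING_CLASSES = {CAR}  # countable classes; everything else is "stuff"

# Semantic map: one class id per pixel, no notion of instances.
semantic = [
    [SKY, SKY, SKY, SKY],
    [CAR, CAR, ROAD, CAR],
]

# Instance map: 0 = no instance, otherwise a per-object id.
instance = [
    [0, 0, 0, 0],
    [1, 1, 0, 2],
]


def to_panoptic(semantic, instance, divisor=1000):
    """Encode each pixel as class_id * divisor + instance_id.

    Stuff pixels always get instance id 0; thing pixels keep their instance id.
    """
    panoptic = []
    for sem_row, inst_row in zip(semantic, instance):
        row = []
        for cls, inst in zip(sem_row, inst_row):
            inst_id = inst if cls in THING_CLASSES else 0
            row.append(cls * divisor + inst_id)
        panoptic.append(row)
    return panoptic


panoptic = to_panoptic(semantic, instance)
for row in panoptic:
    print(row)
```

The two car regions share a class id in `semantic` but end up with different codes in `panoptic`, while sky and road carry only their class.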
Why Classification CNNs Cannot Do Segmentation
This is worth understanding deeply, because it is the architectural problem that every segmentation model is designed to solve.
A standard classification CNN (AlexNet, VGG, ResNet) has the following structure:
Input: 224×224×3
↓ Conv + Pool ×5
Feature map: 7×7×512 (spatial info compressed ×32)
↓ Global Average Pool or Flatten
Vector: 512 or 25088-dimensional
↓ FC layers
Output: 1000-dimensional class scores (one label per image)
The spatial dimensions go from 224×224 down to 7×7. Information about which pixel was which is progressively discarded. A neuron in the 7×7 feature map has a receptive field covering most or all of the 224×224 input - it "sees" nearly everything, which is great for recognizing the global category, and terrible for knowing that the cat's ear is at pixel (47, 83).
To produce a segmentation map, you need the output to be 224×224×C - the same spatial size as the input, with C class scores per pixel. Standard CNNs do the opposite: they produce 1×1×C.
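The shape arithmetic above can be checked with a few lines of plain Python. This is only a sketch: the helper name `spatial_size_after` is hypothetical, and it assumes each downsampling stage halves the side length with floor division.

```python
def spatial_size_after(input_size: int, num_halvings: int) -> int:
    """Spatial side length after `num_halvings` stride-2 stages (floor division)."""
    size = input_size
    for _ in range(num_halvings):
        size //= 2
    return size


side = spatial_size_after(224, 5)   # five stride-2 stages give a total stride of 32
labels_from_classifier = 1          # one label per image after global pooling
labels_needed = 224 * 224           # one label per pixel for segmentation

print(side, labels_from_classifier, labels_needed)
```

A total stride of 32 takes 224 down to 7, and the gap between one label and 50,176 labels is what the rest of the architecture has to close.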
The fundamental tension that every segmentation architecture must navigate:
Global context vs spatial precision:
- To classify a pixel correctly, you need large context ("this pixel is road because there are cars and sky around it")
- To localize class boundaries correctly, you need high spatial resolution ("the road ends at exactly this pixel, not two pixels to the left")
Standard CNNs achieve global context by destroying spatial resolution. The art of segmentation architectures is recovering both.
Fully Convolutional Networks (FCN, 2015): The First Deep Segmentation Model
Jonathan Long, Evan Shelhamer, and Trevor Darrell at Berkeley published Fully Convolutional Networks in 2015 - the paper that created the field of deep semantic segmentation.
The core insight: every fully connected layer in a classification CNN can be rewritten as a convolution. The first FC layer, which consumes the whole 7×7×512 feature map, becomes a convolution with a 7×7 kernel; each later FC layer becomes a 1×1 convolution, which is mathematically identical to a fully connected layer applied to each spatial location independently. After this replacement, the network becomes "fully convolutional" - it can accept an input of any spatial size and produces an output that is a spatial map, not a vector.
VGG-16 classification CNN:            VGG-16 converted to FCN:
Input: 224×224×3                      Input: 224×224×3 (or any size)
Conv layers → 7×7×512                 Conv layers → 7×7×512
FC6: 7×7×512 → 4096 vector            Conv 7×7 (padded): 7×7×512 → 7×7×4096
FC7: 4096 → 4096 vector               Conv 1×1: 7×7×4096 → 7×7×4096
FC8: 4096 → 1000 class scores         Conv 1×1: 7×7×4096 → 7×7×21 (21 VOC classes)
                                      Upsample 32×: 7×7×21 → 224×224×21
(Sizes are simplified: the convolutionalized FC6 is shown with padding so that the map stays 7×7.)
For a 224×224 input, the FCN output before upsampling is 7×7×21. Upsampling this by 32× (the total stride of the backbone) recovers the original spatial size. But 32× upsampling is very coarse - the predictions are blurry at object boundaries.
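The claim that a 1×1 convolution is a fully connected layer applied at every spatial location can be verified numerically. The sketch below copies the weights of an `nn.Linear` into an `nn.Conv2d`; the channel counts (512 in, 21 out) are chosen for illustration and do not reproduce the real VGG-16 head.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A fully connected layer acting on a 512-dim feature vector.
fc = nn.Linear(512, 21)

# The equivalent 1x1 convolution: same weights, reshaped to (out, in, 1, 1).
conv = nn.Conv2d(512, 21, kernel_size=1)
with torch.no_grad():
    conv.weight.copy_(fc.weight.view(21, 512, 1, 1))
    conv.bias.copy_(fc.bias)

feature_map = torch.randn(1, 512, 7, 7)

with torch.no_grad():
    # Apply the FC layer independently at every spatial location.
    fc_out = fc(feature_map.permute(0, 2, 3, 1))   # (1, 7, 7, 21)
    fc_out = fc_out.permute(0, 3, 1, 2)            # (1, 21, 7, 7)

    conv_out = conv(feature_map)                   # (1, 21, 7, 7)

    # The convolutional form also accepts other input sizes unchanged.
    bigger = conv(torch.randn(1, 512, 10, 15))

max_diff = (fc_out - conv_out).abs().max().item()
print(conv_out.shape, bigger.shape, max_diff)
```

The two outputs agree up to floating-point error, and only the convolutional version keeps working when the input grows.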
Skip connections are FCN's solution. At stride 16, the feature map is 14×14. At stride 8, it is 28×28. These intermediate maps have finer spatial detail. FCN adds predictions from earlier layers:
FCN-32s: Upsample 32× directly from final feature map
→ very coarse, smooth but imprecise boundaries
FCN-16s: Add prediction from pool4 (stride 16), upsample 16×
→ 2× finer spatial detail at boundaries
FCN-8s: Add predictions from pool3 (stride 8) AND pool4, upsample 8×
→ 4× finer than FCN-32s, much better boundary delineation
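The three variants differ only in where the coarse scores are fused with finer ones. Here is a hedged sketch of that fusion; the random tensors stand in for real VGG features, the channel counts are assumed, and bilinear interpolation replaces FCN's learned upsampling for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

num_classes = 21

# Hypothetical backbone features for a 224x224 input (VGG-style channel counts).
pool3 = torch.randn(1, 256, 28, 28)   # stride 8
pool4 = torch.randn(1, 512, 14, 14)   # stride 16
final = torch.randn(1, 4096, 7, 7)    # stride 32

# 1x1 "score" layers turn each feature map into per-class logits.
score_final = nn.Conv2d(4096, num_classes, kernel_size=1)
score_pool4 = nn.Conv2d(512, num_classes, kernel_size=1)
score_pool3 = nn.Conv2d(256, num_classes, kernel_size=1)


def up(x, size):
    # Bilinear interpolation stands in for FCN's learned upsampling here.
    return F.interpolate(x, size=size, mode="bilinear", align_corners=False)


with torch.no_grad():
    s32 = score_final(final)                               # (1, 21, 7, 7)
    s16 = up(s32, pool4.shape[-2:]) + score_pool4(pool4)   # fuse at stride 16
    s8 = up(s16, pool3.shape[-2:]) + score_pool3(pool3)    # fuse at stride 8

    fcn32s = up(s32, (224, 224))   # one 32x jump
    fcn16s = up(s16, (224, 224))   # 16x jump after one fusion
    fcn8s = up(s8, (224, 224))     # 8x jump after two fusions

print(fcn32s.shape, fcn16s.shape, fcn8s.shape)
```

All three outputs have the same shape; what changes is how much fine-grained evidence was added before the final upsampling step.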
FCN-8s on PASCAL VOC 2011: 62.7% mIoU. The previous state of the art (SDS, built on R-CNN region features) reached 52.6% on the same benchmark. That is roughly a 10-point gain, about 20% relative, from a single architectural insight.
Why FCN worked was not just the architecture - it was that pretrained ImageNet features could be directly repurposed for dense prediction. The semantic understanding encoded in VGG's convolutional filters transferred to pixel-level understanding with only fine-tuning.
The Encoder-Decoder Architecture
FCN established the template that most segmentation architectures follow: an encoder (contracting path) that builds semantic representations while reducing spatial size, followed by a decoder (expanding path) that recovers spatial resolution while using the semantic representations.
Encoder (Contracting):                 Decoder (Expanding):
Input H×W×3                            Output H×W×C
    ↓                                      ↑
Conv+Pool → H/2 × W/2 × 64             Upsample + Conv → H/2 × W/2 × 64
    ↓                                      ↑
Conv+Pool → H/4 × W/4 × 128            Upsample + Conv → H/4 × W/4 × 128
    ↓                                      ↑
Conv+Pool → H/8 × W/8 × 256            Upsample + Conv → H/8 × W/8 × 256
    ↓                                      ↑
Conv+Pool → H/16 × W/16 × 512          Upsample + Conv → H/16 × W/16 × 512
    ↓                                      ↑
Bottleneck H/32 × W/32 × 1024 ─────────────┘
(deepest semantic features)
Why the decoder is not trivial:
Simple bilinear upsampling recovers spatial size but cannot recover spatial detail. Information that was discarded during downsampling (exactly which pixels had certain edge patterns) is genuinely gone. The decoder must either receive that information through skip connections from the encoder, or it must hallucinate plausible boundaries based on coarse semantics.
Transposed convolution (sometimes called deconvolution - a misnomer) is learned upsampling. It inserts learnable weights into the upsampling process, allowing the network to learn which spatial patterns to reconstruct. However, transposed convolutions are prone to checkerboard artifacts - a characteristic grid pattern that appears when the kernel size is not divisible by the stride, so neighboring output pixels receive contributions from different numbers of kernel taps. A common fix is to use nearest-neighbor (or bilinear) upsampling followed by a standard convolution, which covers every output pixel evenly and largely avoids the artifact.
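The uneven coverage behind checkerboard artifacts is easy to see by pushing an all-ones image through a transposed convolution whose weights are all ones: each output value then counts how many kernel taps landed on that pixel. The helper name `overlap_pattern` and the 4×4 input are assumptions made for this sketch.

```python
import torch
import torch.nn as nn


def overlap_pattern(kernel_size: int, stride: int) -> torch.Tensor:
    """Feed an all-ones image through an all-ones transposed conv.

    Each output value counts how many kernel taps contributed to that pixel.
    """
    deconv = nn.ConvTranspose2d(1, 1, kernel_size=kernel_size, stride=stride, bias=False)
    with torch.no_grad():
        deconv.weight.fill_(1.0)
        return deconv(torch.ones(1, 1, 4, 4))[0, 0]


uneven = overlap_pattern(kernel_size=3, stride=2)   # 3 is not divisible by 2
even = overlap_pattern(kernel_size=4, stride=2)     # 4 is divisible by 2

# Look at the interior only, away from border effects.
uneven_interior = uneven[2:7, 2:7]
even_interior = even[2:8, 2:8]
print(uneven_interior)
print(even_interior)

# The resize-then-convolve alternative keeps the same 2x output size.
resize_conv = nn.Sequential(
    nn.Upsample(scale_factor=2, mode="nearest"),
    nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False),
)
with torch.no_grad():
    out = resize_conv(torch.ones(1, 1, 4, 4))
print(out.shape)
```

With kernel size 3 and stride 2 the interior counts alternate, which is the grid pattern in question; with kernel size 4 and stride 2 the interior is uniform.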
U-Net (2015): The Medical Imaging Standard
Olaf Ronneberger, Philipp Fischer, and Thomas Brox at the University of Freiburg published U-Net in May 2015. The paper had a specific, urgent motivation: the EM ISBI challenge required segmenting neurons in electron microscopy images, and the training set had only 30 images. Standard deep learning required millions of images. U-Net had to work with almost nothing.
The two innovations that made this possible:
1. Concatenation-based skip connections (not addition like FCN):
At each decoder level, the upsampled feature map from the level below is concatenated with the corresponding encoder feature map. This means the decoder simultaneously receives:
- The what from the bottleneck (deep semantic context: "this is a cell membrane")
- The where from the encoder (precise spatial location: "the membrane edge is at this exact pixel")
Concatenation (rather than FCN's element-wise addition) preserves both sets of features independently - the decoder can learn how to weight and combine them through subsequent convolutions. Addition would force the encoder and decoder features into the same representation space, losing information.
2. Aggressive data augmentation - elastic deformations:
With only 30 training images of cell boundaries, U-Net applied extensive random elastic deformations to create a much larger effective training set. Elastic deformation warps the image as if it were a rubber sheet, creating plausible variants that preserve the biological structure while providing the network with novel training examples.
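As a rough illustration of elastic deformation, the sketch below draws random displacements on a coarse grid, upsamples them into a smooth field, and warps image and mask with the same field. The function name `elastic_deform`, the uniform noise, the 3×3 grid, and the `max_shift` scale are assumptions for this example, not the exact recipe from the U-Net paper.

```python
import torch
import torch.nn.functional as F


def elastic_deform(image, mask, grid_points=3, max_shift=0.1, generator=None):
    """Warp an image and its mask with the same smooth random displacement field.

    image: (B, C, H, W) float tensor; mask: (B, H, W) integer labels.
    max_shift is expressed in normalized [-1, 1] coordinates (an assumed scale).
    """
    b, _, h, w = image.shape
    # Coarse random displacements, upsampled to a smooth per-pixel field.
    coarse = (torch.rand(b, 2, grid_points, grid_points, generator=generator) * 2 - 1) * max_shift
    field = F.interpolate(coarse, size=(h, w), mode="bicubic", align_corners=True)

    # Identity sampling grid in normalized coordinates, shape (B, H, W, 2) as (x, y).
    ys = torch.linspace(-1, 1, h)
    xs = torch.linspace(-1, 1, w)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    identity = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0).expand(b, -1, -1, -1)
    grid = identity + field.permute(0, 2, 3, 1)

    warped_image = F.grid_sample(image, grid, mode="bilinear",
                                 padding_mode="border", align_corners=True)
    # Nearest-neighbor sampling keeps mask values as valid class ids.
    warped_mask = F.grid_sample(mask[:, None].float(), grid, mode="nearest",
                                padding_mode="border", align_corners=True)
    return warped_image, warped_mask[:, 0].long()


gen = torch.Generator().manual_seed(0)
image = torch.rand(2, 1, 64, 64)
mask = (image[:, 0] > 0.5).long()
warped_image, warped_mask = elastic_deform(image, mask, generator=gen)
print(warped_image.shape, warped_mask.shape, warped_mask.unique().tolist())
```

The important design point is that the mask is warped with the same field as the image but with nearest-neighbor sampling, so labels stay aligned and are never blended into invalid class ids.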
U-Net achieved the best results on the cell segmentation challenge - with 30 training images. It was published in May 2015 and to date has accumulated over 60,000 citations, making it one of the most cited computer vision papers in history. It became the de facto standard for medical image segmentation and remains competitive against transformer-based models a decade later.
The name comes from the U-shape of the architecture diagram: encoder on the left side descending, decoder on the right side ascending, with skip connections crossing the middle.
Dilated (Atrous) Convolutions: DeepLab's Contribution
Liang-Chieh Chen and colleagues (with the later versions developed at Google) published the DeepLab series starting in 2015, introducing a fundamentally different approach to the context-vs-precision tradeoff.
The problem formulation:
Aggressive downsampling is harmful for dense prediction, because fine spatial detail cannot be recovered once it is gone - yet standard conv networks reduce spatial resolution by 32× through pooling and strided convolutions. For segmentation, can we get the large receptive field needed for context WITHOUT the resolution loss?
Standard convolution with stride 2: halves spatial resolution, so every later layer covers twice as much of the input per kernel tap.
Dilated (atrous) convolution: inserts gaps between kernel elements, expanding the receptive field WITHOUT changing the spatial resolution of the output.
Standard 3×3 convolution            Dilated 3×3 convolution, dilation=2
(effective receptive field: 3×3)    (effective receptive field: 5×5)
. . . . . . .                       . . . . . . .
. . . . . . .                       . X . X . X .
. . X X X . .                       . . . . . . .
. . X X X . .                       . X . X . X .
. . X X X . .                       . . . . . . .
. . . . . . .                       . X . X . X .
. . . . . . .                       . . . . . . .
For a 3×3 kernel with dilation factor $d$, the kernel elements are placed $d$ pixels apart. The effective receptive field size is:
$$k_{\text{eff}} = k + (k - 1)(d - 1) = 2d + 1 \quad \text{for } k = 3$$
| Dilation | Effective RF | Spatial stride |
|---|---|---|
| d=1 | 3×3 | 1 (standard conv) |
| d=2 | 5×5 | 1 |
| d=4 | 9×9 | 1 |
| d=6 | 13×13 | 1 |
| d=12 | 25×25 | 1 |
| d=18 | 37×37 | 1 |
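The table follows directly from the effective-kernel formula. A small sketch in plain Python reproduces it; the helper names are hypothetical, and `same_padding` gives the padding that keeps the output the same size as the input at stride 1.

```python
def effective_kernel_size(kernel_size: int, dilation: int) -> int:
    """Span covered by a dilated kernel: k + (k - 1) * (d - 1)."""
    return kernel_size + (kernel_size - 1) * (dilation - 1)


def same_padding(kernel_size: int, dilation: int) -> int:
    """Padding that keeps the output the same size as the input at stride 1."""
    return dilation * (kernel_size - 1) // 2


for d in (1, 2, 4, 6, 12, 18):
    k_eff = effective_kernel_size(3, d)
    print(f"d={d:>2}: effective RF {k_eff}x{k_eff}, padding {same_padding(3, d)}")
```

For a 3×3 kernel the required padding equals the dilation rate, which is why dilated layers are usually written with `padding=rate, dilation=rate`.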
DeepLab removes the downsampling from the last one or two stages of a backbone (ResNet-101), setting their stride to 1 where they would normally take the total stride from 8 to 32. To compensate for the lost receptive field, the convolutions in those stages are replaced with dilated convolutions. The final feature map is now at stride 8 instead of stride 32 - 4× higher spatial resolution along each axis - while the receptive field is maintained through dilation.
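The effect of trading stride for dilation shows up directly in tensor shapes. The two-layer stacks below are toy stand-ins for a backbone stage (an assumption for illustration), not the actual ResNet blocks.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 65, 65)  # pretend this is a stride-8 feature map

# Standard stage: the first conv downsamples, the second runs at the lower resolution.
standard = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),    # stride 8 -> 16
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
)

# Atrous version: stride removed, following conv dilated to cover the same extent.
atrous = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),    # stays at stride 8
    nn.Conv2d(64, 64, kernel_size=3, padding=2, dilation=2),  # dilation compensates
)

with torch.no_grad():
    standard_out = standard(x)
    atrous_out = atrous(x)

print(standard_out.shape, atrous_out.shape)
```

The atrous stack keeps the 65×65 resolution at the cost of running every later layer on four times as many positions, which is the main compute trade-off of this design.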
Atrous Spatial Pyramid Pooling (ASPP):
Different objects are at different scales. A pedestrian 5 meters away needs a different receptive field than a car 20 meters away. ASPP applies dilated convolutions with multiple dilation rates in parallel and concatenates the results:
Feature map at stride 8
│
├── 1×1 conv (rate=1) → captures pixel-local detail
├── 3×3 dilated(rate=6) → captures 13×13 context
├── 3×3 dilated(rate=12) → captures 25×25 context
├── 3×3 dilated(rate=18) → captures 37×37 context
└── Global Avg Pool → captures image-global context
│
Concatenate → 1×1 conv → prediction
This single module provides context at 5 different scales simultaneously, making the model robust to objects of very different sizes in the same scene.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ASPPModule(nn.Module):
    """
    Atrous Spatial Pyramid Pooling.
    Applies parallel dilated convolutions at multiple rates to capture
    multi-scale context without changing the feature map spatial resolution.
    """

    def __init__(self, in_channels: int, out_channels: int, rates: tuple[int, ...] = (6, 12, 18)):
        super().__init__()
        # 1×1 conv branch (rate=1, captures local detail)
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        # Dilated 3×3 conv branches (padding=rate keeps the spatial size unchanged)
        self.dilated_convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels,
                    kernel_size=3, padding=rate, dilation=rate, bias=False
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )
            for rate in rates
        ])
        # Global average pooling branch
        # (BatchNorm on a 1×1 map needs batch size > 1 in training mode)
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        # Fusion: concatenate all branches, project to out_channels
        n_branches = 1 + len(rates) + 1  # 1×1 + dilated + global
        self.project = nn.Sequential(
            nn.Conv2d(n_branches * out_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h, w = x.shape[-2:]
        # Local detail
        branches = [self.conv1x1(x)]
        # Multi-scale dilated context
        for dilated_conv in self.dilated_convs:
            branches.append(dilated_conv(x))
        # Global context (upsample to match spatial size)
        global_feat = self.global_avg_pool(x)
        global_feat = F.interpolate(global_feat, size=(h, w), mode="bilinear", align_corners=False)
        branches.append(global_feat)
        # Concatenate and project
        return self.project(torch.cat(branches, dim=1))


# Test ASPP
aspp = ASPPModule(in_channels=2048, out_channels=256)
x = torch.randn(2, 2048, 65, 65)  # stride-8 feature map for 520×520 input
out = aspp(x)
print(f"ASPP output: {out.shape}")  # (2, 256, 65, 65)
DeepLabV3+: State of the Art for Dense Prediction
DeepLabV3+ (Chen et al., 2018) combined the ASPP module with a simple but effective encoder-decoder structure:
Encoder: Xception-65 or ResNet-101 backbone with atrous convolutions (output stride=16) followed by ASPP module.
Decoder: Instead of simple 16× bilinear upsampling, DeepLabV3+ uses a lightweight decoder:
- Upsample ASPP output by 4× (to stride-4 resolution)
- Concatenate with low-level features from the early encoder (stride-4 features, reduced to 48 channels with a 1×1 conv)
- Apply 3×3 conv layers to refine
- Upsample 4× to full resolution
This two-stage decoder is significantly better than direct 16× upsampling because the stride-4 encoder features contain precise boundary information.
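The four decoder steps can be sketched as a small module. The class name `SimpleDeepLabDecoder`, the 256-channel refinement width, and the input shapes are assumptions chosen for illustration; only the 48-channel reduction and the overall order of operations come from the description above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleDeepLabDecoder(nn.Module):
    """Hypothetical minimal version of the two-stage decoder described above."""

    def __init__(self, low_level_channels=256, aspp_channels=256, num_classes=19):
        super().__init__()
        # Reduce low-level features to 48 channels so they do not dominate the concat.
        self.reduce = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, kernel_size=1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True),
        )
        self.refine = nn.Sequential(
            nn.Conv2d(aspp_channels + 48, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Conv2d(256, num_classes, kernel_size=1)

    def forward(self, aspp_out, low_level, out_size):
        low = self.reduce(low_level)
        # Step 1: bring the stride-16 ASPP output up to stride-4 resolution.
        x = F.interpolate(aspp_out, size=low.shape[-2:], mode="bilinear", align_corners=False)
        # Steps 2-3: concatenate with the low-level features and refine with 3x3 convs.
        x = self.refine(torch.cat([x, low], dim=1))
        # Step 4: final 4x upsampling of the class logits to the input resolution.
        return F.interpolate(self.classifier(x), size=out_size, mode="bilinear", align_corners=False)


decoder = SimpleDeepLabDecoder().eval()
aspp_out = torch.randn(1, 256, 16, 16)    # stride 16 for a 256x256 input
low_level = torch.randn(1, 256, 64, 64)   # stride 4
with torch.no_grad():
    logits = decoder(aspp_out, low_level, out_size=(256, 256))
print(logits.shape)
```

Upsampling to the low-level feature size rather than by a fixed factor of 4 keeps the sketch robust to input sizes that are not exact multiples of 16.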
DeepLabV3+ vs U-Net - when to use which:
| Property | U-Net | DeepLabV3+ |
|---|---|---|
| Medical / microscopy images | Excellent | Adequate |
| Natural scene segmentation | Adequate | Excellent |
| Small training set | Strong (with augmentation) | Weaker (needs more data) |
| Resolution requirements | Pixel-perfect | Near-pixel-perfect |
| Backbone flexibility | Custom encoder | ResNet/Xception |
| Training complexity | Simple | Moderate |
| mIoU on Cityscapes | ~67% (standard U-Net) | 82.1% (DeepLabV3+ Xception) |
Loss Functions for Segmentation
The choice of loss function has an outsized effect on segmentation performance, especially when classes are imbalanced.
Pixel-wise cross-entropy:
Standard cross-entropy treats each pixel independently. The problem: in a typical street scene, road pixels might be 40% of the image, sky 25%, buildings 20%, and pedestrians only 3%. Cross-entropy loss is dominated by gradients from the majority classes. The model learns to predict "road" and "sky" very well, and "pedestrian" and "cyclist" poorly - precisely the classes you care most about for autonomous driving.
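Dice loss:
$$\mathcal{L}_{\text{Dice}} = 1 - \frac{2\sum_i p_i g_i}{\sum_i p_i + \sum_i g_i}$$
where $p_i$ are predicted probabilities and $g_i$ are ground truth binary labels for a single class, summed over pixels $i$. Dice loss directly optimizes the overlap between prediction and ground truth, independent of class frequency. A class with 1,000 pixels contributes just as much to the Dice loss as a class with 1,000,000 pixels.
For multi-class segmentation, compute Dice per class and average:
$$\mathcal{L}_{\text{Dice}} = 1 - \frac{1}{C}\sum_{c=1}^{C}\frac{2\sum_i p_{i,c}\,g_{i,c} + \epsilon}{\sum_i p_{i,c} + \sum_i g_{i,c} + \epsilon}$$
The $\epsilon$ (smooth factor, typically 1.0) prevents division by zero when a class is absent from a batch.
Focal loss for segmentation:
$$\mathcal{L}_{\text{focal}} = -(1 - p_t)^{\gamma}\,\log(p_t)$$
where $p_t$ is the predicted probability of the true class at a pixel. Focal loss downweights easy pixels (background that the model correctly classifies with high confidence) and concentrates the loss on hard pixels (rare classes, ambiguous boundaries). $\gamma = 2$ is the standard choice.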
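A minimal multi-class focal loss can be built on top of per-pixel cross-entropy, since cross-entropy equals the negative log-probability of the true class. This is a sketch under assumptions: the function name `focal_loss`, the absence of a per-class alpha weight, and the mean over valid pixels are choices made here, not something specified in the text.

```python
import torch
import torch.nn.functional as F


def focal_loss(logits, targets, gamma=2.0, ignore_index=255):
    """Multi-class focal loss: mean over valid pixels of (1 - p_t)^gamma * CE.

    logits: (B, C, H, W); targets: (B, H, W) integer labels.
    """
    # Per-pixel cross-entropy equals -log(p_t); ignored pixels come back as 0.
    ce = F.cross_entropy(logits, targets, reduction="none", ignore_index=ignore_index)
    p_t = torch.exp(-ce)                  # probability assigned to the true class
    loss = (1.0 - p_t) ** gamma * ce      # easy pixels (p_t near 1) are downweighted
    valid = targets != ignore_index
    return loss[valid].mean()


torch.manual_seed(0)
logits = torch.randn(2, 5, 8, 8)
targets = torch.randint(0, 5, (2, 8, 8))
targets[0, 0, :4] = 255                   # a few ignored pixels

fl = focal_loss(logits, targets, gamma=2.0)
ce = F.cross_entropy(logits, targets, ignore_index=255)
fl_gamma0 = focal_loss(logits, targets, gamma=0.0)
print(fl.item(), ce.item(), fl_gamma0.item())
```

Setting `gamma=0` recovers plain cross-entropy, which is a convenient sanity check when wiring this into a training loop.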
Combined loss (cross-entropy + Dice) in practice:
import torch
import torch.nn as nn
import torch.nn.functional as F


class DiceLoss(nn.Module):
    """
    Multi-class Soft Dice Loss.
    Handles class imbalance by computing overlap per class independently.
    """

    def __init__(self, smooth: float = 1.0, ignore_index: int = 255):
        super().__init__()
        self.smooth = smooth
        self.ignore_index = ignore_index

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            logits: (B, C, H, W) raw model output (before softmax)
            targets: (B, H, W) integer class labels
        Returns:
            scalar Dice loss
        """
        # Create valid pixel mask, broadcastable over the class dimension
        valid_mask = (targets != self.ignore_index)
        valid = valid_mask[:, None].float()  # (B, 1, H, W)
        # Zero out probabilities at ignored pixels so they do not inflate the denominator
        probs = torch.softmax(logits, dim=1) * valid  # (B, C, H, W)
        n_classes = logits.shape[1]
        # One-hot encode targets for valid pixels
        targets_clamped = targets.clone()
        targets_clamped[~valid_mask] = 0  # avoid out-of-range indices
        targets_oh = F.one_hot(targets_clamped, n_classes)  # (B, H, W, C)
        targets_oh = targets_oh.permute(0, 3, 1, 2).float()  # (B, C, H, W)
        # Zero out invalid pixels in one-hot targets
        targets_oh = targets_oh * valid
        # Per-class Dice over batch and spatial dimensions
        dims = (0, 2, 3)
        intersection = (probs * targets_oh).sum(dim=dims)
        cardinality = probs.sum(dim=dims) + targets_oh.sum(dim=dims)
        dice_per_class = (2.0 * intersection + self.smooth) / (cardinality + self.smooth)
        return 1.0 - dice_per_class.mean()


class SegmentationLoss(nn.Module):
    """
    Combined Cross-Entropy + Dice loss.
    CE provides stable gradient for all pixels.
    Dice directly optimizes the overlap metric, handling class imbalance.
    """

    def __init__(
        self,
        ce_weight: float = 0.5,
        dice_weight: float = 0.5,
        class_weights: torch.Tensor | None = None,
        ignore_index: int = 255,
    ):
        super().__init__()
        self.ce = nn.CrossEntropyLoss(
            weight=class_weights,
            ignore_index=ignore_index,
        )
        self.dice = DiceLoss(ignore_index=ignore_index)
        self.ce_weight = ce_weight
        self.dice_weight = dice_weight

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> dict:
        ce_loss = self.ce(logits, targets)
        dice_loss = self.dice(logits, targets)
        total = self.ce_weight * ce_loss + self.dice_weight * dice_loss
        # "loss" keeps its graph for backward(); the components are plain floats for logging
        return {
            "loss": total,
            "ce_loss": ce_loss.item(),
            "dice_loss": dice_loss.item(),
        }
Evaluation: mIoU
Mean Intersection over Union is the standard metric for semantic segmentation. It measures, on average across all classes, how well the predicted pixels overlap with the ground truth pixels.
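Per-class IoU (Jaccard Index):
$$\text{IoU}_c = \frac{TP_c}{TP_c + FP_c + FN_c}$$
- $TP_c$: pixels predicted as class $c$ that are truly class $c$
- $FP_c$: pixels predicted as class $c$ that are NOT class $c$
- $FN_c$: pixels that are truly class $c$ but predicted as something else
Note that $TP_c + FP_c + FN_c = |P_c \cup G_c|$, where $P_c$ and $G_c$ are the sets of pixels predicted and labeled as class $c$, so this is equivalent to $|P_c \cap G_c| / |P_c \cup G_c|$.
Mean IoU:
$$\text{mIoU} = \frac{1}{C}\sum_{c=1}^{C}\text{IoU}_c$$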
Why mIoU is fairer than pixel accuracy:
On Cityscapes (19 classes), road pixels make up ~35% of the image and rider pixels ~0.3%. A model predicting "road" everywhere achieves 35% pixel accuracy but an mIoU of only about 1.8%: road's IoU is 0.35 (every non-road pixel counts as a false positive for road), every other class scores 0, and the average over 19 classes is 0.35 / 19. mIoU treats every class equally - a model cannot hide its failure on rare classes behind success on common ones.
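The "road everywhere" argument can be worked through numerically in plain Python. The even split of the remaining 65% across the other 18 classes is an assumption made to keep the example small; the result does not depend on it.

```python
NUM_CLASSES = 19
ROAD = 0

# Hypothetical class frequencies: road 35%, the rest split evenly.
freq = [0.35] + [0.65 / (NUM_CLASSES - 1)] * (NUM_CLASSES - 1)

# A degenerate model predicts "road" for every pixel.
pixel_accuracy = freq[ROAD]

ious = []
for c in range(NUM_CLASSES):
    if c == ROAD:
        tp = freq[ROAD]          # road pixels, correctly labeled
        fp = 1.0 - freq[ROAD]    # every non-road pixel is a false positive
        fn = 0.0
    else:
        tp, fp, fn = 0.0, 0.0, freq[c]   # class is never predicted
    ious.append(tp / (tp + fp + fn))

miou = sum(ious) / NUM_CLASSES
print(f"pixel accuracy = {pixel_accuracy:.3f}, mIoU = {miou:.3f}")
```

Pixel accuracy comes out at 0.350 while mIoU is about 0.018, which is the gap the metric is designed to expose.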
import torch


def compute_miou(
    preds: torch.Tensor,  # (N, H, W) predicted class indices
    targets: torch.Tensor,  # (N, H, W) ground truth class indices
    num_classes: int,
    ignore_index: int = 255,
) -> dict:
    """
    Compute per-class IoU and mean IoU over an entire dataset.
    Uses a confusion matrix for efficiency - avoids looping over pixels.
    Entry conf_matrix[i, j] = number of pixels truly class i predicted as class j.
    """
    # Keep only pixels with a valid ground-truth label
    valid_mask = targets != ignore_index
    valid_preds = preds[valid_mask].long()
    valid_targets = targets[valid_mask].long()
    # Vectorized accumulation: each (target, pred) pair maps to one flat bin
    indices = valid_targets * num_classes + valid_preds
    conf_matrix = torch.bincount(
        indices, minlength=num_classes * num_classes
    ).reshape(num_classes, num_classes)
    # Per-class IoU from confusion matrix
    # TP: diagonal; FP: column sum - diagonal; FN: row sum - diagonal
    tp = conf_matrix.diag()
    fp = conf_matrix.sum(dim=0) - tp
    fn = conf_matrix.sum(dim=1) - tp
    iou_per_class = tp.float() / (tp + fp + fn).float().clamp(min=1e-6)
    # Classes not present in ground truth (row sum == 0) are excluded from mean
    present_mask = conf_matrix.sum(dim=1) > 0
    miou = iou_per_class[present_mask].mean().item()
    return {
        "miou": miou,
        "iou_per_class": iou_per_class.tolist(),
        "confusion_matrix": conf_matrix,
    }
U-Net: PyTorch Implementation from Scratch
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """
    Two consecutive Conv(3×3) → BatchNorm → ReLU blocks.
    The basic building block of U-Net.
    Receptive field per block: 5×5 (two 3×3 convolutions).
    """

    def __init__(self, in_channels: int, out_channels: int, mid_channels: int | None = None):
        super().__init__()
        mid_channels = mid_channels or out_channels
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


class Down(nn.Module):
    """
    Encoder block: MaxPool2d(2) → DoubleConv.
    Halves spatial resolution, doubles channels.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.pool_conv = nn.Sequential(
            nn.MaxPool2d(kernel_size=2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.pool_conv(x)


class Up(nn.Module):
    """
    Decoder block: Upsample → concatenate skip connection → DoubleConv.
    Args:
        in_channels: channels after concatenation (upsampled decoder features + skip)
        out_channels: output channels
        bilinear: use bilinear upsampling (True) or transposed conv (False)
            Bilinear is smoother, avoids checkerboard artifacts.
    """

    def __init__(self, in_channels: int, out_channels: int, bilinear: bool = True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            # Bilinear upsampling keeps the channel count, so the concat holds
            # in_channels // 2 from below + in_channels // 2 from the skip
            self.conv = DoubleConv(in_channels, out_channels, mid_channels=in_channels // 2)
        else:
            # Transposed conv: learnable upsampling, halves channels
            self.up = nn.ConvTranspose2d(
                in_channels, in_channels // 2, kernel_size=2, stride=2
            )
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x_decoder: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x_decoder: (B, C, H, W) features from the decoder level below (before upsampling)
            x_skip: (B, C', H', W') skip connection from encoder
                H' and W' may differ by 1 pixel from the upsampled size due to
                integer division in MaxPool - handle with padding.
        """
        x_decoder = self.up(x_decoder)
        # Pad x_decoder if its size doesn't match x_skip exactly
        # This happens when input dimensions are not divisible by 2^depth
        dh = x_skip.size(2) - x_decoder.size(2)
        dw = x_skip.size(3) - x_decoder.size(3)
        x_decoder = F.pad(x_decoder, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
        # Concatenate skip connection (fine spatial detail) with upsampled decoder features
        x = torch.cat([x_skip, x_decoder], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    """Final 1×1 convolution to produce per-pixel class logits."""

    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class UNet(nn.Module):
    """
    U-Net for semantic segmentation (Ronneberger et al., 2015).
    Architecture:
        - 4 encoder blocks (double channels at each step: 64 → 128 → 256 → 512)
        - Bottleneck (1024 channels; 512 when bilinear=True)
        - 4 decoder blocks with skip connections from encoder
        - 1×1 final conv for class logits
    Input: (B, n_channels, H, W)
    Output: (B, n_classes, H, W) - same spatial size as input
    """

    def __init__(
        self,
        n_channels: int = 3,
        n_classes: int = 2,
        bilinear: bool = True,
    ):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        # Bilinear upsampling cannot change the channel count, so the layers feeding
        # each Up block output half as many channels (this also keeps param count down)
        factor = 2 if bilinear else 1
        # Encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        # Decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Encoder - save all stages for skip connections
        x1 = self.inc(x)  # (B, 64, H, W)
        x2 = self.down1(x1)  # (B, 128, H/2, W/2)
        x3 = self.down2(x2)  # (B, 256, H/4, W/4)
        x4 = self.down3(x3)  # (B, 512, H/8, W/8)
        x5 = self.down4(x4)  # (B, 512, H/16, W/16) [bilinear=True]
        # Decoder - each step receives skip connection from corresponding encoder stage
        x = self.up1(x5, x4)  # (B, 256, H/8, W/8)
        x = self.up2(x, x3)  # (B, 128, H/4, W/4)
        x = self.up3(x, x2)  # (B, 64, H/2, W/2)
        x = self.up4(x, x1)  # (B, 64, H, W)
        return self.outc(x)  # (B, n_classes, H, W)


# Verify model dimensions
if __name__ == "__main__":
    model = UNet(n_channels=3, n_classes=19)  # Cityscapes: 19 classes
    x = torch.randn(2, 3, 512, 512)
    out = model(x)
    print(f"Input: {x.shape}")
    print(f"Output: {out.shape}")
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Parameters: {n_params:,}")
    # Input: torch.Size([2, 3, 512, 512])
    # Output: torch.Size([2, 19, 512, 512])
    # Parameters: about 17.3M with bilinear=True (about 31.0M with bilinear=False)
Instance Segmentation: Mask R-CNN
Semantic segmentation labels all pixels of the same class identically. In many applications - tracking pedestrians, counting cells, grasping individual objects - you need to distinguish separate instances of the same class.
Mask R-CNN (He et al., 2017) extends Faster R-CNN with a third parallel head that predicts a binary mask for each detected object instance.
Architecture:
Image → Backbone (ResNet + FPN)
                ↓
      RPN → region proposals
                ↓
  RoI Align (per-proposal features)
                ↓
       ┌────────┼──────────┐
       ↓        ↓          ↓
Classification  Box Reg    Mask Head
(what class?)   (where?)   (which pixels?)
                           ↓
                   28×28 binary mask
                     per instance
RoI Align vs RoI Pooling:
RoI Pooling divides a region into an H×W grid and max-pools each cell. The problem: the region and its grid boundaries are rounded to integer feature-map coordinates. For small regions (e.g., a 5×5 pixel object), this quantization error is large relative to the object - the pooled region may include wrong pixels.
RoI Align eliminates quantization by using bilinear interpolation at exactly the projected floating-point coordinates. For each output cell, it samples 4 points using bilinear weights. This sub-pixel precision is essential for generating accurate masks (a 28×28 mask on a 30×30 object instance has no tolerance for pixel misalignment).
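The difference between rounding a sample point and interpolating at it can be shown on a toy feature map. This sketch covers only the bilinear sampling step, not a full RoI Align layer; the `bilinear_sample` helper and the position-encoding feature map are assumptions for illustration.

```python
import math


def bilinear_sample(feature, y, x):
    """Sample a 2D grid (list of rows) at fractional coordinates (y, x)."""
    y0, x0 = int(math.floor(y)), int(math.floor(x))
    y1 = min(y0 + 1, len(feature) - 1)
    x1 = min(x0 + 1, len(feature[0]) - 1)
    wy, wx = y - y0, x - x0
    top = (1 - wx) * feature[y0][x0] + wx * feature[y0][x1]
    bottom = (1 - wx) * feature[y1][x0] + wx * feature[y1][x1]
    return (1 - wy) * top + wy * bottom


# Toy feature map whose value encodes position: f[y][x] = 10 * y + x.
feature = [[10 * y + x for x in range(4)] for y in range(4)]

y, x = 1.25, 2.75                            # a projected RoI sample point
quantized = feature[round(y)][round(x)]      # RoI Pooling-style rounding
aligned = bilinear_sample(feature, y, x)     # RoI Align-style interpolation
print(quantized, aligned)
```

Because the toy feature map is linear in position, the interpolated value matches the true value at (1.25, 2.75), whereas rounding snaps to the value at (1, 3).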
Mask head:
For each detected object (already classified and localized), the mask head predicts a 28×28 binary mask: which pixels within the bounding box belong to this specific instance? The mask is predicted independently per class - for each of the $K$ classes, a 28×28 mask is predicted with a per-pixel sigmoid, and the mask for the detected class is selected. This avoids class competition inside the mask head.
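Selecting the mask for the detected class is a single indexing operation. In the sketch below, the random logits stand in for a real mask head output, and the instance count, class count, and labels are made up for illustration.

```python
import torch

torch.manual_seed(0)
num_instances, num_classes = 3, 80

# Hypothetical mask-head output: one 28x28 logit map per class per instance.
mask_logits = torch.randn(num_instances, num_classes, 28, 28)
# Class predicted for each instance by the classification head.
labels = torch.tensor([0, 17, 42])

# Pick, for each instance, only the mask channel of its predicted class.
selected = mask_logits[torch.arange(num_instances), labels]   # (3, 28, 28)

# Per-pixel sigmoid (not a softmax across classes), then threshold.
binary_masks = torch.sigmoid(selected) > 0.5
print(selected.shape, binary_masks.dtype)
```

Because each class channel goes through its own sigmoid, the mask for one class is never suppressed by a competing class at the same pixel.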
Practical Training Details
Class weighting for imbalanced datasets:
When class imbalance is severe and Dice loss alone is not sufficient (e.g., very rare classes appearing in only some batches), weight the cross-entropy loss inversely proportional to class frequency:
import numpy as np
import torch


def compute_class_weights(
    dataset_pixel_counts: dict[int, int],
    num_classes: int,
    method: str = "inverse_frequency",
) -> torch.Tensor:
    """
    Compute class weights for weighted cross-entropy.
    Args:
        dataset_pixel_counts: {class_id: total_pixel_count_in_dataset}
        num_classes: total number of classes
        method: "inverse_frequency" or "median_frequency"
    Classes with no counted pixels get weight 0 instead of an enormous weight.
    """
    counts = np.array([dataset_pixel_counts.get(c, 0) for c in range(num_classes)], dtype=float)
    present = counts > 0
    if not present.any():
        raise ValueError("dataset_pixel_counts contains no pixels")
    total = counts.sum()
    weights = np.zeros(num_classes)
    if method == "inverse_frequency":
        # Weight inversely proportional to frequency
        weights[present] = total / (present.sum() * counts[present])
    elif method == "median_frequency":
        # Weight = median_freq / class_freq (Eigen & Fergus, 2015)
        freq = counts[present] / total
        weights[present] = np.median(freq) / freq
    else:
        raise ValueError(f"Unknown method: {method}")
    # Normalize so the present classes have mean weight 1
    weights[present] = weights[present] / weights[present].mean()
    return torch.tensor(weights, dtype=torch.float32)


# Example: Cityscapes approximate class frequencies (only a few classes listed here;
# in practice, supply counts for all 19, since unlisted classes get weight 0)
pixel_counts = {
    0: 5_000_000,  # road (very common)
    1: 2_000_000,  # sidewalk
    2: 3_000_000,  # building
    11: 50_000,  # person (rare)
    12: 5_000,  # rider (very rare)
}
weights = compute_class_weights(pixel_counts, num_classes=19)
criterion = torch.nn.CrossEntropyLoss(weight=weights, ignore_index=255)
Mixed precision training for high-resolution inputs:
Segmentation on 1024×1024 or larger images exhausts GPU memory quickly. Mixed precision (FP16) stores most activations in half precision, which substantially reduces memory use and often speeds up training:
from torch.cuda.amp import GradScaler, autocast
scaler = GradScaler()
for images, masks in train_loader:
images, masks = images.cuda(), masks.cuda()
optimizer.zero_grad()
with autocast(): # Compute in FP16
logits = model(images)
loss = criterion(logits, masks)["loss"]
scaler.scale(loss).backward() # Scale gradients to avoid underflow
scaler.step(optimizer)
scaler.update()
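To make the memory pressure concrete, here is a back-of-the-envelope calculation for a single activation tensor. The layer shape (64 channels at full resolution, batch of 4) is a hypothetical example, not a measurement of any particular model.

```python
def activation_bytes(batch, channels, height, width, bytes_per_value):
    """Memory needed to store one activation tensor."""
    return batch * channels * height * width * bytes_per_value

# Hypothetical layer: 64 channels at full 1024x1024 resolution, batch of 4.
fp32 = activation_bytes(4, 64, 1024, 1024, bytes_per_value=4)
fp16 = activation_bytes(4, 64, 1024, 1024, bytes_per_value=2)

print(f"FP32: {fp32 / 2**30:.2f} GiB")  # FP32: 1.00 GiB
print(f"FP16: {fp16 / 2**30:.2f} GiB")  # FP16: 0.50 GiB
```

A network stores many such tensors for the backward pass, which is why activations dominate at high resolution. Real savings are usually somewhat below 2×, since autocast keeps the weights in FP32 and runs numerically sensitive ops in FP32.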
Patch-based training for very large images:
Whole-slide pathology images can be 50,000 × 50,000 pixels. Processing them in a single pass is impractical - instead, extract and train on patches:
import random

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class PatchSegmentationDataset(Dataset):
    """
    Extract random patches from large images during training.
    Prefers patches with at least `foreground_threshold` foreground (non-zero)
    pixels; after 50 failed attempts the last sampled patch is used as-is.
    Note: PIL loads the whole image into memory, so true whole-slide images
    need a tiled reader instead.
    """

    def __init__(
        self,
        image_paths: list[str],
        mask_paths: list[str],
        patch_size: int = 512,
        patches_per_image: int = 8,
        foreground_threshold: float = 0.05,
        transform=None,
    ):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.patch_size = patch_size
        self.patches_per_image = patches_per_image
        self.fg_threshold = foreground_threshold
        self.transform = transform

    def __len__(self) -> int:
        return len(self.image_paths) * self.patches_per_image

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        img_idx = idx // self.patches_per_image
        image = Image.open(self.image_paths[img_idx]).convert("RGB")
        mask = Image.open(self.mask_paths[img_idx])
        W, H = image.size
        p = self.patch_size  # assumes W >= p and H >= p
        # Sample patch with minimum foreground content
        for _ in range(50):  # try up to 50 times
            x = random.randint(0, W - p)
            y = random.randint(0, H - p)
            mask_patch = mask.crop((x, y, x + p, y + p))
            mask_arr = np.array(mask_patch, dtype=np.int64)
            if (mask_arr > 0).mean() >= self.fg_threshold:
                break
        image_patch = image.crop((x, y, x + p, y + p))
        if self.transform:
            # The transform is expected to convert both PIL patches to tensors
            return self.transform(image_patch, mask_patch)
        # No transform: convert to tensors so the return type matches the annotation
        image_tensor = torch.from_numpy(np.array(image_patch)).permute(2, 0, 1).float() / 255.0
        return image_tensor, torch.from_numpy(mask_arr)
Test-time augmentation (TTA) for segmentation:
Averaging predictions over multiple augmented versions of the test image typically improves mIoU by roughly 1–3 points, at the cost of one extra forward pass per augmentation:
import torch


def tta_predict(
    model: torch.nn.Module,
    image: torch.Tensor,
    num_classes: int,
) -> torch.Tensor:
    """
    Test-time augmentation: average predictions over the original image and
    its horizontal flip. Multi-scale inputs are not included in this version.

    Args:
        image: (1, C, H, W) single test image
    Returns: (1, num_classes, H, W) averaged probability map
    """
    model.eval()
    h, w = image.shape[-2:]
    prob_sum = torch.zeros(1, num_classes, h, w, device=image.device)
    n_augmentations = 0
    with torch.no_grad():
        for flip in [False, True]:
            img = image.flip(-1) if flip else image  # horizontal flip
            logits = model(img)
            probs = torch.softmax(logits, dim=1)
            if flip:
                probs = probs.flip(-1)  # flip predictions back
            prob_sum += probs
            n_augmentations += 1
    return prob_sum / n_augmentations
Modern Approaches: Beyond Encoder-Decoder CNNs
SegFormer (2021): Transformers for Segmentation
The vision transformer revolution reached segmentation with SegFormer (Xie et al., NVIDIA, NeurIPS 2021). It replaces the CNN encoder with a hierarchical vision transformer - the Mix-Transformer (MiT) - and pairs it with an unusually lightweight MLP decoder.
The Mix-Transformer has four stages at strides 4, 8, 16, and 32, mirroring the multi-scale structure of ResNet + FPN but using self-attention instead of convolutions. Each stage uses Efficient Self-Attention: instead of computing full O(N²) attention over all N spatial positions, it shortens the key and value sequences by a reduction ratio R, which cuts the cost to O(N²/R). The paper sets R to 64, 16, 4, and 1 for stages 1 through 4, so the highest-resolution stages, where N is largest, get the most aggressive reduction. This makes transformer encoding tractable at high resolution.
Key design decision: no positional encoding. Standard ViT learns positional embeddings for one fixed input resolution, so they must be interpolated when the test resolution differs, which tends to hurt accuracy. SegFormer instead places a 3×3 depth-wise convolution with zero padding inside each Mix-FFN block. The model implicitly learns position from the padding context, and naturally handles different input sizes at test time without interpolating positional codes.
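A quick calculation shows how much the reduction saves. This is a sketch under stated assumptions: a 512×512 input, the four strides named above, and the sequence-reduction ratios reported in the SegFormer paper (64, 16, 4, 1); the helper name is made up for the example.

```python
def attention_entries(height, width, stride, reduction):
    """Return (N, entries) where entries = N queries x (N / R) keys."""
    n = (height // stride) * (width // stride)   # tokens at this stage
    return n, n * (n // reduction)

# Assumed per-stage settings: (stride, reduction ratio R).
stages = [(4, 64), (8, 16), (16, 4), (32, 1)]
for stride, r in stages:
    n, reduced = attention_entries(512, 512, stride, r)
    print(f"stride {stride:>2}: N={n:>6}  full={n * n:>12,}  reduced={reduced:>11,}")
```

At stride 4 the full attention matrix would hold about 268 million entries per head, while the reduced version holds about 4 million, which is what keeps the high-resolution stage affordable.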
The decoder is a simple all-MLP design with four steps:
Features from 4 MiT stages (different resolutions and channels)
│
├── Linear projection to uniform C channels
├── Upsample all to 1/4 input resolution
├── Concatenate
└── Linear fusion → per-pixel class predictions
This lightweight decoder works because the transformer encoder already captures long-range context through self-attention - you do not need a heavy decoder to aggregate context from spatial positions. The encoder does that work. The decoder just needs to project and fuse.
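The four decoder steps can be traced at the shape level with NumPy. This is an illustrative sketch, not the reference implementation: weights are random, biases and nonlinearities are omitted, upsampling is nearest-neighbour for simplicity, and the channel widths (32, 64, 160, 256) and embedding size (64) are assumptions chosen for the example.

```python
import numpy as np

def mlp_decoder(features, proj_weights, fuse_weight, cls_weight):
    """Project, upsample, concatenate, and fuse multi-scale features."""
    target_h = features[0].shape[1]                    # stride-4 map sets the output size
    unified = []
    for feat, w in zip(features, proj_weights):
        # Per-pixel linear projection: (C_i, H_i, W_i) -> (C, H_i, W_i)
        proj = np.einsum("oc,chw->ohw", w, feat)
        # Nearest-neighbour upsampling to 1/4 input resolution
        scale = target_h // proj.shape[1]
        proj = proj.repeat(scale, axis=1).repeat(scale, axis=2)
        unified.append(proj)
    fused = np.concatenate(unified, axis=0)                # (4C, H/4, W/4)
    fused = np.einsum("oc,chw->ohw", fuse_weight, fused)   # linear fusion -> (C, H/4, W/4)
    return np.einsum("oc,chw->ohw", cls_weight, fused)     # (num_classes, H/4, W/4)

rng = np.random.default_rng(0)
in_channels, embed_dim, num_classes = (32, 64, 160, 256), 64, 19
sizes = (32, 16, 8, 4)   # a 128x128 input at strides 4, 8, 16, 32
features = [rng.normal(size=(c, s, s)) for c, s in zip(in_channels, sizes)]
proj_weights = [rng.normal(size=(embed_dim, c)) for c in in_channels]
fuse_weight = rng.normal(size=(embed_dim, 4 * embed_dim))
cls_weight = rng.normal(size=(num_classes, embed_dim))

logits = mlp_decoder(features, proj_weights, fuse_weight, cls_weight)
print(logits.shape)  # (19, 32, 32)
```

Every operation here is a per-pixel matrix multiply or a resize; there are no spatial convolutions in the decoder, which is why it adds so few parameters.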
SegFormer results (as reported in the paper):
- SegFormer-B0 (3.8M params): 37.4 mIoU on ADE20K and 76.2 mIoU on Cityscapes; at reduced input resolution it reaches about 48 fps on Cityscapes with 71.9 mIoU
- SegFormer-B5 (about 85M params): 51.8 mIoU on ADE20K
- On Cityscapes: SegFormer-B5 achieves 84.0 mIoU vs DeepLabV3+'s 80.9
SAM: Segment Anything Model (Meta AI, 2023)
SAM is to segmentation what GPT-3 was to text - a foundation model trained at scale that generalizes to a wide range of segmentation tasks without fine-tuning.
The core claim: train on enough data (1.1 billion masks from 11 million images), with the right architecture and interactive prompting design, and you can build a segmentation model that works on images it has never seen before - medical scans, satellite imagery, underwater footage, microscopy.
The architecture has three components:
Image Encoder: ViT-H (632M parameters) processes the image at 1024×1024 and produces a 64×64 embedding. This is expensive (~0.15 seconds on A100 GPU), but it runs only once per image. The embedding is cached and reused for every prompt on that image.
Prompt Encoder: encodes user inputs into embeddings:
- Positive/negative point clicks → positional + type embeddings
- Bounding box → two corner point embeddings
- Mask (from a previous iteration) → downsampled and convolved to dense embedding
Mask Decoder: a lightweight 2-layer transformer decoder that attends between the image embeddings and prompt embeddings. It outputs three masks at different semantic granularities (for an ambiguous point click: just the eye, the face, or the whole head?) plus predicted IoU scores for each. The three-mask design handles ambiguity: the user picks the one that corresponds to the intended granularity.
Why SAM changed the field:
- Zero-shot generalization: SAM works on medical images, satellite imagery, underwater footage, microscopy - domains not seen in training - with no fine-tuning.
- Interactive segmentation at real-time speeds: after the image encoder runs once, the mask decoder produces results in milliseconds. You can click, see a mask, click again to refine, and get instant feedback.
- High-quality automatic segmentation: run a grid of point prompts over the image and get an "everything" segmentation with no human input. This was used to generate the 1B mask training set itself.
Practical usage pattern:
# Using SAM via the segment-anything library (Meta's official release)
# pip install segment-anything
from segment_anything import sam_model_registry, SamPredictor
import numpy as np
from PIL import Image

# Load SAM model (ViT-H, ViT-L, or ViT-B available)
# `checkpoint` is the local path to the downloaded weights file
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth")
sam = sam.to("cuda")
predictor = SamPredictor(sam)

# Set image - image encoder runs once here
image = np.array(Image.open("medical_scan.png").convert("RGB"))
predictor.set_image(image)

# Predict from a point prompt (specify the tumor location with one click)
# The mask decoder runs in <50ms for each prompt
input_point = np.array([[300, 200]])  # (x, y) coordinates
input_label = np.array([1])           # 1 = positive (foreground)
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,  # Return 3 masks at different granularities
)
# masks: (3, H, W) boolean arrays
# scores: (3,) IoU confidence for each mask
best_mask = masks[np.argmax(scores)]  # Pick highest-confidence mask
print(f"Mask shape: {best_mask.shape}")         # (H, W)
print(f"Foreground pixels: {best_mask.sum()}")  # number of segmented pixels

# For automatic segmentation (no prompts - segment everything)
from segment_anything import SamAutomaticMaskGenerator

mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,           # Grid density for automatic point prompts
    pred_iou_thresh=0.86,         # Filter out low-quality masks
    stability_score_thresh=0.92,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,     # Minimum mask size (removes tiny fragments)
)
masks = mask_generator.generate(image)
# Returns list of dicts: {segmentation, area, bbox, predicted_iou, ...}
print(f"Found {len(masks)} segments automatically")
When to use SAM vs U-Net vs DeepLabV3+:
| Situation | Recommended Approach |
|---|---|
| Supervised, labeled dataset exists, specific classes | U-Net or DeepLabV3+ - best mIoU on your task |
| Medical imaging, limited labels (< 100 images) | U-Net with pretrained encoder + Dice loss |
| Need to annotate a new dataset interactively | SAM for annotation assistance - human-in-the-loop |
| Zero-shot on new domain with no labels | SAM - generalizes without fine-tuning |
| Real-time segmentation on edge/mobile | SegFormer-B0 or Fast-SCNN |
| Research: highest mIoU on standard benchmarks | Mask2Former or SegFormer-B5 |
SAM's image encoder is a powerful vision backbone. A productive pattern: use SAM's ViT-H image encoder (frozen or lightly fine-tuned) + a small task-specific decoder head trained on your labeled data. You get SAM's rich representations without training a large backbone from scratch, and training often converges several times faster with less labeled data.
SAM2 (2024): Video Segmentation
Meta released SAM2 in 2024, extending SAM to video. The core addition: a memory bank that stores encoded representations from past frames. When segmenting a new frame, a memory attention module conditions the current frame's embedding on the stored memory before the mask decoder runs - enabling consistent tracking of the segmented object across frames without re-prompting.
SAM2 can segment an object in the first frame of a video with a single click and track it across the entire sequence, handling occlusions and re-appearances. On image segmentation, SAM2 also outperforms SAM on most benchmarks.
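The memory idea can be illustrated with a toy bounded buffer. This is a conceptual sketch only: the class name, the capacity, and the string placeholders are hypothetical, and the real model also stores prompted frames and learned encodings that are omitted here.

```python
from collections import deque

class FrameMemory:
    """Toy FIFO memory: keeps the most recent `capacity` frame entries."""

    def __init__(self, capacity=6):
        # deque with maxlen evicts the oldest entry automatically
        self.entries = deque(maxlen=capacity)

    def add(self, frame_index, embedding):
        self.entries.append((frame_index, embedding))

    def context(self):
        """Frame indices the next prediction could attend to."""
        return [idx for idx, _ in self.entries]

memory = FrameMemory(capacity=3)
for t in range(5):
    memory.add(t, embedding=f"emb_{t}")   # placeholder for a real feature tensor
print(memory.context())  # [2, 3, 4]
```

Bounding the buffer keeps per-frame cost constant no matter how long the video runs, at the price of forgetting how the object looked in much earlier frames.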
Production Notes
ONNX export for deployment:
import torch


def export_unet_onnx(
    model: torch.nn.Module,
    input_size: tuple[int, int] = (512, 512),
    output_path: str = "unet_segmentation.onnx",
) -> None:
    model.eval()
    dummy = torch.randn(1, 3, *input_size)
    torch.onnx.export(
        model,
        dummy,
        output_path,
        input_names=["image"],
        output_names=["logits"],
        dynamic_axes={
            "image": {0: "batch", 2: "height", 3: "width"},
            "logits": {0: "batch", 2: "height", 3: "width"},
        },
        opset_version=13,
    )
    print(f"Exported to {output_path}")
Tiling for large images at inference:
import torch


def tile_predict(
    model: torch.nn.Module,
    image: torch.Tensor,
    tile_size: int = 512,
    overlap: int = 64,
    num_classes: int = 19,
) -> torch.Tensor:
    """
    Predict segmentation on a large image by tiling.
    Overlap between tiles prevents boundary artifacts.

    Args:
        image: (1, C, H, W) full-resolution image
        tile_size: size of each square tile
        overlap: pixel overlap between adjacent tiles
        num_classes: number of segmentation classes
    Returns:
        (1, num_classes, H, W) probability map for the full image
    """
    _, C, H, W = image.shape
    if H < tile_size or W < tile_size:
        raise ValueError("Image is smaller than tile_size; pad it or reduce tile_size")
    step = tile_size - overlap
    prob_map = torch.zeros(1, num_classes, H, W, device=image.device)
    count_map = torch.zeros(1, 1, H, W, device=image.device)
    model.eval()
    with torch.no_grad():
        for y in range(0, H - overlap, step):
            for x in range(0, W - overlap, step):
                # Clamp the tile to the image and shift it back so it stays full-size
                y2 = min(y + tile_size, H)
                x2 = min(x + tile_size, W)
                y1 = y2 - tile_size
                x1 = x2 - tile_size
                tile = image[:, :, y1:y2, x1:x2]
                logits = model(tile)
                probs = torch.softmax(logits, dim=1)
                # Accumulate probabilities and how many tiles covered each pixel
                prob_map[:, :, y1:y2, x1:x2] += probs
                count_map[:, :, y1:y2, x1:x2] += 1.0
    return prob_map / count_map.clamp(min=1.0)
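To check that the tiling loop covers every pixel, it helps to look at the tile offsets along one axis in isolation. The helper below is a hypothetical extraction that mirrors the loop bounds and clamping used in tile_predict; it needs no model or tensors.

```python
def tile_origins(length, tile_size=512, overlap=64):
    """Start offsets along one axis, mirroring the clamped loop in tile_predict."""
    step = tile_size - overlap
    origins = []
    for start in range(0, length - overlap, step):
        end = min(start + tile_size, length)
        origins.append(end - tile_size)   # shift the last tile back inside the image
    return origins

print(tile_origins(1024))  # [0, 448, 512]
print(tile_origins(1000))  # [0, 448, 488]
```

The last tile is shifted back rather than padded, so it overlaps its neighbour by more than `overlap` pixels; the count map in tile_predict compensates for that when averaging.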
Padding and divisibility. U-Net with 4 downsampling steps requires input dimensions to be divisible by $2^4 = 16$. An input of 513×513 will cause a shape mismatch in the Up block when the upsampled decoder feature (256×256 after up4) doesn't match the encoder skip connection (257×257 from x1 when input is 513×513). The Up block in the implementation above handles this with explicit padding - but you should still pad your inputs to the nearest multiple of 16 before inference to be safe.
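The padding arithmetic is small enough to show directly. This is a minimal sketch, assuming a multiple of 16 (four 2× downsampling steps) and padding added on the bottom and right edges; the function name is made up for the example.

```python
def pad_to_multiple(height, width, multiple=16):
    """Return the padded (height, width) and the (bottom, right) padding needed."""
    pad_h = (-height) % multiple
    pad_w = (-width) % multiple
    return (height + pad_h, width + pad_w), (pad_h, pad_w)

print(pad_to_multiple(513, 513))   # ((528, 528), (15, 15))
print(pad_to_multiple(720, 1280))  # ((720, 1280), (0, 0))
```

After inference, crop the prediction back to the original height and width so the padded border does not leak into the output.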
When to use pretrained encoders. The U-Net implementation above uses a custom encoder trained from scratch - fine for medical imaging where the domain is very different from ImageNet (grayscale microscopy vs natural photos). For natural scene segmentation (Cityscapes, ADE20K), prefer a pretrained ResNet or EfficientNet as the encoder and initialize only the decoder from scratch. Pretrained encoders can reduce training time by 5–10× and improve mIoU by 5–15 points on natural image datasets.
Interview Q&A
Q1: What problem does U-Net solve, and why do skip connections use concatenation rather than addition?
The fundamental problem U-Net solves is the spatial information loss in the encoder. As the encoder downsamples to build semantic features (large receptive field for context), it discards the precise spatial information needed for pixel-accurate segmentation. The decoder upsamples to recover full resolution, but upsampling alone cannot reconstruct discarded spatial detail. Skip connections pass the encoder's feature maps directly to the corresponding decoder level, providing fine-grained spatial information (edge locations, texture boundaries) that the decoder combines with deep semantic context. U-Net uses concatenation rather than addition (like ResNet skip connections or FCN) because concatenation preserves both feature sets independently - the decoder receives both the semantic "what is this region?" and the spatial "exactly where is the boundary?" as separate channels, then learns through convolution how to combine them optimally. Addition forces both to lie in the same representation space, which is harder to satisfy when they encode different levels of abstraction.
Q2: Why are dilated (atrous) convolutions useful for segmentation? What does the dilation rate control?
Segmentation requires large receptive fields for context (to classify a pixel, you need to see surrounding objects) AND high spatial resolution (to draw precise boundaries). Standard CNNs achieve large receptive fields through pooling, but pooling destroys the spatial resolution needed for precise segmentation. Dilated convolutions expand the receptive field by inserting gaps of $d-1$ pixels between kernel elements. A 3×3 conv with dilation $d$ has an effective receptive field of $(2d+1) \times (2d+1)$ without any stride or pooling - spatial resolution is fully preserved. The dilation rate directly controls the scale of context: low dilation (d=1-2) captures local texture and edges; high dilation (d=12-18) captures global structure and scene layout. DeepLab's ASPP uses multiple dilation rates in parallel to capture multi-scale context simultaneously - different neurons respond to features at different scales, making the model robust to objects of varying sizes in the same image.
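The receptive-field arithmetic for a single dilated layer is easy to verify. The sketch below uses the general span formula k + (k - 1)(d - 1) for a kernel of size k and dilation d; the function name is an assumption for the example.

```python
def effective_kernel_size(kernel_size, dilation):
    """Span covered by a dilated kernel along one axis: k + (k - 1)(d - 1)."""
    return kernel_size + (kernel_size - 1) * (dilation - 1)

for d in (1, 2, 6, 12, 18):
    span = effective_kernel_size(3, d)
    print(f"dilation {d:>2}: 3x3 kernel spans {span}x{span} pixels")
```

A 3×3 kernel with dilation 18 spans 37×37 pixels while still using only nine weights, which is why ASPP can mix several scales of context cheaply.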
Q3: What is the difference between semantic segmentation, instance segmentation, and panoptic segmentation? Give a concrete example of when each is needed.
Semantic segmentation labels every pixel with a class but makes no distinction between separate instances - all car pixels get label "car" regardless of how many cars are in the scene. Instance segmentation labels every pixel with both a class and an instance ID - car #1 gets one color mask, car #2 gets another. Panoptic segmentation combines both: "things" (countable, discrete objects: cars, pedestrians) get per-instance labels; "stuff" (amorphous regions: road, sky, vegetation) get only semantic labels. Concrete use cases: lane detection in autonomous driving needs semantic segmentation (which pixels are the drivable lane?). Cell counting in pathology needs instance segmentation (exactly how many cells are there, and where is each one?). A robotic manipulation system needs panoptic segmentation (which table surface can I place objects on - stuff - and which specific cup should I pick up - things with instance IDs for tracking?).
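The three label formats can be compared on a toy row of six pixels. This is an illustrative sketch; the class names, the tuple encoding, and the helper are assumptions made for the example rather than a dataset's actual file format.

```python
THINGS = {"car"}   # countable classes; everything else is treated as stuff

semantic = ["road", "car", "car", "road", "car", "sky"]   # class per pixel
instance = [None, 1, 1, None, 2, None]                     # ID per "thing" pixel

def to_panoptic(semantic, instance):
    """Merge both maps: things keep (class, id), stuff keeps (class, None)."""
    return [
        (cls, inst if cls in THINGS else None)
        for cls, inst in zip(semantic, instance)
    ]

panoptic = to_panoptic(semantic, instance)
print(panoptic)
num_cars = len({inst for cls, inst in panoptic if cls == "car"})
print(num_cars)  # 2
```

The semantic map alone cannot answer "how many cars?", since all three car pixels carry the same label; the panoptic map can, while still labeling road and sky.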
Q4: Explain mean IoU (mIoU) as a segmentation metric. Why is pixel accuracy insufficient, and when can mIoU also be misleading?
Pixel accuracy measures the fraction of pixels classified correctly. It is dominated by large classes: if road covers 40% of image pixels in a self-driving dataset, a model predicting "road" for every pixel reaches 40% accuracy without recognizing anything else, and a model that handles only road, building, and sky can score far higher while missing every pedestrian. mIoU computes IoU separately for each class: $\text{IoU}_c = \frac{TP_c}{TP_c + FP_c + FN_c}$, then averages across all classes. Each class contributes equally to mIoU regardless of how many pixels it covers, so a model cannot hide failure on rare classes behind success on common ones. mIoU of 0.72 means the model overlaps the ground truth by 72% on average across all classes. However, mIoU can also be misleading. Because it is an unweighted average, a single number hides which classes fail: with 19 classes, a model can reach 0.72 mIoU while scoring near zero on one safety-critical class such as rider. It also treats every class as equally important, and it says little about boundary quality, since a few mislabeled boundary pixels barely move the IoU of a large region. In practice, always report per-class IoU alongside mIoU and specifically analyze performance on safety-critical or rare classes.
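A toy example makes the gap between the two metrics visible. The sketch below uses flat Python lists as stand-ins for label maps; the class split (90 road pixels, 10 pedestrian pixels) is invented for illustration.

```python
def per_class_iou(pred, target, num_classes):
    """IoU_c = TP / (TP + FP + FN), computed over flat label lists."""
    ious = []
    for c in range(num_classes):
        tp = sum(p == c and t == c for p, t in zip(pred, target))
        fp = sum(p == c and t != c for p, t in zip(pred, target))
        fn = sum(p != c and t == c for p, t in zip(pred, target))
        denom = tp + fp + fn
        ious.append(tp / denom if denom else float("nan"))
    return ious

# Toy image: 90 road pixels (class 0), 10 pedestrian pixels (class 1).
target = [0] * 90 + [1] * 10
pred = [0] * 100                      # model predicts "road" everywhere

accuracy = sum(p == t for p, t in zip(pred, target)) / len(target)
ious = per_class_iou(pred, target, num_classes=2)
miou = sum(ious) / len(ious)
print(accuracy)  # 0.9
print(ious)      # [0.9, 0.0]
print(miou)      # 0.45
```

Pixel accuracy reports 90% for a model that never finds a pedestrian, while mIoU drops to 0.45 and the per-class list shows exactly which class failed.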
Q5: What is the Dice loss, when should you use it instead of cross-entropy, and why do practitioners typically combine both?
Dice loss is $\mathcal{L}_{\text{Dice}} = 1 - \frac{2\sum_i p_i g_i}{\sum_i p_i + \sum_i g_i}$, where $p_i$ are predicted probabilities and $g_i$ are ground truth binary labels. It directly optimizes the overlap between prediction and ground truth, making it inherently class-balanced: a rare class with 1,000 pixels contributes just as much Dice loss as a common class with 1,000,000 pixels, because Dice computes overlap within each class independently. Use Dice loss when class imbalance is severe (tumor vs background, defect vs normal surface) and cross-entropy is dominated by majority-class gradients, causing the model to ignore rare classes. The reason practitioners combine both is that they have complementary strengths: cross-entropy provides dense, pixel-level gradient signal for every pixel in every batch - its gradient is well-conditioned and training is stable. Dice loss provides the right objective for the imbalanced classes but can have unstable gradients for very small objects (when the denominator is near zero). Combined loss CE + Dice gets stable training from CE and the correct imbalance-handling objective from Dice, and in practice often outperforms either alone on medical and industrial segmentation benchmarks.
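The difference in how the two losses react to a missed rare class shows up in a small numeric example. This is a sketch under stated assumptions: binary segmentation, a constant prediction of 0.1 everywhere, and a small `eps` smoothing term added to guard against empty masks; the function names are made up for the example.

```python
import math

def dice_loss(probs, targets, eps=1e-6):
    """Soft Dice loss: 1 - 2*sum(p*g) / (sum(p) + sum(g)); eps guards empty masks."""
    intersection = sum(p * g for p, g in zip(probs, targets))
    return 1.0 - (2.0 * intersection + eps) / (sum(probs) + sum(targets) + eps)

def bce_loss(probs, targets):
    """Mean binary cross-entropy over pixels."""
    return -sum(
        g * math.log(p) + (1 - g) * math.log(1 - p) for p, g in zip(probs, targets)
    ) / len(probs)

# 1,000 pixels with 10 foreground pixels; the model outputs 0.1 everywhere.
targets = [1.0] * 10 + [0.0] * 990
probs = [0.1] * 1000

bce = bce_loss(probs, targets)
dice = dice_loss(probs, targets)
print(f"BCE:  {bce:.3f}")   # about 0.127
print(f"Dice: {dice:.3f}")  # about 0.982
```

Cross-entropy looks nearly converged because 99% of the pixels are easy background, while Dice stays close to its maximum because the foreground overlap is almost zero. Summing the two keeps the dense per-pixel gradient and the overlap objective in one loss.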
