Google TPU Architecture
Reading time: ~40 min - Interview relevance: High - Target roles: ML Systems Engineer, AI Infrastructure Engineer, Research Engineer
A systolic array is a 2D grid of multiply-accumulate units where data flows through cells like water through pipes. No cache lookups. No branch mispredictions. Just math, moving at wire speed. That is why a TPU can sustain 275 TFLOPS on a matrix multiply while a GPU spends half its time fetching data.
The Migration That Went Wrong
It is 3 AM and your team is six hours into a BERT pretraining run on a TPU v4 Pod. The job should be blazing. Instead, the training throughput dashboard looks like a heartbeat monitor - 40 seconds of high TFLOPS, then a 12-second dip to near zero, then back up. Repeat.
Your senior engineer checks the logs and sees a flood of XLA messages: retracing function for new shapes. The culprit is a dynamic padding scheme your team borrowed from your A100 training code. On GPUs, padding to the nearest power of two is a minor optimization. On TPUs, every unique input shape is a new compilation event. Your 12-second "dips" are XLA recompiling the entire model from scratch.
This scenario plays out constantly when teams migrate ML workloads from GPUs to TPUs without understanding the fundamental architectural difference. GPUs are general-purpose throughput machines that handle variable input gracefully. TPUs are purpose-built matrix engines that require you to commit to a fixed computation graph up front - and reward you with dramatically lower cost per FLOP when you do.
Understanding why this constraint exists requires going back to first principles: what is a systolic array, why did Google build one, and what does that mean for your code?
Why TPUs Exist
In 2013, Google engineers ran the numbers and found that if voice search users started using speech recognition powered by deep neural networks for just three minutes per day, Google would need to double its entire data center capacity to handle the inference load. GPUs were too expensive per inference operation, too power-hungry, and not available in the quantities Google needed. Off-the-shelf silicon designed for gaming was not going to scale to billions of Google Search and Translate queries.
The engineering team proposed building a custom ASIC - an Application-Specific Integrated Circuit designed to do exactly one thing at the maximum possible efficiency: multiply matrices. The project was kept secret for years. When AlphaGo defeated Lee Sedol in 2016, Google revealed that the neural network inference ran on TPUs. The hardware community was stunned - Google had been running a secret AI hardware program for three years.
The core insight was this: neural network inference at the time was dominated by matrix multiplication. If you could build a chip that did matrix multiplication an order of magnitude more efficiently than a GPU, you could serve the same inference workload at a fraction of the cost and power. Everything else - activation functions, normalization, attention - could be handled by auxiliary units.
This constraint - "optimize for matrix multiply above all else" - shaped every architectural decision in the TPU and continues to shape it today.
Historical Context: The TPU Generations
Understanding where TPUs stand today requires understanding where they came from.
TPU v1 (2016) - Inference Only
The first TPU was designed purely for inference. It ran INT8 arithmetic, delivered 92 TOPS (tera-operations per second), consumed roughly 40W, and lived in a PCIe slot like a GPU. The heart of it was the Matrix Multiply Unit (MXU): a 256x256 systolic array with 65,536 multiply-accumulate cells. There was no HBM - just 28MB of on-chip SRAM plus 8GB of off-chip DDR3 at 34 GB/s. No training support. No floating-point gradients.
TPU v1 powered Google Search, Google Translate, and RankBrain. It ran AlphaGo. At 40W and 92 TOPS, the performance-per-watt ratio was roughly 10x better than contemporary GPUs for neural network inference.
TPU v2 (2017) - Training Capable
The team added two changes to enable training: BF16 floating-point arithmetic and High Bandwidth Memory (HBM). BF16 - brain floating-point - has the same 8-bit exponent as FP32 but only 7 mantissa bits (versus 23). This means BF16 covers the same dynamic range as FP32, which matters for gradient stability during training, but uses half the bits. The HBM provided 8GB per TPU core with 600 GB/s bandwidth. Each v2 chip delivered 45 TFLOPS BF16.
The v2 also introduced the Pod concept: 256 chips interconnected with custom high-speed links in a 2D torus topology.
TPU v3 (2018) - Liquid Cooling
A four-chip v3 board delivers 420 TFLOPS BF16, and that much compute generates substantial heat. Google moved to liquid cooling for v3, doubled the HBM to 16GB per core, and pushed bandwidth to 900 GB/s. A v3 Pod contained 1024 chips. This was the hardware that trained BERT, T5, and the early GPT-scale language models at Google.
TPU v4 (2021) - Optical Switching
TPU v4 is where the architecture matured significantly. Each chip delivers 275 TFLOPS BF16, carries 32GB of HBM, and connects via custom ICI (Inter-Chip Interconnect) running at speeds that enable all-reduce across thousands of chips at practical bandwidth. A v4 Pod contains 4096 chips interconnected in a 3D torus. The total pod bandwidth reaches 1.2 Pb/s.
The headline innovation was optical circuit switching in the Pod network - Google built custom optical switches that reconfigure the inter-chip topology without electrical signaling overhead. This means a v4 Pod can be reconfigured for different parallelism strategies (data parallel, tensor parallel, pipeline parallel) by changing the optical switch configuration rather than re-cabling.
PaLM (540B parameters) and Gemini Ultra were trained on v4 Pods.
TPU v5e and v5p (2023)
v5e is the efficiency variant: lower cost per chip, optimized for training and serving mid-size models (7B-70B range), good performance-per-dollar for inference workloads. v5p is the performance variant: higher TFLOPS, more HBM per chip, designed for frontier model training. Google has not disclosed full v5 specs, but reported benchmarks show 2-3x improvement over v4 in training throughput for transformer workloads.
Core Concept: The Systolic Array
This is the central piece of TPU architecture and the concept that confuses engineers most when they first encounter it.
What a Systolic Array Is
A systolic array is a 2D grid of processing elements (PEs). Each PE can do one multiply-accumulate (MAC) operation per clock cycle. The key property: data flows through the grid in a pipelined, synchronized way - no PE needs to "fetch" data from memory mid-computation. The data arrives pre-scheduled, one step at a time, like a conveyor belt.
Here is the simplest possible example. Consider multiplying a 4x4 matrix A by a 4x4 matrix B to get matrix C:
In a naive implementation, you would load a row of A, load a column of B, multiply element-wise, and sum. For a 256x256 matrix, that is 256 rows times 256 columns times 256 multiplications = 16.7 million operations. Each operation requires loading data from memory.
In the TPU's 256x256 systolic array, the matrix B is preloaded into the weight FIFO and "fills" the array column by column. Rows of matrix A flow in from the left, one row per clock cycle. Each PE multiplies the incoming A value with its resident B weight, adds the result to its running sum, and passes the partial result to the PE below it. After 256 clock cycles (one row per cycle), every cell in the output matrix C has accumulated its final sum.
The critical insight: once the weights are loaded, no PE ever touches memory during the computation. Data flows from neighbor to neighbor at register speed. The systolic array achieves near-100% utilization of its arithmetic units, which is why it delivers sustained TFLOPS close to its theoretical peak - something GPUs rarely achieve in practice.
Weights loaded into array (static during one matrix multiply)
B[0,0] B[0,1] B[0,2] B[0,3]
B[1,0] B[1,1] B[1,2] B[1,3]
B[2,0] B[2,1] B[2,2] B[2,3]
B[3,0] B[3,1] B[3,2] B[3,3]
A[0,:] --> [PE] --> [PE] --> [PE] --> [PE]
| | | |
A[1,:] --> [PE] --> [PE] --> [PE] --> [PE]
| | | |
A[2,:] --> [PE] --> [PE] --> [PE] --> [PE]
| | | |
A[3,:] --> [PE] --> [PE] --> [PE] --> [PE]
| | | |
C[0] ... C columns flow out below
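To make the dataflow concrete, here is a toy cycle-level simulation of a weight-stationary systolic array in plain Python - a sketch for intuition only (real MXU hardware pipelines weight loading and draining differently), but it reproduces the skewed entry and flow-down accumulation described above:

import numpy as np

def systolic_matmul(A, B):
    """Toy simulation: PE(i, j) holds weight B[i, j]; activation A[r, i]
    enters array row i from the left, skewed one cycle per row; partial
    sums flow down each column, so C[r, j] = sum_i A[r, i] * B[i, j]
    drains out of the bottom edge."""
    n = A.shape[0]
    a = np.zeros((n, n))   # activation register in each PE
    s = np.zeros((n, n))   # partial sum arriving from the PE above
    C = np.zeros((n, n))
    for t in range(3 * n - 2):  # fill + compute + drain cycles
        # Activations shift one PE to the right; a skewed slice of A enters
        a[:, 1:] = a[:, :-1].copy()
        for i in range(n):
            r = t - i
            a[i, 0] = A[r, i] if 0 <= r < n else 0.0
        # Every PE does one multiply-accumulate with its resident weight
        out = s + a * B
        # The bottom row of PEs emits finished dot products
        for j in range(n):
            r = t - (n - 1) - j
            if 0 <= r < n:
                C[r, j] = out[n - 1, j]
        # Partial sums advance one row down for the next cycle
        s[1:] = out[:-1]
        s[0] = 0.0
    return C

A = np.arange(16.0).reshape(4, 4)
B = 2.0 * np.eye(4)
assert np.allclose(systolic_matmul(A, B), A @ B)

Note the 3n - 2 cycle count: n cycles to fill the pipeline, plus the drain of the skewed partial sums. That fill/drain overhead is exactly why small matrices waste a large array.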
How This Differs from GPU Tensor Cores
GPU Tensor Cores also do matrix multiplication, but the execution model is fundamentally different.
A GPU Tensor Core operates on 16x16 tiles. The GPU scheduler assigns thread blocks to Tensor Cores, each thread block loads its tile from global memory or shared memory, the Tensor Core runs the 16x16 matrix multiply, and the result is written back to shared memory. The GPU relies on its massive L2 cache and shared memory hierarchy to keep the Tensor Cores fed.
The problem: cache misses. For large matrix multiplications that exceed L2 cache, the GPU Tensor Cores stall waiting for data from HBM. This is why GPU utilization (MFU - Model FLOP Utilization) for LLM training typically runs at 30-50% of theoretical peak, even with careful optimization.
The TPU MXU avoids cache misses by preloading weights into the systolic array's internal weight FIFOs. The 256x256 array holds 65,536 weights in local registers. For a single matrix multiply, those weights never move - only the activations flow through. This is why TPU MFU for matrix-heavy workloads (attention, FFN layers in transformers) routinely reaches 50-60% of theoretical peak.
The tradeoff: the systolic array is not programmable at the individual-operation level. You cannot issue arbitrary instructions to individual PEs. The only operation supported is "stream data through and accumulate." Everything else - softmax, layer norm, ReLU, element-wise ops - happens in the Vector Processing Unit (VPU).
TPU Memory Hierarchy
A TPU chip has three levels of memory, each with very different characteristics.
On-chip SRAM (scratchpad memory)
Each TPU v4 chip has 16MB of on-chip SRAM organized as a scratchpad. This memory is directly accessible to the VPU and sits adjacent to the systolic array. Its bandwidth is high enough that it is effectively never the bottleneck, and latency is on the order of a single cycle. The SRAM holds intermediate activations between VPU operations.
HBM (High Bandwidth Memory)
32GB of HBM per v4 chip at roughly 1.2 TB/s bandwidth. This is where model weights, activations for the full batch, and optimizer states (Adam moments) live during training. For inference, model weights are loaded from HBM into the systolic array weight FIFOs before each forward pass.
HBM bandwidth is the primary bottleneck for inference. The ratio that matters is arithmetic intensity: FLOPS per byte of HBM access. The TPU's systolic array has very high arithmetic intensity because once weights are loaded, they are reused across the batch dimension.
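A back-of-envelope roofline makes this concrete, using the v4 figures from this section (275 TFLOPS BF16 peak, ~1.2 TB/s HBM; the matmul shapes below are illustrative):

PEAK_FLOPS = 275e12          # TPU v4 BF16 peak, from above
HBM_BW = 1.2e12              # bytes/s
ridge = PEAK_FLOPS / HBM_BW  # ~229 FLOPs per HBM byte to stay compute-bound

def matmul_intensity(m, k, n, bytes_per_elem=2):  # BF16 = 2 bytes
    flops = 2 * m * k * n                           # one multiply + one add per MAC
    hbm_bytes = bytes_per_elem * (m*k + k*n + m*n)  # read A and B, write C
    return flops / hbm_bytes

print(ridge)                                # ~229
print(matmul_intensity(8192, 8192, 32768))  # ~3600 FLOPs/byte: compute-bound
print(matmul_intensity(1, 8192, 32768))     # ~1 FLOP/byte: memory-bound decode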
Host DRAM (CPU memory)
The host machine running the TPU has its own DRAM. Data pipelines, dataset loading, and pre/post-processing run on the CPU and transfer data to TPU HBM over PCIe (for PCIe TPUs) or over custom interconnect (for Cloud TPU VMs where the CPU is tightly coupled).
For training, the TPU's memory hierarchy means that optimizer states are the biggest memory consumer. A 7B parameter model in BF16 requires 14GB just for weights. Adam optimizer states add 2x the weight size in FP32 = 56GB. This exceeds a single v4 chip's 32GB HBM, which is why even "small" LLMs require model parallelism on TPUs.
TPU Pod Topology and Inter-Chip Interconnect
A single TPU v4 chip is fast. A v4 Pod with 4096 chips is a supercomputer.
The chips are connected in a 3D torus topology. Each chip has 6 ICI (Inter-Chip Interconnect) links - two in each of the X, Y, and Z dimensions. Each link runs at 192 GB/s bidirectional. With 4096 chips and 6 links per chip, the total ICI bandwidth is enormous: approximately 1.2 Pb/s aggregated.
This topology is important for collective operations during distributed training. All-reduce (summing gradients across all chips) in a 3D torus takes on the order of N^(1/3) hops per dimension, versus N hops around a ring. For 4096 chips, that is 16 hops versus 4096 hops.
The 3D torus enables several parallelism strategies that map naturally onto the topology:
- Data parallelism: each chip holds a full model copy, different mini-batches, gradients are all-reduced across the torus.
- Tensor parallelism: a single layer's weight matrix is sharded across chips along one torus dimension. Each chip computes a partial matrix multiply, then does an all-reduce across that dimension.
- Pipeline parallelism: different layers live on different groups of chips. Micro-batches flow through the pipeline in stages.
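In modern JAX, these strategies are expressed by laying out a logical device mesh and annotating arrays with shardings; XLA's GSPMD partitioner then inserts the collectives automatically. A minimal sketch, assuming an 8-device slice (the mesh shape and axis names are illustrative):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 4x2 logical mesh: a 'data' axis for batch sharding, a 'model' axis for
# weight sharding, mirroring two dimensions of the physical torus.
devices = np.array(jax.devices()).reshape(4, 2)
mesh = Mesh(devices, axis_names=('data', 'model'))

# Data parallel: shard the batch rows. Tensor parallel: shard weight columns.
x = jax.device_put(jnp.ones((32, 1024)), NamedSharding(mesh, P('data', None)))
w = jax.device_put(jnp.ones((1024, 4096)), NamedSharding(mesh, P(None, 'model')))

# jit partitions the matmul across the mesh and inserts the needed collectives.
y = jax.jit(jnp.dot)(x, w)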
PaLM used a combination of data parallelism and model parallelism across 6144 TPU v4 chips. The 3D torus topology made the collective communication bandwidth sufficient to prevent communication from becoming the bottleneck at this scale.
XLA: The Compiler That Makes TPUs Work
You do not write TPU assembly. You write JAX (or PyTorch/XLA), and a compiler called XLA (Accelerated Linear Algebra) translates your high-level operations into TPU machine code.
The XLA Compilation Pipeline
When you decorate a function with @jax.jit, JAX does not execute the function immediately. Instead, it traces through the function with abstract values (ShapedArrays), recording every operation into an HLO program. XLA then compiles the HLO program to TPU machine code. This compilation can take 30 seconds to several minutes for large models.
The compiled program is cached. If you call the same function again with inputs of the same shape and dtype, the cached program runs immediately. If the shape changes, JAX retraces and XLA recompiles - this is the source of the dynamic shape problem in the opening scenario.
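You can watch these stages explicitly with JAX's ahead-of-time APIs on recent versions - useful for checking what XLA will compile before committing to a long run. The function here is illustrative:

import jax
import jax.numpy as jnp

@jax.jit
def f(x, w):
    return jnp.tanh(jnp.dot(x, w))

x = jnp.ones((8, 128))
w = jnp.ones((128, 256))

lowered = f.lower(x, w)          # tracing: abstract shapes -> HLO, nothing executes
print(lowered.as_text()[:300])   # inspect the HLO program XLA will compile
compiled = lowered.compile()     # the expensive XLA compilation step happens here
print(compiled.cost_analysis())  # XLA's FLOP/byte estimates (format varies by version)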
Why Static Shapes Are Required
The TPU's systolic array has a fixed size: 256x256. To run a 1024x1024 matrix multiply, XLA tiles it into sixteen 256x256 sub-problems, generates the exact sequence of memory transfers and compute operations required, and hard-codes those sequences into the compiled program.
There is no runtime dispatch. No "if the shape is X, do A; if the shape is Y, do B." The compiled program is a fixed sequence of TPU instructions, like a tape that runs through the hardware. If your shape changes, you need a different tape.
This is why dynamic input lengths in NLP models are a problem. A BERT model taking sequences of length 128 and sequences of length 512 requires two different compiled programs. Your options:
- Bucket padding: pad all sequences to a fixed set of lengths (e.g., 128, 256, 512). You accept some wasted compute on padding but avoid recompilation.
- Static batches: fix your batch size and sequence length and never deviate. Common for training jobs where you control the data pipeline.
- Separate compiled models: pre-compile multiple versions for each expected input shape and select at runtime. Used for inference serving.
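A minimal sketch of the bucket-padding option (bucket boundaries are illustrative; pair this with the padding helpers shown later in this section):

BUCKETS = (128, 256, 512)

def bucket_length(seq_len, buckets=BUCKETS):
    """Smallest bucket that fits the sequence; longer sequences are truncated."""
    for b in buckets:
        if seq_len <= b:
            return b
    return buckets[-1]

# Every sequence is padded to one of three lengths, so the jitted model
# compiles at most three times - once per bucket - not once per unique length.
assert bucket_length(97) == 128
assert bucket_length(300) == 512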
JAX Programming Model
JAX is Google's library for high-performance numerical computing on accelerators. For TPUs, it is the primary programming interface.
Core JAX Concepts
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
# jnp is like numpy but runs on TPU/GPU
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
w = jnp.array([[0.5, 0.1], [0.2, 0.8]])
# Matrix multiply on TPU - uses systolic array automatically
result = jnp.dot(x, w)
print(result)  # result is a device array; printing copies its values to host
jit: Compiling for TPU
import jax
import jax.numpy as jnp
from functools import partial
@jax.jit
def forward_pass(params, x):
"""Simple 2-layer MLP - will be compiled to TPU machine code."""
w1, b1, w2, b2 = params
h = jnp.tanh(jnp.dot(x, w1) + b1) # Systolic array for dot
logits = jnp.dot(h, w2) + b2 # Systolic array for dot
return logits
# First call: traces + compiles (slow, ~30s for large models)
output = forward_pass(params, x_batch)
# Subsequent calls with same shapes: runs compiled code (fast)
output = forward_pass(params, x_batch_2) # Same shape = fast
# PROBLEM: different shape triggers recompilation
output = forward_pass(params, x_longer_batch) # SLOW - recompiles
Correct Pattern: Pad to Fixed Size
import jax
import jax.numpy as jnp
def pad_to_fixed_length(tokens, max_len=512):
"""Pad sequence to fixed length to avoid XLA recompilation."""
current_len = tokens.shape[0]
if current_len > max_len:
return tokens[:max_len]
pad_width = max_len - current_len
return jnp.pad(tokens, (0, pad_width), constant_values=0)
def create_attention_mask(tokens, max_len=512):
"""Create mask so attention ignores padding."""
current_len = tokens.shape[0]
mask = jnp.arange(max_len) < current_len # True for real tokens
return mask.astype(jnp.float32)
# Now all batches have identical shapes - no recompilation
@jax.jit
def encode_batch(params, token_ids, attention_mask):
# XLA compiles once for shape (batch_size, 512)
# All subsequent calls reuse compiled code
return bert_forward(params, token_ids, attention_mask)
pmap: Data Parallelism Across TPU Chips
import jax
import jax.numpy as jnp
from jax import pmap
from functools import partial  # needed for partial(pmap, ...) below
# Get number of available TPU devices
num_devices = jax.device_count()
print(f"Available TPU devices: {num_devices}")  # e.g., 8 on a v3-8 (one JAX device per core)
# Replicate params across all devices
params_replicated = jax.device_put_replicated(params, jax.devices())
# pmap: vectorizes over leading axis, distributes across devices
@partial(pmap, axis_name='batch')
def train_step(params, x_batch, y_batch):
def loss_fn(p):
logits = forward_pass(p, x_batch)
return cross_entropy_loss(logits, y_batch)
loss, grads = jax.value_and_grad(loss_fn)(params)
# All-reduce gradients across all chips via ICI
grads = jax.lax.pmean(grads, axis_name='batch')
return loss, grads
# x_sharded shape: (num_devices, per_device_batch, seq_len)
# Each device gets its slice automatically
x_sharded = x_batch.reshape(num_devices, -1, seq_len)
y_sharded = y_batch.reshape(num_devices, -1)
loss, grads = train_step(params_replicated, x_sharded, y_sharded)
Profiling with JAX Profiler
import jax
import jax.numpy as jnp
# Capture a trace for tensorboard / perfetto
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
# Your model forward/backward pass here
for step in range(10):
loss, grads = train_step(params_replicated, x_sharded, y_sharded)
params_replicated = update_params(params_replicated, grads)
loss.block_until_ready()  # Force synchronization so the trace captures real step time
# View trace: open /tmp/jax-trace in https://ui.perfetto.dev
# Key metrics to look for:
# - MXU utilization (should be > 40% for matrix-heavy code)
# - Host-to-device transfer time (should be < 5% of step time)
# - Time in compilation vs execution
Performance Analysis: When TPUs Win
TPUs dominate on workloads with these characteristics:
| Workload Property | TPU Advantage | Reason |
|---|---|---|
| Large batch matrix multiply | High | MXU 100% utilized, no cache misses |
| Fixed-shape transformer training | High | XLA compiles once, runs at peak throughput |
| Dense feedforward layers | High | GEMM bound, systolic array ideal |
| Attention (dense, fixed seq len) | High | Large matmuls, statically shaped |
| Sparse attention, variable lengths | Low | Dynamic shapes trigger recompilation |
| Custom CUDA ops | Low/None | No CUDA on TPU |
| Small batch inference | Medium | MXU underutilized at small batch sizes |
| Research with custom ops | Low | Operator coverage gaps in XLA |
PaLM training on TPU v4:
PaLM (540B parameters) was trained on 6144 TPU v4 chips. The reported model FLOPs utilization (MFU) was 46.2% - remarkably high for a 540B parameter model. The key factors:
- PaLM used a "Pathways" data parallel + tensor parallel strategy that mapped cleanly onto the TPU Pod 3D torus
- All sequence lengths were bucketed to fixed sizes, eliminating recompilation
- The feedforward and attention layers are pure matrix multiplies - perfect MXU utilization
By Google's estimates, a comparable run on A100s would have required roughly 2x the hardware cost and substantially more power.
Production Engineering Notes
Sizing Your TPU Slice
TPU slices are purchased in specific configurations, and the number in a slice name counts TensorCores, not chips: a "v4-128" is 64 TPU v4 chips, since each v4 chip has two TensorCores. Common configurations:
- v4-8: single host, 4 chips, 128GB HBM total. Good for fine-tuning and 7B-scale experiments.
- v4-32: 16 chips, 512GB HBM. Good for 7B-13B model training.
- v4-128: 64 chips, 2TB HBM. Training 30B-70B models.
- v4-512 to v4-8192: large slices up to the full 4096-chip pod, and multi-pod beyond. Frontier model training.
Rule of thumb: training with Adam needs roughly 16 bytes of HBM per parameter (2-byte BF16 weights + 4-byte FP32 master weights + 8 bytes of FP32 Adam moments + 2-byte gradients), before counting activations. A 70B parameter model therefore needs on the order of 1.1TB, so a v4-64 slice (32 chips, 1TB HBM) is the bare minimum and v4-128 is the practical choice.
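The bytes-per-parameter accounting above, as a quick sizing helper (a sketch: it assumes the BF16-weights + FP32-Adam recipe described here and excludes activations entirely):

def adam_training_hbm_gb(n_params_billion,
                         bytes_weights=2,   # BF16 weights
                         bytes_master=4,    # FP32 master copy
                         bytes_moments=8,   # two FP32 Adam moments
                         bytes_grads=2):    # BF16 gradients
    """State-only HBM estimate in GB; activations come on top of this."""
    per_param = bytes_weights + bytes_master + bytes_moments + bytes_grads
    return n_params_billion * per_param

print(adam_training_hbm_gb(7))   # 112 GB  - already more than three v4 chips
print(adam_training_hbm_gb(70))  # 1120 GB - ~1TB of HBM before activations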
Minimizing Compilation Time in Practice
import jax
import jax.numpy as jnp
from functools import partial  # used by the static_argnums example below
# Strategy 1: persistent compilation cache
# JAX can write compiled programs to disk and reuse them across restarts
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
# Strategy 2: Warm up explicitly before timed benchmarks
@jax.jit
def model_step(params, batch):
return forward_pass(params, batch)
# Warmup with representative shapes
dummy_batch = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
_ = model_step(params, dummy_batch).block_until_ready()
print("Compilation complete, starting timed benchmark...")
# Strategy 3: Use static_argnums for values that change shapes
@partial(jax.jit, static_argnums=(2,))
def encode(params, tokens, seq_len):
# seq_len is static - different values = different compiled programs
# but this is explicit, not accidental
return transformer_encoder(params, tokens, seq_len)
Memory Layout Matters
XLA has strong opinions about tensor memory layout. The default layout for 2D tensors is row-major (C-order). For matrix multiply with the systolic array, input layout affects data flow through the MXU.
import jax
import jax.numpy as jnp
# XLA may insert transposes if layout doesn't match MXU expectations
# Be explicit about shapes to help the compiler
def attention(q, k, v):
# q, k, v: (batch, heads, seq_len, head_dim)
# This layout is standard and XLA handles it well
scale = q.shape[-1] ** -0.5
scores = jnp.einsum('bhqd,bhkd->bhqk', q, k) * scale
weights = jax.nn.softmax(scores, axis=-1)
return jnp.einsum('bhqk,bhkd->bhqd', weights, v)
# Avoid layouts that require implicit transposes
# Bad: (seq_len, batch, heads, head_dim) - XLA will add a transpose
# Good: (batch, heads, seq_len, head_dim) - maps to MXU layout
Common Mistakes
:::danger Dynamic Shapes Kill Performance
The single most common TPU performance bug is accidentally introducing dynamic shapes. This can happen through:
- jnp.where(mask, x, y) where mask changes shape
- Python-level if/else on tensor shapes inside a jitted function
- Variable-length sequence inputs without bucketing
- jax.lax.cond with different-shaped branches
The symptom: periodic spikes in step time (recompilation) interleaved with fast steps. Enable jax.config.update("jax_log_compiles", True) and watch for unexpected compilation messages during training.
Fix: pad all inputs to fixed shapes before passing them to jitted functions. Use jnp.pad(). Accept the wasted compute on padding - it is almost always less costly than recompilation.
:::
:::danger Do Not Use Python Conditionals on Tensors Inside jit
This is the most common mistake from engineers coming from PyTorch:
import jax
import jax.numpy as jnp

# WRONG - Python conditional depends on a traced tensor value
@jax.jit
def bad_forward(x, threshold):
    if jnp.mean(x) > threshold:  # Python if on a traced value raises an error under jit
        return jax.nn.relu(x)
    else:
        return jnp.tanh(x)

# CORRECT - use jax.lax.cond for data-dependent branching
@jax.jit
def good_forward(x, threshold):
    return jax.lax.cond(
        jnp.mean(x) > threshold,
        jax.nn.relu,
        jnp.tanh,
        x
    )
JAX tracing converts Python conditionals to static branches. If the condition depends on a traced value, JAX raises an error. If it depends on a Python value, it specializes the compiled program to that value (which means recompilation if the value changes).
:::
:::warning Small Batch Sizes Underutilize the MXU
The 256x256 MXU achieves peak throughput when operating on matrices that are at least 256 in all dimensions. A batch size of 4 with a 1024-dimension embedding layer means the activation matrix is (4, 1024) - the leading dimension is 4, far smaller than 256. The MXU will be 98% idle on the batch dimension.
Practical minimum batch sizes for good TPU utilization: 32 or higher for training. For inference, use continuous batching or dynamic batching to aggregate requests into larger batches before submitting to the TPU.
:::
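The arithmetic behind that warning, as a reusable helper (the 256 comes from the MXU dimensions discussed earlier):

import math

def mxu_row_utilization(m, tile=256):
    """Fraction of systolic-array rows doing useful work when the leading
    dimension of the activation matrix is m."""
    return m / (math.ceil(m / tile) * tile)

print(mxu_row_utilization(4))    # 0.0156 - ~98% of the array idle
print(mxu_row_utilization(32))   # 0.125
print(mxu_row_utilization(256))  # 1.0   - fully occupied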
:::warning Unsupported Operations Fall Back to Host
Some operations are not natively supported in XLA's TPU backend. When JAX encounters them, it either errors or silently falls back to executing on the CPU host. This fallback involves a device-to-host memory transfer, CPU execution, and a host-to-device transfer back - typically 10-100x slower than an on-device implementation.
Common problematic ops:
- Some advanced indexing patterns (scatter with non-contiguous indices)
- Custom CUDA kernels (not supported at all)
- Operations on very small tensors where XLA overhead exceeds execution time
- Certain sparse tensor operations
Always inspect the lowered program before a long run: jax.make_jaxpr(your_function)(*dummy_inputs) shows the traced jaxpr, and jax.jit(your_function).lower(*dummy_inputs).as_text() shows the HLO. Look for unexpected custom_call nodes that indicate fallback to host.
:::
Interview Questions and Answers
Q1: Explain how a systolic array works and why it achieves better utilization than GPU tensor cores for large matrix multiplications.
A systolic array is a 2D grid of processing elements (PEs) where each PE performs one multiply-accumulate per clock cycle. The key property is that data flows through the grid in a pipelined, pre-scheduled manner - weights are preloaded into PE registers, and activation rows flow in from the left while partial sums accumulate downward. No PE fetches from memory during computation.
GPU Tensor Cores also perform matrix multiplies but rely on the cache hierarchy (L2, shared memory) to feed the compute units. For matrix operations larger than L2 cache, Tensor Cores stall waiting for HBM fetches. The TPU's MXU avoids this by fitting the weight matrix (or a tile of it) entirely in PE registers - 65,536 MACs operating on locally-held weights.
The practical result: TPU MFU (model FLOP utilization) for transformer training is typically 50-60% of theoretical peak, versus 30-45% for well-optimized GPU training. The gap widens for larger models where GPU cache miss rates increase.
Q2: Why does XLA require static shapes, and how do you write a JAX training loop that avoids dynamic shape recompilation?
XLA compiles a function by tracing it with abstract ShapedArrays (shape + dtype, no actual values). The compiled program is a fixed sequence of TPU instructions that are tiled and scheduled for the specific input shapes at compile time. There is no runtime dispatch logic in the compiled code - it executes the same instruction sequence regardless of the actual tensor values.
If input shapes change, the tiling is wrong (a 256x512 tile pattern does not map to a 256x384 input without producing incorrect results or accessing out-of-bounds memory), so XLA must recompile.
To avoid this: (1) pad all variable-length inputs to fixed sizes using jnp.pad(), (2) create attention masks to ignore padding in the computation, (3) use static_argnums in jax.jit to mark integer-valued arguments that should be treated as compile-time constants, (4) pre-bucket your dataset by sequence length into fixed-width buckets (e.g., 128, 256, 512, 1024) and compile separate programs for each bucket.
Q3: How does TPU Pod topology enable efficient gradient all-reduce at scale, and how does this compare to GPU cluster all-reduce over InfiniBand?
TPU v4 Pods use a 3D torus topology with ICI (Inter-Chip Interconnect) links at 192 GB/s per link, 6 links per chip. The 3D torus means every chip is connected to 6 neighbors. An all-reduce across all 4096 chips in a v4 Pod takes on the order of 16 hops per torus dimension (4096^(1/3) = 16), with each hop moving data at full link bandwidth.
GPU clusters running InfiniBand use ring or tree all-reduce topologies via NCCL. A ring all-reduce across N GPUs takes 2(N-1) steps, with each step using one InfiniBand link. InfiniBand bandwidth per link is 200-400 Gb/s - comparable per-link bandwidth to ICI, but the topology means the communication takes more steps.
More importantly: in GPU clusters, all-reduce traffic competes with other traffic on shared switches. In a TPU Pod, ICI is dedicated point-to-point silicon with no packet contention. The practical result is that gradient synchronization in TPU Pod training is predictable and low-latency, while GPU cluster all-reduce can have high variance depending on network congestion.
Q4: What is the BF16 format and why did Google choose it for TPU training rather than FP16?
BF16 (Brain Float 16) allocates its 16 bits as: 1 sign bit, 8 exponent bits, 7 mantissa bits. FP16 allocates: 1 sign bit, 5 exponent bits, 10 mantissa bits.
The 8 exponent bits in BF16 match FP32 exactly. This means BF16 and FP32 cover the same dynamic range (roughly 1e-38 to 3e38 for normal values). FP16 has only 5 exponent bits, so it covers a much narrower range (about 6e-5 to 65504 for normal values), which causes gradient underflow and overflow during training.
The consequence: training with FP16 requires loss scaling - artificially multiplying the loss by a large constant before backprop, then dividing the gradients back, to avoid FP16 underflow. This adds implementation complexity and can still cause instability for very deep networks or unusual gradient distributions.
BF16 has the same range as FP32 so no loss scaling is needed. The lower mantissa precision (7 vs 23 bits in FP32) reduces numerical accuracy slightly but in practice neural network training is robust to this - stochastic gradient descent is already noisy, and 7 mantissa bits is enough to represent gradient magnitudes accurately.
TPU v2 was designed around BF16 from the start. Google's training stack has never needed to implement loss scaling.
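The range difference is easy to demonstrate in a few lines (NumPy stands in for FP16, which TPUs do not implement; the values are chosen to straddle the FP16 limits):

import numpy as np
import jax.numpy as jnp

grad = 1e-8  # small but legitimate gradient magnitude
print(np.float16(grad))                          # 0.0 - underflows FP16 entirely
print(jnp.asarray(grad, dtype=jnp.bfloat16))     # ~1e-08 - representable in BF16
print(np.float16(70000.0))                       # inf - overflows FP16 (max 65504)
print(jnp.asarray(70000.0, dtype=jnp.bfloat16))  # ~70144 - comfortably in BF16 range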
Q5: A team is migrating a 70B parameter transformer from A100s to TPU v4. What questions would you ask to assess whether the migration will succeed, and what are the top risks?
First, the operational questions:
- Does the model use any custom CUDA kernels? Those do not exist on TPU. Any custom attention implementations, custom quantization ops, or C++ extensions need to be rewritten in JAX or replaced with XLA-compatible equivalents.
- What are the sequence lengths in the training data? Are they variable? If yes, what is the distribution? You need to design a bucketing strategy that covers 99.9% of your data without excessive padding waste.
- Does the data pipeline use PyTorch DataLoader? JAX training loops expect NumPy or tf.data inputs. You will need to rewrite the data pipeline, or use a tf.data-to-JAX bridge.
- What parallelism strategy is currently used? ZeRO-3 on GPUs does not map directly to TPU. You will need to redesign using jax.pmap + manual sharding or the newer jax.sharding API.
Top risks:
- XLA operator gaps: advanced attention variants (e.g., FlashAttention with custom tiling) may not have XLA equivalents and will require rewriting.
- Compilation time: a 70B model may take 15-30 minutes to compile in XLA. This is a fixed cost paid once per unique shape, but it means your startup time is high and any shape changes are expensive.
- Memory layout issues: the v4 slice needed for 70B training (roughly v4-128 for reasonable batch sizes) may have different memory pressure characteristics than your A100 setup, requiring adjustments to gradient checkpointing.
- Performance regression on non-matmul ops: if the model has significant non-matmul computation (custom MoE routing, variable-length attention), the VPU may be a bottleneck where the GPU was faster.
Q6: How does XLA operation fusion improve TPU performance, and how do you write code that enables effective fusion?
XLA's optimizer analyzes the HLO computation graph and fuses multiple element-wise operations into a single kernel to eliminate intermediate HBM round-trips. For example:
x = jax.nn.relu(x)  # HBM read, compute, HBM write
x = x + bias        # HBM read, compute, HBM write
x = x * scale       # HBM read, compute, HBM write
Without fusion, this requires 6 HBM accesses (3 reads + 3 writes). With fusion, XLA combines these into a single fused kernel: HBM read once, compute all three operations in registers, HBM write once. 2 HBM accesses instead of 6 - a 3x reduction in memory bandwidth.
XLA performs this fusion automatically, but you can help it by:
- Keeping operations on the same device without unnecessary device_put calls between them
- Avoiding Python-level loops that break the computation graph into multiple JIT regions
- Using jax.checkpoint (activation checkpointing) at coarse boundaries rather than fine-grained checkpoints, so XLA can fuse across operation boundaries within a segment
The profile tool (jax.profiler.trace) shows whether fused operations appear as single HLO entries. If you see many small sequential HBM accesses for element-wise ops, the fusion is not happening - usually because something in the code path broke the JIT region.
TPU vs GPU: A Decision Framework
Not every ML workload belongs on a TPU. The workload table above and the migration questions in Q5 supply the criteria. In short: choose TPUs for fixed-shape, matmul-dominated training with no custom CUDA kernels; stay on GPUs when the workload needs custom kernels, highly variable shapes, or low-latency small-batch inference.
Real-World LLM Training on TPUs: The PaLM and Gemini Story
The two most instructive case studies for TPU training at scale are PaLM (Pathways Language Model) and Gemini.
PaLM (2022): 540B parameters on 6144 TPU v4 chips
PaLM was trained using a combination of data parallelism and model parallelism. The Pathways system managed task scheduling across the v4 Pod. Key architectural decisions that made TPU training efficient:
- All sequence lengths bucketed to 2048 tokens. No dynamic shapes in training.
- Feedforward network used SwiGLU activation, which is a composition of element-wise ops that XLA fuses efficiently.
- Multi-query attention (single key/value heads shared across query heads) reduced the memory bandwidth requirement for KV cache.
- The feedforward and attention layers dominate FLOP count, and both are pure matrix multiplies - the MXU stays busy.
Reported model FLOPs utilization (MFU): 46.2%. This is the fraction of theoretical peak TFLOPS spent on useful model computation (excluding padding, recompilation, and communication overhead). For a 540B model across 6144 chips, 46% is excellent.
PaLM was trained on 780 billion tokens. On A100 hardware with similar compute, Google estimated the same run would have cost roughly 2x more.
Gemini Ultra (2023): The multimodal push
Gemini Ultra was trained across multiple TPU v4 Pods simultaneously using the Pathways infrastructure. The multimodal design - processing text, images, audio, and video in a single model - posed specific challenges for TPU efficiency:
- Image tokens and text tokens have different sequence lengths. Bucketing strategy required separate compiled programs for different modality combinations.
- Video inputs introduced temporal dimension - 3D attention patterns with larger attention matrices.
- The model was trained with "interleaved" batches where each batch contained a mix of modality types, each pre-bucketed to fixed shapes.
The key lesson from both runs: TPU efficiency at scale requires treating shape management as a first-class engineering concern, not an afterthought. The teams building PaLM and Gemini had dedicated engineers whose primary job was managing the bucketing strategy, monitoring for recompilation events, and optimizing data pipeline throughput to keep the TPUs fed.
Gradient Checkpointing and Memory Management on TPU
Training large models on TPU requires careful memory management. Activations from the forward pass need to be stored for the backward pass, but storing all activations for a 100+ layer transformer quickly exceeds HBM capacity.
Gradient checkpointing with jax.checkpoint:
import jax
import jax.numpy as jnp
from functools import partial
# Without checkpointing: all activations stored in HBM during forward pass
# Memory: O(layers * batch_size * seq_len * hidden_dim)
# With per-layer checkpointing: only layer inputs are stored; each layer's
# internal activations are recomputed during the backward pass
# Memory: O(layers * batch_size * seq_len * hidden_dim) for the saved inputs,
# plus a single layer's internals at a time
# Compute: ~33% more FLOPs (each activation recomputed once)
@partial(jax.checkpoint, prevent_cse=False)
def transformer_layer(params, x):
    """Single transformer layer (single-head attention for brevity) with gradient checkpointing."""
    # Attention
    q = jnp.dot(x, params['q_proj'])  # (batch, seq, head_dim)
    k = jnp.dot(x, params['k_proj'])
    v = jnp.dot(x, params['v_proj'])
    scale = q.shape[-1] ** -0.5
    attn = jax.nn.softmax(jnp.einsum('bqd,bkd->bqk', q, k) * scale)
    out = jnp.einsum('bqk,bkd->bqd', attn, v)
    x = x + jnp.dot(out, params['out_proj'])
    # FFN
    h = jax.nn.gelu(jnp.dot(x, params['ff1']))
    x = x + jnp.dot(h, params['ff2'])
    return x
def full_model_forward(params, x):
"""Apply all layers with checkpointing - saves memory at cost of compute."""
for layer_params in params['layers']:
x = transformer_layer(layer_params, x)
return x
# XLA compiles this: during backward pass, it recomputes each layer's
# activations from the checkpointed layer input rather than storing all of them.
# HBM usage: one layer's activations at a time, not all layers simultaneously.
Choosing what to checkpoint:
Not everything should be checkpointed. The rule is to checkpoint at the granularity where the saved memory is large and the recompute cost is modest. For transformers:
- Checkpoint at transformer layer boundaries (recompute the full layer during backward)
- Do not checkpoint at individual operation level (too granular, overhead exceeds savings)
- For very large models, use "selective checkpointing": checkpoint the attention and FFN blocks separately, keeping the residual stream in HBM
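jax.checkpoint supports this selective style directly through rematerialization policies. A sketch, assuming a recent JAX version with jax.checkpoint_policies (the block and parameter names here are illustrative):

import jax
import jax.numpy as jnp
from functools import partial

# Save matmul results (expensive to recompute on the MXU); rematerialize the
# cheap element-wise ops around them during the backward pass.
@partial(jax.checkpoint, policy=jax.checkpoint_policies.checkpoint_dots)
def ffn_block(params, x):
    h = jax.nn.gelu(jnp.dot(x, params['ff1']))
    return x + jnp.dot(h, params['ff2'])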
TPU-Specific Numerics: BF16 Training Stability
Training in BF16 on TPU is not identical to training in FP32 on CPU. Understanding where numerical differences arise prevents debugging headaches.
Layer normalization precision:
Layer norm computes mean and variance, which involves summing many small numbers. In BF16, this sum can accumulate significant error for long sequences. The standard fix: cast to FP32 for the reduction in layer norm, then cast back to BF16.
import jax
import jax.numpy as jnp
def layer_norm(x, gamma, beta, eps=1e-5):
"""Layer norm with FP32 accumulation for numerical stability."""
# Cast to FP32 for mean/variance computation
x_f32 = x.astype(jnp.float32)
mean = jnp.mean(x_f32, axis=-1, keepdims=True)
var = jnp.var(x_f32, axis=-1, keepdims=True)
x_norm = (x_f32 - mean) / jnp.sqrt(var + eps)
# Cast back to BF16 for main computation
return (gamma * x_norm + beta).astype(x.dtype)
# This pattern is used in virtually all production TPU training stacks.
# The FP32 cast is cheap (VPU operation) and the stability benefit is large.
Softmax stability:
Softmax with large logits overflows even BF16 - exp outgrows the format's ~3.4e38 maximum once a logit exceeds roughly 89. The numerically stable implementation subtracts the max before exponentiation. JAX's jax.nn.softmax does this automatically, but custom attention implementations often forget it.
import jax
import jax.numpy as jnp
def stable_softmax(logits, axis=-1):
"""Numerically stable softmax - essential for BF16 training."""
# Subtract max for numerical stability
# Without this: exp(large_logit) overflows (BF16 max ~3.4e38, hit once
# logits exceed ~89; 65504 is the FP16 limit, not BF16's)
shifted = logits - jnp.max(logits, axis=axis, keepdims=True)
exp_logits = jnp.exp(shifted)
return exp_logits / jnp.sum(exp_logits, axis=axis, keepdims=True)
# jax.nn.softmax already implements this correctly.
# Only write your own if you need a custom implementation.
Summary
Google TPUs are purpose-built matrix multiplication engines built around a 256x256 systolic array (the MXU). The systolic array achieves near-peak utilization on large matmuls by preloading weights into PE registers and flowing activations through at wire speed - no cache hierarchy involved.
The price of this efficiency is the static computation graph requirement. XLA compiles JAX programs to fixed TPU instruction sequences. Dynamic input shapes invalidate the compiled program and trigger expensive recompilation. Every effective TPU programmer learns to work with static shapes, padding to fixed sizes, and pre-bucketing variable-length datasets.
TPU Pods connect thousands of chips over high-bandwidth ICI in a 3D torus topology, enabling efficient gradient all-reduce for distributed training at scales (PaLM, Gemini) that require thousands of chips. The combination of per-chip efficiency and inter-chip bandwidth is why Google can train frontier models at lower cost than equivalent GPU clusters.
When evaluating a TPU migration: assess custom op usage, variable shape patterns, and parallelism strategy first. If the workload is a standard transformer with controlled input shapes and no custom CUDA ops, TPUs will likely deliver 1.5-2x better cost-efficiency than A100s at scale. If the workload has heavy custom ops or highly variable shapes, the migration cost may outweigh the gains.
