TPU Architecture and Use

The BERT Pretraining Cost Problem

A research team needed to pretrain BERT-Large from scratch. Their estimate for their GPU cluster: 64 NVIDIA V100 GPUs × 4 days (96 hours) × $3/hr = $18,432. Acceptable, but the compute allocation was constrained. Then someone ran the numbers on a Google Cloud TPU v3 Pod: 512 TPU cores available as a flat resource, BERT-Large pretraining in 76 minutes, cost: 512 × 8 cores × $2.40/hr per 8-core slice × 1.27 hours = approximately $1,230. Fifteen times cheaper.

The team's reaction: why does nobody talk about this? They had been budgeting GPU time when a TPU Pod would have trained their model in under 2 hours for a fraction of the cost.

But then they tried to actually use it. XLA compilation failed on three custom ops. The JAX learning curve was steep. Dynamic shapes that worked trivially in PyTorch caused compilation failures in XLA. After a week of debugging, they had the model running on TPU, but they had also discovered why TPUs are not universally adopted: the productivity cost of adapting code to TPU constraints is real, and it amortizes only over large, repeated training runs.

This lesson gives you the complete picture - what TPUs are, why they are fast for certain workloads, what the programming model tradeoffs are, and how to decide whether TPUs are right for your training scenario.


Why TPUs Exist

Google designed TPUs to solve a specific problem: Google's data centers were running neural network inference at such scale that a significant fraction of all compute cycles were spent on matrix multiplication. The computational density of conventional CPUs and GPUs was insufficient. The cost per inference was too high.

The first TPU (v1, 2016) was an inference accelerator: a custom ASIC optimized purely for the matrix multiply-accumulate operations that dominate neural network computation. It achieved 92 TOPS at 8-bit precision, consuming 40W - an order of magnitude better efficiency than contemporary GPUs for inference.

TPU v2 and v3 added High Bandwidth Memory on the chip package and support for training (backward pass). TPU v4 (2021) increased performance to 275 TFLOPS per chip at bfloat16, with dramatically improved interconnect topology. TPU v5e and v5p followed in 2023: v5e delivers 197 TFLOPS at bfloat16 (393 TOPS at int8) as the cost-efficient part, while v5p reaches 459 TFLOPS at bfloat16 and targets large model training.

The key design insight: GPUs are general-purpose parallel processors. TPUs are purpose-built matrix multiplication accelerators. For workloads that are almost entirely matrix multiplications (neural network training and inference), this specialization pays off in performance-per-watt.


The Systolic Array

The core of every TPU is a systolic array - a 2D grid of multiply-accumulate units where data flows rhythmically between units, like a heartbeat (hence "systolic").

A 256×256 systolic array has 65,536 multiply-accumulate units. In each clock cycle:

  1. One row of the input matrix enters from the left
  2. One column of the weight matrix enters from the top
  3. Each unit multiplies its input × weight and adds to its running accumulator
  4. Results flow down and to the right

After roughly 256 clock cycles of streaming (plus a short pipeline fill and drain), the entire 256×256 output of a matrix multiplication emerges from the bottom. The key property: once the pipeline is full, all 65,536 multiply-accumulate units are busy every clock cycle. There is no memory bandwidth bottleneck during computation - weights are already loaded into the array, and inputs flow through in a streaming fashion.

Compare to a GPU: even with Tensor Cores, a matrix multiply requires loading both matrices from HBM, executing the GEMM, and writing results back. The systolic array avoids most of this memory traffic by doing everything in the array itself.

Systolic Array - 4×4 example (256×256 in TPU v1, 128×128 per MXU in later generations):

    Weights:           W00    W01    W02    W03
                        ↓      ↓      ↓      ↓
    Input:  → [A00] →  [*] →  [*] →  [*] →  [*]
              [A01] →  [*] →  [*] →  [*] →  [*]
              [A02] →  [*] →  [*] →  [*] →  [*]
              [A03] →  [*] →  [*] →  [*] →  [*]
                        ↓      ↓      ↓      ↓
    Output:            C00    C01    C02    C03

    [*] = multiply-accumulate unit
    Each cycle: all 16 units multiply their input × weight,
                accumulate into local register
    Data flows: inputs flow right, weights flow down,
                partial sums accumulate in each unit
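
The dataflow in the diagram can be simulated in a few lines of plain Python. This is a minimal sketch of an output-stationary array (each unit keeps its own accumulator, as described above); the function name `systolic_matmul` and the skewed input schedule are illustrative assumptions, and a real MXU may organize its dataflow differently.

```python
def systolic_matmul(A, B):
    """Simulate C = A @ B on a grid of multiply-accumulate units.

    Unit (i, j) owns the accumulator for C[i][j]. Rows of A enter from the
    left and columns of B enter from the top, each delayed by its index so
    that matching operands meet in the right unit on the right cycle.
    """
    n, k_dim, m = len(A), len(B), len(B[0])
    acc = [[0] * m for _ in range(n)]    # local accumulator per unit
    a_reg = [[0] * m for _ in range(n)]  # operand travelling right
    b_reg = [[0] * m for _ in range(n)]  # operand travelling down
    total_cycles = k_dim + n + m - 2     # streaming plus pipeline fill/drain

    for t in range(total_cycles):
        # Visit units from the far corner backwards so each unit still sees
        # the value its neighbour held on the previous cycle.
        for i in reversed(range(n)):
            for j in reversed(range(m)):
                if j == 0:
                    k = t - i            # row i is delayed by i cycles
                    a_in = A[i][k] if 0 <= k < k_dim else 0
                else:
                    a_in = a_reg[i][j - 1]
                if i == 0:
                    k = t - j            # column j is delayed by j cycles
                    b_in = B[k][j] if 0 <= k < k_dim else 0
                else:
                    b_in = b_reg[i - 1][j]
                a_reg[i][j], b_reg[i][j] = a_in, b_in
                acc[i][j] += a_in * b_in
    return acc, total_cycles


C, cycles = systolic_matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]])
print(C, cycles)
```

Note that no unit ever reads from a shared memory inside the loop: operands arrive only from the left or top neighbour, which is the property that removes memory traffic from the inner computation.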

TPU v4 Architecture

TPU v4 (Google, 2021) specifications:

  • Chips per pod: 4,096 chips per TPU v4 Pod
  • Compute per chip: 275 TFLOPS BF16 (two TensorCores, each with four MXUs; every MXU is a 128×128 systolic array)
  • On-chip memory: 32 GB HBM per chip at 1.2 TB/s bandwidth
  • Interconnect: ICI (Inter-Chip Interconnect) - custom 3D torus topology with 1.6 TB/s aggregate bandwidth per pod
  • Power: ~170W per chip

The interconnect is particularly notable. GPUs use NVLink (for within-node) + InfiniBand (for cross-node), with a dramatic bandwidth drop at the node boundary. TPU pods use ICI uniformly across all chips - there is no "node boundary" topology discontinuity. Every chip links to its torus neighbors at the same bandwidth; traffic to distant chips takes several hops, but it never crosses onto a slower class of link.

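This uniform interconnect lets bandwidth-heavy parallelism (for example, sharding a model across many chips) span far more devices than the handful inside one NVLink-connected GPU node, which is hard to do cleanly on a GPU cluster.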
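
To make the torus topology concrete, here is a small sketch that lists a chip's direct neighbors and counts hops between two chips. The 16×16×16 arrangement is an assumption chosen only because it multiplies out to 4,096 chips, and the helper names are hypothetical; this does not model real ICI routing.

```python
def torus_neighbors(coord, dims):
    """Return the chips one hop away from `coord` in a wraparound torus."""
    neighbors = []
    for axis, size in enumerate(dims):
        for step in (-1, 1):
            nxt = list(coord)
            nxt[axis] = (nxt[axis] + step) % size  # wraparound link
            neighbors.append(tuple(nxt))
    return neighbors


def torus_hops(a, b, dims):
    """Minimum hop count between two chips, going either way round each ring."""
    return sum(
        min(abs(x - y), size - abs(x - y))
        for x, y, size in zip(a, b, dims)
    )


DIMS = (16, 16, 16)  # assumed layout: 16 * 16 * 16 = 4,096 chips

print(torus_neighbors((0, 0, 0), DIMS))        # six direct neighbors
print(torus_hops((0, 0, 0), (15, 0, 0), DIMS)) # wraparound makes this 1 hop
print(torus_hops((0, 0, 0), (8, 8, 8), DIMS))  # farthest chip in this layout
```

The wraparound links are what keep the worst-case distance at half of each ring rather than its full length.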


XLA and the Compilation Model

TPUs do not execute arbitrary code like a GPU's CUDA kernels. Instead, they execute XLA (Accelerated Linear Algebra) programs - compiled computation graphs that the XLA compiler has optimized for the specific hardware.

When you write a JAX function and JIT-compile it with @jax.jit, JAX traces it and hands the result to XLA. The pipeline:

  1. Traces the function to build a computation graph
  2. Performs algebraic simplifications (constant folding, dead code elimination)
  3. Fuses element-wise operations into single kernels
  4. Lays out tensors in memory to maximize systolic array utilization
  5. Compiles to machine code specific to the target hardware (TPU, GPU, CPU)
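
The first stages of this pipeline can be inspected directly. The sketch below uses `jax.make_jaxpr` to show the traced graph and the `lower(...)` method of a jitted function to show the program handed to XLA; the function `scaled_tanh` is a made-up example, and the exact text printed varies between JAX versions.

```python
import jax
import jax.numpy as jnp


def scaled_tanh(x):
    # Two element-wise ops; XLA is free to fuse these into one kernel.
    return jnp.tanh(x) * 2.0


x = jnp.ones((4, 8), dtype=jnp.float32)

# Tracing: JAX records the operations applied to an abstract f32[4,8] value.
jaxpr = jax.make_jaxpr(scaled_tanh)(x)
print(jaxpr)

# Lowering: the traced graph is converted to XLA's input format,
# specialized to this exact shape and dtype.
lowered = jax.jit(scaled_tanh).lower(x)
print(lowered.as_text())
```

Both printouts mention the concrete shape, which is a useful reminder that the compiled program is tied to it.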

The constraint that causes friction: XLA requires static shapes. The computation graph is compiled for specific tensor shapes. If a tensor's shape changes between calls (e.g., variable-length sequences in a batch), XLA must recompile - or you must pad sequences to a fixed length and mask variable-length processing.

This is the biggest practical obstacle for users coming from PyTorch. PyTorch is define-by-run (eager mode): shapes can change freely between calls. XLA is define-then-run: shapes are fixed at compilation time.
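
A small experiment makes the recompilation cost visible. This sketch counts how often a jitted function is traced, first with inputs padded to a fixed length and then with raw variable-length inputs; `masked_mean`, `pad_to`, and `MAX_LEN` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # bumped only while JAX traces (and so recompiles) the function


@jax.jit
def masked_mean(tokens, mask):
    global trace_count
    trace_count += 1  # Python side effect: runs during tracing, not on cached calls
    return jnp.sum(tokens * mask) / jnp.maximum(jnp.sum(mask), 1.0)


def pad_to(seq, length):
    """Pad a Python list to `length` and build the matching 0/1 mask."""
    pad = length - len(seq)
    tokens = jnp.array(seq + [0.0] * pad, dtype=jnp.float32)
    mask = jnp.array([1.0] * len(seq) + [0.0] * pad, dtype=jnp.float32)
    return tokens, mask


MAX_LEN = 8  # assumed fixed length for this toy example
sequences = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0] * 7]

# Padded: every call has shape (8,), so one compilation serves all of them.
padded_means = [float(masked_mean(*pad_to(s, MAX_LEN))) for s in sequences]
padded_traces = trace_count

# Unpadded: each new length is a new shape and triggers a fresh trace.
trace_count = 0
for s in sequences:
    arr = jnp.array(s, dtype=jnp.float32)
    masked_mean(arr, jnp.ones_like(arr))
unpadded_traces = trace_count

print(padded_traces, unpadded_traces)
```

The mask keeps the padded positions out of the result, so padding changes the cost of the computation but not its answer.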


JAX: The TPU-Native ML Framework

JAX (Bradbury et al., Google Brain, 2018) is the ML framework most closely tied to TPUs, although it also targets GPUs and CPUs. It compiles NumPy-like code to XLA programs via jax.jit. JAX is also the framework used by many recent DeepMind and Google research projects (AlphaFold, Gemini training components, etc.).

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, jit, vmap

# JAX is functional - no mutable state, pure functions
def linear_layer(params: dict, x: jnp.ndarray) -> jnp.ndarray:
    return x @ params["W"] + params["b"]

def model_forward(params: dict, x: jnp.ndarray) -> jnp.ndarray:
    """Simple MLP forward pass."""
    x = linear_layer(params["layer1"], x)
    x = jax.nn.relu(x)
    x = linear_layer(params["layer2"], x)
    return x

def cross_entropy_loss(params: dict, x: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
    """Mean cross-entropy against one-hot labels; returns a scalar array."""
    logits = model_forward(params, x)
    return -jnp.mean(jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1))

# Gradient of the loss with respect to params (argument 0), JIT compiled
loss_grad_fn = jit(grad(cross_entropy_loss, argnums=0))

# JIT compile the training step - traces once, compiles to XLA, then fast
@jit
def train_step(params: dict, x: jnp.ndarray, labels: jnp.ndarray, learning_rate: float):
    """Single training step - JIT compiled to XLA."""
    loss, grads = jax.value_and_grad(cross_entropy_loss)(params, x, labels)
    # Tree map applies the update to every leaf in the params pytree
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g,
        params,
        grads,
    )
    return new_params, loss


# Initialize parameters
params = {
    "layer1": {
        "W": jnp.array(np.random.randn(784, 256) * 0.01, dtype=jnp.float32),
        "b": jnp.zeros(256),
    },
    "layer2": {
        "W": jnp.array(np.random.randn(256, 10) * 0.01, dtype=jnp.float32),
        "b": jnp.zeros(10),
    },
}

# Vectorize over batch dimension with vmap
# vmap(f) applies f independently to each element along axis 0 - like a batched for-loop
# but compiled to efficient batch operations
# (defined after params, because partial() captures params immediately)
batched_forward = vmap(partial(model_forward, params), in_axes=0)

# Training loop
x = jnp.array(np.random.randn(1024, 784), dtype=jnp.float32)
labels = jnp.array(np.eye(10)[np.random.randint(0, 10, 1024)], dtype=jnp.float32)

for step in range(100):
    params, loss = train_step(params, x, labels, 1e-3)
    if step % 10 == 0:
        print(f"Step {step}: loss = {float(loss):.4f}")

TPU Data Parallelism with pmap

JAX's pmap (parallel map) distributes computation across TPU chips:

# Continues from the previous example: reuses params, x, labels and cross_entropy_loss
from functools import partial

import jax
import jax.numpy as jnp
from jax import pmap

# pmap maps over the devices attached to this host
n_devices = jax.local_device_count()
print(f"Available devices: {n_devices}")

# Replicate parameters across all local devices
def replicate(x):
    """Replicate a value on all local devices."""
    return jax.device_put_replicated(x, jax.local_devices())

replicated_params = replicate(params)

# in_axes: params, x and labels carry a leading device axis;
# learning_rate (None) is broadcast unchanged to every device
@partial(pmap, axis_name="batch", in_axes=(0, 0, 0, None))
def distributed_train_step(params, x, labels, learning_rate):
    """
    pmap executes this function independently on each device.
    axis_name="batch" allows cross-device operations via pmean/psum.
    """
    loss, grads = jax.value_and_grad(cross_entropy_loss)(params, x, labels)

    # Average gradients across all devices (all-reduce)
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")

    # Update parameters
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g,
        params,
        grads,
    )
    return new_params, loss

# Split batch across devices: shape (n_devices, local_batch, ...)
def shard_batch(x, n_devices):
    """Split batch dimension across devices."""
    assert x.shape[0] % n_devices == 0
    return x.reshape(n_devices, x.shape[0] // n_devices, *x.shape[1:])

# Example: 1024 total batch, 8 devices → 128 per device
x_sharded = shard_batch(x, n_devices)
labels_sharded = shard_batch(labels, n_devices)

# Run distributed training step
replicated_params, losses = distributed_train_step(
    replicated_params,
    x_sharded,
    labels_sharded,
    jnp.array(1e-3),  # scalar, broadcast to all devices via in_axes=None
)
# After pmean every device holds the same loss, so read it from device 0
print(f"Mean loss: {float(losses[0]):.4f}")

GPU vs TPU: When Each Wins

Cost Comparison: BERT-Large Pretraining

| Hardware | Time | Cost/hr per unit | Total cost |
|---|---|---|---|
| TPU v3 Pod (512 cores) | 76 min | $2.40/hr per 8-core | ~$1,230 |
| 64× A100 80GB (on-demand) | 6 hours | $3.21/hr per GPU | ~$1,232 |
| 64× A100 80GB (spot) | 6 hours | ~$1.00/hr per GPU | ~$384 |
| TPU v4 Pod (preemptible) | 45 min | ~$1.20/hr per 8-core | ~$460 |

The TPU advantage is real for large-scale standard training. GPU spot instances are also competitive on price-performance, provided the job can tolerate preemption. For custom research code with non-standard ops, GPUs win on productivity.


TPU Limitations

No support for custom CUDA operations. Every TensorFlow/JAX op you write must be expressible in XLA's primitive operations. Custom CUDA kernels, third-party libraries that link against CUDA (some tokenizers, some data augmentation libraries), and custom C extensions cannot run on TPU. FlashAttention - one of the most important GPU optimizations - was initially unavailable on TPU because it is written as a hand-tuned CUDA kernel (though implementations now exist for TPU via JAX).

Static shapes required. Variable-length inputs require padding. For NLP with highly variable document lengths, this wastes compute on padding tokens.

Limited operator coverage. While coverage has dramatically improved, you may encounter ops that are not supported or perform poorly on TPU. Check before committing to TPU for a new architecture.

Cold compilation cost. The first XLA compilation of a function can take 30–300 seconds. Subsequent calls hit the compilation cache. For iterative research with frequent model changes, this adds up. For production training runs that run the same code for days, it is a one-time cost.


Production Engineering Notes

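Use bfloat16 natively - TPUs are designed for it. TPU Matrix Units operate natively in bfloat16. Float32 matrix operations are significantly slower on TPU (roughly 4× less FLOP throughput). Use bfloat16 for the bulk of TPU training compute, and keep float32 only where numerics demand it (for example, optimizer state or loss accumulation).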

Profile compilation time separately from execution time. JAX/XLA compilation can dominate early training steps. Use jax.jit(...).lower(...).compile() to pre-compile before training, or count the first 5 steps as warmup in your benchmark timing.
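
A minimal sketch of separating the two timings with ahead-of-time compilation. The toy `step` function and the array shapes are assumptions for illustration; the `lower(...).compile()` chain is the part that carries over to a real training step.

```python
import time

import jax
import jax.numpy as jnp


def step(w, x):
    # Stand-in for a real training step.
    return jnp.tanh(x @ w).sum()


w = jnp.ones((256, 256), dtype=jnp.float32)
x = jnp.ones((64, 256), dtype=jnp.float32)

# Compile ahead of time for these exact shapes and dtypes.
t0 = time.perf_counter()
compiled = jax.jit(step).lower(w, x).compile()
compile_s = time.perf_counter() - t0

# Time execution only. block_until_ready() waits for the asynchronous
# dispatch to finish so the measurement covers the actual computation.
t0 = time.perf_counter()
out = compiled(w, x)
out.block_until_ready()
run_s = time.perf_counter() - t0

print(f"compile: {compile_s:.3f}s, run: {run_s:.6f}s")
```

The compiled object only accepts inputs matching the shapes it was lowered with, so a shape change shows up as an error rather than a silent recompile.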

Checkpoint to GCS, not local disk. TPU VMs have limited local disk. All checkpoints should go to Google Cloud Storage. Use orbax (JAX checkpoint library) which supports async GCS writes.


Common Mistakes

:::danger Using dynamic shapes or Python-level control flow that depends on tensor values
JAX traces your function with abstract values and compiles a static computation graph. If jitted Python code contains if tensor.sum() > 0:, tracing fails with a concretization error, because the traced value has no concrete boolean. If the condition depends on a plain Python value instead, only the branch taken during tracing is compiled and the other branch is silently dropped. Use jax.lax.cond(condition, true_fn, false_fn, ...) for data-dependent branching.
:::
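
A short sketch of data-dependent branching under jit. The functions `signed_scale` and `broken` are hypothetical examples: the first uses `jax.lax.cond`, the second shows what happens when a Python `if` is applied to a traced value.

```python
import jax
import jax.numpy as jnp


@jax.jit
def signed_scale(x):
    # Both branches are traced and compiled; the choice is made on the
    # device at run time, based on the actual data.
    return jax.lax.cond(
        jnp.sum(x) > 0,
        lambda v: v * 2.0,   # taken when the sum is positive
        lambda v: v * -1.0,  # taken otherwise
        x,
    )


pos = signed_scale(jnp.array([1.0, 2.0]))
neg = signed_scale(jnp.array([-1.0, -2.0]))
print(pos, neg)


@jax.jit
def broken(x):
    # A Python `if` needs a concrete bool, but x is an abstract tracer here.
    if jnp.sum(x) > 0:
        return x * 2.0
    return x * -1.0


try:
    broken(jnp.array([1.0, 2.0]))
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True
print("Python if under jit raised:", raised)
```

For simple element-wise selection, `jnp.where` is often enough and avoids writing two branch functions.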

:::warning Benchmarking TPU performance on small models or short runs
The XLA compilation overhead amortizes over long training runs. If you run a 5-minute benchmark on TPU vs GPU, compilation costs dominate TPU timing. Always benchmark over at least 500 steps, discarding the first 20 as compilation warmup. Also benchmark at the scale where you plan to deploy - TPU pods have different performance characteristics at 512 chips than at 8 chips.
:::

:::tip Start with Flax or T5X for TPU training rather than raw JAX
Writing a full transformer model in raw JAX is educational but time-consuming. Flax (Google's neural network library) provides PyTorch-like model definitions that compile cleanly to XLA. T5X (Google's training framework) provides the full training infrastructure including checkpointing, evaluation, and multi-host distribution. These frameworks encode hard-won best practices for TPU programming.
:::


Interview Questions

Q1: What is a systolic array and why is it well-suited for matrix multiplication?

A systolic array is a 2D grid of multiply-accumulate units where data flows rhythmically between adjacent cells. For matrix multiplication $C = A \times B$: input rows flow from left to right, weight columns flow from top to bottom, and each unit accumulates partial products in place. Every unit is busy every clock cycle - utilization approaches 100% during the core computation. GPUs perform matrix multiplication via GEMM kernels that must load both matrices from HBM, execute the multiply, and write results back. The systolic array avoids most of this memory traffic by streaming data through the array and accumulating results in-register. This makes it extremely efficient at the specific operation that dominates neural network computation.

Q2: What does XLA compilation mean and why does it require static shapes?

XLA (Accelerated Linear Algebra) is a compiler that takes a high-level computation graph (expressed in XLA's HLO representation, which JAX emits after tracing) and compiles it to optimized hardware-specific code. The compilation performs algebraic simplifications, operation fusion (merging multiple element-wise ops into one pass), memory layout optimization, and hardware-specific code generation. Static shapes are required because many compiler optimizations depend on knowing tensor dimensions at compile time - buffer sizes, loop bounds, parallelization strategies. With static shapes, XLA can produce highly optimized code where all of these are constants at compile time. Dynamic shapes require either recompilation for each new shape or generalized code that is less efficient.

Q3: When would you recommend TPUs over GPUs for a training project?

TPUs are most compelling for: (1) large-scale standard transformer training (>10B parameters, >7 day runs) where the training cost savings justify the 1–2 week productivity cost of adapting code to XLA/JAX; (2) scenarios where you can access TPU Pods (Google Cloud) with preemptible pricing for significant cost savings; (3) teams already invested in the JAX/Flax/T5X ecosystem who do not pay the adaptation cost. GPUs are preferred for: (4) research with custom ops, dynamic shapes, or non-standard architectures; (5) smaller-scale training where compilation overhead is a larger fraction of total time; (6) inference workloads where GPU tooling (TensorRT, vLLM) is more mature.

Q4: Explain JAX's pmap and how it differs from PyTorch's DDP.

pmap (parallel map) applies a function independently to each element along a leading device dimension, one element per device. Data-parallel gradient averaging is expressed explicitly inside the function using jax.lax.pmean - there is no hidden gradient synchronization. This explicit model makes the communication pattern visible in code, which simplifies debugging. DDP wraps the model in a module that registers gradient all-reduce hooks on the parameters; the hooks fire during the backward pass, overlapping communication with computation. DDP is more transparent to PyTorch users (no model code changes needed) but the synchronization is implicit. pmap requires restructuring code around functional transforms but gives more explicit control over what communicates and when.

Q5: A BERT-Large pretraining job costs $18,000 on your GPU cluster. You estimate TPU v4 Pod preemptible would cost $1,500 for the same result, but your team uses PyTorch and has no JAX experience. How do you make the decision?

The $16,500 cost savings needs to be weighed against: (1) developer time to port BERT-Large training code from PyTorch to JAX/Flax - estimate 1–3 engineer-weeks for a clean port with XLA compatibility; at $200/hr loaded rate, that is $8,000–$24,000; (2) debugging time when XLA compilation fails on edge cases - budget another week; (3) ongoing maintenance cost of maintaining a JAX codebase when the team's expertise is PyTorch. Net calculation: if this is a one-time pretraining run, the porting cost likely exceeds the compute savings. If you plan to run BERT variants repeatedly (ablations, different configs, future model versions), the amortized porting cost is worth it. My recommendation: for one run, use GPU cluster and accept the cost. For a research program with multiple large-scale runs, invest in JAX porting. Alternatively, investigate whether the GPU run can be made more cost-efficient with spot instances (potentially 60% cheaper) before committing to the porting effort.
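
The break-even reasoning in this answer can be written down as a short calculation. This is a sketch under stated assumptions: a 40-hour engineer-week (implied by $8,000 per week at $200/hr) and a hypothetical helper name `porting_break_even`; it ignores the ongoing maintenance cost mentioned above.

```python
import math


def porting_break_even(gpu_cost, tpu_cost, port_weeks, hourly_rate, hours_per_week=40):
    """Return (savings per run, one-time porting cost, runs needed to break even)."""
    savings_per_run = gpu_cost - tpu_cost
    porting_cost = port_weeks * hours_per_week * hourly_rate
    runs_to_break_even = math.ceil(porting_cost / savings_per_run)
    return savings_per_run, porting_cost, runs_to_break_even


# 1 to 3 weeks of porting, plus a fourth row that adds the extra debugging week
for weeks in (1, 3, 4):
    savings, cost, runs = porting_break_even(18_000, 1_500, weeks, 200)
    print(f"{weeks} week(s): porting ${cost:,}, saves ${savings:,}/run, break-even after {runs} run(s)")
```

With these inputs the port pays for itself within one or two runs, which is why the recommendation turns on whether the run will be repeated.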

© 2026 EngineersOfAI. All rights reserved.