
Signals and IPC for ML

Reading time: ~35 min · Interview relevance: High · Target roles: MLOps Engineer, ML Systems Engineer, AI Infrastructure


It is 11:30 PM and your distributed training job has been running for 18 hours. The job is at step 45,000 of 60,000. You have about six hours of compute left before it finishes. Then the cloud spot instance gets preempted. Kubernetes sends SIGTERM to your training container. Your training process receives the signal. What happens next depends entirely on whether you handled it correctly.

If your code looks like most tutorial PyTorch code - a simple for loop with no signal handling - the process exits immediately on receiving SIGTERM. The last saved checkpoint was at step 40,000, saved five hours ago. You lose five hours of compute. You restart the job from step 40,000 and pay for those five hours again. On a cluster of 64 A100s at $32/GPU/hour, that is $10,240 gone because nobody wrote a twelve-line signal handler.

If your code handles SIGTERM correctly, the process receives the signal, sets a flag that the training loop checks at the end of the current step, finishes the current optimizer step atomically, saves a checkpoint at step 45,000, and exits cleanly. On restart, it loads from step 45,000 and you lose only the seconds since the last checkpoint save - not five hours.

Beyond graceful shutdown, signal handling connects to the broader world of inter-process communication (IPC). ML training and inference systems are almost always multi-process: DataLoader workers, gradient accumulation processes, model server workers, monitoring sidecars. These processes must communicate efficiently. The mechanisms they use - shared memory, pipes, Unix domain sockets, message queues - have dramatically different performance characteristics and semantic guarantees. Choosing the wrong mechanism for the wrong problem is a common source of both correctness bugs and performance bottlenecks.

This lesson covers both sides: signals as an OS mechanism for process control, and IPC as a mechanism for inter-process data exchange. For ML systems, these are daily operational concerns, not academic topics.


Why This Exists

The Unix process model was designed around isolated address spaces. Each process has its own virtual memory. Processes cannot read each other's memory directly. This isolation is a security and correctness feature - a bug in one process cannot corrupt another's state. But it creates a problem for systems that need to share data or coordinate actions across processes: how do you communicate if you cannot share memory?

Signals solve the narrow problem of control: send an asynchronous notification to a process that something has happened. IPC mechanisms solve the broader problem of data exchange: pass data between processes without copying it through the kernel more times than necessary. The kernel provides multiple IPC mechanisms because no single mechanism is optimal for all use cases. Shared memory is fastest but has no synchronization. Pipes are simple but unidirectional and stream-oriented. Message queues preserve message boundaries. Unix domain sockets support bidirectional full-duplex communication with optional authentication.

For ML systems, the key tension is between latency and simplicity. DataLoader workers need to pass tensors to the training loop at GPU feed speed. At batch_size=64 and 224x224 RGB images, you are moving roughly 12 MB per batch. If that 12 MB crosses a Unix pipe, it gets copied into a kernel buffer and out again - about 24 MB of kernel memory operations per batch. With shared memory, the tensor lives in /dev/shm and the training loop reads it in place - zero copies for the data path, just a small message to indicate availability.


Historical Context

Unix signals date to Version 7 Unix (1979) from Bell Labs. The original signal semantics were unreliable - signals could be lost during delivery. POSIX signals (IEEE Std 1003.1, 1988) standardized reliable signal delivery and introduced sigaction(), which replaced the older signal() function with well-defined semantics. Modern Linux uses POSIX signal semantics throughout.

POSIX IPC (shared memory via shm_open, message queues via mq_open, semaphores) was standardized as part of POSIX.1b (real-time extensions) in 1993. SysV IPC (the older shmget/msgget family) predates POSIX IPC and comes from Unix System V (1983). Both exist in modern Linux. Python's multiprocessing and multiprocessing.shared_memory modules use POSIX shared memory on Linux (backed by /dev/shm) and platform-native equivalents elsewhere, such as named shared memory on Windows.

ZeroMQ (ZMQ) was created by iMatix Corporation in 2007 and open-sourced in 2008. It provides a high-performance asynchronous messaging library that works across processes and machines using a consistent socket-like API. It became popular in distributed ML systems because it handles the complexity of asynchronous message passing, reconnection, and message routing that raw Unix sockets require you to implement yourself.


Core Concepts

Unix Signals - The Notification Layer

A signal is an asynchronous notification sent to a process by the kernel, another process, or the process itself. Signals interrupt the process's normal execution flow at any point and run a registered handler function before returning to the interrupted code.

The most important signals for ML workloads:

| Signal | Number | Default action | Meaning for ML |
|---|---|---|---|
| SIGTERM | 15 | Terminate | Polite shutdown request - save checkpoint and exit |
| SIGKILL | 9 | Terminate (uncatchable) | Immediate death - no handler possible |
| SIGINT | 2 | Terminate | Ctrl+C - treat like SIGTERM in production |
| SIGUSR1 | 10 | Terminate (default) | User-defined: e.g., trigger checkpoint save |
| SIGUSR2 | 12 | Terminate (default) | User-defined: e.g., toggle debug logging |
| SIGHUP | 1 | Terminate | Terminal hangup - used for config reload in daemons |
| SIGCHLD | 17 | Ignore | Child process exited - used by process supervisors |
| SIGALRM | 14 | Terminate | Timer expiry - used for watchdog timeouts |

The SIGKILL constraint is absolute: SIGKILL cannot be caught, blocked, or ignored by the process. The kernel delivers it directly to the scheduler, which terminates the process immediately. There is no checkpoint saving, no cleanup, no flush. Kubernetes sends SIGTERM first and waits for terminationGracePeriodSeconds (default 30 seconds). If the process has not exited by then, Kubernetes sends SIGKILL. Your signal handler must complete all cleanup within that grace period.
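
A minimal sketch of that constraint from Python: registering a handler for SIGTERM works, while attempting the same for SIGKILL is rejected by the kernel.

import signal

def handler(signum, frame):
    pass

signal.signal(signal.SIGTERM, handler)      # fine: SIGTERM is catchable

try:
    signal.signal(signal.SIGKILL, handler)  # the kernel refuses this
except (OSError, ValueError) as e:
    print(f"Cannot catch SIGKILL: {e}")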

Python signal handling constraints:

  1. Python signal handlers only run in the main thread. If your training loop is running in a background thread, SIGTERM will be received by the main thread but the handler cannot safely interact with thread-local PyTorch state. The recommended pattern: the main thread runs the training loop; signal handlers set a threading.Event that the training loop checks.
  2. Python uses a safe signaling mechanism: the C-level signal handler just sets a flag, and the flag is checked at the next "safe point" in the Python interpreter (between bytecodes). This means signal delivery is slightly delayed but never interrupts critical sections like memory allocation.
  3. Multiprocessing complicates signal handling. When you call DataLoader(num_workers=4), PyTorch forks 4 worker processes. SIGTERM is sent to the parent only. The parent must propagate the signal to workers and wait for them to finish.
import signal
import threading
import time
import os
import sys
from pathlib import Path


class GracefulTrainingShutdown:
    """Signal handler that allows a training job to save a checkpoint
    before exiting when SIGTERM or SIGINT is received.

    Usage pattern:
        shutdown = GracefulTrainingShutdown(checkpoint_dir="/checkpoints")
        shutdown.register()

        for step, batch in enumerate(dataloader):
            if shutdown.should_stop:
                # Training loop checks this flag at end of each step
                print(f"Graceful shutdown requested at step {step}")
                shutdown.save_checkpoint(model, optimizer, step)
                sys.exit(0)
            # ... normal training step ...
    """

    def __init__(self, checkpoint_dir: str, grace_period_seconds: int = 60):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.grace_period_seconds = grace_period_seconds
        self._stop_event = threading.Event()
        self._shutdown_requested = False
        self._checkpoint_requested = False
        self._signal_received = None

    def register(self) -> None:
        """Register handlers for SIGTERM, SIGINT, and SIGUSR1.

        MUST be called from the main thread. Signal handlers only
        execute in the main Python thread.
        """
        signal.signal(signal.SIGTERM, self._handle_shutdown)
        signal.signal(signal.SIGINT, self._handle_shutdown)
        signal.signal(signal.SIGUSR1, self._handle_checkpoint_request)
        print(f"Signal handlers registered. PID={os.getpid()}")
        print("  SIGTERM/SIGINT: graceful shutdown + checkpoint save")
        print("  SIGUSR1: checkpoint save without stopping")
        print(f"  To trigger: kill -TERM {os.getpid()}")
        print(f"  To checkpoint: kill -USR1 {os.getpid()}")

    def _handle_shutdown(self, signum: int, frame) -> None:
        """Signal handler for SIGTERM and SIGINT.

        This function runs asynchronously when the signal arrives.
        We only set the flag here - no I/O, no heavy operations.
        The actual checkpoint save happens in the training loop
        when it next checks should_stop.
        """
        sig_name = signal.Signals(signum).name
        print(f"\nReceived {sig_name} (signal {signum}). "
              f"Will stop after current step.", flush=True)
        self._shutdown_requested = True
        self._signal_received = signum
        self._stop_event.set()

    def _handle_checkpoint_request(self, signum: int, frame) -> None:
        """SIGUSR1 handler: save checkpoint without stopping training.

        This allows operators to trigger a checkpoint mid-training,
        e.g., before a planned maintenance window.
        """
        print("\nReceived SIGUSR1. Will save checkpoint after current step.",
              flush=True)
        self._checkpoint_requested = True

    @property
    def should_stop(self) -> bool:
        return self._shutdown_requested

    @property
    def checkpoint_requested(self) -> bool:
        return self._checkpoint_requested

    def clear_checkpoint_request(self) -> None:
        self._checkpoint_requested = False

    def save_checkpoint(
        self,
        model,
        optimizer,
        step: int,
        loss: float | None = None,
        extra: dict | None = None,
    ) -> Path:
        """Save a checkpoint. Returns the checkpoint path."""
        import torch
        checkpoint_path = self.checkpoint_dir / f"checkpoint_step_{step}.pt"
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

        state = {
            "step": step,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
            "signal_received": self._signal_received,
        }
        if extra:
            state.update(extra)

        # Write to a temp file first, then atomic rename.
        # This prevents a partial checkpoint if we are killed mid-write.
        tmp_path = checkpoint_path.with_suffix(".pt.tmp")
        torch.save(state, tmp_path)

        # Ensure durability: flush the temp file to disk before the rename,
        # so the renamed checkpoint is complete even if we exit immediately after.
        with open(tmp_path, "rb") as f:
            os.fsync(f.fileno())
        tmp_path.rename(checkpoint_path)

        print(f"Checkpoint saved: {checkpoint_path} (step={step})", flush=True)
        return checkpoint_path


def train_with_graceful_shutdown(
    model,
    optimizer,
    dataloader,
    checkpoint_dir: str,
    max_steps: int = 100_000,
) -> None:
    """Training loop with proper signal handling.

    Key patterns:
    1. Check shutdown flag at end of each step (not inside optimizer.step())
    2. Save checkpoint on shutdown request
    3. Use sys.exit(0) for clean exit (lets atexit handlers run)
    4. Import torch inside the function so signal handlers can be registered
       before any CUDA operations start
    """
    import torch

    shutdown = GracefulTrainingShutdown(checkpoint_dir)
    shutdown.register()

    for step, batch in enumerate(dataloader):
        if step >= max_steps:
            break

        # --- Normal training step ---
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()

        # --- Periodic checkpoint save ---
        if step % 1000 == 0 and step > 0:
            shutdown.save_checkpoint(model, optimizer, step, loss=loss.item())

        # --- Handle SIGUSR1: checkpoint without stopping ---
        if shutdown.checkpoint_requested:
            shutdown.save_checkpoint(model, optimizer, step, loss=loss.item())
            shutdown.clear_checkpoint_request()

        # --- Check shutdown flag at end of each step ---
        # Never check in the middle of an optimizer step - that could
        # save a checkpoint with mismatched model and optimizer state.
        if shutdown.should_stop:
            print(f"Graceful shutdown: saving final checkpoint at step {step}")
            shutdown.save_checkpoint(model, optimizer, step, loss=loss.item())
            print("Checkpoint saved. Exiting cleanly.")
            sys.exit(0)

    # Normal completion
    shutdown.save_checkpoint(model, optimizer, max_steps)
    print("Training complete.")

Shared Memory - Zero-Copy Data Passing

Shared memory allows two processes to read and write the same physical RAM pages without copying data through the kernel. It is the fastest IPC mechanism because the "transfer" involves no kernel involvement beyond the initial mapping setup - processes access the shared region using ordinary memory instructions (MOV, etc.).

In Linux, /dev/shm is a tmpfs filesystem that serves as POSIX shared memory backing storage. Python's multiprocessing.shared_memory module uses this by default on Linux. PyTorch's DataLoader workers store tensor data here so the training process can access it without copying.

The critical invariant: shared memory requires external synchronization. Two processes writing to the same location concurrently is a data race. Python's multiprocessing module provides Lock, Semaphore, and Event primitives backed by POSIX semaphores for this purpose.

import numpy as np
import multiprocessing as mp
from multiprocessing import shared_memory
import time


def create_shared_tensor_buffer(
    shape: tuple,
    dtype=np.float32,
    name: str | None = None,
) -> tuple[shared_memory.SharedMemory, np.ndarray]:
    """Create a shared memory region and a numpy array backed by it.

    The array can be passed to worker processes by name.
    Workers attach to the same region and access the same physical memory.
    No copying occurs during transfer - only the name string is passed.

    Returns:
        shm: the SharedMemory object (call shm.unlink() when done)
        arr: numpy array backed by the shared memory
    """
    nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
    shm = shared_memory.SharedMemory(create=True, size=nbytes, name=name)
    arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
    return shm, arr


def worker_fill_buffer(
    shm_name: str,
    shape: tuple,
    dtype_str: str,
    ready_event,
    consumed_event,
    worker_id: int,
) -> None:
    """Worker process: attaches to shared memory and fills it with data.

    The producer-consumer protocol:
    1. Worker fills the buffer with a preprocessed batch
    2. Worker sets ready_event to signal the consumer
    3. Worker waits on consumed_event before filling the next batch
    4. Consumer reads the buffer, processes it, sets consumed_event

    This is a single-buffer version of the pattern - in production you use
    two or more buffers and ping-pong between them to overlap I/O and compute.
    """
    dtype = np.dtype(dtype_str)
    # Attach to the existing shared memory region (do not create)
    shm = shared_memory.SharedMemory(name=shm_name, create=False)
    arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)

    try:
        for batch_idx in range(100):
            # Simulate preprocessing (image decode, normalization, etc.)
            time.sleep(0.01)
            # Write directly into shared memory - zero copies
            arr[:] = np.random.randn(*shape).astype(dtype) * (batch_idx + 1)

            # Signal consumer that data is ready
            ready_event.set()
            # Wait for consumer to finish before overwriting
            consumed_event.wait()
            consumed_event.clear()
    finally:
        shm.close()  # Detach but do not destroy (the creator calls unlink)


def consumer_training_loop(
    shm_name: str,
    shape: tuple,
    dtype_str: str,
    ready_event,
    consumed_event,
) -> None:
    """Main process: reads from shared memory and processes batches.

    The consumer side of the zero-copy DataLoader pattern.
    PyTorch's DataLoader does essentially this - workers write tensors
    into /dev/shm, the main process reads them without copying.
    """
    dtype = np.dtype(dtype_str)
    shm = shared_memory.SharedMemory(name=shm_name, create=False)
    arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)

    import torch
    processed = 0
    start_time = time.time()

    for batch_idx in range(100):
        # Wait for worker to fill the buffer
        ready_event.wait()
        ready_event.clear()

        # torch.from_numpy(arr) would share the buffer with zero copy.
        # We copy here so the worker can safely overwrite the buffer while
        # we still hold the tensor (and so autograd owns its own memory).
        tensor = torch.from_numpy(arr.copy())

        # Signal worker that we consumed this batch
        consumed_event.set()
        processed += 1

    elapsed = time.time() - start_time
    print(f"Processed {processed} batches in {elapsed:.2f}s")
    print(f"Throughput: {processed / elapsed:.1f} batches/sec")
    shm.close()


def demo_shared_memory_pipeline() -> None:
    """Demonstrate the shared memory producer-consumer pattern."""
    SHAPE = (32, 3, 224, 224)  # batch_size x channels x H x W
    DTYPE = "float32"

    # Create shared memory buffer
    shm, _ = create_shared_tensor_buffer(SHAPE, dtype=np.float32)

    ready_event = mp.Event()
    consumed_event = mp.Event()

    # Start worker process
    worker = mp.Process(
        target=worker_fill_buffer,
        args=(shm.name, SHAPE, DTYPE, ready_event, consumed_event, 0),
        daemon=True,
    )
    worker.start()

    # Run consumer in main process
    consumer_training_loop(shm.name, SHAPE, DTYPE, ready_event, consumed_event)

    worker.join()
    shm.unlink()  # Destroy the shared memory region
    print(f"Shared memory region '{shm.name}' destroyed")


def monitor_dev_shm() -> dict:
    """Check /dev/shm usage. Critical for ML workloads with many DataLoader workers.

    Docker gives containers only 64 MB of /dev/shm by default (--shm-size).
    With 8 DataLoader workers and 32-image batches of 224x224 float32:
        each batch is 32 * 3 * 224 * 224 * 4 bytes ~= 19 MB,
        each worker holds ~2 prefetched batches ~= 38 MB,
        8 workers ~= 300 MB minimum needed in /dev/shm.
    Docker's default 64 MB will cause DataLoader workers to crash silently.
    """
    import shutil
    shm_path = "/dev/shm"
    total, used, free = shutil.disk_usage(shm_path)
    return {
        "total_gb": round(total / (1024 ** 3), 2),
        "used_gb": round(used / (1024 ** 3), 2),
        "free_gb": round(free / (1024 ** 3), 2),
        "utilization_pct": round(used / total * 100, 1),
        "recommendation": (
            "OK" if free > 4 * (1024 ** 3)
            else "WARNING: less than 4 GB free in /dev/shm - DataLoader workers may fail"
        ),
    }

Pipes - Unidirectional Stream IPC

A pipe is the simplest IPC mechanism: a kernel-managed circular buffer with one write end and one read end. Data written to the write end appears in order at the read end. Pipes are unidirectional. For bidirectional communication you need two pipes or a different mechanism.

multiprocessing.Pipe() returns a pair of Connection objects. By default it creates a full-duplex (bidirectional) pipe using socketpair() on Linux - not an actual Unix pipe, but with similar semantics.

For ML, pipes are appropriate for control messages and small status updates between processes. They are inappropriate for passing large tensors because the data is copied through the kernel buffer.

import multiprocessing as mp
import time
import os


def model_server_worker(
    request_pipe: mp.connection.Connection,
    response_pipe: mp.connection.Connection,
    model_weights_path: str,
) -> None:
    """Worker process that loads a model and serves inference requests.

    Receives batches via shared memory, sends results back via pipe.
    The request pipe carries metadata (batch_id, shm_name, shape).
    The response pipe carries results (small tensors or scalars).
    """
    import torch

    # Load model once at startup
    print(f"Worker {os.getpid()}: loading model...")
    # model = torch.load(model_weights_path)
    # model.eval()
    print(f"Worker {os.getpid()}: ready")

    while True:
        try:
            # Receive request metadata
            if not request_pipe.poll(timeout=1.0):
                continue
            msg = request_pipe.recv()

            if msg is None:  # Shutdown sentinel
                print(f"Worker {os.getpid()}: received shutdown signal")
                break

            batch_id = msg["batch_id"]
            # In real code: load from shared memory by msg["shm_name"]
            # For demo: simulate inference
            time.sleep(0.005)  # Simulate 5ms inference
            result = {"batch_id": batch_id, "predictions": [0.9, 0.1]}

            response_pipe.send(result)

        except EOFError:
            break
        except Exception as e:
            print(f"Worker error: {e}")
            response_pipe.send({"error": str(e)})


class PipeBasedModelServer:
    """Simple model server using pipes for request/response coordination.

    For production, use Unix domain sockets (shown later) for better
    performance and support for multiple concurrent clients.
    """

    def __init__(self, model_path: str, num_workers: int = 2):
        self.workers = []
        self.request_pipes = []
        self.response_pipes = []

        for i in range(num_workers):
            # mp.Pipe(duplex=False) returns (receive_end, send_end).
            # The worker reads requests from req_recv and writes results to
            # resp_send; the parent keeps the opposite ends of each pipe.
            req_recv, req_send = mp.Pipe(duplex=False)
            resp_recv, resp_send = mp.Pipe(duplex=False)

            worker = mp.Process(
                target=model_server_worker,
                args=(req_recv, resp_send, model_path),
                daemon=True,
            )
            worker.start()
            self.workers.append(worker)
            self.request_pipes.append(req_send)
            self.response_pipes.append(resp_recv)

        self._next_worker = 0

    def predict(self, batch_id: int, batch_data) -> dict:
        """Send inference request to next available worker (round-robin)."""
        worker_idx = self._next_worker % len(self.workers)
        self._next_worker += 1

        msg = {"batch_id": batch_id, "data": batch_data}
        self.request_pipes[worker_idx].send(msg)
        return self.response_pipes[worker_idx].recv()

    def shutdown(self) -> None:
        """Gracefully shut down all workers."""
        for pipe in self.request_pipes:
            pipe.send(None)  # Shutdown sentinel
        for worker in self.workers:
            worker.join(timeout=5.0)
            if worker.is_alive():
                worker.terminate()

Unix Domain Sockets - Full-Duplex Local IPC

Unix domain sockets (UDS) are bidirectional, full-duplex communication channels that exist as filesystem paths rather than IP addresses. They support both connection-oriented (SOCK_STREAM, like TCP) and connectionless (SOCK_DGRAM, like UDP) modes.

Advantages over pipes for ML serving:

  • Multiple clients can connect to one server (unlike pipes which are point-to-point)
  • Supports credential passing: the server can verify the client's UID/GID without a separate authentication step
  • Higher throughput than pipes for large messages because the kernel can use zero-copy tricks
  • Supports ancillary data (for example, passing file descriptors with SCM_RIGHTS) and scatter-gather I/O

UDS latency is typically 5-20 microseconds for small messages, compared to 50-200 microseconds for TCP loopback. For a model server handling thousands of requests per second, this difference accumulates.

import socket
import os
import struct
import json
import threading
from pathlib import Path


SOCKET_PATH = "/tmp/ml-model-server.sock"


def create_uds_server(
    socket_path: str,
    max_connections: int = 50,
) -> socket.socket:
    """Create a Unix domain socket server.

    Permissions on the socket file control who can connect.
    chmod 0600 restricts to the same user only.
    chmod 0660 allows same group (useful for container-internal services).
    """
    # Remove stale socket from previous run
    if os.path.exists(socket_path):
        os.unlink(socket_path)

    server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    server_sock.bind(socket_path)
    # Restrict to owner only - only this UID can connect
    os.chmod(socket_path, 0o600)
    server_sock.listen(max_connections)

    return server_sock


def handle_inference_client(
    client_sock: socket.socket,
    client_addr: str,
    model,
) -> None:
    """Handle one client connection: receive batches, return predictions.

    Protocol (length-prefixed JSON):
        [4 bytes: message length][N bytes: JSON message body]
    """

    def recv_message(sock: socket.socket) -> dict | None:
        """Receive a length-prefixed JSON message."""
        header = b""
        while len(header) < 4:
            chunk = sock.recv(4 - len(header))
            if not chunk:
                return None
            header += chunk
        msg_len = struct.unpack(">I", header)[0]

        data = b""
        while len(data) < msg_len:
            chunk = sock.recv(msg_len - len(data))
            if not chunk:
                return None
            data += chunk
        return json.loads(data.decode("utf-8"))

    def send_message(sock: socket.socket, msg: dict) -> None:
        """Send a length-prefixed JSON message."""
        data = json.dumps(msg).encode("utf-8")
        header = struct.pack(">I", len(data))
        sock.sendall(header + data)

    try:
        while True:
            request = recv_message(client_sock)
            if request is None:
                break

            # Process the inference request
            # In production: deserialize tensor from shm_name in request
            batch_id = request.get("batch_id")
            inputs = request.get("inputs", [])

            # Simulate model inference
            import time
            time.sleep(0.003)
            predictions = [0.7, 0.2, 0.1]

            response = {
                "batch_id": batch_id,
                "predictions": predictions,
                "latency_ms": 3.0,
            }
            send_message(client_sock, response)

    except (ConnectionResetError, BrokenPipeError):
        pass
    finally:
        client_sock.close()


def run_uds_model_server(socket_path: str = SOCKET_PATH) -> None:
    """Run a Unix domain socket model server.

    Accepts multiple concurrent clients. Each client gets its own
    handling thread. For production, use asyncio or a worker pool
    instead of a thread per client.
    """
    server_sock = create_uds_server(socket_path)
    print(f"Model server listening on {socket_path}")
    print(f"Socket permissions: {oct(os.stat(socket_path).st_mode)}")

    # model = load_model()  # Load once
    model = None

    try:
        while True:
            client_sock, _ = server_sock.accept()

            # Get client credentials (Linux-specific)
            # SO_PEERCRED returns (pid, uid, gid) of connecting process
            try:
                creds = client_sock.getsockopt(
                    socket.SOL_SOCKET, socket.SO_PEERCRED,
                    struct.calcsize("3i")
                )
                pid, uid, gid = struct.unpack("3i", creds)
                print(f"Client connected: PID={pid}, UID={uid}, GID={gid}")
            except OSError:
                pass

            thread = threading.Thread(
                target=handle_inference_client,
                args=(client_sock, socket_path, model),
                daemon=True,
            )
            thread.start()
    finally:
        server_sock.close()
        os.unlink(socket_path)


def uds_inference_client(
    socket_path: str,
    batch_id: int,
    inputs: list,
) -> dict:
    """Client: send inference request to UDS model server."""
    client_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    client_sock.connect(socket_path)

    request = json.dumps({"batch_id": batch_id, "inputs": inputs}).encode()
    header = struct.pack(">I", len(request))
    client_sock.sendall(header + request)

    # Receive response
    header = client_sock.recv(4)
    msg_len = struct.unpack(">I", header)[0]
    data = b""
    while len(data) < msg_len:
        data += client_sock.recv(msg_len - len(data))

    client_sock.close()
    return json.loads(data.decode())

PyTorch DataLoader Shared Memory in Detail

PyTorch's DataLoader with num_workers > 0 uses fork-based multiprocessing on Linux. Each worker is a forked copy of the main process that runs the dataset's __getitem__ method. The resulting tensors are stored in /dev/shm using PyTorch's shared memory allocator, and the main process receives a handle (the shared memory file descriptor and offset) that it uses to access the tensor without copying.

import torch
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
import psutil
import os


class MLDataset(Dataset):
    """Example dataset with proper shared memory awareness."""

    def __init__(self, data_path: str, transform=None):
        self.data_path = data_path
        self.transform = transform
        # Load file index only in main process
        # Workers inherit this via fork (copy-on-write)
        self.file_list = list(sorted(Path(data_path).glob("*.pt")))

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx: int):
        # This runs in a worker process
        sample = torch.load(self.file_list[idx])
        if self.transform:
            sample = self.transform(sample)
        # Return value is stored in /dev/shm by PyTorch's worker manager
        return sample


def configure_dataloader_for_system(
    dataset: Dataset,
    batch_size: int = 32,
) -> DataLoader:
    """Configure DataLoader based on available system resources.

    Key decisions:
    - num_workers: set to num_CPU_cores - 2 (leave headroom for training)
    - pin_memory: True if CUDA available, False otherwise
      (pin_memory pre-maps memory for DMA, speeds GPU transfer)
    - prefetch_factor: 2 is default; increase for I/O-bound datasets
    - persistent_workers: keep workers alive between epochs
      (avoids the 2-5 second worker startup overhead per epoch)
    """
    num_cpus = os.cpu_count() or 4
    # Check available /dev/shm space
    shm_info = psutil.disk_usage("/dev/shm")
    shm_free_gb = shm_info.free / (1024 ** 3)

    # Estimate memory needed: batch_size * sample_size * prefetch_factor * num_workers
    # For 224x224 float32 images: 3 * 224 * 224 * 4 = 602,112 bytes per sample
    bytes_per_sample = 3 * 224 * 224 * 4
    bytes_per_batch = batch_size * bytes_per_sample
    num_workers = max(1, num_cpus - 2)
    prefetch_factor = 2
    estimated_shm_gb = (num_workers * prefetch_factor * bytes_per_batch) / (1024 ** 3)

    if estimated_shm_gb > shm_free_gb * 0.8:
        print(f"WARNING: Estimated /dev/shm usage {estimated_shm_gb:.1f} GB "
              f"exceeds 80% of available {shm_free_gb:.1f} GB.")
        print("Reducing num_workers to prevent /dev/shm overflow.")
        # Reduce workers until estimated usage fits
        while num_workers > 1 and estimated_shm_gb > shm_free_gb * 0.8:
            num_workers -= 1
            estimated_shm_gb = (num_workers * prefetch_factor * bytes_per_batch) / (1024 ** 3)

    cuda_available = torch.cuda.is_available()
    print("DataLoader configuration:")
    print(f"  num_workers: {num_workers}")
    print(f"  pin_memory: {cuda_available}")
    print(f"  prefetch_factor: {prefetch_factor}")
    print(f"  est. /dev/shm use: {estimated_shm_gb:.2f} GB")

    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=cuda_available,
        prefetch_factor=prefetch_factor,
        persistent_workers=(num_workers > 0),
        # multiprocessing_context="fork" is the default on Linux.
        # "spawn" is safer but slower (re-imports everything).
        # "forkserver" is intermediate: forks from a clean server process.
    )

ZeroMQ for Distributed ML Communication

ZeroMQ (ZMQ) provides a high-performance asynchronous message queue library. It exposes socket-like endpoints that can span processes, machines, and networks using a unified API. ZMQ handles connection management, message framing, reconnection, and high-water-mark backpressure automatically.

For ML systems, ZMQ is commonly used in:

  • Parameter server architectures (workers push gradients, server distributes updated weights)
  • Distributed inference pipelines (load balancer distributes requests to workers)
  • Experiment tracking sidecars (training process sends metrics to collector)
  • Model update broadcasting (new model version is published to all serving workers)
import zmq
import pickle
import time
import threading
import numpy as np


def gradient_aggregator_server(
    pull_address: str = "tcp://*:5555",
    pub_address: str = "tcp://*:5556",
    num_workers: int = 4,
) -> None:
    """Parameter server: collects gradients from workers, broadcasts updates.

    Pattern: PULL (receive gradients) + PUB (broadcast updated weights)

    Worker pattern: PUSH (send gradients) + SUB (receive updated weights)

    This is a simple synchronous barrier implementation.
    For async SGD, remove the barrier and apply gradients as they arrive.
    """
    context = zmq.Context()

    # PULL socket: workers PUSH gradients here
    pull_sock = context.socket(zmq.PULL)
    pull_sock.bind(pull_address)

    # PUB socket: broadcast updated weights to all workers
    pub_sock = context.socket(zmq.PUB)
    pub_sock.bind(pub_address)

    print("Parameter server started:")
    print(f"  Gradient PULL: {pull_address}")
    print(f"  Weight PUB: {pub_address}")

    # Simulate model weights
    weights = np.random.randn(1000).astype(np.float32)
    step = 0
    lr = 0.01

    while True:
        # Collect gradients from all workers (synchronous barrier)
        gradients = []
        for worker_idx in range(num_workers):
            msg = pull_sock.recv_pyobj()
            gradients.append(msg["gradient"])

        # Average gradients (AllReduce equivalent)
        avg_gradient = np.mean(gradients, axis=0)

        # Update weights with SGD
        weights -= lr * avg_gradient
        step += 1

        # Broadcast updated weights to all workers
        update_msg = {
            "step": step,
            "weights": weights,
        }
        # First frame is the topic; workers subscribed to b"weights" receive it
        pub_sock.send_multipart([b"weights", pickle.dumps(update_msg)])

        if step % 100 == 0:
            print(f"Step {step}: avg_gradient_norm={np.linalg.norm(avg_gradient):.4f}")


def training_worker(
    worker_id: int,
    push_address: str = "tcp://localhost:5555",
    sub_address: str = "tcp://localhost:5556",
    num_steps: int = 1000,
) -> None:
    """Training worker: computes gradients and pushes to parameter server.

    In real PyTorch DDP, this is handled by NCCL AllReduce.
    This ZMQ pattern is useful for heterogeneous hardware (CPU + GPU workers),
    for parameter server experiments, or for fault-tolerant training where
    workers can join/leave independently.
    """
    context = zmq.Context()

    # PUSH socket: send gradients to parameter server
    push_sock = context.socket(zmq.PUSH)
    push_sock.connect(push_address)

    # SUB socket: receive updated weights from parameter server
    sub_sock = context.socket(zmq.SUB)
    sub_sock.connect(sub_address)
    sub_sock.setsockopt(zmq.SUBSCRIBE, b"weights")  # Subscribe to weight updates

    # Allow pub/sub to connect before first message
    time.sleep(0.1)

    print(f"Worker {worker_id}: connected to parameter server")

    # Local model state (simplified)
    local_weights = np.random.randn(1000).astype(np.float32)

    for step in range(num_steps):
        # Simulate forward + backward pass
        gradient = np.random.randn(1000).astype(np.float32) * 0.1

        # Push gradient to server
        msg = {"worker_id": worker_id, "step": step, "gradient": gradient}
        push_sock.send_pyobj(msg)

        # Wait for updated weights from server
        topic, payload = sub_sock.recv_multipart()
        update = pickle.loads(payload)
        local_weights = update["weights"].copy()

    print(f"Worker {worker_id}: finished {num_steps} steps")
    push_sock.close()
    sub_sock.close()
    context.term()


def demo_zmq_parameter_server(num_workers: int = 4) -> None:
    """Demonstrate ZMQ-based parameter server with N workers."""
    # Start parameter server in background thread
    server_thread = threading.Thread(
        target=gradient_aggregator_server,
        args=("tcp://*:5555", "tcp://*:5556", num_workers),
        daemon=True,
    )
    server_thread.start()
    time.sleep(0.2)  # Give server time to bind

    # Start workers
    worker_threads = []
    for worker_id in range(num_workers):
        t = threading.Thread(
            target=training_worker,
            args=(worker_id, "tcp://localhost:5555", "tcp://localhost:5556", 200),
            daemon=True,
        )
        t.start()
        worker_threads.append(t)

    for t in worker_threads:
        t.join()
    print("All workers finished.")


Production Engineering Notes

The grace period math: Kubernetes default terminationGracePeriodSeconds is 30 seconds. For a training job, your checkpoint save must complete within 30 seconds. For a 7B model in bf16 (14 GB), writing to NVMe at 3 GB/s takes about 5 seconds. Writing to a network-mounted PVC at 1 GB/s takes 14 seconds. Both fit within 30 seconds. But if you are also running torch.save() through overlayfs to the container's writable layer, throughput might be only 300 MB/s - that is 47 seconds for 14 GB. You exceed the grace period. Solution: always write checkpoints to a mounted NVMe PVC, not to the container's root filesystem.
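
A back-of-the-envelope check like the following makes that arithmetic explicit. This is a hypothetical helper, not part of any library; the safety margin reserves part of the grace period for finishing the current step, serializing state dicts, fsync, and the rename.

def checkpoint_fits_grace_period(
    model_size_gb: float,
    write_throughput_gb_s: float,
    grace_period_s: float = 30.0,
    safety_margin: float = 0.5,
) -> bool:
    """Rough estimate: can a checkpoint of this size be written before SIGKILL?"""
    write_seconds = model_size_gb / write_throughput_gb_s
    return write_seconds <= grace_period_s * safety_margin

# The numbers from the paragraph above (14 GB of bf16 weights):
print(checkpoint_fits_grace_period(14, 3.0))  # NVMe at 3 GB/s      -> True  (~5 s)
print(checkpoint_fits_grace_period(14, 1.0))  # network PVC, 1 GB/s -> True, but barely (~14 s)
print(checkpoint_fits_grace_period(14, 0.3))  # overlayfs, 300 MB/s -> False (~47 s)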

Worker process signal propagation: When your training process receives SIGTERM, Python delivers it to the main thread. But your DataLoader workers (forked by PyTorch) are separate processes. They do not receive SIGTERM automatically unless you propagate it. PyTorch's DataLoader handles this in its _shutdown_workers() method. But if you write custom multiprocessing code, you must explicitly send signals to child processes: os.kill(worker.pid, signal.SIGTERM) for each worker before calling worker.join().
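
For custom multiprocessing code, explicit propagation might look like this minimal sketch; shutdown_workers is a hypothetical helper, and Process.kill() is the SIGKILL escalation of last resort.

import os
import signal
import multiprocessing as mp

def shutdown_workers(workers: list[mp.Process], timeout: float = 10.0) -> None:
    """Ask child processes to exit via SIGTERM, then escalate to SIGKILL."""
    for w in workers:
        if w.is_alive():
            try:
                os.kill(w.pid, signal.SIGTERM)  # equivalent to w.terminate() on Unix
            except ProcessLookupError:
                pass                            # worker already exited
    for w in workers:
        w.join(timeout=timeout)
        if w.is_alive():
            w.kill()                            # SIGKILL: cannot be caught or ignored
            w.join()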

ZMQ high-water marks: ZMQ sockets have a configurable high-water mark (HWM) - the maximum number of messages buffered before the socket blocks or drops, depending on the socket type. For gradient PUSH sockets in a parameter server, set ZMQ_SNDHWM to the number of messages you are willing to buffer if the server falls behind. If the parameter server is slow, workers should block rather than accumulate unbounded memory. For example, push_sock.setsockopt(zmq.SNDHWM, 100) limits buffering to roughly 100 messages, as sketched below.
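
A minimal sketch of the socket options involved (the address is a placeholder); note that HWM options must be set before connect() or bind() to take effect.

import zmq

ctx = zmq.Context()
push_sock = ctx.socket(zmq.PUSH)
push_sock.setsockopt(zmq.SNDHWM, 100)  # buffer at most ~100 outgoing messages, then block
push_sock.setsockopt(zmq.LINGER, 0)    # drop unsent messages on close instead of blocking
push_sock.connect("tcp://localhost:5555")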

/dev/shm and Docker's default 64 MB limit: Docker sets /dev/shm to 64 MB by default. With 4 DataLoader workers loading 224x224 float32 images in batches of 32: each worker needs approximately 32 * 3 * 224 * 224 * 4 bytes = 19 MB per batch, with 2 prefetched batches = 38 MB per worker, 4 workers = 152 MB total. Docker's default 64 MB is insufficient. The symptom: DataLoader workers crash silently and training hangs. Fix: docker run --shm-size=4g or in Kubernetes, add an emptyDir volume with medium: Memory at /dev/shm.


Common Mistakes

:::danger Signal Handler Does Heavy Work A signal handler that calls torch.save(), opens files, or acquires locks is unsafe. A signal can arrive at any point in the program - including while the process is inside malloc(), inside PyTorch's memory allocator, or holding a lock. Re-entering those code paths from a signal handler while they are already executing can deadlock the process or corrupt the heap. Signal handlers must only set a flag (an integer or a threading.Event). The actual cleanup work happens in the normal execution path when the flag is checked. :::

:::danger SIGTERM Not Reaching Training Process Inside Container In Docker containers, if your command runs through a shell (shell-form CMD python train.py, or CMD ["/bin/sh", "-c", "python train.py"]), SIGTERM from Docker goes to the shell running as PID 1, not to your Python process. The shell does not forward SIGTERM to its child, and PID 1 ignores signals it has not installed handlers for. Docker then sends SIGKILL after the grace period, killing the shell and your Python process without any checkpoint. Fix: always use exec form in Dockerfiles (ENTRYPOINT ["python", "train.py"]), or if you need a shell script, use exec python train.py as the last line to replace the shell with your process. :::

:::warning multiprocessing.Queue vs Pipe Performance multiprocessing.Queue is backed by a Pipe plus a Lock plus a Semaphore and a background thread for serialization. It is convenient but has 3-10x higher overhead than a raw Pipe. For high-throughput tensor passing (thousands of tensors per second), use multiprocessing.Pipe directly or, better, shared memory with a lightweight signaling mechanism. Queue.put() serializes the tensor using pickle (CPU memory copy), sends the bytes through a pipe (another copy), and the receiver deserializes (another copy). For a 12 MB batch tensor, that is 36 MB of copies. Shared memory does zero copies for the tensor data. :::
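
A rough way to see where that overhead comes from is to time just the serialization step against an in-place write into a pre-allocated shared block. This single-process sketch isolates the pickle cost only; a real Queue additionally pays for the pipe transfer, the feeder thread, and the receiving copy.

import pickle
import time

import numpy as np
from multiprocessing import shared_memory

batch = np.random.randn(32, 3, 224, 224).astype(np.float32)  # ~19 MB

# The Queue path pickles and unpickles every batch.
t0 = time.perf_counter()
for _ in range(50):
    blob = pickle.dumps(batch, protocol=pickle.HIGHEST_PROTOCOL)
    _ = pickle.loads(blob)
pickle_ms = (time.perf_counter() - t0) / 50 * 1000

# The shared-memory path: the producer writes once into a pre-allocated
# block; the consumer would read it in place (no copy to time at all).
shm = shared_memory.SharedMemory(create=True, size=batch.nbytes)
view = np.ndarray(batch.shape, dtype=batch.dtype, buffer=shm.buf)
t0 = time.perf_counter()
for _ in range(50):
    view[:] = batch
shm_ms = (time.perf_counter() - t0) / 50 * 1000

print(f"pickle round-trip: {pickle_ms:.2f} ms/batch, shared-memory write: {shm_ms:.2f} ms/batch")
shm.close()
shm.unlink()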

:::warning /dev/shm Leaks from Crashed Workers If a DataLoader worker crashes (SIGKILL, OOM kill), it may not clean up its shared memory allocations in /dev/shm. PyTorch creates files like /dev/shm/torch_<pid>_<offset> for each tensor. Crashed workers leak these files. Over time (across many training restarts), /dev/shm fills up. Symptom: DataLoader workers fail to allocate shared memory even though df -h /dev/shm shows space. Fix: periodically run ls -la /dev/shm/ | grep torch and delete files whose owner PID no longer exists. PyTorch registers an atexit handler to clean these up, but atexit does not run on SIGKILL. :::
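
A cleanup sketch along the lines described above, assuming the torch_<pid>_... naming convention mentioned in the warning and using /proc/<pid> to test whether the creating process still exists; clean_stale_torch_shm is a hypothetical helper.

import re
from pathlib import Path

def clean_stale_torch_shm(shm_dir: str = "/dev/shm") -> int:
    """Remove torch_* shared memory files whose creating process is gone."""
    removed = 0
    for entry in Path(shm_dir).glob("torch_*"):
        match = re.match(r"torch_(\d+)_", entry.name)
        if not match:
            continue
        pid = int(match.group(1))
        if not Path(f"/proc/{pid}").exists():  # owning process no longer running
            try:
                entry.unlink()
                removed += 1
            except OSError:
                pass                            # raced with another cleaner; ignore
    return removed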


Interview Questions

Q1: What is the difference between SIGTERM and SIGKILL? Why can SIGKILL not be caught?

SIGTERM (signal 15) is a software termination request. The process can catch it, run a signal handler, and decide what to do (save state, flush buffers, exit cleanly). SIGKILL (signal 9) is delivered by the kernel directly to the process scheduler. There is no signal delivery to the process's signal handler because the kernel terminates the process before user code can run. This is by design: SIGKILL is the guarantee that a process can always be killed. If processes could catch SIGKILL, a buggy or malicious process could ignore it and become unkillable. For ML training, this means your graceful shutdown logic can only work with SIGTERM. Kubernetes always sends SIGTERM first and waits for terminationGracePeriodSeconds before escalating to SIGKILL. Your job: save a checkpoint within that grace period.

Q2: How does PyTorch's DataLoader use /dev/shm, and what happens when it runs out of space?

PyTorch's DataLoader with num_workers > 0 forks N worker processes. Each worker runs your dataset's __getitem__ method and stores the resulting tensor in a shared memory region backed by /dev/shm. PyTorch creates files in /dev/shm with names like torch_<shm_id>_<offset> and maps them into both the worker's address space and the main process's address space. The main process receives a descriptor (file descriptor + offset) via a Unix pipe from the worker. It then reads the tensor directly from the mapped shared memory without any data copy. When /dev/shm fills up (most commonly because Docker's default 64 MB limit is too small for the batch size), PyTorch workers fail to allocate shared memory. The symptom is usually a RuntimeError: unable to open shared memory object or the training process hanging indefinitely because workers crash silently and stop producing batches.

Q3: Implement a signal handler that triggers a checkpoint save without stopping training.

import signal
import threading

_checkpoint_requested = threading.Event()

def _usr1_handler(signum, frame):
    # Only set a flag - no I/O allowed in signal handler
    _checkpoint_requested.set()

signal.signal(signal.SIGUSR1, _usr1_handler)

# In training loop:
for step, batch in enumerate(dataloader):
    loss = train_step(model, optimizer, batch)

    if _checkpoint_requested.is_set():
        save_checkpoint(model, optimizer, step)
        _checkpoint_requested.clear()
        print(f"On-demand checkpoint saved at step {step}")

SIGUSR1 is sent with kill -USR1 <pid>. The signal handler only sets a threading.Event (one atomic operation, signal-safe). The training loop checks the event at the end of each step and performs the I/O synchronously, avoiding any reentrancy issues with PyTorch's memory allocator.

Q4: What is the performance difference between multiprocessing.Queue and shared memory for passing tensors between processes?

multiprocessing.Queue for tensors: the tensor is pickled (CPU copy), written to a pipe (kernel buffer copy), and unpickled (CPU copy). For a 12 MB batch tensor, approximately 36 MB of memory operations occur, plus two system calls (write + read). Latency is typically 1-5 milliseconds per tensor. Shared memory: the worker writes the tensor directly into a pre-allocated /dev/shm region. The main process receives a 16-byte handle (shm name + offset) via a fast pipe. The main process reads the tensor from the shared region using ordinary memory reads - zero copies, zero system calls for the data transfer. Latency is typically 50-200 microseconds (dominated by the semaphore signal + kernel wake-up latency). For a DataLoader targeting 1000 batches/second at 12 MB/batch, Queue uses 36 GB/s of memory bandwidth. Shared memory uses effectively 0 MB/s for the IPC path (the data was already written to /dev/shm by the worker).

Q5: How would you use ZeroMQ to build a gradient aggregation system that tolerates worker failures?

Design: workers use zmq.PUSH sockets to send gradients to a parameter server's zmq.PULL socket. The server uses zmq.PUB to broadcast updated weights. Fault tolerance mechanism: (1) set ZMQ_LINGER=0 on worker sockets so they do not block on close if the server is unreachable. (2) On the server, use a timer-based approach instead of a strict barrier: collect gradients for up to T milliseconds, average whatever arrived, and broadcast. Workers that fail simply stop sending - the server aggregates from the remaining workers. (3) Use heartbeat messages: workers send periodic {"type": "heartbeat", "worker_id": N} so the server can detect dead workers and reduce the expected contributor count. (4) Workers use zmq.DEALER/zmq.ROUTER instead of PUSH/PULL for reliable request-reply with worker identity tracking, enabling per-worker retry logic. This is the classic parameter server design, used by older MXNet training (ps-lite) and similar distributed SGD systems.
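
A sketch of item (2), the deadline-based collection, reusing the PULL socket from the earlier parameter server example; collect_gradients_with_deadline is a hypothetical helper name.

import time

import numpy as np
import zmq

def collect_gradients_with_deadline(pull_sock: zmq.Socket, window_ms: int = 200):
    """Average whatever gradients arrive within window_ms; None if nothing arrived."""
    poller = zmq.Poller()
    poller.register(pull_sock, zmq.POLLIN)
    deadline = time.monotonic() + window_ms / 1000
    grads = []
    while True:
        remaining_ms = (deadline - time.monotonic()) * 1000
        if remaining_ms <= 0:
            break
        if poller.poll(timeout=remaining_ms):  # returns [] on timeout
            grads.append(pull_sock.recv_pyobj()["gradient"])
    return np.mean(grads, axis=0) if grads else None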

Q6: Your training job runs on spot instances that can be preempted with 2 minutes warning. How do you design signal handling and checkpointing to minimize compute waste?

Three-layer design: (1) Register a SIGTERM handler that sets a should_stop event immediately and logs the preemption time. (2) Every N steps, save a checkpoint, choosing N so that the write time (checkpoint size divided by storage throughput) fits comfortably inside the 2-minute warning; saving state dicts incrementally, layer by layer, avoids peak memory spikes. (3) On SIGTERM, check the age of the last checkpoint: if it is stale (say, more than 90 seconds of training progress old), save an emergency checkpoint immediately and exit; if it is recent, skip the save and exit - it is already fresh enough. The checkpoint file should be written atomically: write to .pt.tmp, call os.fsync(), then os.rename(). Rename is atomic on the same filesystem. If preemption kills the process mid-write (SIGKILL after the grace period), the .pt.tmp file is corrupt but the previous .pt file is intact. On restart, detect and delete any .pt.tmp orphans.

© 2026 EngineersOfAI. All rights reserved.