
Network Security for ML Platforms

Six months after launch, a security researcher emails your team. They found that your ML inference API - the one serving 50 million requests per day - is leaking model weights. Not through a code exploit. Not through a database breach. Through the inference API itself: by sending carefully crafted inputs and analyzing the outputs, they reconstructed the architecture and approximate weights of your proprietary text classifier - a model extraction attack. They did not steal the model files. They extracted the model through the API you gave them.

This is a category of threat that traditional network security does not address. Firewalls stop unauthorized network access. TLS stops eavesdropping. Authentication stops unauthenticated users. But your system was doing all of that correctly. The authenticated, authorized users - the legitimate API customers - were the attack vector. The model itself, not the surrounding infrastructure, was the exfiltrated asset.

ML platforms face a threat landscape that is fundamentally different from traditional web services. The assets being protected are not just data and code - they are trained models representing millions of dollars of compute, proprietary training datasets, and the intellectual property of months of research. The attack vectors are not just network intrusion - they include the inference API itself (model extraction), the training pipeline (data poisoning via malicious training data), and the serving infrastructure (prompt injection that causes the model to leak system prompts or act outside its intended scope).

Securing an ML platform requires layers of defense that span traditional network security and ML-specific mitigations. This lesson teaches both: the network security fundamentals (mTLS between services, Kubernetes network policies, eBPF enforcement, secrets management, PKI) and the ML-specific security controls (rate limiting inference for extraction protection, model signing for supply chain security, network isolation for training data, prompt injection defense at the network layer). By the end, you will have a threat model and a control set for an ML platform that defends against both traditional network attacks and ML-specific adversarial patterns.

Why This Exists

The first generation of ML infrastructure was built by researchers, not security engineers. The mental model was a university HPC cluster: trusted users, isolated network, focus on getting experiments to run. Security was an afterthought. As ML moved into production - serving real users, handling sensitive data, representing real financial value - that mental model became a liability.

The threat landscape expanded in three directions simultaneously. First, models became valuable enough to steal. A GPT-3-scale language model represents $4-5M in training compute at 2021 prices. A fine-tuned medical or financial model may represent competitive advantage worth orders of magnitude more. Second, training data became sensitive. Healthcare ML models train on patient records. Financial models train on proprietary trading data. The training pipeline became an attack surface. Third, ML systems themselves introduced new attack vectors - prompt injection, model inversion, membership inference - that traditional security tools do not detect.

Historical Context

Zero-trust networking, the security model that drives most of this lesson, was formally articulated by John Kindervag at Forrester Research in 2010. The core principle: never trust any network connection by default, even internal ones. Authenticate and authorize every request, regardless of where it originates. This was a response to the perimeter security model (trust everything inside the firewall) that attackers had learned to bypass by compromising a single internal host.

Google published its zero-trust implementation - BeyondCorp - in a series of papers from 2014-2018. The core insight: network location is not a proxy for trustworthiness. A machine on the corporate network is not inherently more trusted than a machine on the internet. Authentication and authorization should be based on device identity and user identity, not network location.

The mutual TLS (mTLS) standard for service-to-service authentication follows from zero-trust principles: every service proves its identity to every other service on every connection, regardless of whether they are on the same node or across datacenters. Service meshes (Istio, Linkerd, Consul Connect) automate mTLS for Kubernetes workloads without requiring application code changes.

Threat Model for ML Platforms

Before implementing controls, you need to enumerate what you are defending against. For an ML platform, the recurring threat categories - each addressed by a control later in this lesson - are:

  • Model theft: extraction through the inference API, or exfiltration of model artifacts from storage
  • Training data exposure: sensitive datasets (patient records, trading data) read by compromised or over-privileged services
  • Data poisoning: malicious training data or crafted gradient updates injected through the training pipeline
  • Prompt injection: attacker-controlled content that overrides LLM instructions, arriving via the API
  • Supply chain tampering: backdoored model artifacts substituted in storage or the registry
  • Traditional intrusion and lateral movement: a compromised pod pivoting to other services, secrets, or etcd

mTLS: Mutual TLS for Service Authentication

Standard TLS (one-way TLS) authenticates the server to the client. The client verifies the server's certificate. The server does not verify the client - this is the model for HTTPS websites.

mTLS (mutual TLS) adds client authentication: the client also presents a certificate, and the server verifies it. For ML platform microservices, this means:

  • The inference service knows it is talking to the legitimate training orchestrator, not an impersonator
  • The feature store knows a request comes from an authorized ML pipeline, not a rogue process
  • If an internal service is compromised, it cannot make unauthorized calls to other services (its certificate identifies it, and its access is limited to what that certificate authorizes)

mTLS Implementation

"""
mtls_server.py - Minimal mTLS server demonstrating mutual certificate verification.
In production, use a service mesh (Istio/Linkerd) instead of manual mTLS.
This shows the underlying mechanics.

Generate test certificates:
# CA
openssl genrsa -out ca.key 4096
openssl req -new -x509 -key ca.key -sha256 -out ca.crt -days 365 \
-subj "/CN=ML Platform CA"

# Server cert
openssl genrsa -out server.key 4096
openssl req -new -key server.key -out server.csr -subj "/CN=inference-service"
openssl x509 -req -in server.csr -CA ca.crt -CAkey ca.key \
-CAcreateserial -out server.crt -days 365

# Client cert (for each service that needs to call the inference API)
openssl genrsa -out client.key 4096
openssl req -new -key client.key -out client.csr -subj "/CN=feature-store-service"
openssl x509 -req -in client.csr -CA ca.crt -CAkey ca.key \
-CAcreateserial -out client.crt -days 365
"""

import ssl
import socket
import threading
import json


def create_server_ssl_context(
    cert_file: str,
    key_file: str,
    ca_cert_file: str,
) -> ssl.SSLContext:
    """Create an SSL context requiring a client certificate (mTLS server side)."""
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    ctx.load_cert_chain(cert_file, key_file)

    # Require a client certificate - this is what makes it mTLS
    ctx.verify_mode = ssl.CERT_REQUIRED
    ctx.load_verify_locations(ca_cert_file)

    # Security hardening. Note: set_ciphers configures TLS <= 1.2 suites only;
    # the TLS 1.3 suites are enabled by default.
    ctx.minimum_version = ssl.TLSVersion.TLSv1_3
    ctx.set_ciphers("ECDHE+AESGCM:ECDHE+CHACHA20")

    return ctx


def create_client_ssl_context(
    cert_file: str,
    key_file: str,
    ca_cert_file: str,
) -> ssl.SSLContext:
    """Create an SSL context presenting a client certificate (mTLS client side)."""
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

    # Present our certificate to the server
    ctx.load_cert_chain(cert_file, key_file)

    # Verify the server certificate
    ctx.load_verify_locations(ca_cert_file)
    ctx.verify_mode = ssl.CERT_REQUIRED

    ctx.minimum_version = ssl.TLSVersion.TLSv1_3

    return ctx


class MTLSInferenceServer:
    """
    mTLS-protected inference server.
    Only accepts connections from clients with certificates signed by the internal CA.
    """

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 8443,
        cert_dir: str = "/etc/ml-platform/certs",
    ):
        self.host = host
        self.port = port
        self.ssl_context = create_server_ssl_context(
            cert_file=f"{cert_dir}/server.crt",
            key_file=f"{cert_dir}/server.key",
            ca_cert_file=f"{cert_dir}/ca.crt",
        )

    def handle_request(self, conn: ssl.SSLSocket):
        """Handle a single inference request, with client identity verification."""
        try:
            # Get the client certificate subject - this is the service's identity
            client_cert = conn.getpeercert()
            client_cn = dict(
                x[0] for x in client_cert.get("subject", [])
            ).get("commonName", "unknown")

            # Authorization check: only allow specific services to call inference
            allowed_services = {
                "feature-store-service",
                "batch-pipeline-service",
                "api-gateway",
            }
            if client_cn not in allowed_services:
                conn.sendall(json.dumps({
                    "error": f"Service '{client_cn}' is not authorized"
                }).encode())
                return

            # Read and process the request
            data = conn.recv(65536)
            request = json.loads(data)

            # Run inference (placeholder)
            response = {
                "predictions": [0.9, 0.1],
                "client_service": client_cn,  # Audit log: who called us
            }
            conn.sendall(json.dumps(response).encode())

        except ssl.SSLError as e:
            print(f"[mtls] TLS error (likely missing/invalid client cert): {e}")
        finally:
            conn.close()

    def serve(self):
        """Start the mTLS server."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind((self.host, self.port))
            sock.listen(50)

            print(f"mTLS inference server on {self.host}:{self.port}")
            print("  TLS 1.3, client certificate required (mTLS)")

            with self.ssl_context.wrap_socket(sock, server_side=True) as ssock:
                while True:
                    conn, addr = ssock.accept()
                    t = threading.Thread(
                        target=self.handle_request, args=(conn,), daemon=True
                    )
                    t.start()


class MTLSInferenceClient:
    """Client that presents its certificate to prove its identity."""

    def __init__(self, cert_dir: str = "/etc/ml-platform/certs",
                 service_name: str = "feature-store-service"):
        self.ssl_context = create_client_ssl_context(
            cert_file=f"{cert_dir}/{service_name}.crt",
            key_file=f"{cert_dir}/{service_name}.key",
            ca_cert_file=f"{cert_dir}/ca.crt",
        )
        self.ssl_context.check_hostname = False  # Internal certs use CN, not SANs

    def predict(self, host: str, port: int, payload: dict) -> dict:
        """Make an inference call with mTLS client authentication."""
        with socket.create_connection((host, port)) as sock:
            with self.ssl_context.wrap_socket(sock, server_hostname=host) as ssock:
                # Verify the server certificate identity (CA check already done by TLS)
                server_cert = ssock.getpeercert()
                server_cn = dict(
                    x[0] for x in server_cert.get("subject", [])
                ).get("commonName", "")
                if server_cn != "inference-service":
                    raise ValueError(f"Server CN mismatch: got '{server_cn}'")

                ssock.sendall(json.dumps(payload).encode())
                response = ssock.recv(65536)
                return json.loads(response)
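
A usage sketch for the client, under the assumption that the certificates from the module docstring are installed and a server is listening; the host name and feature values are illustrative:

# Hypothetical call site - host and payload are illustrative
client = MTLSInferenceClient(service_name="feature-store-service")
result = client.predict("inference.ml-serving.svc.cluster.local", 8443,
                        {"features": [0.2, 0.7, 0.1]})
print(result["predictions"], "- served for", result["client_service"])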

cert-manager for Automated Certificate Management

Manual certificate management does not scale. cert-manager automates certificate issuance and rotation in Kubernetes:

# cert-manager-ml-platform.yaml
# Automated mTLS certificates for ML services

# Install cert-manager first:
# kubectl apply -f https://github.com/cert-manager/cert-manager/releases/download/v1.14.0/cert-manager.yaml

---
# Internal CA (Certificate Authority) for the ML platform
apiVersion: cert-manager.io/v1
kind: ClusterIssuer
metadata:
  name: ml-platform-ca
spec:
  ca:
    secretName: ml-platform-ca-key-pair  # Secret containing the CA cert and key

---
# Certificate for the inference service (auto-renewed 7 days before expiry)
apiVersion: cert-manager.io/v1
kind: Certificate
metadata:
  name: inference-service-cert
  namespace: ml-serving
spec:
  secretName: inference-service-tls
  duration: 720h     # 30 days
  renewBefore: 168h  # Renew 7 days before expiry
  subject:
    organizations:
      - ml-platform
  commonName: inference-service
  dnsNames:
    - inference-service.ml-serving.svc.cluster.local
    - inference-service.ml-serving.svc
  issuerRef:
    name: ml-platform-ca
    kind: ClusterIssuer

---
# Certificate for the feature store service
apiVersion: cert-manager.io/v1
kind: Certificate
metadata:
  name: feature-store-cert
  namespace: ml-serving
spec:
  secretName: feature-store-tls
  duration: 720h
  renewBefore: 168h
  commonName: feature-store-service
  dnsNames:
    - feature-store.ml-serving.svc.cluster.local
  issuerRef:
    name: ml-platform-ca
    kind: ClusterIssuer

Kubernetes Network Policies

By default, all pods in a Kubernetes cluster can communicate with all other pods. This is the "flat network" model - convenient for development, dangerous for production. A compromised inference pod can make calls to etcd, to the training data storage, to every other service. Network policies change this.

# ml-serving-network-policy.yaml
# Restrict what the inference service pods can communicate with

apiVersion: networking.k8s.io/v1
kind: NetworkPolicy
metadata:
  name: inference-service-policy
  namespace: ml-serving
spec:
  # Apply to inference service pods
  podSelector:
    matchLabels:
      app: inference-service

  policyTypes:
    - Ingress
    - Egress

  ingress:
    # Allow from the API gateway only
    - from:
        - namespaceSelector:
            matchLabels:
              name: api-gateway
          podSelector:
            matchLabels:
              app: api-gateway
      ports:
        - protocol: TCP
          port: 8000

    # Allow from monitoring (Prometheus scrape)
    - from:
        - namespaceSelector:
            matchLabels:
              name: monitoring
      ports:
        - protocol: TCP
          port: 9090

  egress:
    # Allow to the model registry (to load models)
    - to:
        - namespaceSelector:
            matchLabels:
              name: model-registry
      ports:
        - protocol: TCP
          port: 443

    # Allow to the feature store
    - to:
        - podSelector:
            matchLabels:
              app: feature-store
      ports:
        - protocol: TCP
          port: 6565

    # Allow DNS resolution
    - to:
        - namespaceSelector: {}
      ports:
        - protocol: UDP
          port: 53
        - protocol: TCP
          port: 53

# Everything else (including calls to training data, etcd, etc.) is blocked
---
# Deny all traffic by default in the ml-serving namespace
apiVersion: networking.k8s.io/v1
kind: NetworkPolicy
metadata:
  name: default-deny-all
  namespace: ml-serving
spec:
  podSelector: {}  # Matches all pods
  policyTypes:
    - Ingress
    - Egress
  # No ingress/egress rules = deny everything

eBPF with Cilium

Kubernetes NetworkPolicy is implemented by CNI plugins, and the implementation quality varies. Calico and Cilium are the most capable. Cilium uses eBPF (extended Berkeley Packet Filter) to enforce policies in the Linux kernel, providing several advantages over iptables-based approaches:

  • Policies are enforced at the kernel level, not in user space
  • eBPF programs can inspect Layer 7 (HTTP, gRPC) content, not just IP/port
  • Near-zero performance overhead compared to iptables for large rule sets
  • Deep network observability with Hubble

# cilium-ml-policy.yaml
# Cilium L7 network policy - enforce HTTP-level rules on the inference API
# Requires the Cilium CNI installed in the cluster

apiVersion: "cilium.io/v2"
kind: CiliumNetworkPolicy
metadata:
  name: inference-l7-policy
  namespace: ml-serving
spec:
  endpointSelector:
    matchLabels:
      app: inference-service

  ingress:
    - fromEndpoints:
        - matchLabels:
            app: api-gateway
      toPorts:
        - ports:
            - port: "8000"
              protocol: TCP
          # L7 HTTP rules - only allow specific endpoints and methods
          rules:
            http:
              - method: "POST"
                path: "/v1/predict"
              - method: "GET"
                path: "/health"
              - method: "GET"
                path: "/metrics"
      # All other paths (e.g., /admin or /debug endpoints) are rejected at L7

  egress:
    # Allow outbound HTTPS to the model registry only
    - toFQDNs:
        - matchName: "models.internal.example.com"
      toPorts:
        - ports:
            - port: "443"
              protocol: TCP
"""
ebpf_policy_audit.py - Query Cilium's Hubble observability API to audit
network traffic policy violations in the ML platform.

Requires: Hubble CLI or Hubble Relay API access
pip install grpcio grpcio-tools
"""

import subprocess
import json
from datetime import datetime
from typing import List, Dict


def get_policy_drops(namespace: str, time_window: str = "1h") -> List[Dict]:
"""
Query Hubble for dropped packets (policy violations) in a namespace.
Useful for identifying misconfigured network policies or lateral movement attempts.
"""
result = subprocess.run(
[
"hubble", "observe",
"--namespace", namespace,
"--verdict", "DROPPED",
"--last", "1000",
"--output", "json",
],
capture_output=True, text=True
)

drops = []
for line in result.stdout.split("\n"):
if not line.strip():
continue
try:
event = json.loads(line)
flow = event.get("flow", {})
if flow:
drops.append({
"timestamp": flow.get("time", ""),
"source": flow.get("source", {}).get("identity", "unknown"),
"destination": flow.get("destination", {}).get("identity", "unknown"),
"l4": flow.get("l4", {}),
"verdict": flow.get("verdict", ""),
"drop_reason": flow.get("drop_reason_desc", ""),
})
except json.JSONDecodeError:
pass

return drops


def detect_lateral_movement_attempts(drops: List[Dict]) -> List[Dict]:
    """
    Identify suspicious lateral movement patterns in network policy drops.
    Heuristic: the same source pod attempting connections to many different destinations.
    """
    source_destinations = defaultdict(set)
    for drop in drops:
        source = drop.get("source", "")
        dest = drop.get("destination", "")
        if source and dest:
            source_destinations[source].add(dest)

    # Sources trying to reach many destinations = potential lateral movement scan
    suspicious = []
    for source, destinations in source_destinations.items():
        if len(destinations) >= 5:  # Threshold: 5+ different blocked destinations
            suspicious.append({
                "source": source,
                "blocked_destination_count": len(destinations),
                "destinations": list(destinations),
                "severity": "HIGH" if len(destinations) > 20 else "MEDIUM",
            })

    return sorted(suspicious, key=lambda x: x["blocked_destination_count"],
                  reverse=True)

Secrets Management with HashiCorp Vault

ML platforms use many secrets: API keys for LLM providers, database credentials for the feature store, S3 credentials for model storage, encryption keys for data at rest. Kubernetes Secrets store them as base64 (not encrypted by default). HashiCorp Vault provides a proper secrets management solution.

"""
vault_ml_secrets.py - Vault integration for ML platform secrets management.

Setup:
vault secrets enable -path=ml-platform kv-v2
vault write ml-platform/data/openai api_key="sk-..." org_id="org-..."
vault write ml-platform/data/s3 access_key="..." secret_key="..."

Install: pip install hvac
"""

import hvac
import os
import time
import threading
import logging
from typing import Optional, Dict


class VaultMLSecretsManager:
    """
    Vault-backed secrets manager for ML services.
    Handles authentication via the Kubernetes service account JWT,
    secret rotation, and dynamic credentials.
    """

    def __init__(
        self,
        vault_addr: str = "https://vault.ml-platform.svc.cluster.local:8200",
        secrets_path: str = "ml-platform",
        role: str = "inference-service",
    ):
        self._vault_addr = vault_addr
        self._secrets_path = secrets_path
        self._role = role
        self._client: Optional[hvac.Client] = None
        self._token_expires_at: float = 0
        self._cache: Dict[str, tuple] = {}  # key -> (value, expires_at)
        self._lock = threading.Lock()

    def _authenticate(self):
        """Authenticate to Vault using the Kubernetes service account token (JWT)."""
        # Read the service account token mounted by Kubernetes
        jwt_path = "/var/run/secrets/kubernetes.io/serviceaccount/token"
        with open(jwt_path) as f:
            jwt = f.read().strip()

        client = hvac.Client(url=self._vault_addr)
        result = client.auth.kubernetes.login(
            role=self._role,
            jwt=jwt,
        )

        client.token = result["auth"]["client_token"]
        lease_duration = result["auth"]["lease_duration"]
        self._token_expires_at = time.time() + lease_duration - 60  # Renew 1 min early

        self._client = client
        logging.info(f"[vault] Authenticated as role '{self._role}', "
                     f"token valid for {lease_duration}s")

    def _ensure_authenticated(self):
        """Re-authenticate if the token is about to expire."""
        if self._client is None or time.time() > self._token_expires_at:
            self._authenticate()

    def get_secret(self, secret_name: str, key: str,
                   cache_ttl_seconds: int = 300) -> str:
        """
        Retrieve a secret from Vault with local caching.
        The cache reduces Vault API calls for frequently used secrets.
        """
        cache_key = f"{secret_name}/{key}"

        with self._lock:
            if cache_key in self._cache:
                value, expires_at = self._cache[cache_key]
                if time.time() < expires_at:
                    return value

        self._ensure_authenticated()
        response = self._client.secrets.kv.v2.read_secret_version(
            path=secret_name,
            mount_point=self._secrets_path,
        )
        value = response["data"]["data"][key]

        with self._lock:
            self._cache[cache_key] = (value, time.time() + cache_ttl_seconds)

        return value

    def rotate_api_key(self, service_name: str, generate_fn):
        """
        Rotate an API key: generate a new one, write it to Vault, return the new key.
        Vault KV v2 keeps version history, so the previous key can still be read
        (and honored) during a rollover window.
        """
        new_key = generate_fn()
        self._ensure_authenticated()

        # Write a new version (Vault KV v2 keeps version history)
        self._client.secrets.kv.v2.create_or_update_secret(
            path=f"api-keys/{service_name}",
            secret={"api_key": new_key, "rotated_at": str(time.time())},
            mount_point=self._secrets_path,
        )

        # Invalidate the cache
        with self._lock:
            keys_to_remove = [k for k in self._cache
                              if k.startswith(f"api-keys/{service_name}")]
            for k in keys_to_remove:
                del self._cache[k]

        logging.info(f"[vault] Rotated API key for {service_name}")
        return new_key

    def get_dynamic_s3_credentials(
        self,
        bucket: str,
        policy: str = "read-only",
        ttl: str = "1h",
    ) -> dict:
        """
        Generate short-lived S3 credentials via the Vault AWS secrets engine.
        Better than static keys: the credentials expire automatically.
        Used for: ML training jobs that need temporary S3 access.
        """
        self._ensure_authenticated()
        response = self._client.secrets.aws.generate_credentials(
            name=f"ml-{policy}",  # Pre-configured Vault role with an S3 policy
            ttl=ttl,
        )
        return {
            "access_key": response["data"]["access_key"],
            "secret_key": response["data"]["secret_key"],
            "session_token": response["data"]["security_token"],
            "expiry": time.time() + 3600,  # Matches the default 1h ttl above
        }
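
A short usage sketch, assuming a Vault Kubernetes auth role named inference-service and the KV paths from the module docstring exist; names are illustrative:

# Hypothetical usage - role, paths, and bucket are illustrative
secrets = VaultMLSecretsManager(role="inference-service")
openai_key = secrets.get_secret("openai", "api_key")

# Short-lived S3 credentials for a training job (expire automatically)
s3_creds = secrets.get_dynamic_s3_credentials(bucket="training-data")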

JWT Validation Middleware for ML APIs

mTLS authenticates services to each other; JSON Web Tokens (JWTs) carry the identity, entitlements, and tier of the end caller. The middleware below validates tokens, checks scopes, and enforces per-client rate limits and model access rules driven by token claims.

"""
jwt_middleware.py - JWT validation middleware for ML inference API.
Validates tokens, extracts claims, enforces ML-specific rate limits per user.

Install: pip install PyJWT cryptography
"""

import jwt
import time
import logging
from functools import wraps
from typing import Optional, Dict, Any
from collections import defaultdict
import threading


class MLAPIJWTValidator:
"""
JWT validation middleware for ML inference endpoints.
Supports:
- RS256/ES256 signed tokens (asymmetric - verify without the signing key)
- Scoped permissions (inference:read, training:write, admin:all)
- Rate limiting per client_id from token claims
- Model version access control from token claims
"""

def __init__(
self,
public_key_pem: str,
algorithm: str = "RS256",
issuer: str = "https://auth.ml-platform.example.com",
audience: str = "ml-inference-api",
):
self._public_key = public_key_pem
self._algorithm = algorithm
self._issuer = issuer
self._audience = audience

# Rate limiting state: client_id -> (count, window_start)
self._rate_limits: Dict[str, tuple] = defaultdict(lambda: (0, time.time()))
self._rate_limit_lock = threading.Lock()

# Default rate limits per tier from JWT claims
self._tier_limits = {
"free": 10, # 10 requests per minute
"standard": 100,
"enterprise": 10000,
}

def validate_token(self, token: str) -> Dict[str, Any]:
"""
Validate JWT and return claims. Raises ValueError on any validation failure.
"""
try:
claims = jwt.decode(
token,
self._public_key,
algorithms=[self._algorithm],
audience=self._audience,
issuer=self._issuer,
options={
"require": ["exp", "iat", "sub", "scope", "client_id"],
"verify_exp": True,
}
)
except jwt.ExpiredSignatureError:
raise ValueError("Token expired")
except jwt.InvalidAudienceError:
raise ValueError("Token not valid for this API")
except jwt.InvalidIssuerError:
raise ValueError("Token issued by untrusted issuer")
except jwt.DecodeError as e:
raise ValueError(f"Invalid token: {e}")

return claims

def check_scope(self, claims: dict, required_scope: str) -> bool:
"""
Check that the token has the required scope.
Example scopes: "inference:predict", "models:read", "training:submit"
"""
token_scopes = set(claims.get("scope", "").split())
return required_scope in token_scopes

def check_rate_limit(self, client_id: str, tier: str = "standard") -> bool:
"""
Token bucket rate limiting per client_id.
Returns True if request is allowed, False if rate limited.
"""
limit_per_minute = self._tier_limits.get(tier, 100)

with self._rate_limit_lock:
count, window_start = self._rate_limits[client_id]
now = time.time()

# Reset window every minute
if now - window_start > 60:
self._rate_limits[client_id] = (1, now)
return True

if count >= limit_per_minute:
return False

self._rate_limits[client_id] = (count + 1, window_start)
return True

    def check_model_access(self, claims: dict, model_name: str,
                           model_version: str) -> bool:
        """
        Check whether the token grants access to a specific model version.
        Enterprise users get access to all models.
        Standard users only get access to stable versions.
        Free users only get access to public models.
        """
        tier = claims.get("tier", "free")
        allowed_models = claims.get("allowed_models", [])  # Explicit model whitelist

        if tier == "enterprise":
            return True
        if allowed_models and model_name in allowed_models:
            return True
        if tier == "standard" and model_version == "stable":
            return True
        if tier == "free" and model_name in {"public-classifier", "public-embedder"}:
            return True

        return False


def require_jwt(validator: MLAPIJWTValidator, scope: str):
    """
    Decorator for ML inference endpoint functions.
    Validates the JWT, checks the scope, enforces rate limits.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(request, *args, **kwargs):
            # Extract the token from the Authorization header
            auth_header = request.headers.get("Authorization", "")
            if not auth_header.startswith("Bearer "):
                return {"error": "Missing or invalid Authorization header"}, 401

            token = auth_header[7:]  # Strip "Bearer "

            try:
                claims = validator.validate_token(token)
            except ValueError as e:
                logging.warning(f"[auth] Token validation failed: {e}")
                return {"error": str(e)}, 401

            # Scope check
            if not validator.check_scope(claims, scope):
                logging.warning(f"[auth] Insufficient scope: need {scope}, "
                                f"got {claims.get('scope', '')}")
                return {"error": f"Insufficient scope. Required: {scope}"}, 403

            # Rate limit check
            client_id = claims.get("client_id", "unknown")
            tier = claims.get("tier", "free")
            if not validator.check_rate_limit(client_id, tier):
                logging.warning(f"[rate-limit] Client {client_id} exceeded limit")
                return {"error": "Rate limit exceeded"}, 429

            # Inject claims into the request for downstream use
            request.jwt_claims = claims
            return func(request, *args, **kwargs)

        return wrapper
    return decorator
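
A wiring sketch for the decorator - the public key constant, endpoint function, and model name are hypothetical; any framework whose request object exposes .headers works the same way:

# Hypothetical endpoint wiring - PUBLIC_KEY_PEM and the handler are illustrative
validator = MLAPIJWTValidator(public_key_pem=PUBLIC_KEY_PEM)

@require_jwt(validator, scope="inference:predict")
def predict_endpoint(request):
    claims = request.jwt_claims
    if not validator.check_model_access(claims, "fraud-model", "stable"):
        return {"error": "Model access denied"}, 403
    return {"predictions": [0.9, 0.1]}, 200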

Model Supply Chain Security

A trained model is a build artifact. Like a software binary, it can be tampered with after the training run - a malicious actor who gains access to the model storage could insert backdoors into model weights. Model signing provides cryptographic assurance that the artifact you are loading is the one that was produced by your training pipeline.

"""
model_signing.py - Sign and verify ML model artifacts.
Uses cosign for OCI artifact signing (same tool used for container image signing).

Install: pip install sigstore (or use cosign CLI)
"""

import hashlib
import json
import subprocess
import tempfile
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Optional


@dataclass
class ModelProvenance:
"""SLSA-inspired provenance for ML model artifacts."""
model_name: str
model_version: str
training_job_id: str
training_start_time: str
training_end_time: str
dataset_hash: str # SHA256 of training dataset manifest
training_code_commit: str # Git SHA of training code
framework: str # "pytorch==2.2.0"
hardware: str # "8x A100-80GB"
final_loss: float
artifact_sha256: str # SHA256 of the model file itself


def sha256_file(path: str) -> str:
"""Compute SHA256 hash of a file."""
h = hashlib.sha256()
with open(path, "rb") as f:
for chunk in iter(lambda: f.read(65536), b""):
h.update(chunk)
return h.hexdigest()


def sign_model_artifact(
    model_path: str,
    provenance: ModelProvenance,
    signing_key: str = "gcpkms://projects/ml-platform/locations/global/keyRings/ml-signing/cryptoKeys/model-key",
) -> dict:
    """
    Sign a model artifact with cosign.
    Writes the signature to a bundle file alongside the artifact.

    In production, use keyless signing with OIDC (GitHub Actions, etc.)
    or KMS-backed keys.
    """
    # Verify the artifact hash matches what's in the provenance
    actual_hash = sha256_file(model_path)
    if actual_hash != provenance.artifact_sha256:
        raise ValueError(
            f"Artifact hash mismatch. Expected: {provenance.artifact_sha256}, "
            f"Got: {actual_hash}"
        )

    # Write the provenance to a temp file (store it alongside the artifact)
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json",
                                     delete=False) as f:
        json.dump(asdict(provenance), f, indent=2)
        provenance_path = f.name

    # Sign with cosign; the bundle file carries the signature
    result = subprocess.run(
        [
            "cosign", "sign-blob",
            "--key", signing_key,
            "--bundle", f"{model_path}.cosign.bundle",
            model_path,
        ],
        capture_output=True, text=True,
    )

    if result.returncode != 0:
        raise RuntimeError(f"cosign signing failed: {result.stderr}")

    return {
        "artifact_path": model_path,
        "artifact_sha256": actual_hash,
        "signature_bundle": f"{model_path}.cosign.bundle",
        "provenance_path": provenance_path,
    }


def verify_model_artifact(
    model_path: str,
    expected_provenance: Optional[ModelProvenance] = None,
    signing_key_or_cert: str = "gcpkms://projects/ml-platform/locations/global/keyRings/ml-signing/cryptoKeys/model-key",
) -> bool:
    """
    Verify a model artifact's signature before loading it.
    Call this in your model serving startup code.
    """
    bundle_path = f"{model_path}.cosign.bundle"
    if not Path(bundle_path).exists():
        raise FileNotFoundError(
            f"No signature bundle found at {bundle_path}. "
            f"Model may be unsigned or tampered with."
        )

    result = subprocess.run(
        [
            "cosign", "verify-blob",
            "--key", signing_key_or_cert,
            "--bundle", bundle_path,
            model_path,
        ],
        capture_output=True, text=True,
    )

    if result.returncode != 0:
        raise SecurityError(f"Model signature verification FAILED: {result.stderr}")

    # Additional check: verify the file hash matches the provenance record
    actual_hash = sha256_file(model_path)
    if expected_provenance and actual_hash != expected_provenance.artifact_sha256:
        raise SecurityError(
            f"Model file hash does not match the provenance record. "
            f"Possible tampering: {actual_hash} vs {expected_provenance.artifact_sha256}"
        )

    return True
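
A serving-startup sketch: verify before load, and fail closed. The model path is illustrative:

# Hypothetical startup hook - the model path is illustrative
MODEL_PATH = "/models/text-classifier-v3.pt"

verify_model_artifact(MODEL_PATH)  # Raises SecurityError if verification fails
# Only reached if verification passed:
# model = torch.load(MODEL_PATH)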

Zero-Trust Network Architecture for ML

Firewall Rules for GPU Clusters

# GPU cluster firewall configuration using iptables
# Run on each GPU training node

# Default: drop all traffic except explicitly allowed

# Allow established connections (required for all bidirectional communication)
iptables -A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
iptables -A OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT

# Allow SSH from bastion host only
iptables -A INPUT -s 10.0.0.5 -p tcp --dport 22 -j ACCEPT

# Allow NCCL communication between training nodes (specific CIDR for GPU subnet)
iptables -A INPUT -s 10.1.0.0/24 -p tcp --dport 29500:29510 -j ACCEPT
iptables -A INPUT -s 10.1.0.0/24 -p udp --dport 29500:29510 -j ACCEPT

# Allow InfiniBand RDMA (managed by IB kernel drivers, not iptables)
# IB traffic bypasses the kernel TCP/IP stack - secure at the fabric level

# Allow Prometheus node metrics scraping from monitoring subnet
iptables -A INPUT -s 10.0.1.0/24 -p tcp --dport 9100 -j ACCEPT

# Allow DNS queries to cluster DNS
iptables -A OUTPUT -d 10.96.0.10 -p udp --dport 53 -j ACCEPT
iptables -A OUTPUT -d 10.96.0.10 -p tcp --dport 53 -j ACCEPT

# Allow NTP (time synchronization - critical for distributed training timestamps)
iptables -A OUTPUT -p udp --dport 123 -j ACCEPT

# Block all other inbound traffic
iptables -A INPUT -j DROP

# Block outbound internet access from training nodes (prevents data exfiltration)
# Exception: allow access to internal package mirrors only
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT # Internal network OK
iptables -A OUTPUT -j DROP # Block public internet

# Save rules
iptables-save > /etc/iptables/rules.v4

:::danger Model Extraction via Inference API
A model with high-confidence softmax outputs is vulnerable to model extraction attacks. By sending thousands of carefully chosen inputs and observing the outputs, an adversary can train a "surrogate model" that approximates your proprietary model. This was demonstrated against commercial prediction APIs by Tramèr et al. (USENIX Security 2016).

Mitigations:

  1. Return top-K predictions only (not the full probability distribution)
  2. Add calibrated noise to output probabilities (differential privacy)
  3. Implement aggressive rate limiting per client (makes extraction impractically expensive)
  4. Monitor for systematic probing patterns: many requests covering the input distribution uniformly
  5. Limit output precision: round probabilities to 2 decimal places
  6. For critical models, consider returning only the predicted class label with no probability

None of these fully prevent extraction by a determined adversary with an unlimited budget. Defense in depth is the goal; a sketch of mitigations 1, 2, and 5 follows this callout.
:::
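
A minimal sketch of mitigations 1, 2, and 5, assuming the model returns a NumPy probability vector; harden_output and the Laplace noise scale are illustrative, and the scale must be calibrated against acceptable accuracy loss:

import numpy as np

def harden_output(probs: np.ndarray, top_k: int = 3,
                  noise_scale: float = 0.01, decimals: int = 2) -> dict:
    """Top-K only, calibrated noise, reduced precision (mitigations 1, 2, 5)."""
    noisy = probs + np.random.laplace(0.0, noise_scale, size=probs.shape)
    noisy = np.clip(noisy, 0.0, None)
    noisy = noisy / noisy.sum()  # Re-normalize after noising

    top_idx = np.argsort(noisy)[::-1][:top_k]
    return {int(i): round(float(noisy[i]), decimals) for i in top_idx}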

:::warning Kubernetes Secrets Are Not Encrypted by Default
kubectl get secret my-secret -o yaml will show you the "encrypted" data as base64. Base64 is encoding, not encryption. Anyone with kubectl access can read all secrets in the cluster.

For real encryption:

  1. Enable encryption at rest in the kube-apiserver configuration (EncryptionConfiguration)
  2. Use Vault with the vault-agent-injector to inject secrets as environment variables without storing them in Kubernetes
  3. Use the Secrets Store CSI Driver to mount secrets from Vault/AWS SSM directly into pods
  4. Never use environment variables for highly sensitive secrets in containers (they appear in docker inspect and /proc/PID/environ)
:::

:::danger Prompt Injection via API
For LLM-based ML services, network security must include prompt injection defense. An attacker who can control the content that gets concatenated into your model's prompt can override system instructions, extract the system prompt, or cause the model to perform unintended actions. This is a network-level threat: the malicious content arrives via the API.

Network-level mitigations:

  1. Input validation: block known injection patterns at the API gateway ("ignore previous instructions", "you are now...", etc.)
  2. Output validation: scan LLM output for patterns indicating system prompt leakage
  3. Separate trust boundaries: user-provided content should never have the same trust level as system prompt content in your LLM serving pipeline
  4. Rate limiting on unusual output lengths (very long outputs may indicate a jailbreak succeeded)

These controls do not fully solve prompt injection - it is fundamentally a model alignment problem. But network-level controls raise the cost of attacks significantly; a gateway-side screening sketch follows this callout.
:::
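
A deliberately simple gateway-side screening sketch for mitigations 1 and 2 - the patterns are illustrative and known to be incomplete, since pattern lists catch only unsophisticated attacks:

import re

# Illustrative patterns only - real deployments need continuously updated rules
INJECTION_PATTERNS = [
    re.compile(r"ignore (all |any )?(previous|prior) instructions", re.I),
    re.compile(r"you are now", re.I),
    re.compile(r"reveal (your )?(system )?prompt", re.I),
]

def screen_input(user_text: str) -> bool:
    """Return False if the input matches a known injection pattern (mitigation 1)."""
    return not any(p.search(user_text) for p in INJECTION_PATTERNS)

def screen_output(llm_output: str, system_prompt: str) -> bool:
    """Return False if the output appears to leak the system prompt (mitigation 2)."""
    # Flag outputs that quote a long verbatim run of the system prompt
    return system_prompt[:80] not in llm_output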

Interview Q&A

Q1: Explain the difference between TLS and mTLS and when each is appropriate in an ML platform.

Standard TLS authenticates the server to the client: when your browser connects to https://inference.example.com, the server presents a certificate proving its identity, and the client verifies it. The client does not prove its identity. This is appropriate for user-facing APIs where clients are users or external services with no predictable certificates. mTLS authenticates both sides: the client also presents a certificate, and the server verifies it. For internal ML platform communication - inference service calling feature store, training orchestrator calling model registry - mTLS ensures that only authorized internal services can make calls, even if an internal network is compromised. If an inference pod is compromised, it cannot impersonate the training orchestrator to make unauthorized calls, because it does not have the training orchestrator's certificate.

Q2: A Kubernetes pod running an ML inference service was compromised by an attacker who exploited a vulnerability in a dependency. What does the attacker have access to by default, and how would Kubernetes NetworkPolicy and Cilium limit the blast radius?

By default with no NetworkPolicy, the compromised pod can make network connections to any other pod in the cluster, any Kubernetes service, the Kubernetes API server, etcd (if accessible), any cloud metadata endpoint, and potentially the public internet. The attacker can exfiltrate training data, query other ML models, steal secrets from Kubernetes Secrets, and attempt lateral movement. With NetworkPolicy (enforcement by Calico or Cilium), the compromised pod can only make connections that the policy explicitly allows - for example, only to the feature store on port 6565 and DNS on port 53. Attempts to reach etcd, the training data store, or external IPs are blocked at the kernel level before the TCP SYN leaves the node. Cilium's eBPF enforcement adds Layer 7 filtering: even within allowed connections, only specific HTTP paths are permitted, preventing the compromised inference service from making admin API calls to the feature store even on an allowed port.

Q3: What is the threat model for model supply chain attacks, and how does model signing with cosign address it?

The threat: a malicious actor with write access to your model storage (S3, GCS, a model registry) can replace a model file with a backdoored version that behaves normally for most inputs but triggers malicious behavior (e.g., always predicts a specific class, leaks information in output probabilities) for inputs controlled by the attacker. This is analogous to a software supply chain attack (replacing a package with a backdoored version). Model signing with cosign creates a cryptographic signature over the model file's SHA256 hash, signed with a key that only the CI/CD system has access to (ideally a KMS-backed key or keyless OIDC signing). Before loading any model, the serving infrastructure verifies the signature. A replaced model file will have a different hash, the signature check will fail, and the model will not be loaded. The provenance record adds additional assurance: you know not just that the file is intact, but which training job produced it, which dataset was used, and what the final loss was.

Q4: How does cert-manager in Kubernetes automate mTLS certificate lifecycle management, and what happens when a certificate expires?

cert-manager is a Kubernetes controller that watches Certificate resources and automatically issues, renews, and rotates TLS certificates. You configure a ClusterIssuer (backed by a CA, Let's Encrypt, or a Vault PKI mount), then create Certificate resources that specify which service needs a cert, its DNS names, and the issuer. cert-manager creates a Kubernetes Secret containing the cert and key. Pods mount the Secret as a volume. cert-manager monitors certificate expiry and begins renewal when the certificate is within the renewBefore window (typically 7-30 days before expiry). On renewal, cert-manager updates the Secret with the new certificate. If your application reloads TLS certificates dynamically (which most TLS libraries support), zero-downtime rotation is achieved. If a certificate is not renewed (cert-manager is down, CA unreachable), the old certificate expires and mTLS connections begin failing - this is why cert-manager itself should run with multiple replicas and monitoring.

Q5: What is the principle of zero-trust networking and how does it differ from perimeter security for ML infrastructure?

Perimeter security (traditional model): trust everything inside the corporate network or the Kubernetes cluster. The firewall at the network edge is the primary control. Once inside the perimeter, services talk to each other freely. The failure mode: one compromised service has free access to everything else. Zero-trust: never trust any connection by default, regardless of network location. Every service authenticates every request. An internal service is not automatically trusted just because it is on the same network or in the same cluster. For ML infrastructure, zero-trust means: inference pods authenticate to feature store via mTLS (not just shared network); training workers authenticate to data stores with short-lived credentials from Vault (not static keys); network policies prevent any lateral movement even within the ML namespace; JWT validation on every inference request even from internal microservices. The practical benefit: a compromised training worker cannot access the inference model registry. A compromised inference pod cannot read training data. The blast radius of any single compromise is bounded by that service's explicitly granted permissions.

Q6: Describe three ML-specific network-level threats that traditional web security controls do not address.

First, model extraction via API: an authenticated user with legitimate API access systematically probes the model with carefully chosen inputs to reconstruct the model's behavior (and approximately its weights) in a surrogate model. Traditional authentication does not prevent this because the attacker is authorized. Defense: rate limiting, output noise injection, monitoring for probing patterns.

Second, training data poisoning via network: in federated learning or continual learning systems, model updates arrive over the network from potentially untrusted sources. A malicious participant sends crafted gradient updates that cause the model to behave maliciously on specific inputs while appearing normal on others. Defense: gradient clipping, Byzantine-robust aggregation, cryptographic verification of gradient sources.

Third, prompt injection via API for LLM-based services: an attacker controls content that gets incorporated into the LLM's context (user-provided text, retrieved documents, tool outputs) and includes instructions designed to override the system prompt and cause the model to leak sensitive information or take unintended actions. Traditional input sanitization is insufficient because distinguishing benign instructions from malicious ones requires semantic understanding. Defense: input validation, output scanning, trust boundary separation, rate limiting on anomalous outputs.
