Triton Inference Server and TorchServe
The Production Scenario
Your team has been running model inference through a FastAPI server wrapping PyTorch. It works. Your senior ML engineer wrote it in a weekend - it handles preprocessing, loads the model, runs inference, and returns JSON. It has served you well for 6 months at 500 requests per second.
Then two things happen simultaneously. First, your GPU utilization charts show 35% average utilization. Your model takes 8ms to process a single request, and requests arrive one at a time, so the GPU spends 65% of its time idle. Second, you need to deploy a second model - a TensorRT-optimized version for lower latency. Your FastAPI server only speaks PyTorch. Adding TensorRT support means rewriting the server.
This is exactly the problem that purpose-built serving frameworks solve. Triton Inference Server handles both issues. It natively supports PyTorch, TensorRT, ONNX, TensorFlow, and others, so no rewrite is needed. It also implements dynamic batching, so GPU utilization can rise from 35% toward 85%. You configure it with a directory structure and a config file. You do not write inference server code.
Understanding what these frameworks actually give you - and when their complexity is justified - is the practical skill this lesson delivers.
Why This Exists - The Problems with DIY Serving
Building an inference server from scratch with FastAPI or Flask solves the first 80% of the problem easily. The remaining 20% - dynamic batching, multi-framework support, model ensemble, concurrent model execution, health protocols, metrics - each requires significant engineering. Purpose-built serving frameworks come with these features implemented, tested, and tuned.
Specifically:
Dynamic batching: When requests arrive one at a time, each runs as a batch of 1. For many transformer and vision models, a GPU such as an NVIDIA A10G can process a batch of dozens of inputs (sometimes over a hundred) in little more time than a batch of 1. Without dynamic batching, most of that compute capacity sits unused. Implementing a correct dynamic batcher with a configurable maximum batch size and maximum wait time is non-trivial. Triton has it built in.
Multi-backend support: A production ML infrastructure typically has models in multiple frameworks: PyTorch for research models, TensorRT-optimized versions for production, ONNX for cross-platform deployment. A single Triton server can serve all of them. No framework-specific server code needed.
Model ensemble pipelines: Running preprocessing → model → postprocessing as a single Triton ensemble removes the client round trips and re-serialization between steps. Intermediate tensors pass from one step to the next inside the server rather than over the network.
gRPC and HTTP: Triton exposes both protocols with the same configuration. You get gRPC performance for internal calls and HTTP for debugging, with no code change.
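Dynamic batching is the item on this list that is hardest to build yourself. The following is a minimal `asyncio` sketch of what any dynamic batcher has to do. It collects requests until either `max_batch_size` is reached or `max_delay_s` has passed since the first request in the window. It then runs one batched call and hands each caller its own result. `DynamicBatcher`, `batch_fn`, and `demo` are hypothetical names. This is not how Triton implements batching internally, and it omits error handling, request cancellation, and multiple model instances.

```python
import asyncio
from typing import Any, Awaitable, Callable


class DynamicBatcher:
    """Groups concurrently submitted requests into batches (illustrative sketch)."""

    def __init__(
        self,
        batch_fn: Callable[[list], Awaitable[list]],  # runs the model on a whole batch
        max_batch_size: int = 16,
        max_delay_s: float = 0.005,  # plays the role of max_queue_delay_microseconds
    ) -> None:
        self.batch_fn = batch_fn
        self.max_batch_size = max_batch_size
        self.max_delay_s = max_delay_s
        self.queue: asyncio.Queue = asyncio.Queue()
        self.batch_sizes: list[int] = []  # size of every batch executed

    async def submit(self, item: Any) -> Any:
        """Enqueue one request and wait until its batch has been executed."""
        future = asyncio.get_running_loop().create_future()
        await self.queue.put((item, future))
        return await future

    async def run(self) -> None:
        """Background loop: form a batch, execute it once, scatter the results."""
        loop = asyncio.get_running_loop()
        while True:
            # Block until a request arrives; its arrival opens the batching window
            batch = [await self.queue.get()]
            deadline = loop.time() + self.max_delay_s
            # Keep collecting until the batch is full or the window closes
            while len(batch) < self.max_batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self.queue.get(), remaining))
                except asyncio.TimeoutError:
                    break
            items = [item for item, _ in batch]
            self.batch_sizes.append(len(items))
            results = await self.batch_fn(items)  # one model call for the whole batch
            for (_, future), result in zip(batch, results):
                future.set_result(result)


async def demo() -> tuple[list, list[int]]:
    """Send 20 concurrent requests through a batcher wrapping a fake model."""

    async def fake_model(xs: list) -> list:
        await asyncio.sleep(0.001)  # stand-in for one batched GPU call
        return [x * 2 for x in xs]

    batcher = DynamicBatcher(fake_model, max_batch_size=8, max_delay_s=0.01)
    worker = asyncio.create_task(batcher.run())
    results = await asyncio.gather(*(batcher.submit(i) for i in range(20)))
    worker.cancel()
    return results, batcher.batch_sizes
```

Running `asyncio.run(demo())` returns the 20 doubled values in request order. The model itself is called far fewer than 20 times, because concurrent requests share batches. The two knobs map directly onto Triton's `max_batch_size` and `max_queue_delay_microseconds`.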
Historical Context
Before purpose-built serving frameworks existed, teams at large companies such as Google and Facebook built custom inference servers, often in C++. These servers were maintained by small specialized teams and were mostly never open-sourced. Google's TensorFlow Serving, open-sourced in 2016, was a notable exception. The cost of building and maintaining these servers was one reason large companies could deploy ML in production sooner than startups.
NVIDIA released TensorRT Inference Server in 2018, later renamed to Triton Inference Server. It was designed for GPU inference specifically, with TensorRT as the primary backend and dynamic batching from day one. Over time it added support for CPU backends, ONNX, PyTorch (via LibTorch), and TensorFlow.
Facebook open-sourced TorchServe in 2020, jointly with AWS. TorchServe is PyTorch-specific and prioritizes ease of use for the PyTorch ecosystem: a simple CLI to package and serve any PyTorch model, with customizable pre/post-processing in Python.
The two sit at different points on the complexity/control spectrum. Triton targets maximum performance and multi-framework support. TorchServe targets PyTorch teams who want to ship quickly.
NVIDIA Triton Inference Server
Model Repository Structure
Triton uses a directory-based model repository. The directory structure defines what models are available and how to serve them:
```text
model_repository/
├── resnet50/
│   ├── config.pbtxt          ← model configuration
│   └── 1/                    ← version 1
│       └── model.plan        ← TensorRT engine
├── bert_classifier/
│   ├── config.pbtxt
│   ├── 1/
│   │   └── model.pt          ← TorchScript model
│   └── 2/
│       └── model.pt          ← newer version
└── preprocessing_pipeline/
    ├── config.pbtxt          ← ensemble config
    └── 1/                    ← empty version directory (an ensemble has no weights)
```
A model config for a TensorRT model:
```protobuf
# model_repository/resnet50/config.pbtxt
name: "resnet50"
platform: "tensorrt_plan"
max_batch_size: 64
input [
  {
    name: "input"
    data_type: TYPE_FP32
    dims: [ 3, 224, 224 ]  # CHW format - Triton handles N (batch) automatically
  }
]
output [
  {
    name: "output"
    data_type: TYPE_FP32
    dims: [ 1000 ]  # ImageNet classes
  }
]
# Dynamic batching - Triton's killer feature
dynamic_batching {
  preferred_batch_size: [ 8, 16, 32 ]
  max_queue_delay_microseconds: 5000  # Wait up to 5ms for more requests
}
# Instance groups: how many model instances to load
instance_group [
  {
    kind: KIND_GPU
    count: 2  # 2 concurrent model instances on GPU
  }
]
```
A config for a PyTorch (TorchScript) model:
```protobuf
# model_repository/bert_classifier/config.pbtxt
name: "bert_classifier"
platform: "pytorch_libtorch"
max_batch_size: 32
input [
  {
    name: "input_ids"
    data_type: TYPE_INT64
    dims: [ 128 ]  # Sequence length
  },
  {
    name: "attention_mask"
    data_type: TYPE_INT64
    dims: [ 128 ]
  }
]
output [
  {
    name: "logits"
    data_type: TYPE_FP32
    dims: [ 2 ]  # Binary classification
  }
]
dynamic_batching {
  preferred_batch_size: [ 4, 8, 16 ]
  max_queue_delay_microseconds: 10000
}
version_policy {
  latest { num_versions: 2 }  # Keep 2 most recent versions loaded
}
```
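Triton treats each numerically named subdirectory of a model as a version. The `version_policy` setting decides which of those versions get loaded. The sketch below mimics the three policy kinds (`all`, `latest { num_versions: N }`, `specific { versions: [...] }`) over a list of directory names, so the selection rule is easy to check. It follows the model repository documentation, which says names that are not numeric or that start with a zero are ignored. `select_versions` is a hypothetical helper, not Triton code. When no policy is given, Triton loads only the latest version.

```python
from typing import Iterable, Optional


def select_versions(
    dir_names: Iterable[str],
    policy: str = "latest",  # default policy: latest with num_versions 1
    num_versions: int = 1,
    specific: Optional[Iterable[int]] = None,
) -> list[int]:
    """Versions a Triton-style version_policy would load (illustrative sketch)."""
    # Only numeric subdirectory names count as versions; names with a leading
    # zero (including "0") are skipped, as are files such as config.pbtxt
    versions = sorted(
        int(name)
        for name in dir_names
        if name.isascii() and name.isdigit() and not name.startswith("0")
    )
    if policy == "all":
        return versions
    if policy == "latest":
        return versions[-num_versions:] if num_versions > 0 else []
    if policy == "specific":
        wanted = set(specific or ())
        return [v for v in versions if v in wanted]
    raise ValueError(f"unknown version policy: {policy!r}")


# bert_classifier/ holds config.pbtxt, 1/ and 2/ with latest { num_versions: 2 }
print(select_versions(["config.pbtxt", "1", "2"], "latest", num_versions=2))  # [1, 2]
```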
Starting Triton
```bash
# Pull the Triton container
docker pull nvcr.io/nvidia/tritonserver:24.01-py3

# Start with GPU support and the model repository.
# Ports: 8000 = HTTP, 8001 = gRPC, 8002 = metrics (Prometheus).
# Comments cannot follow a trailing backslash, so they live up here.
docker run --gpus all \
  -p 8000:8000 \
  -p 8001:8001 \
  -p 8002:8002 \
  -v /path/to/model_repository:/models \
  nvcr.io/nvidia/tritonserver:24.01-py3 \
  tritonserver \
    --model-repository=/models \
    --log-verbose=0 \
    --strict-model-config=false
```
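Triton reports server readiness at `/v2/health/ready`, which is the same endpoint the Kubernetes readiness probe uses later in this lesson. It returns HTTP 200 once the server is ready. Scripts and CI jobs that start the container need to wait for that before sending traffic. Below is a minimal standard-library polling helper. It assumes the HTTP port is mapped to 8000 as above, and `wait_until_ready` is a hypothetical name.

```python
import time
import urllib.request


def wait_until_ready(
    base_url: str = "http://localhost:8000",
    timeout_s: float = 120.0,
    poll_interval_s: float = 2.0,
) -> bool:
    """Poll Triton's /v2/health/ready until it returns 200 or timeout_s elapses."""
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            with urllib.request.urlopen(f"{base_url}/v2/health/ready", timeout=5) as resp:
                if resp.status == 200:
                    return True
        except OSError:
            # Connection refused, timeout, or a non-2xx reply (HTTPError is an OSError)
            pass
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        time.sleep(min(poll_interval_s, remaining))
```

For example, a deployment script could call `wait_until_ready(timeout_s=300)` after `docker run` and abort if it returns `False`. The generous timeout leaves room for large models to load.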
Python Client for Triton
```python
# triton_client.py
import numpy as np
import torchvision.transforms as transforms
import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient
from PIL import Image


def preprocess_image(image_path: str) -> np.ndarray:
    """Preprocess an image to the ResNet50 input format."""
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    img = Image.open(image_path).convert("RGB")
    tensor = transform(img).numpy()  # Shape: [3, 224, 224]
    return tensor[np.newaxis, :]     # Add batch dim: [1, 3, 224, 224]


def predict_grpc(
    image_paths: list[str],
    triton_url: str = "localhost:8001",
    model_name: str = "resnet50",
) -> np.ndarray:
    """Batch prediction via gRPC - faster than HTTP for large inputs.

    len(image_paths) must not exceed the model's max_batch_size (64 here).
    """
    client = grpcclient.InferenceServerClient(url=triton_url)

    # Preprocess all images into one batch: [N, 3, 224, 224]
    batch = np.concatenate(
        [preprocess_image(p) for p in image_paths], axis=0
    ).astype(np.float32)

    # Build the Triton input; name and datatype must match config.pbtxt
    inputs = [grpcclient.InferInput("input", list(batch.shape), "FP32")]
    inputs[0].set_data_from_numpy(batch)

    # Request the output tensor by name
    outputs = [grpcclient.InferRequestedOutput("output")]

    # Call Triton
    response = client.infer(model_name=model_name, inputs=inputs, outputs=outputs)
    return response.as_numpy("output")  # Shape: [N, 1000]


def predict_http(
    features: np.ndarray,
    triton_url: str = "localhost:8000",
    model_name: str = "bert_classifier",
) -> np.ndarray:
    """HTTP prediction - useful for debugging.

    features: [N, >=128] array of token IDs, with 0 used as padding.
    """
    client = httpclient.InferenceServerClient(url=triton_url)
    # Truncate to the 128-token sequence length declared in config.pbtxt
    input_ids = features[:, :128].astype(np.int64)
    # Attend to real tokens only (padding ID assumed to be 0)
    attention_mask = (input_ids != 0).astype(np.int64)

    inputs = [
        httpclient.InferInput("input_ids", list(input_ids.shape), "INT64"),
        httpclient.InferInput("attention_mask", list(attention_mask.shape), "INT64"),
    ]
    inputs[0].set_data_from_numpy(input_ids)
    inputs[1].set_data_from_numpy(attention_mask)
    outputs = [httpclient.InferRequestedOutput("logits")]

    response = client.infer(model_name=model_name, inputs=inputs, outputs=outputs)
    return response.as_numpy("logits")


def check_model_metadata(client: grpcclient.InferenceServerClient, model_name: str) -> None:
    """Inspect model configuration - useful for debugging.

    Expects a gRPC client, whose responses are protobuf messages; the HTTP
    client returns plain dicts instead (e.g. metadata["inputs"]).
    """
    metadata = client.get_model_metadata(model_name)
    config = client.get_model_config(model_name)
    print(f"Model: {model_name}")
    print(f"Inputs: {[(i.name, i.shape, i.datatype) for i in metadata.inputs]}")
    print(f"Outputs: {[(o.name, o.shape, o.datatype) for o in metadata.outputs]}")
    print(f"Max batch size: {config.config.max_batch_size}")
```
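Both Triton clients wrap the KServe v2 inference protocol. Over HTTP, a request is a JSON body with an `inputs` list. Each entry carries a tensor `name`, `shape`, `datatype`, and `data`, where `data` holds the values flattened in row-major order. An optional `outputs` list selects which tensors to return. The stdlib-only sketch below builds such a body, which is handy for curl-style debugging. `v2_tensor` and `build_v2_request` are hypothetical helpers, and the token IDs are made-up example values. For large tensors, the binary data extension that `tritonclient` can use is much more efficient than JSON.

```python
import json
from math import prod
from typing import Optional


def v2_tensor(name: str, shape: list[int], datatype: str, data: list) -> dict:
    """One entry of a v2 "inputs" list; data is the tensor flattened row-major."""
    if prod(shape) != len(data):
        raise ValueError(f"shape {shape} needs {prod(shape)} values, got {len(data)}")
    return {"name": name, "shape": shape, "datatype": datatype, "data": data}


def build_v2_request(inputs: list[dict], output_names: Optional[list[str]] = None) -> str:
    """Serialize a KServe v2 inference request body as JSON."""
    body: dict = {"inputs": inputs}
    if output_names:
        # Optional: request only the named outputs
        body["outputs"] = [{"name": n} for n in output_names]
    return json.dumps(body)


# Example for bert_classifier: batch of 1, sequence length 128, 0 = padding
ids = [101, 2023, 102] + [0] * 125
mask = [1 if t != 0 else 0 for t in ids]
payload = build_v2_request(
    [
        v2_tensor("input_ids", [1, 128], "INT64", ids),
        v2_tensor("attention_mask", [1, 128], "INT64", mask),
    ],
    output_names=["logits"],
)
# POST `payload` to http://localhost:8000/v2/models/bert_classifier/infer
```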
Triton Ensemble Pipelines
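Triton ensembles connect models in a DAG. The output of one model feeds into the next inside the server, without a network round trip. The example below assumes a `tokenizer` model that is not shown in the repository listing above, for instance a Python-backend model. It is assumed to output `input_ids` and `attention_mask`: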
```protobuf
# model_repository/preprocessing_pipeline/config.pbtxt
name: "preprocessing_pipeline"
platform: "ensemble"
max_batch_size: 32  # Must not exceed any step's max_batch_size (bert_classifier: 32)
input [
  {
    name: "raw_text"
    data_type: TYPE_STRING
    dims: [ 1 ]
  }
]
output [
  {
    name: "classification"
    data_type: TYPE_FP32
    dims: [ 2 ]
  }
]
ensemble_scheduling {
  step [
    {
      model_name: "tokenizer"
      model_version: 1
      input_map {
        key: "text"
        value: "raw_text"         # Maps pipeline input to model input
      }
      output_map {
        key: "input_ids"
        value: "tokenized_ids"    # Internal tensor name
      }
      output_map {
        key: "attention_mask"
        value: "tokenized_mask"   # Internal tensor name
      }
    },
    {
      model_name: "bert_classifier"
      model_version: 2
      input_map {
        key: "input_ids"
        value: "tokenized_ids"    # Maps tokenizer output to BERT input
      }
      input_map {
        key: "attention_mask"
        value: "tokenized_mask"   # bert_classifier declares both inputs
      }
      output_map {
        key: "logits"
        value: "classification"   # Maps to pipeline output
      }
    }
  ]
}
```
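Misnamed or unmapped tensors are a common reason an ensemble fails to load. Every tensor a step reads must be either an ensemble input or some step's output. Every input a composing model declares must be mapped, and every ensemble output must be produced by a step. The pure-Python sketch below checks those three rules over the `input_map`/`output_map` pairs. Each step is represented as a dict, which is a simplification. `check_ensemble_wiring` is a hypothetical helper. Triton does its own validation at load time, including data types and shapes, which this sketch ignores.

```python
from typing import Optional


def check_ensemble_wiring(
    ensemble_inputs: set[str],
    ensemble_outputs: set[str],
    steps: list[dict],
    model_inputs: Optional[dict[str, set[str]]] = None,
) -> list[str]:
    """Return wiring problems in an ensemble; an empty list means none found.

    Each step mirrors one ensemble_scheduling entry:
    {"model": name, "input_map": {model_input: tensor}, "output_map": {model_output: tensor}}
    model_inputs optionally lists each model's declared inputs so that
    unmapped inputs are reported too.
    """
    # Steps run in data-dependency order, not listing order, so any step's
    # output counts as available (cycles are not checked here)
    produced = {t for step in steps for t in step["output_map"].values()}
    available = set(ensemble_inputs) | produced
    problems = []
    for step in steps:
        for model_input, tensor in step["input_map"].items():
            if tensor not in available:
                problems.append(f"{step['model']}.{model_input} reads unknown tensor '{tensor}'")
        declared = (model_inputs or {}).get(step["model"], set())
        for missing in sorted(declared - set(step["input_map"])):
            problems.append(f"{step['model']} input '{missing}' is not mapped")
    for tensor in sorted(set(ensemble_outputs) - produced):
        problems.append(f"ensemble output '{tensor}' is never produced")
    return problems
```

With the `preprocessing_pipeline` steps expressed this way, the function returns an empty list. If the `attention_mask` mapping is removed from the BERT step, it reports `bert_classifier input 'attention_mask' is not mapped`.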
TorchServe
TorchServe is PyTorch's native serving framework. It is Python-native and simpler to configure than Triton. That makes it a good fit for teams in the PyTorch ecosystem who want to ship quickly.
Packaging a Model for TorchServe
```python
# save_model_for_torchserve.py
import torch
from torchvision import models

# Step 1: Train or load your model
# (torchvision >= 0.13 replaced pretrained=True with the weights= argument)
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
model.eval()

# Step 2: Convert to TorchScript.
# TorchScript is a subset of Python that can be compiled and exported.
# TorchServe can also serve eager-mode models (packaging the model class with
# --model-file), but a TorchScript file needs no model source at serving time.
example_input = torch.randn(1, 3, 224, 224)
scripted_model = torch.jit.script(model)
scripted_model.save("resnet50_scripted.pt")

# Alternatively, use torch.jit.trace for models without data-dependent control flow
traced_model = torch.jit.trace(model, example_input)
traced_model.save("resnet50_traced.pt")
```
```python
# handler.py - Custom handler for pre/post-processing
import base64
import io
import os

import torch
from PIL import Image
from torchvision import transforms
from ts.torch_handler.base_handler import BaseHandler


class ResNet50Handler(BaseHandler):
    """
    Custom TorchServe handler.
    BaseHandler provides: model loading, device management, and a default
    inference step. Requests arrive already grouped into a batch by the
    TorchServe frontend.
    You implement: preprocess, inference (optional override), postprocess.
    """

    def initialize(self, context):
        """Called once per worker at startup - load model and preprocessing."""
        super().initialize(context)  # Loads the TorchScript model
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])
        # Load class labels. Files passed with --extra-files are unpacked into
        # the model directory, so resolve the path against it.
        model_dir = context.system_properties.get("model_dir")
        with open(os.path.join(model_dir, "imagenet_classes.txt")) as f:
            self.labels = [line.strip() for line in f]

    def preprocess(self, data: list) -> torch.Tensor:
        """Convert raw HTTP request bodies to a model input tensor.

        data: list of dicts, one per request in the batch.
        """
        tensors = []
        for item in data:
            # Accept both raw bytes and base64-encoded images
            image_bytes = item.get("body") or item.get("data")
            if isinstance(image_bytes, str):
                image_bytes = base64.b64decode(image_bytes)
            img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            tensors.append(self.transform(img))
        # Stack into a batch tensor [N, 3, 224, 224]
        return torch.stack(tensors).to(self.device)

    def postprocess(self, output: torch.Tensor) -> list:
        """Convert the model output tensor to a JSON-serializable response.

        output: tensor [N, 1000] - one row per request in the batch.
        Must return exactly one entry per request, in request order.
        """
        probs = torch.softmax(output, dim=1)
        top5_probs, top5_indices = torch.topk(probs, k=5, dim=1)
        results = []
        for i in range(len(top5_probs)):
            results.append({
                "predictions": [
                    {
                        "label": self.labels[idx.item()],
                        "probability": prob.item(),
                    }
                    for prob, idx in zip(top5_probs[i], top5_indices[i])
                ]
            })
        return results
```
Package and serve:
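```bash
# Package the model into a .mar archive
mkdir -p model_store
torch-model-archiver \
  --model-name resnet50 \
  --version 1.0 \
  --serialized-file resnet50_scripted.pt \
  --handler handler.py \
  --extra-files imagenet_classes.txt \
  --export-path model_store/

# Start TorchServe; which models to load, and their batching settings,
# come from the config file below
torchserve \
  --start \
  --model-store model_store/ \
  --ts-config torchserve_config.properties
```

```properties
# torchserve_config.properties
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
number_of_netty_threads=32
job_queue_size=1000
install_py_dep_per_model=false
load_models=resnet50.mar

# Batching is configured per model.
# maxBatchDelay: milliseconds to wait for a full batch.
# Properties files do not support inline comments, so keep comments on their own lines.
models={\
  "resnet50": {\
    "1.0": {\
      "defaultVersion": true,\
      "marName": "resnet50.mar",\
      "minWorkers": 1,\
      "maxWorkers": 1,\
      "batchSize": 16,\
      "maxBatchDelay": 10\
    }\
  }\
}
```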
```python
# torchserve_client.py
import requests


def predict_torchserve(
    image_path: str,
    torchserve_url: str = "http://localhost:8080",  # inference API
    model_name: str = "resnet50",
) -> dict:
    """Send a prediction request to TorchServe."""
    with open(image_path, "rb") as f:
        image_bytes = f.read()
    response = requests.post(
        f"{torchserve_url}/predictions/{model_name}",
        data=image_bytes,
        headers={"Content-Type": "application/octet-stream"},
    )
    response.raise_for_status()
    return response.json()


def list_models(torchserve_url: str = "http://localhost:8081") -> dict:
    """List all registered models via the management API."""
    response = requests.get(f"{torchserve_url}/models")
    response.raise_for_status()
    return response.json()


def register_model(
    torchserve_url: str,  # management API, e.g. http://localhost:8081
    model_url: str,       # .mar file in the model store, or a URL to one
    model_name: str,
    initial_workers: int = 1,
) -> None:
    """Register a new model at runtime - no restart needed."""
    response = requests.post(
        f"{torchserve_url}/models",
        params={
            "url": model_url,
            "model_name": model_name,
            "initial_workers": initial_workers,
            "synchronous": "true",
        },
    )
    response.raise_for_status()
    print(f"Registered {model_name}")
```
ONNX Runtime: Framework-Agnostic Inference
ONNX Runtime is not a serving framework but an inference engine. It runs ONNX models (exported from PyTorch, TensorFlow, scikit-learn, etc.) with optimized execution across backends:
```python
# onnx_inference.py
import time

import numpy as np
import onnxruntime as ort


def export_pytorch_to_onnx(
    model,
    example_input: np.ndarray,
    output_path: str,
) -> None:
    """Export a PyTorch model to ONNX format."""
    import torch

    model.eval()  # Export inference-mode behavior (e.g. dropout disabled)
    # Export with dynamic axes so any batch size works
    torch.onnx.export(
        model,
        torch.from_numpy(example_input),
        output_path,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size"},  # Batch dimension is dynamic
            "output": {0: "batch_size"},
        },
        opset_version=17,
        do_constant_folding=True,  # Fold constants for optimization
    )
    print(f"Exported to {output_path}")


class OnnxInferenceEngine:
    """Production-ready ONNX Runtime inference engine."""

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",  # "cuda" or "cpu"
    ):
        providers = []
        if device == "cuda":
            providers.append(
                ("CUDAExecutionProvider", {
                    "device_id": 0,
                    "arena_extend_strategy": "kNextPowerOfTwo",
                    "gpu_mem_limit": 4 * 1024 * 1024 * 1024,  # 4GB
                    "cudnn_conv_algo_search": "EXHAUSTIVE",  # Find best cuDNN algorithm
                    "do_copy_in_default_stream": True,
                })
            )
        providers.append("CPUExecutionProvider")  # Fallback if CUDA is unavailable

        # Session options for performance
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = (
            ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        )
        sess_options.intra_op_num_threads = 4
        # ORT_PARALLEL runs independent graph branches concurrently; it only
        # helps models that have such branches (ORT_SEQUENTIAL is the default)
        sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL

        self.session = ort.InferenceSession(
            model_path,
            sess_options=sess_options,
            providers=providers,
        )

        # Inspect model I/O
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name
        # Print the providers actually in use, which exposes a silent CPU fallback
        print(f"ONNX model loaded, providers: {self.session.get_providers()}")
        print(f"Input: {self.input_name}, Output: {self.output_name}")

    def predict(self, features: np.ndarray) -> np.ndarray:
        """Run inference. Input/output are numpy arrays."""
        return self.session.run(
            [self.output_name],
            {self.input_name: features.astype(np.float32)},
        )[0]

    def benchmark(
        self,
        feature_dim: int,
        batch_size: int = 32,
        n_runs: int = 100,
    ) -> dict:
        """Benchmark this model at the given batch size.

        Assumes a single 2-D input of shape [batch_size, feature_dim].
        """
        dummy_input = np.random.randn(batch_size, feature_dim).astype(np.float32)

        # Warmup
        for _ in range(10):
            self.predict(dummy_input)

        times = []
        for _ in range(n_runs):
            start = time.perf_counter()
            self.predict(dummy_input)
            times.append((time.perf_counter() - start) * 1000)

        times.sort()
        return {
            "p50_ms": times[int(n_runs * 0.50)],
            "p95_ms": times[int(n_runs * 0.95)],
            "p99_ms": times[int(n_runs * 0.99)],
            "throughput_per_sec": (n_runs * batch_size) / sum(times) * 1000,
        }
```
Framework Comparison
| Framework | Best For | Dynamic Batching | Multi-Framework | Complexity |
|---|---|---|---|---|
| Triton | GPU, high-performance, multi-model | Yes, configurable | Yes (TRT, ONNX, TF, PyTorch) | High |
| TorchServe | PyTorch teams, quick deployment | Yes, via config | No | Medium |
| ONNX Runtime | Cross-framework, embedded, edge | No (DIY) | Via ONNX export | Low |
| FastAPI + model | Early stage, simple use cases | No (DIY) | No | Low |
Benchmarking Serving Frameworks
```python
# framework_benchmark.py
import asyncio
import time
from typing import Callable

import aiohttp
import numpy as np


async def benchmark_endpoint(
    url: str,
    payload_fn: Callable[[], dict],
    n_requests: int = 1000,
    concurrency: int = 10,
) -> dict:
    """Benchmark an HTTP inference endpoint with concurrent requests."""
    semaphore = asyncio.Semaphore(concurrency)
    latencies: list[float] = []

    # Share one session so connections are reused; opening a new session per
    # request would add connection setup time to every latency sample.
    async with aiohttp.ClientSession() as session:

        async def make_request() -> None:
            async with semaphore:
                payload = payload_fn()
                start = time.perf_counter()
                async with session.post(url, json=payload) as response:
                    response.raise_for_status()  # Errors must not count as fast responses
                    await response.read()
                latencies.append((time.perf_counter() - start) * 1000)

        wall_start = time.perf_counter()
        await asyncio.gather(*(make_request() for _ in range(n_requests)))
        wall_seconds = time.perf_counter() - wall_start

    latencies.sort()
    n = len(latencies)
    return {
        "n_requests": n_requests,
        "concurrency": concurrency,
        "p50_ms": latencies[int(n * 0.50)],
        "p95_ms": latencies[int(n * 0.95)],
        "p99_ms": latencies[int(n * 0.99)],
        # Throughput uses wall-clock time; summing per-request latencies would
        # ignore that `concurrency` requests are in flight at once.
        "throughput_rps": n_requests / wall_seconds,
    }


async def compare_frameworks() -> None:
    # Assumes the same 512-feature model is deployed behind each server:
    # FastAPI on :8080, TorchServe on :8085 (moved off its default 8080 to
    # avoid clashing with FastAPI), and Triton on :8000, each under the model
    # name "model". FastAPI and the TorchServe handler accept {"features": [...]}.
    feature_dim = 512

    def json_payload() -> dict:
        return {"features": np.random.randn(feature_dim).tolist()}

    def triton_payload() -> dict:
        # KServe v2 inference protocol JSON body
        return {
            "inputs": [{
                "name": "input",
                "shape": [1, feature_dim],
                "datatype": "FP32",
                "data": np.random.randn(feature_dim).tolist(),
            }]
        }

    print("Benchmarking serving frameworks (512-dim input, 1000 requests, 10 concurrent):")
    targets = [
        ("FastAPI (baseline)", "http://localhost:8080/predict", json_payload),
        ("TorchServe", "http://localhost:8085/predictions/model", json_payload),
        ("Triton (HTTP)", "http://localhost:8000/v2/models/model/infer", triton_payload),
    ]
    for label, url, payload_fn in targets:
        results = await benchmark_endpoint(url, payload_fn)
        print(f"\n{label}:")
        print(f"  p50: {results['p50_ms']:.1f}ms, p99: {results['p99_ms']:.1f}ms, "
              f"throughput: {results['throughput_rps']:.0f} req/s")


if __name__ == "__main__":
    asyncio.run(compare_frameworks())
```
Kubernetes Deployment for Triton
```yaml
# kubernetes/triton-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: triton-inference-server
  namespace: ml-serving
spec:
  replicas: 3
  selector:
    matchLabels:
      app: triton
  template:
    metadata:
      labels:
        app: triton
      annotations:
        prometheus.io/scrape: "true"
        prometheus.io/port: "8002"
    spec:
      containers:
        - name: triton
          image: nvcr.io/nvidia/tritonserver:24.01-py3
          args:
            - tritonserver
            # S3 access needs AWS credentials in the pod (e.g. AWS_ACCESS_KEY_ID env vars)
            - --model-repository=s3://my-models/production/
            - --log-verbose=0
            - --allow-metrics=true
            - --metrics-port=8002
          ports:
            - containerPort: 8000  # HTTP
              name: http
            - containerPort: 8001  # gRPC
              name: grpc
            - containerPort: 8002  # Metrics
              name: metrics
          resources:
            requests:
              cpu: "4"
              memory: "8Gi"
              nvidia.com/gpu: "1"
            limits:
              cpu: "8"
              memory: "16Gi"
              nvidia.com/gpu: "1"  # GPU request and limit must be equal
          readinessProbe:
            httpGet:
              path: /v2/health/ready
              port: 8000
            initialDelaySeconds: 60  # Wait for model loading
            periodSeconds: 10
          livenessProbe:
            httpGet:
              path: /v2/health/live
              port: 8000
            initialDelaySeconds: 30
            periodSeconds: 30
```
Production Engineering Notes
Triton metrics are Prometheus-compatible: Triton exposes latency, queue depth, GPU utilization, and throughput per model on port 8002. Point Prometheus at it and you get production monitoring for free.
Model loading timeout: Large models (multi-GB) take time to load from S3 or NFS. Set Kubernetes readiness probe initialDelaySeconds generously (60-120 seconds for large models) to prevent health check failures during startup.
TorchServe worker scaling: TorchServe workers are Python processes. For CPU-bound preprocessing, scale the number of workers toward the number of CPU cores. For GPU inference, start with 1-2 workers per GPU. That way one worker can preprocess its next batch while another worker's batch runs on the GPU.
:::warning Triton Dynamic Batching Max Queue Delay Trade-off
Setting max_queue_delay_microseconds too high increases p99 latency, because requests wait longer for a full batch. Setting it too low means batches rarely fill, which wastes GPU compute. Profile your traffic to find the right value. For example, if the median inter-request gap is 5 ms, a 15 ms delay (max_queue_delay_microseconds: 15000) gathers only about 4 requests per batch. Filling a batch of 16 would take about 75 ms. Choose the delay from your latency budget and accept partial batches at low traffic; batches fill on their own as load rises.
:::
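The arithmetic behind that warning can be written as a small sketch. With evenly spaced arrivals every `gap_ms`, a window opened by the first request collects about `1 + floor(delay_ms / gap_ms)` requests, capped at the maximum batch size. This is a deliberate simplification: real traffic is bursty, and Triton can dispatch without waiting when it can already form a preferred batch size. Treat the numbers as a starting point for load testing, not a prediction. `expected_batch_size` and `delay_to_fill` are hypothetical helpers.

```python
import math


def expected_batch_size(gap_ms: float, delay_ms: float, max_batch: int) -> int:
    """Requests gathered in one batching window, assuming evenly spaced arrivals."""
    return min(max_batch, 1 + math.floor(delay_ms / gap_ms))


def delay_to_fill(gap_ms: float, target_batch: int) -> float:
    """Window length needed for evenly spaced arrivals to reach target_batch."""
    return (target_batch - 1) * gap_ms


# max_queue_delay_microseconds is in microseconds: 15 ms -> 15000
print(expected_batch_size(gap_ms=5, delay_ms=15, max_batch=16))  # 4
print(delay_to_fill(gap_ms=5, target_batch=16))  # 75
```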
:::danger TorchScript Limitations
Not all PyTorch code can be compiled with torch.jit.script. Scripting requires code whose types can be inferred or annotated statically. Dynamically typed values, `*args`/`**kwargs`, arbitrary Python objects, and calls into non-TorchScript libraries such as NumPy will fail. Test scripting early in development: discovering the day before the deployment deadline that your model cannot be scripted is painful. Use torch.jit.trace as a fallback for models whose control flow does not depend on the input data. Be aware that tracing records only the path taken for the example input, so any data-dependent branches are baked in.
:::
:::warning ONNX Opset Version Compatibility
ONNX models are versioned by opset - the set of operations the model uses. PyTorch and ONNX Runtime version updates may change which opsets are supported. An ONNX model exported with opset_version=17 may not be runnable on an older ONNX Runtime. Pin ONNX Runtime versions in your production Docker images and test model exports against that pinned version.
:::
Interview Q&A
Q: What does Triton Inference Server's dynamic batching actually do, and why does it improve GPU utilization?
Dynamic batching accumulates requests that arrive within a configurable time window (max_queue_delay_microseconds) and combines them into a single batch before sending it to the GPU. For many models, a GPU processes a batch of 32 images in little more wall-clock time than a batch of 1. Its thousands of CUDA cores execute in parallel, and a single request cannot keep them all busy. Without dynamic batching, every request is a batch of 1 and most of that parallel capacity sits idle. With dynamic batching, fixed per-execution costs (kernel launches, reading weights from memory, framework overhead) are amortized across 16-32 requests at once. This can raise GPU utilization substantially, for example from 35% to 85% as in the opening scenario.
Q: When would you choose TorchServe over Triton?
Choose TorchServe when your team is entirely in the PyTorch ecosystem with no near-term plans to use TensorRT or other frameworks. It also fits when you value faster time-to-deployment over maximum throughput, and when your QPS requirements are moderate enough that per-GPU throughput is not the bottleneck. TorchServe's Python handler is much easier to customize than Triton's backend system. Choose Triton when you need TensorRT optimization, since Triton's TRT backend is first-class. It is also the better choice when you serve models from multiple frameworks (ONNX, TF, PyTorch), need ensemble pipelines with minimal serialization overhead, or need maximum throughput and are willing to invest in the more complex configuration.
Q: What is the ONNX format and why would you export a model to it?
ONNX (Open Neural Network Exchange) is a cross-framework model representation. ONNX Runtime can run a model exported to ONNX regardless of which framework trained it (PyTorch, TensorFlow, scikit-learn, XGBoost). There are several reasons to export to ONNX. (1) Deploy on hardware or runtimes that do not support the training framework directly, such as edge devices, embedded systems, or .NET services. (2) Apply ONNX Runtime graph optimizations that are independent of the training framework. (3) Use Triton's ONNX backend, which can be faster than the PyTorch backend for some transformer models. (4) Validate that model inference is framework-independent, which is useful for auditing.
Q: How would you migrate a FastAPI model server to Triton at scale?
Phase 1 - Shadow mode: run Triton alongside FastAPI, mirror all requests to both, and compare outputs and latencies. Validate output consistency and measure Triton's latency improvement. Phase 2 - Canary: route 5% of traffic to Triton and monitor business metrics and error rates, gradually increasing to 100%. Phase 3 - Decommission: once 100% of traffic is on Triton, keep FastAPI running for 1 week as an emergency rollback target, then remove it. The key risk during migration is preprocessing. Triton's preprocessing must produce bit-identical model inputs to FastAPI's preprocessing. Model outputs may still differ slightly in floating point, especially with TensorRT or FP16, so compare them within a tolerance and check that top-1 predictions agree. Use the shadow comparison data to catch preprocessing discrepancies before they affect production.
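In shadow mode, the per-request comparison is therefore numeric rather than exact. Below is a minimal stdlib sketch of one comparison. It assumes both servers return a flat list of logits for the same input. `compare_shadow` and its default tolerance are illustrative assumptions to tune against your own model and precision.

```python
def compare_shadow(
    baseline: list[float],
    candidate: list[float],
    atol: float = 1e-3,  # hypothetical tolerance; tune per model and precision
) -> dict:
    """Compare one request's logits from the old server and the new one."""
    if not baseline or len(baseline) != len(candidate):
        raise ValueError("outputs are empty or have different lengths")
    max_abs_diff = max(abs(a - b) for a, b in zip(baseline, candidate))
    top1_baseline = max(range(len(baseline)), key=baseline.__getitem__)
    top1_candidate = max(range(len(candidate)), key=candidate.__getitem__)
    return {
        "max_abs_diff": max_abs_diff,
        "within_tolerance": max_abs_diff <= atol,
        "same_top1": top1_baseline == top1_candidate,
    }
```

Over the shadow period, track the fraction of requests with `same_top1` and the distribution of `max_abs_diff`. A sudden jump usually points at a preprocessing mismatch rather than numeric noise.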
Q: What is a Triton ensemble and when would you use it?
A Triton ensemble defines a pipeline of models (and preprocessing steps) as a DAG in a configuration file. Triton executes the pipeline internally, passing each step's output to the next step inside the server. No client round trips or JSON serialization happen between steps. Use an ensemble when a preprocessing model (e.g., a tokenizer) is followed by the main model and the intermediate tensors would be expensive to send back and forth. The ensemble itself does not batch. However, each composing model keeps its own scheduler, so a step with dynamic batching enabled still batches the requests flowing through the pipeline. The cost of an ensemble is configuration complexity: you must define the input/output tensor names and shapes for every step explicitly.
