Edge ML Deployment
The Production Scenario
Apple's Face ID unlocks hundreds of millions of iPhones. Every unlock involves a neural network inference: the TrueDepth camera projects more than 30,000 infrared dots onto the user's face, the Apple Neural Engine (ANE) runs a face-matching model, and the authentication decision is made fast enough to feel instant. No cloud call. No network latency. No face data leaving the device.
Apple has not published the Face ID model's architecture, but its constraints are clear. It is a compact neural network that runs on the Apple Neural Engine, a purpose-built accelerator for the matrix arithmetic at the core of neural networks. The entire pipeline, from camera capture to authentication decision, must finish fast enough that unlocking feels instantaneous. Meeting constraints like these typically relies on the techniques this lesson covers: compression, low-precision arithmetic, and compilation for the target hardware.
This is edge ML at its most ambitious. But the constraints are universal: whether you are running Face ID on an iPhone, a fall detection model on an Apple Watch, a defect detection model on a factory camera, or a wake word detection model on a smart speaker - the engineering challenges are the same. Your model must be small enough to fit in the device's memory, fast enough to meet latency requirements, efficient enough to not drain the battery, and accurate enough to be useful. All of these goals are in tension. Optimizing one makes the others harder.
This lesson covers the full stack of edge ML deployment: from model compression (making models small enough to fit on devices) to the runtime environments that execute them (TFLite, CoreML, ONNX Runtime), to the operational challenges of updating models on 100 million devices simultaneously.
Why This Exists - The Case for On-Device ML
Cloud inference is the default for most ML applications. Send data to a server, run the model, return the result. Why complicate things with edge deployment?
Latency: A round trip to a cloud server adds tens to hundreds of milliseconds (50-200ms is common on mobile networks), even under good conditions. Many ML applications need faster responses: face unlock (must feel instant), augmented reality (must keep pace with 60fps video, about 16.7ms per frame), keyboard prediction (must appear as the user types). On-device inference eliminates network latency entirely.
Privacy: Users increasingly resist sending personal data (face images, health metrics, personal communications) to external servers. Face ID processes biometric data entirely on the device, and Google's speech recognition can run locally on Pixel phones. On-device processing means the data never leaves the user's control.
Offline capability: Network connectivity is not guaranteed. Factory floor automation, autonomous vehicles, agricultural sensors, and rural health monitoring devices must work when the network is unavailable. Edge ML provides continuous operation without network dependency.
Cost: At 1 billion requests per day, cloud inference can cost on the order of $50,000-500,000 per month, depending on model size and cloud provider. Inference that runs on the user's device costs the provider nothing at the margin, so shifting even a fraction of that traffic on-device produces significant savings.
Bandwidth: Sending raw sensor data (HD video, audio, LiDAR) to the cloud for processing consumes enormous bandwidth. Running ML at the edge reduces what must be transmitted to just the result - a single prediction or label rather than megabytes of raw data.
Historical Context
On-device machine learning for mobile devices became practical around 2017-2018. Before that, mobile hardware - while capable of basic neural network operations - lacked the purpose-built accelerators that make edge ML economically viable.
Three hardware developments enabled the current generation of edge ML: Apple's Neural Engine (introduced in the A11 Bionic chip, 2017), Google's Edge TPU (announced in 2018 and shipped in Coral devices, while Pixel phones carry their own Google-designed accelerators), and Qualcomm's Hexagon DSP with Hexagon Vector eXtensions (HVX), which Qualcomm now markets as part of its AI Engine in Snapdragon chips.
The software ecosystem followed: TensorFlow Lite (Google, 2017), Core ML (Apple, 2017), and ONNX Runtime (Microsoft, open-sourced in 2018, with mobile support added later). These runtimes translate a trained model (PyTorch, TensorFlow) into optimized, device-specific execution.
Model Compression for Edge
A ResNet-50 model has about 25.6 million parameters and weighs roughly 100MB at FP32. A recent iPhone has 6-8GB of RAM, but only a fraction of it is available to a third-party app. A Raspberry Pi 4 has 1-8GB of RAM shared with the OS. A typical microcontroller has a few hundred KB of SRAM (often 512KB or less). Model compression reduces model size and compute requirements without catastrophic accuracy loss.
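The weight side of that arithmetic is simply parameter count times bytes per parameter. The small sketch below uses hypothetical helper names. It counts weights only and ignores file-format overhead, quantization metadata, and activation memory, all of which add to the real footprint:

```python
# Bytes needed to store one parameter at each precision (weights only).
BYTES_PER_PARAM = {"fp32": 4.0, "fp16": 2.0, "int8": 1.0, "int4": 0.5}


def weight_memory_mb(n_params: int, precision: str) -> float:
    """Approximate weight storage in megabytes (1 MB = 1e6 bytes)."""
    return n_params * BYTES_PER_PARAM[precision] / 1e6


def fits_budget(n_params: int, precision: str, budget_mb: float) -> bool:
    """True if the weights alone fit in the given memory budget."""
    return weight_memory_mb(n_params, precision) <= budget_mb


resnet50_params = 25_600_000  # ~25.6M parameters
for precision in BYTES_PER_PARAM:
    print(f"{precision}: {weight_memory_mb(resnet50_params, precision):.1f} MB")
```

Even at INT4, ResNet-50's roughly 12.8MB of weights is about 25x a 512KB microcontroller budget. Microcontroller deployment therefore needs much smaller architectures, not just lower precision.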
The three core techniques:
1. Quantization
Convert weights from FP32 (4 bytes) to INT8 (1 byte) or INT4 (0.5 bytes). A 100MB model becomes 25MB at INT8.
```python
# quantize_for_tflite.py
import os

import numpy as np
import tensorflow as tf


def _dir_size_bytes(path: str) -> int:
    """Total size of every file under a SavedModel directory (graph + variables)."""
    return sum(
        os.path.getsize(os.path.join(root, name))
        for root, _, files in os.walk(path)
        for name in files
    )


def quantize_model_for_edge(
    saved_model_path: str,
    output_path: str,
    calibration_data: np.ndarray,
    quantize_type: str = "int8",
) -> bytes:
    """
    Quantize a TensorFlow SavedModel for TFLite deployment.
    quantize_type:
    - "float16": FP16 weights, ~2x size reduction, minimal accuracy loss
    - "dynamic": INT8 weights only (dynamic-range quantization), ~4x size
      reduction, no calibration needed
    - "int8": full INT8 (weights + activations), ~4x, requires calibration data
    """
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)

    if quantize_type == "float16":
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.target_spec.supported_types = [tf.float16]
    elif quantize_type == "dynamic":
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
    elif quantize_type == "int8":
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        # Conversion fails if an op cannot be quantized (no silent float fallback)
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8

        def representative_dataset():
            # Calibration samples let the converter estimate activation ranges
            for i in range(min(500, len(calibration_data))):
                yield [calibration_data[i : i + 1].astype(np.float32)]

        converter.representative_dataset = representative_dataset
    else:
        raise ValueError(f"Unknown quantize_type: {quantize_type!r}")

    tflite_model = converter.convert()
    with open(output_path, "wb") as f:
        f.write(tflite_model)

    # Measure the whole SavedModel directory: saved_model.pb holds only the
    # graph, while the weights live under variables/
    original_size_mb = _dir_size_bytes(saved_model_path) / 1e6
    compressed_size_mb = len(tflite_model) / 1e6
    print(f"Original: {original_size_mb:.1f} MB")
    print(f"Quantized: {compressed_size_mb:.1f} MB")
    print(f"Compression: {original_size_mb / compressed_size_mb:.1f}x")
    return tflite_model
```
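Under the hood, the converter turns each float range it observes during calibration into a scale and a zero point. The NumPy sketch below shows that affine mapping for a single tensor. The helper names are hypothetical, and the scheme is simplified: TFLite's INT8 spec, for instance, quantizes weights symmetrically per channel rather than with one asymmetric range per tensor.

```python
import numpy as np


def affine_params(x_min: float, x_max: float, qmin: int = -128, qmax: int = 127):
    """Scale and zero point mapping the float range [x_min, x_max] onto [qmin, qmax].

    The range is widened to include 0.0 so that zero is exactly representable.
    """
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0  # degenerate all-zero tensor
    zero_point = int(np.clip(round(qmin - x_min / scale), qmin, qmax))
    return scale, zero_point


def quantize(x: np.ndarray, scale: float, zero_point: int,
             qmin: int = -128, qmax: int = 127) -> np.ndarray:
    # q = clip(round(x / scale) + zero_point)
    q = np.round(x / scale) + zero_point
    return np.clip(q, qmin, qmax).astype(np.int8)


def dequantize(q: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    # x ~= (q - zero_point) * scale
    return (q.astype(np.float32) - zero_point) * scale


x = np.array([-1.0, -0.25, 0.0, 0.6, 2.0], dtype=np.float32)
scale, zero_point = affine_params(float(x.min()), float(x.max()))
x_hat = dequantize(quantize(x, scale, zero_point), scale, zero_point)
print(f"scale={scale:.5f}, zero_point={zero_point}, max error={np.abs(x - x_hat).max():.5f}")
```

Rounding bounds the error for in-range values at half a quantization step (`scale / 2`). Values outside the calibrated range are clipped, which is why the calibration data must cover realistic inputs.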
2. Pruning
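Pruning removes model weights that contribute little to accuracy. After fine-tuning, a well-pruned model might have 80-90% of its weights set to zero with less than 1% accuracy drop. Zeroed weights shrink the file only when the model is stored in a sparse or compressed format, and they speed up inference only when the runtime has sparse kernels or when whole neurons/channels are removed: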
```python
# model_pruning.py
from typing import Iterable, Optional

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


def apply_magnitude_pruning(
    model: nn.Module,
    sparsity: float = 0.5,
    n_finetuning_steps: int = 1000,
    train_loader: Optional[Iterable] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
) -> nn.Module:
    """
    Apply L1 (magnitude) unstructured pruning to every nn.Linear layer.
    sparsity: fraction of each layer's weights to zero out (0.5 = 50%)

    Unstructured pruning zeroes individual weights; dense kernels still
    multiply by those zeros, so speedups need sparse kernels. Structured
    pruning (e.g. prune.ln_structured) zeroes whole neurons/channels, which
    speeds up inference once those channels are physically removed.
    """
    # Mask out the smallest-magnitude weights in each Linear layer
    for module in model.modules():
        if isinstance(module, nn.Linear):
            prune.l1_unstructured(module, name="weight", amount=sparsity)

    # Fine-tune with the masks in place so the remaining weights compensate
    if train_loader is not None and optimizer is not None:
        loss_fn = nn.CrossEntropyLoss()
        model.train()
        for step, (inputs, labels) in enumerate(train_loader):
            if step >= n_finetuning_steps:
                break
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), labels)
            loss.backward()
            optimizer.step()

    # Make pruning permanent: fold each mask into its weight tensor
    for module in model.modules():
        if isinstance(module, nn.Linear):
            prune.remove(module, "weight")

    # Count remaining non-zero parameters (biases and non-Linear layers included)
    total_params = sum(p.numel() for p in model.parameters())
    nonzero_params = sum(int(torch.count_nonzero(p)) for p in model.parameters())
    print(f"Model sparsity: {1 - nonzero_params / total_params:.1%}")
    print(f"Remaining non-zero parameters: {nonzero_params:,} / {total_params:,}")
    return model
```
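Masking leaves a dense layer doing the same amount of work. Structured pruning pays off once the pruned neurons are physically removed, leaving a smaller dense layer that any runtime executes faster. Below is a minimal NumPy sketch that assumes a plain two-layer MLP with weights laid out as `(out_features, in_features)`, as in `nn.Linear`. The helper names are hypothetical, and ranking neurons by the L1 norm of their incoming weights is one common heuristic, not the only one:

```python
import numpy as np


def relu(x: np.ndarray) -> np.ndarray:
    return np.maximum(x, 0.0)


def mlp_forward(x, w1, b1, w2, b2):
    """Two-layer MLP: x -> fc1 -> ReLU -> fc2."""
    return relu(x @ w1.T + b1) @ w2.T + b2


def prune_hidden_neurons(w1, b1, w2, keep_ratio: float):
    """Physically remove the hidden neurons whose incoming weights have the
    smallest L1 norm. Returns smaller matrices and the kept neuron indices."""
    n_keep = max(1, int(round(w1.shape[0] * keep_ratio)))
    importance = np.abs(w1).sum(axis=1)  # L1 norm of each neuron's weight row
    keep = np.sort(np.argsort(importance)[-n_keep:])
    # fc1 loses rows (output neurons); fc2 loses the matching input columns
    return w1[keep], b1[keep], w2[:, keep], keep


rng = np.random.default_rng(0)
w1, b1 = rng.normal(size=(8, 4)), rng.normal(size=8)
w2, b2 = rng.normal(size=(3, 8)), rng.normal(size=3)
pw1, pb1, pw2, keep = prune_hidden_neurons(w1, b1, w2, keep_ratio=0.5)
print(pw1.shape, pw2.shape)  # 8 hidden units shrink to 4
```

The shrunken network computes exactly what the original would with the dropped neurons zeroed out, so accuracy recovery still depends on fine-tuning after pruning.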
3. Knowledge Distillation
Train a small "student" model to mimic the behavior of a large "teacher" model. The student learns from the teacher's soft probability outputs, which carry more information than hard labels (for example, that an image of a cat looks somewhat like a dog and nothing like a car), and it often ends up noticeably more accurate than the same student trained on labels alone:
```python
# knowledge_distillation.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class DistillationLoss(nn.Module):
    """
    Combined distillation loss:
    - Soft loss: KL divergence between student and teacher soft predictions
    - Hard loss: Cross-entropy with ground truth labels
    The temperature T controls the "softness" of the teacher's distribution.
    Higher T reveals more information about inter-class relationships.
    """

    def __init__(
        self,
        temperature: float = 4.0,
        alpha: float = 0.7,  # Weight for soft loss
    ):
        super().__init__()
        self.T = temperature
        self.alpha = alpha
        self.beta = 1 - alpha

    def forward(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        # Soft loss: student tries to match teacher's soft probabilities
        # T > 1 softens the distribution, making it easier to learn from
        soft_student = F.log_softmax(student_logits / self.T, dim=-1)
        soft_teacher = F.softmax(teacher_logits / self.T, dim=-1)
        soft_loss = F.kl_div(soft_student, soft_teacher, reduction="batchmean")
        soft_loss = soft_loss * (self.T ** 2)  # Scale by T^2 to maintain gradient magnitude
        # Hard loss: standard cross-entropy with ground truth
        hard_loss = F.cross_entropy(student_logits, labels)
        return self.alpha * soft_loss + self.beta * hard_loss


def distill(
    teacher: nn.Module,
    student: nn.Module,
    train_loader,
    n_epochs: int = 30,
    temperature: float = 4.0,
    alpha: float = 0.7,
    lr: float = 1e-3,
):
    """Train student model with knowledge distillation from teacher."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    teacher = teacher.to(device).eval()
    student = student.to(device)
    optimizer = torch.optim.Adam(student.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
    criterion = DistillationLoss(temperature=temperature, alpha=alpha)

    for epoch in range(n_epochs):
        student.train()
        total_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            with torch.no_grad():
                teacher_logits = teacher(inputs)
            student_logits = student(inputs)
            loss = criterion(student_logits, teacher_logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            _, predicted = student_logits.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)
        scheduler.step()
        print(f"Epoch {epoch+1}/{n_epochs}: "
              f"loss={total_loss/len(train_loader):.4f}, "
              f"acc={correct/total:.2%}")
    return student
```
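The temperature's effect on the teacher's distribution is easy to see numerically. Below is a small NumPy sketch; the logits and class names are made up for illustration:

```python
import numpy as np


def softmax_with_temperature(logits: np.ndarray, T: float = 1.0) -> np.ndarray:
    """Softmax over the last axis after dividing the logits by temperature T."""
    z = logits / T
    z = z - z.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def entropy(p: np.ndarray) -> float:
    """Shannon entropy in nats; higher means a flatter distribution."""
    return float(-(p * np.log(p)).sum())


teacher_logits = np.array([8.0, 4.0, 1.0])  # hypothetical scores for cat, dog, car
for T in (1.0, 4.0):
    p = softmax_with_temperature(teacher_logits, T)
    print(f"T={T}: probs={np.round(p, 3)}, entropy={entropy(p):.3f}")
```

At T=1 the teacher puts about 98% of its mass on "cat", so the soft targets say little beyond the hard label. At T=4 the split is roughly 65% / 24% / 11%, exposing that "dog" is a far more plausible confusion than "car". That relational information is what the student learns from the soft loss.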
Edge Inference Runtimes
TensorFlow Lite
TFLite is Google's on-device inference runtime. It runs on Android, iOS, Raspberry Pi, and microcontrollers:
```python
# tflite_inference.py
import time
from typing import Optional

import numpy as np


class TFLiteInferenceEngine:
    """Production TFLite inference with proper input/output handling."""

    def __init__(self, model_path: str, delegate_path: Optional[str] = None):
        import tflite_runtime.interpreter as tflite

        # Optional hardware delegate loaded from a shared library, e.g.
        # "libedgetpu.so.1" for a Coral Edge TPU on Linux. On Android, the
        # NNAPI and GPU delegates are configured through the Java/Kotlin or
        # C++ APIs instead.
        delegates = []
        if delegate_path:
            try:
                delegates.append(tflite.load_delegate(delegate_path))
            except ValueError:
                pass  # Delegate failed to load: fall back to CPU kernels

        self.interpreter = tflite.Interpreter(
            model_path=model_path,
            experimental_delegates=delegates,
            num_threads=4,  # CPU threads for ops the delegate does not handle
        )
        self.interpreter.allocate_tensors()

        # Cache input/output tensor details
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()

        # Cache the input tensor's quantization parameters
        self.input_dtype = self.input_details[0]["dtype"]
        self.input_scale, self.input_zero_point = self.input_details[0]["quantization"]

    def preprocess_input(self, features: np.ndarray) -> np.ndarray:
        """Convert FP32 input to the model's integer type if it is quantized."""
        if self.input_dtype in (np.int8, np.uint8):
            # Dequantization is float_val = (q - zero_point) * scale, so
            # q = round(float_val / scale) + zero_point, clipped to the dtype range
            info = np.iinfo(self.input_dtype)
            q = np.round(features / self.input_scale) + self.input_zero_point
            return np.clip(q, info.min, info.max).astype(self.input_dtype)
        return features.astype(np.float32)

    def predict(self, features: np.ndarray) -> np.ndarray:
        """Run inference on a single example."""
        input_data = self.preprocess_input(features)
        input_data = np.expand_dims(input_data, axis=0)  # Add batch dimension
        self.interpreter.set_tensor(self.input_details[0]["index"], input_data)
        self.interpreter.invoke()
        output = self.interpreter.get_tensor(self.output_details[0]["index"])

        # Dequantize integer outputs back to float
        output_details = self.output_details[0]
        if output_details["dtype"] in (np.int8, np.uint8):
            scale, zero_point = output_details["quantization"]
            output = (output.astype(np.float32) - zero_point) * scale
        return output[0]  # Remove batch dimension

    def benchmark(self, feature_dim: int, n_runs: int = 1000) -> dict:
        """Benchmark latency on the target device (assumes a 1-D feature input)."""
        dummy = np.random.randn(feature_dim).astype(np.float32)
        # Warmup: the first invocations include one-time setup costs
        for _ in range(10):
            self.predict(dummy)
        times = []
        for _ in range(n_runs):
            start = time.perf_counter()
            self.predict(dummy)
            times.append((time.perf_counter() - start) * 1000)
        times.sort()
        return {
            "p50_ms": times[int(n_runs * 0.50)],
            "p95_ms": times[int(n_runs * 0.95)],
            "p99_ms": times[int(n_runs * 0.99)],
            "mean_ms": sum(times) / len(times),
        }
```
Core ML (iOS/macOS)
```python
# coreml_converter.py
import coremltools as ct
import torch


def convert_pytorch_to_coreml(
    pytorch_model: torch.nn.Module,
    input_shape: tuple,
    model_name: str,
    quantize: bool = True,
) -> ct.models.MLModel:
    """
    Convert a PyTorch model to Core ML format for Apple devices.
    The converted model can run on:
    - Apple Neural Engine (A12 and later for third-party apps): usually the
      fastest and most energy-efficient option
    - GPU
    - CPU: always available as a fallback
    With compute_units=ALL, Core ML decides at runtime which unit runs each
    part of the model.
    """
    pytorch_model.eval()

    # Step 1: Convert to TorchScript by tracing with an example input
    example_input = torch.randn(1, *input_shape)
    traced = torch.jit.trace(pytorch_model, example_input)

    # Step 2: Convert to Core ML
    coreml_model = ct.convert(
        traced,
        inputs=[ct.TensorType(
            name="input",
            shape=ct.Shape(shape=(1, *input_shape)),
        )],
        minimum_deployment_target=ct.target.iOS16,
        # Allow CPU, GPU, and Neural Engine
        compute_units=ct.ComputeUnit.ALL,
    )

    # Step 3: Quantize weights to INT8 (optional but recommended for mobile;
    # coremltools.optimize requires coremltools 7 or later)
    if quantize:
        from coremltools.optimize.coreml import (
            OpLinearQuantizerConfig,
            OptimizationConfig,
            linear_quantize_weights,
        )
        config = OptimizationConfig(
            global_config=OpLinearQuantizerConfig(mode="linear_symmetric")
        )
        coreml_model = linear_quantize_weights(coreml_model, config)
        print("Applied INT8 weight quantization")

    # Descriptive metadata, shown in Xcode's model viewer
    coreml_model.author = "Your Company"
    coreml_model.short_description = model_name
    coreml_model.version = "1.0"

    # Save as an .mlpackage bundle
    coreml_model.save(f"{model_name}.mlpackage")
    print(f"Saved Core ML model: {model_name}.mlpackage")
    return coreml_model
```
OTA Model Updates at Scale
Deploying a model update to 100 million mobile devices is a logistical challenge distinct from any other deployment:
```python
# model_update_client.py (runs on device, in app code)
import hashlib
import threading
import time
from dataclasses import dataclass
from pathlib import Path

import requests


@dataclass
class ModelManifest:
    model_name: str
    version: str
    url: str
    sha256: str
    size_bytes: int
    min_app_version: str


def _sha256_of_file(path: Path) -> str:
    """Hash a file in chunks so large models never sit fully in memory."""
    sha256 = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            sha256.update(chunk)
    return sha256.hexdigest()


class EdgeModelUpdater:
    """
    Handles OTA model updates on the device side.
    Policy:
    - Check for updates on app launch (at most once per 24h)
    - Download in the background when on WiFi
    - Apply the update on the next app cold start
    - Never block the user on the update check or download
    - Fall back to the bundled model if the downloaded model is corrupted
    """

    MANIFEST_URL = "https://models.yourapp.com/manifest.json"
    MODELS_DIR = Path("/data/local/models/")
    UPDATE_CHECK_INTERVAL_HOURS = 24

    def __init__(self, model_name: str, bundled_model_path: str):
        self.model_name = model_name
        self.bundled_model = bundled_model_path
        self.MODELS_DIR.mkdir(parents=True, exist_ok=True)

    def get_current_model_path(self) -> str:
        """Return the path to the best available model version."""
        downloaded_path = self.MODELS_DIR / f"{self.model_name}.tflite"
        if downloaded_path.exists() and self._verify_checksum(downloaded_path):
            return str(downloaded_path)
        return self.bundled_model

    def check_for_update_async(self):
        """Check for updates in a background thread; never blocks the main thread."""
        thread = threading.Thread(target=self._check_and_download, daemon=True)
        thread.start()

    def _check_and_download(self):
        """Background thread: fetch the manifest, download if a newer version exists."""
        try:
            # Respect the update-check interval
            flag_path = self.MODELS_DIR / ".last_check"
            if flag_path.exists():
                age_hours = (time.time() - flag_path.stat().st_mtime) / 3600
                if age_hours < self.UPDATE_CHECK_INTERVAL_HOURS:
                    return

            # Fetch the manifest
            response = requests.get(self.MANIFEST_URL, timeout=10)
            response.raise_for_status()
            manifest_data = response.json()

            # Find this model in the manifest
            model_entry = next(
                (m for m in manifest_data["models"] if m["name"] == self.model_name),
                None,
            )
            if model_entry is not None:
                # Manifest entries use "name"; the dataclass field is model_name
                manifest = ModelManifest(
                    model_name=model_entry["name"],
                    version=model_entry["version"],
                    url=model_entry["url"],
                    sha256=model_entry["sha256"],
                    size_bytes=model_entry["size_bytes"],
                    min_app_version=model_entry["min_app_version"],
                )
                # Download only if this version is not already installed
                # (a real app would also gate this on WiFi and battery level)
                version_path = self.MODELS_DIR / ".version"
                installed = version_path.read_text().strip() if version_path.exists() else None
                if installed != manifest.version:
                    self._download_model(manifest)  # raises on failure, skipping the touch

            # Reached only when up to date or after a successful download
            flag_path.touch()
        except Exception as e:
            # Update failures must never crash the app; the current model stays in use
            print(f"Model update check failed: {e}")

    def _download_model(self, manifest: ModelManifest):
        """Download the model file with integrity verification."""
        tmp_path = self.MODELS_DIR / f"{self.model_name}.tmp"
        final_path = self.MODELS_DIR / f"{self.model_name}.tflite"

        # Stream to disk instead of loading the entire file into memory
        sha256 = hashlib.sha256()
        with requests.get(manifest.url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
                    sha256.update(chunk)

        # Verify integrity before touching the current model
        if sha256.hexdigest() != manifest.sha256:
            tmp_path.unlink()
            raise ValueError("Model download corrupted: hash mismatch")

        # Replace the current model in one step (atomic on POSIX filesystems;
        # unlike rename, Path.replace also overwrites an existing file on Windows)
        tmp_path.replace(final_path)
        # Record the expected hash (for load-time checks) and the version
        (self.MODELS_DIR / ".sha256").write_text(manifest.sha256)
        (self.MODELS_DIR / ".version").write_text(manifest.version)
        print(f"Downloaded model {self.model_name} v{manifest.version}")

    def _verify_checksum(self, path: Path) -> bool:
        """Verify model file integrity before loading."""
        expected_path = self.MODELS_DIR / ".sha256"
        if not expected_path.exists():
            return False
        return _sha256_of_file(path) == expected_path.read_text().strip()
```
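The manifest carries a `min_app_version` field that the updater above never checks. A new model may depend on operators or preprocessing that only newer app builds include, so the client should compare versions before downloading. Below is a sketch with hypothetical helper names. It assumes plain numeric dotted versions like `2.10.1` with no pre-release tags:

```python
def parse_version(version: str) -> tuple:
    """'2.10.1' -> (2, 10, 1). Assumes purely numeric dotted versions."""
    return tuple(int(part) for part in version.split("."))


def is_compatible(app_version: str, min_app_version: str) -> bool:
    """True if the installed app is new enough to run the model.

    Shorter versions are padded with zeros so '2.1' compares equal to '2.1.0'.
    """
    a, m = parse_version(app_version), parse_version(min_app_version)
    width = max(len(a), len(m))
    a += (0,) * (width - len(a))
    m += (0,) * (width - len(m))
    return a >= m


print(is_compatible("2.10.0", "2.9.5"))
```

Comparing numeric tuples rather than raw strings matters here: as strings, `"2.10.0" < "2.9.5"`, which would wrongly block the update.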
Federated Learning: On-Device Training
Federated learning trains models across many devices without centralizing raw data. Each device trains locally on its own data and sends only a model update (a weight delta or gradient, never the raw data) to a central server, which aggregates the updates from many devices:
```python
# federated_learning_client.py
from copy import deepcopy
from typing import Callable

import torch
import torch.nn as nn


class FederatedLearningClient:
    """
    Federated learning client - runs on device.
    Process:
    1. Download global model from server
    2. Fine-tune on local data for N steps
    3. Compute delta (local_weights - global_weights)
    4. Upload the delta (optionally compressed) to the server, never raw data
    5. Server aggregates deltas (FedAvg) and updates the global model
    """

    def __init__(
        self,
        global_model: nn.Module,
        local_lr: float = 0.01,
        local_steps: int = 5,
        clip_grad_norm: float = 1.0,
    ):
        self.global_model = deepcopy(global_model)
        self.local_model = deepcopy(global_model)
        self.local_lr = local_lr
        self.local_steps = local_steps
        self.clip_grad = clip_grad_norm

    def train_local(
        self,
        local_data: list,  # On-device (inputs, labels) mini-batches
        loss_fn: Callable,
    ) -> dict:
        """
        Fine-tune the local model on device data.
        Returns the weight delta (not raw data).
        """
        optimizer = torch.optim.SGD(
            self.local_model.parameters(),
            lr=self.local_lr,
            momentum=0.9,
        )
        self.local_model.train()
        for step in range(self.local_steps):
            # Cycle through the local mini-batches
            inputs, labels = local_data[step % len(local_data)]
            optimizer.zero_grad()
            outputs = self.local_model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            # Clip gradients to keep each local step bounded. Formal differential
            # privacy instead clips the whole client update and adds noise.
            nn.utils.clip_grad_norm_(self.local_model.parameters(), self.clip_grad)
            optimizer.step()

        # Compute weight delta
        delta = {}
        global_state = self.global_model.state_dict()
        local_state = self.local_model.state_dict()
        for name in global_state:
            delta[name] = (local_state[name] - global_state[name]).cpu()
        return delta

    def apply_global_update(self, new_global_model: nn.Module):
        """Update local model with the new global model from server."""
        self.global_model = deepcopy(new_global_model)
        self.local_model = deepcopy(new_global_model)


class FederatedServer:
    """
    Federated learning server - aggregates client updates.
    FedAvg: weighted average of client deltas.
    """

    def __init__(self, global_model: nn.Module):
        self.global_model = global_model

    def aggregate_fedavg(
        self,
        client_deltas: list[dict],
        client_n_samples: list[int],
    ) -> nn.Module:
        """
        Federated Averaging: weight each client's delta by their data size.
        Clients with more local data have more influence.
        """
        total_samples = sum(client_n_samples)
        weights = [n / total_samples for n in client_n_samples]

        # Compute weighted average of all client deltas
        averaged_delta = {}
        for name, param in self.global_model.state_dict().items():
            # Skip integer buffers (e.g. BatchNorm's num_batches_tracked),
            # which cannot hold a fractional weighted average
            if not param.is_floating_point():
                continue
            averaged_delta[name] = torch.zeros_like(param)
            for delta, weight in zip(client_deltas, weights):
                if name in delta:
                    averaged_delta[name] += delta[name].to(param.device) * weight

        # Apply averaged delta to global model
        global_state = self.global_model.state_dict()
        for name in averaged_delta:
            global_state[name] = global_state[name] + averaged_delta[name]
        self.global_model.load_state_dict(global_state)
        return self.global_model
```
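Without PyTorch, FedAvg reduces to a weighted mean. The NumPy sketch below uses one flattened weight vector and two hypothetical clients. The client with three times as much data gets three times the weight:

```python
import numpy as np


def fedavg(global_weights: np.ndarray, client_deltas: list, client_n_samples: list) -> np.ndarray:
    """Apply the sample-weighted average of client deltas to the global weights."""
    total = sum(client_n_samples)
    avg_delta = sum(d * (n / total) for d, n in zip(client_deltas, client_n_samples))
    return global_weights + avg_delta


global_w = np.array([1.0, 1.0])
deltas = [np.array([0.4, 0.0]), np.array([-0.2, 0.6])]  # hypothetical client updates
n_samples = [300, 100]
print(fedavg(global_w, deltas, n_samples))
```

The first client dominates. The aggregated update is `0.75 * [0.4, 0.0] + 0.25 * [-0.2, 0.6] = [0.25, 0.15]`, so the new global weights are `[1.25, 1.15]`.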
Real-World Edge ML Examples
iPhone Face ID (Apple Neural Engine):
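- Model: face-matching neural network; Apple has not published its architecture or parameter count
- Hardware: Apple Neural Engine (ANE), e.g. 16 cores and 11 TOPS on the A14 Bionic
- Latency: fast enough that unlocking feels instant
- Quantization: not publicly documented
- Privacy: face data is encrypted, protected by the Secure Enclave, and never leaves the device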
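Google Pixel Camera (Pixel Neural Core / Google Tensor):
- Night Sight: burst photography + multi-frame image fusion ML
- Portrait Mode: learned on-device depth estimation
- Hardware: Pixel Neural Core (Pixel 4); from the Pixel 6 on, the Google Tensor chip's built-in TPU
- Runtime: Google's own on-device stack; third-party Android apps reach similar accelerators through TFLite delegates such as NNAPI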
iOS Autocorrect and Keyboard Prediction:
- Model: on-device transformer language model (how Apple describes autocorrect since iOS 17)
- Adapts to the user's vocabulary without sending keystroke data to Apple
- Apple has documented using differential privacy to learn new words across users without collecting raw keystrokes
Production Engineering Notes
Never ship a model that cannot fall back: On-device models can become corrupted, run out of memory, or fail to load on specific device configurations. Always maintain a smaller, simpler fallback model (or rule-based logic) that runs if the primary model fails.
Battery impact is a first-class metric: A model that drains the battery 5% faster will generate user complaints and negative app store reviews. Profile energy consumption on real devices (Xcode's Energy Impact gauge, Android's Battery Historian). INT8 quantization often cuts energy consumption substantially compared to FP32 (2-3x is a commonly quoted range), but verify the gain on your target hardware.
Test on the oldest supported device, not your MacBook: Your development machine may be many times more powerful than the devices a significant share of your users carry. Set minimum hardware requirements and test on the oldest device you support before release.
:::warning INT8 Accuracy on Edge Hardware INT8 accuracy on server GPUs (calibrated with your specific dataset) may differ from INT8 accuracy on mobile Neural Engines (which have different INT8 arithmetic implementations). Always benchmark accuracy on target hardware, not just on your development machine. Some models require INT8-aware fine-tuning (quantization-aware training) rather than post-training quantization to achieve acceptable accuracy on mobile hardware. :::
:::danger The Model Size vs RAM Usage Confusion A 10MB TFLite model does not use only 10MB of RAM at runtime. The inference engine also allocates activation buffers (intermediate tensors), whose size depends on the architecture and input resolution rather than on the weight count; for image models they can be several times the model size. A 10MB model with large activations may require 50MB of RAM during inference. On devices with 1GB RAM shared with the OS and other apps, this matters. Measure peak RAM usage on the target device, not just model file size. :::
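A back-of-the-envelope estimate shows why activation memory tracks input resolution rather than weight count. This sketch assumes a strictly sequential network in which only one layer's input and output are alive at a time. Skip connections increase the true peak, and real runtimes plan memory differently, so treat it as a rough sanity check before measuring on the device. The function names and layer shapes are hypothetical:

```python
def tensor_bytes(shape: tuple, bytes_per_elem: int) -> int:
    """Bytes needed to store one tensor of the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_elem


def peak_activation_mb(layer_shapes: list, bytes_per_elem: int = 4) -> float:
    """Rough peak activation memory for a purely sequential network: the
    largest (input + output) pair across consecutive layers."""
    sizes = [tensor_bytes(s, bytes_per_elem) for s in layer_shapes]
    peak = max(a + b for a, b in zip(sizes, sizes[1:]))
    return peak / 1e6


# Hypothetical early layers of an image model at batch size 1 (NHWC layout)
shapes = [(1, 224, 224, 3), (1, 112, 112, 64), (1, 112, 112, 64), (1, 56, 56, 128)]
print(f"FP32 activations: {peak_activation_mb(shapes, 4):.1f} MB")
print(f"INT8 activations: {peak_activation_mb(shapes, 1):.1f} MB")
```

Two 112x112x64 FP32 feature maps alone take about 6.4MB, regardless of how few weights the layers between them have.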
Interview Q&A
Q: Why would you deploy an ML model on-device rather than calling a cloud API?
Four primary reasons: (1) Latency - eliminating the network round trip (50-200ms) enables use cases like face unlock (must feel instant), AR overlays (must match 60fps), and keyboard autocorrect (must appear before the next keystroke). (2) Privacy - biometric data (face, voice, health metrics) can stay on the device and never be transmitted. (3) Offline capability - manufacturing sensors, autonomous vehicles, and rural health monitors must work without connectivity. (4) Cost - at billions of requests per day, shifting inference to devices eliminates cloud compute costs. The trade-offs are model size constraints, hardware variability across device generations, and the complexity of OTA updates.
Q: What is knowledge distillation and how does it enable edge deployment?
Knowledge distillation trains a small student model to mimic a large teacher model's behavior. The student learns from the teacher's soft probability distributions (not just hard labels), which contain richer information - specifically, inter-class similarities that hard labels discard. A student trained with distillation typically achieves better accuracy than a student of the same size trained directly on labels. For edge deployment: train your best large model on your full training infrastructure, then distill it into a model small enough to fit on the target device. DistilBERT, for example, is reported to be 40% smaller than BERT while retaining about 97% of its language-understanding performance.
Q: How does OTA (over-the-air) model updating work for mobile apps?
The app includes a bundled model at install time (available immediately, no download required). A lightweight update client runs periodically (e.g., once per day on app launch) to check a manifest server for new model versions. When a newer version is available, it downloads the model in the background (typically gated on a WiFi connection and battery above a threshold). The downloaded file is checked against the SHA-256 hash listed in the manifest before it replaces the current model. This catches corruption, and it protects against tampering only if the manifest itself is trustworthy (served over TLS, ideally signed). The new model takes effect on the next app cold start, not in the current session. This design keeps updates invisible to users and prevents download failures from affecting the current session.
Q: What is federated learning and why does it matter for mobile ML?
Federated learning trains a model across many devices without centralizing raw data. Each device fine-tunes the current global model on its local data, computes the weight delta (change from global to local weights), and sends only this delta to a central server. The server applies Federated Averaging (a weighted average of all client deltas) to update the global model. No raw user data leaves the device. This enables training on sensitive data (health records, personal communications, browsing history) that users would not consent to send to a server. Google has published its use of FL for next-word prediction in Gboard, and Apple has described using FL for on-device personalization. The privacy property is that the server sees only model updates, not data. Because updates can still leak information about the training examples, they can be further protected with differential privacy noise.
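The differential privacy step follows a clip-then-noise pattern, as in DP variants of FedAvg. Each client bounds the L2 norm of its update, and the server adds Gaussian noise scaled to that bound before averaging. The simplified sketch below uses hypothetical function names. The average is unweighted because bounding each client's influence is the point, and turning `noise_multiplier` into a formal (epsilon, delta) guarantee requires a privacy accountant that this sketch omits:

```python
import numpy as np


def clip_update(delta: np.ndarray, clip_norm: float) -> np.ndarray:
    """Client side: scale the update down so its L2 norm is at most clip_norm."""
    norm = float(np.linalg.norm(delta))
    if norm <= clip_norm:
        return delta
    return delta * (clip_norm / norm)


def noisy_average(clipped_deltas: list, clip_norm: float, noise_multiplier: float,
                  rng: np.random.Generator) -> np.ndarray:
    """Server side: sum the clipped updates, add Gaussian noise with standard
    deviation noise_multiplier * clip_norm, then average over clients."""
    total = np.sum(clipped_deltas, axis=0)
    noise = rng.normal(0.0, noise_multiplier * clip_norm, size=total.shape)
    return (total + noise) / len(clipped_deltas)


rng = np.random.default_rng(0)
updates = [clip_update(np.array([3.0, 4.0]), 1.0), clip_update(np.array([0.0, 0.5]), 1.0)]
print(noisy_average(updates, clip_norm=1.0, noise_multiplier=1.1, rng=rng))
```

Clipping caps how much any single device can move the model, which in turn sets the noise scale needed to hide any one device's contribution.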
Q: How would you benchmark an edge model deployment before shipping to users?
Four dimensions: (1) Accuracy - run the model on a held-out test set representative of real-world inputs. Compare to the server-side model. Alert if accuracy drop exceeds a threshold (e.g., 1% AUC). (2) Latency - measure p50, p95, and p99 inference time on the oldest supported device in your fleet. This is your actual latency, not your development machine latency. (3) Memory - profile peak RAM usage during inference, including activation buffers. Verify it fits within your budget given other app components. (4) Energy - measure battery drain rate during sustained inference (e.g., continuous real-time object detection) using platform tools (Instruments on iOS, Battery Historian on Android). All four must pass before shipping. A model that is accurate but drains the battery fails just as definitively as one that is fast but inaccurate.
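The "all four must pass" rule translates directly into a release gate. Below is a sketch with hypothetical field names and thresholds; the real budgets come from your product requirements and from on-device measurements like those described above:

```python
from dataclasses import dataclass


@dataclass
class EdgeBenchmark:
    accuracy_drop: float         # server metric minus on-device metric (e.g. AUC)
    p99_latency_ms: float        # measured on the oldest supported device
    peak_ram_mb: float           # including activation buffers
    battery_pct_per_hour: float  # drain during sustained inference


@dataclass
class ReleaseBudget:
    max_accuracy_drop: float
    max_p99_latency_ms: float
    max_peak_ram_mb: float
    max_battery_pct_per_hour: float


def release_blockers(b: EdgeBenchmark, budget: ReleaseBudget) -> list:
    """Return the name of every failed dimension; an empty list means ship."""
    checks = {
        "accuracy": b.accuracy_drop <= budget.max_accuracy_drop,
        "latency": b.p99_latency_ms <= budget.max_p99_latency_ms,
        "memory": b.peak_ram_mb <= budget.max_peak_ram_mb,
        "energy": b.battery_pct_per_hour <= budget.max_battery_pct_per_hour,
    }
    return [name for name, ok in checks.items() if not ok]


budget = ReleaseBudget(0.01, 30.0, 80.0, 5.0)
print(release_blockers(EdgeBenchmark(0.004, 22.0, 64.0, 7.5), budget))
```

Returning every failure, rather than stopping at the first, gives the team the full list of blockers from a single benchmark run. The example above is accurate, fast, and small, but still blocked on energy.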
