Membership Inference
The Patient Who Wasn't a Patient
The hospital had trained a clinical prediction model on electronic health records from 2018 to 2022. The model was impressive - it predicted readmission risk with 89% AUC. They deployed it in 2023 as a decision support tool for discharge planning.
Three months after deployment, Priya received a legal notice. A patient was claiming that his medical data had been used to train the model without proper consent. The consent forms from the original data collection specified research use, not AI model training. Because the records were health data, GDPR's purpose-limitation and lawful-basis rules (Articles 5(1)(b), 6, and 9) applied. If the patient's data was used in training, the hospital was exposed.
The patient's attorney had hired an ML security expert. The expert ran what's known as a membership inference attack against the model's API. For each of 1,200 patients, they queried the model with that patient's clinical record and observed the confidence score on the model's prediction. Membership inference theory predicts that models produce higher-confidence (lower-loss) outputs for data they were trained on than for data they haven't seen.
The expert identified 847 records that showed the signature of training set membership. The patient's record was among them. The confidence was high enough that the court accepted it as evidence of training data membership. The hospital settled for $2.3 million.
This is membership inference: using a model's behavior to determine whether specific records were in its training set. It's a privacy attack that every organization training models on personal data needs to understand - before the attorney's letter arrives.
The Theory: Why Models Remember
Training a model on a dataset changes the model's parameters to minimize loss on that dataset. As a result, the model achieves lower loss on training examples than on unseen examples - especially when the model is large or overfit.
The key signal: A model's confidence on an input it was trained on tends to be higher (lower loss) than its confidence on a similar input it wasn't trained on.
This is directly related to generalization: a perfectly generalizing model would show no difference in loss between training and test examples. The larger the train-test loss gap, the more "membership signal" leaks from the model's outputs.
The Privacy-Generalization Tradeoff
Here is the core tension in membership inference:
| Model Property | Generalization | Privacy |
|---|---|---|
| Overfit (large gap) | Poor | Poor - high MI risk |
| Well-regularized (small gap) | Good | Better - lower MI risk |
| Differentially private trained | Good | Strong formal guarantee |
The generalization gap is a direct privacy leak. A model whose loss distribution is identical on training and test data gives a loss-based attacker nothing to work with. Matching average loss alone is not enough, though: individual outlier records can still stand out even when the means agree. A model that performs 10% better on training data than test data is revealing that it "remembers" its training examples - and membership inference exploits this memory.
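One way to put a number on this leak is the attacker's membership advantage: the true-positive rate minus the false-positive rate of a threshold rule. The sketch below is illustrative only. `membership_advantage` is a hypothetical helper, and the confidence values are made up to contrast an overfit model with a well-regularized one.

```python
def membership_advantage(
    member_scores: list[float],
    nonmember_scores: list[float],
    threshold: float,
) -> float:
    """TPR minus FPR for the rule 'score > threshold means member'."""
    tpr = sum(s > threshold for s in member_scores) / len(member_scores)
    fpr = sum(s > threshold for s in nonmember_scores) / len(nonmember_scores)
    return tpr - fpr


# Toy confidences (invented for illustration, not measured from a real model)
overfit_members = [0.99, 0.97, 0.95, 0.98]
overfit_nonmembers = [0.70, 0.65, 0.96, 0.60]
regularized_members = [0.80, 0.72, 0.68, 0.75]
regularized_nonmembers = [0.78, 0.70, 0.69, 0.74]

adv_overfit = membership_advantage(overfit_members, overfit_nonmembers, 0.90)
adv_regularized = membership_advantage(
    regularized_members, regularized_nonmembers, 0.73
)
print(f"overfit advantage:     {adv_overfit:.2f}")
print(f"regularized advantage: {adv_regularized:.2f}")
```

An advantage near 0 means the threshold rule does no better than guessing; an advantage near 1 means members and non-members are almost perfectly separable.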
LLMs and Memorization
Large language models have a particular form of membership inference risk: verbatim memorization. When a model is trained on a specific text, it may reproduce that text near-verbatim given a short prefix. This is different from the statistical signal used in traditional MI attacks - it's direct reproduction of training data.
Research from Carlini et al. (2021) demonstrated that GPT-2 could reproduce hundreds of training examples verbatim, including:
- Phone numbers and addresses of private individuals
- Login credentials in code examples
- Personal emails and messages
- Proprietary code and documents
For LLMs, membership inference has two distinct attack surfaces:
- Confidence-based MI: Use output probabilities to infer membership (classical approach)
- Extraction-based MI: Directly extract training text via prefix completion
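For the confidence-based surface, a common membership score for language models is the average per-token loss (equivalently, perplexity) on the candidate text. The following is a minimal sketch that assumes you can obtain per-token log-probabilities from the model. `sequence_perplexity` and `looks_like_member` are hypothetical helpers, and the log-probability lists are invented values.

```python
import math


def sequence_perplexity(token_logprobs: list[float]) -> float:
    """Perplexity = exp(mean negative log-likelihood per token)."""
    if not token_logprobs:
        raise ValueError("need at least one token log-probability")
    mean_nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(mean_nll)


def looks_like_member(token_logprobs: list[float], max_perplexity: float) -> bool:
    """Flag a text as a likely member when its perplexity is unusually low."""
    return sequence_perplexity(token_logprobs) < max_perplexity


# Invented natural-log probabilities for two short texts
memorized_text = [-0.05, -0.10, -0.02, -0.08]  # model is nearly certain
unseen_text = [-2.3, -1.7, -3.1, -2.0]         # model is often surprised

for name, logprobs in [("memorized", memorized_text), ("unseen", unseen_text)]:
    ppl = sequence_perplexity(logprobs)
    print(f"{name}: perplexity={ppl:.2f}, member={looks_like_member(logprobs, 2.0)}")
```

The cutoff of 2.0 is arbitrary here. In practice it would be calibrated on texts known to be outside the training set.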
Attack Techniques
1. Threshold-Based Attack (Yeom et al., 2018)
The simplest attack: if the model's loss on a sample is below a threshold, classify it as a member.
from typing import Callable

import numpy as np


def threshold_membership_inference(
    model_callable: Callable[[str], tuple],
    candidate_records: list[dict],
    threshold: float = 0.5,
    feature_key: str = "text",
) -> list[dict]:
    """
    Threshold-based membership inference attack.

    For each candidate record, query the model and use the
    output confidence as a membership signal.

    Args:
        model_callable: Function that returns (output, confidence) tuple
        candidate_records: Records to test for membership
        threshold: Confidence above this → predict member
        feature_key: Key in record dict containing the input feature
    """
    results = []
    for record in candidate_records:
        input_text = record.get(feature_key, "")
        output, confidence = model_callable(input_text)
        is_predicted_member = confidence > threshold
        results.append({
            "record_id": record.get("id", "unknown"),
            "confidence": confidence,
            "predicted_member": is_predicted_member,
            "output_preview": str(output)[:100],
        })
    predicted_members = sum(1 for r in results if r["predicted_member"])
    if results:  # avoid dividing by zero on an empty candidate list
        print(f"Predicted members: {predicted_members}/{len(results)} ({predicted_members/len(results)*100:.1f}%)")
    return results


def simulate_confidence_distributions(
    n_members: int = 500,
    n_nonmembers: int = 500,
    generalization_gap: float = 0.2,  # Controls attack strength
) -> dict:
    """
    Simulate the confidence distributions for members vs. non-members.

    Shows how generalization gap determines attack effectiveness.
    A well-regularized model (small gap) → poor attack accuracy.
    An overfit model (large gap) → high attack accuracy.
    """
    # Members: higher confidence (parametrically controlled by gap)
    member_mean = 0.5 + generalization_gap
    member_std = 0.15
    member_confidences = np.clip(
        np.random.normal(member_mean, member_std, n_members), 0, 1
    ).tolist()

    # Non-members: lower confidence
    nonmember_mean = 0.5
    nonmember_std = 0.15
    nonmember_confidences = np.clip(
        np.random.normal(nonmember_mean, nonmember_std, n_nonmembers), 0, 1
    ).tolist()

    # Compute threshold-based attack AUC by sweeping the threshold
    # from the highest confidence to the lowest
    all_confidences = member_confidences + nonmember_confidences
    labels = [1] * n_members + [0] * n_nonmembers
    sorted_pairs = sorted(zip(all_confidences, labels), reverse=True)
    tpr_fpr = [(0.0, 0.0)]  # ROC curve starts at the origin
    tp = fp = 0
    for conf, label in sorted_pairs:
        if label == 1:
            tp += 1
        else:
            fp += 1
        tpr_fpr.append((fp / n_nonmembers, tp / n_members))

    # Trapezoidal integration of TPR over FPR
    auc = 0.0
    for i in range(1, len(tpr_fpr)):
        fpr_diff = tpr_fpr[i][0] - tpr_fpr[i-1][0]
        auc += fpr_diff * (tpr_fpr[i][1] + tpr_fpr[i-1][1]) / 2

    return {
        "generalization_gap": generalization_gap,
        "member_mean_confidence": float(np.mean(member_confidences)),
        "nonmember_mean_confidence": float(np.mean(nonmember_confidences)),
        "attack_auc": auc,
        "privacy_risk": "high" if auc > 0.7 else "medium" if auc > 0.6 else "low",
        "interpretation": f"Attacker can distinguish members from non-members with AUC={auc:.3f}",
    }


# Demonstrate the effect of generalization gap
for gap in [0.05, 0.1, 0.2, 0.35]:
    result = simulate_confidence_distributions(generalization_gap=gap)
    print(f"Gap={gap:.2f} → Attack AUC={result['attack_auc']:.3f} ({result['privacy_risk']} risk)")
2. Shadow Model Attack (Shokri et al., 2017)
The foundational paper. The attacker trains "shadow models" that mimic the target model's training process, then uses those shadows to learn what membership looks like:
from dataclasses import dataclass
from typing import Callable
import random

import numpy as np
from sklearn.linear_model import LogisticRegression


@dataclass
class ShadowModelAttackConfig:
    """Configuration for a shadow model MI attack."""
    n_shadow_models: int = 5
    shadow_train_fraction: float = 0.5
    attack_model_type: str = "logistic"  # "logistic" or "neural"


class ShadowModelAttack:
    """
    Full shadow model membership inference attack (Shokri et al., 2017).

    The attack:
    1. Attacker collects data from the same distribution as target training data
    2. Trains N "shadow models" on subsets of this data (knows in/out split)
    3. Collects model signals (confidence, entropy) for in-set vs. out-set examples
    4. Trains an "attack model": signal → membership prediction
    5. Applies attack model to target model outputs

    This attack does NOT require access to target model weights -
    only API access to get output probabilities.
    """

    def __init__(self, config: ShadowModelAttackConfig):
        self.config = config
        self.attack_model = None

    def generate_shadow_datasets(
        self,
        available_data: list[dict],
    ) -> list[dict]:
        """Generate shadow training datasets."""
        shadow_datasets = []
        for _ in range(self.config.n_shadow_models):
            shuffled = available_data.copy()
            random.shuffle(shuffled)
            n_in = int(len(shuffled) * self.config.shadow_train_fraction)
            shadow_datasets.append({
                "in_data": shuffled[:n_in],
                "out_data": shuffled[n_in:],
            })
        return shadow_datasets

    def extract_signal(
        self,
        model_output: dict,
    ) -> list[float]:
        """
        Extract membership signal from model output.

        In practice: uses full probability distribution if available.
        """
        confidence = model_output.get("confidence", 0.5)
        entropy = model_output.get("entropy", 1.0)
        # Features for attack classifier
        return [
            confidence,
            1 - confidence,          # Uncertainty
            entropy,                 # Output entropy
            confidence ** 2,         # Non-linear features
            abs(confidence - 0.5),   # Distance from uncertainty
        ]

    def collect_shadow_signals(
        self,
        shadow_datasets: list[dict],
        shadow_model_trainer: Callable,
        shadow_model_inference: Callable,
    ) -> tuple[list[list[float]], list[int]]:
        """
        For each shadow model, collect (signal, membership_label) pairs.
        """
        all_signals = []
        all_labels = []
        for shadow_data in shadow_datasets:
            shadow_model = shadow_model_trainer(shadow_data["in_data"])
            for example in shadow_data["in_data"]:
                output = shadow_model_inference(shadow_model, example)
                signal = self.extract_signal(output)
                all_signals.append(signal)
                all_labels.append(1)  # Member
            for example in shadow_data["out_data"]:
                output = shadow_model_inference(shadow_model, example)
                signal = self.extract_signal(output)
                all_signals.append(signal)
                all_labels.append(0)  # Non-member
        return all_signals, all_labels

    def train_attack_classifier(
        self,
        signals: list[list[float]],
        labels: list[int],
    ):
        """Train a binary classifier: signal → is_member."""
        X = np.array(signals)
        y = np.array(labels)
        self.attack_model = LogisticRegression(max_iter=1000, C=1.0)
        self.attack_model.fit(X, y)
        return self.attack_model

    def predict_membership(
        self,
        target_model_output: dict,
    ) -> dict:
        """Predict whether a record is a training set member."""
        if self.attack_model is None:
            raise ValueError("Attack model not trained yet")
        signal = self.extract_signal(target_model_output)
        signal_array = np.array(signal).reshape(1, -1)
        prob_member = self.attack_model.predict_proba(signal_array)[0][1]
        is_member = self.attack_model.predict(signal_array)[0] == 1
        return {
            "predicted_member": bool(is_member),
            "membership_probability": float(prob_member),
            "confidence": "high" if abs(prob_member - 0.5) > 0.3 else "low",
        }
3. LLM-Specific: Verbatim Memorization Detection
For large language models, membership can also be inferred by attempting to extract the text itself:
from typing import Callable

import anthropic

client = anthropic.Anthropic()


def _get_ngrams(text: str, n: int = 5) -> set:
    """Return the set of word-level n-grams in text (lowercased)."""
    words = text.lower().split()
    return {' '.join(words[i:i+n]) for i in range(len(words) - n + 1)}


def detect_verbatim_memorization(
    candidate_texts: list[str],
    prefix_lengths: tuple[int, ...] = (50, 100, 200, 300),
    match_threshold: float = 0.75,
) -> list[dict]:
    """
    Detect verbatim memorization in LLMs.

    For each candidate text, provide a prefix and check if the model
    completes it with the actual continuation. If yes, the text was
    likely in training data.

    This attack is especially effective against:
    - Books and articles (copyrighted content)
    - Code with distinctive patterns
    - Medical or legal records with unique identifiers
    - Personal information appearing in training data

    Args:
        candidate_texts: Full texts to test for memorization
        prefix_lengths: Number of characters to use as prefix
        match_threshold: Word-set overlap (Jaccard) ratio to consider a match
    """
    results = []
    for text in candidate_texts:
        memorization_signals = []
        for prefix_len in prefix_lengths:
            # Need at least 100 characters of true continuation to compare against
            if len(text) < prefix_len + 100:
                continue
            prefix = text[:prefix_len]
            actual_continuation = text[prefix_len:prefix_len+300]

            # Ask model to complete the prefix
            try:
                response = client.messages.create(
                    model="claude-haiku-4-5-20251001",
                    max_tokens=300,
                    messages=[{
                        "role": "user",
                        "content": f"Complete the following text naturally, continuing from where it leaves off:\n\n{prefix}"
                    }]
                )
                model_continuation = response.content[0].text
            except anthropic.APIError:
                # Skip this prefix length if the API call fails
                continue

            # Compute word-level overlap (Jaccard similarity of word sets)
            actual_words = set(actual_continuation.lower().split())
            model_words = set(model_continuation.lower().split())
            if len(actual_words | model_words) > 0:
                overlap = len(actual_words & model_words) / len(actual_words | model_words)
            else:
                overlap = 0.0

            # Also check for shared 5-grams (a stronger signal than word overlap,
            # because word order must match)
            actual_5grams = _get_ngrams(actual_continuation)
            model_5grams = _get_ngrams(model_continuation)
            ngram_overlap = (
                len(actual_5grams & model_5grams) / len(actual_5grams)
                if actual_5grams else 0.0
            )

            memorization_signals.append({
                "prefix_length": prefix_len,
                "word_overlap": overlap,
                "ngram_5_overlap": ngram_overlap,
                "memorized": overlap > match_threshold or ngram_overlap > 0.3,
            })

        is_memorized = any(s["memorized"] for s in memorization_signals)
        max_overlap = max((s["word_overlap"] for s in memorization_signals), default=0.0)
        max_ngram = max((s["ngram_5_overlap"] for s in memorization_signals), default=0.0)
        results.append({
            "text_preview": text[:150] + "...",
            "is_memorized": is_memorized,
            "max_word_overlap": max_overlap,
            "max_ngram_overlap": max_ngram,
            "signals": memorization_signals,
            "risk_level": "critical" if max_ngram > 0.5 else "high" if is_memorized else "low",
        })

    memorized_count = sum(1 for r in results if r["is_memorized"])
    print(f"Memorization detected in {memorized_count}/{len(results)} texts")
    print(f"Critical-risk memorizations: {sum(1 for r in results if r['risk_level'] == 'critical')}")
    return results


def measure_memorization_risk(
    model_callable: Callable[[str], str],
    training_samples: list[str],
    held_out_samples: list[str],
    prefix_length: int = 100,
) -> dict:
    """
    Measure a model's memorization risk.

    For each training sample, check if the model can reproduce it.
    Compare against held-out samples to measure above-chance memorization.
    """
    def get_continuation_match(text: str, prefix_len: int) -> float:
        prefix = text[:prefix_len]
        actual = text[prefix_len:prefix_len + 200]
        generated = model_callable(prefix)
        actual_words = set(actual.lower().split())
        gen_words = set(generated.lower().split())
        if len(actual_words | gen_words) == 0:
            return 0.0
        return len(actual_words & gen_words) / len(actual_words | gen_words)

    # Cap at 100 samples per set to bound the number of model calls
    train_matches = [get_continuation_match(s, prefix_length) for s in training_samples[:100]]
    held_out_matches = [get_continuation_match(s, prefix_length) for s in held_out_samples[:100]]
    avg_train_match = sum(train_matches) / max(len(train_matches), 1)
    avg_held_out_match = sum(held_out_matches) / max(len(held_out_matches), 1)
    memorization_rate = avg_train_match - avg_held_out_match
    return {
        "avg_train_match": avg_train_match,
        "avg_held_out_match": avg_held_out_match,
        "memorization_rate": memorization_rate,
        "privacy_risk": "high" if memorization_rate > 0.15 else "medium" if memorization_rate > 0.05 else "low",
        "verbatim_memorization_count": sum(1 for m in train_matches if m > 0.75),
    }
Measuring Your Model's Privacy Risk
Before deploying, measure how much membership information your model leaks:
from typing import Callable

import numpy as np
from sklearn.metrics import roc_auc_score


def measure_membership_inference_risk(
    model_inference_fn: Callable[[dict], float],  # Function(example) → loss_value
    train_examples: list[dict],
    test_examples: list[dict],
) -> dict:
    """
    Membership inference risk assessment using loss-threshold attacks.

    Runs several threshold variants and reports AUC and accuracy.
    Use this before deploying any model trained on sensitive data.

    Args:
        model_inference_fn: Function that takes an example and returns loss
        train_examples: Examples from training set (max 500 for efficiency)
        test_examples: Examples NOT in training set (same distribution)
    """
    n = min(500, len(train_examples), len(test_examples))

    # Collect losses
    train_losses = [model_inference_fn(ex) for ex in train_examples[:n]]
    test_losses = [model_inference_fn(ex) for ex in test_examples[:n]]

    # Basic statistics
    mean_train_loss = np.mean(train_losses)
    mean_test_loss = np.mean(test_losses)
    generalization_gap = mean_test_loss - mean_train_loss

    # Attack 1: Threshold attack summarized over all thresholds (AUC)
    # Use negative loss as "confidence" (lower loss = higher confidence)
    all_neg_losses = [-l for l in train_losses] + [-l for l in test_losses]
    all_labels = [1] * n + [0] * n
    mi_auc = roc_auc_score(all_labels, all_neg_losses)

    # Attack 2: Single threshold at the mean score (a simple, untuned baseline)
    threshold = np.mean(all_neg_losses)
    predictions = [1 if nl > threshold else 0 for nl in all_neg_losses]
    accuracy = sum(p == l for p, l in zip(predictions, all_labels)) / len(all_labels)

    # Attack 3: Best of three percentile thresholds
    # (a coarse search over candidates, not a true optimum)
    thresholds = np.percentile([-l for l in train_losses], [25, 50, 75])
    best_acc = 0.5  # Baseline (chance)
    for thresh in thresholds:
        preds = [1 if nl > thresh else 0 for nl in all_neg_losses]
        acc = sum(p == l for p, l in zip(preds, all_labels)) / len(all_labels)
        best_acc = max(best_acc, acc)

    # Privacy risk assessment
    if mi_auc > 0.75:
        risk_level = "critical"
    elif mi_auc > 0.65:
        risk_level = "high"
    elif mi_auc > 0.55:
        risk_level = "medium"
    else:
        risk_level = "low"

    recommendations = {
        "critical": [
            "Apply differential privacy (epsilon < 4) immediately",
            "Do NOT deploy with sensitive training data",
            "Consider synthetic data generation instead",
        ],
        "high": [
            "Apply differential privacy (epsilon < 8)",
            "Increase regularization and early stopping",
            "Restrict logprob API access",
        ],
        "medium": [
            "Apply regularization and dropout",
            "Monitor for overfitting in production",
            "Restrict logprob API access",
        ],
        "low": [
            "Privacy risk acceptable with access controls",
            "Monitor generalization gap over time",
        ],
    }

    return {
        "mean_train_loss": float(mean_train_loss),
        "mean_test_loss": float(mean_test_loss),
        "generalization_gap": float(generalization_gap),
        "mi_attack_auc": float(mi_auc),
        "threshold_attack_accuracy": float(accuracy),
        "best_threshold_accuracy": float(best_acc),
        "risk_level": risk_level,
        "recommendations": recommendations[risk_level],
        "n_samples": n,
    }
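AUC averages over every threshold, so it can hide an attack that confidently identifies only a handful of records. A complementary number sometimes reported in MI evaluations is the true-positive rate at a fixed, low false-positive rate. The sketch below is one simple way to compute it in plain Python. `tpr_at_fpr` is a hypothetical helper, and scores are assumed to follow the "higher means more likely a member" convention (for example, negative loss).

```python
def tpr_at_fpr(
    member_scores: list[float],
    nonmember_scores: list[float],
    max_fpr: float = 0.01,
) -> float:
    """TPR of the rule 'score > threshold', with FPR held at or below max_fpr."""
    n_nonmembers = len(nonmember_scores)
    # How many non-members may be wrongly flagged (rounded down)
    allowed_fp = int(max_fpr * n_nonmembers)
    ranked = sorted(nonmember_scores, reverse=True)
    if allowed_fp >= n_nonmembers:
        threshold = float("-inf")  # every score passes
    else:
        # At most `allowed_fp` non-members score strictly above this value
        threshold = ranked[allowed_fp]
    return sum(s > threshold for s in member_scores) / len(member_scores)


# Toy scores (invented): one non-member outscores most members
members = [0.9, 0.8, 0.3, 0.2]
nonmembers = [0.85, 0.5, 0.4, 0.35, 0.3, 0.25, 0.2, 0.15, 0.1, 0.05]

print(tpr_at_fpr(members, nonmembers, max_fpr=0.0))  # no false positives allowed
print(tpr_at_fpr(members, nonmembers, max_fpr=0.1))  # one false positive allowed
```

A non-trivial TPR at a very low FPR means some individuals can be identified with high confidence, even if the overall AUC looks modest.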
Defenses
1. Differential Privacy (DP-SGD)
The gold standard for membership inference defense. Mathematically bounds how much any individual record can influence the model:
import torch
import torch.nn as nn


def train_with_differential_privacy(
    model: nn.Module,
    train_loader,
    epochs: int = 10,
    epsilon: float = 8.0,
    delta: float = 1e-5,
    max_grad_norm: float = 1.0,
    learning_rate: float = 1e-3,
) -> dict:
    """
    Train with DP-SGD (Differentially Private SGD).

    DP-SGD provides formal privacy guarantees:
    - Clips per-example gradients to max_grad_norm
    - Adds calibrated Gaussian noise to clipped gradients
    - Guarantees (epsilon, delta)-differential privacy

    The privacy budget epsilon (rules of thumb, not regulatory thresholds):
    - epsilon < 1: Very strong privacy (often recommended for medical/financial data)
    - 1 <= epsilon < 4: Strong privacy, significant utility cost
    - 4 <= epsilon <= 10: Reasonable privacy for most sensitive use cases
    - epsilon > 10: Weak privacy - consider whether DP is worth the cost

    Delta should be well below 1/N where N is training set size.
    Example: delta = 1e-6 for N ~ 100,000 examples. The default of 1e-5
    is only appropriate when N is well under 100,000.

    Requires: pip install opacus
    """
    try:
        from opacus import PrivacyEngine
    except ImportError:
        print("Install opacus: pip install opacus")
        return {"error": "opacus not installed"}

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    # Wrap model, optimizer, and loader so that each step clips per-example
    # gradients and adds noise calibrated to reach target_epsilon after `epochs`
    privacy_engine = PrivacyEngine()
    model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
        module=model,
        optimizer=optimizer,
        data_loader=train_loader,
        epochs=epochs,
        target_epsilon=epsilon,
        target_delta=delta,
        max_grad_norm=max_grad_norm,
    )

    train_losses = []
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        n_batches = 0
        for batch in train_loader:
            inputs, labels = batch
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            n_batches += 1
        train_losses.append(epoch_loss / max(n_batches, 1))
        epsilon_spent = privacy_engine.get_epsilon(delta)
        print(f"Epoch {epoch+1}: loss={train_losses[-1]:.4f}, epsilon_spent={epsilon_spent:.3f}")

    final_epsilon = privacy_engine.get_epsilon(delta)
    return {
        "final_epsilon": final_epsilon,
        "target_epsilon": epsilon,
        "delta": delta,
        "max_grad_norm": max_grad_norm,
        "train_losses": train_losses,
        "privacy_guarantee": f"({final_epsilon:.2f}, {delta})-differential privacy",
    }


def explain_dp_epsilon_tradeoffs() -> dict:
    """
    Practical guide to choosing epsilon for DP-SGD.

    Shows rough accuracy-privacy tradeoffs. The accuracy figures are
    indicative only and vary widely with model, data, and dataset size.
    """
    return {
        "epsilon_guide": {
            "< 1": {
                "privacy_level": "Very strong",
                "typical_accuracy_drop": "5-15%",
                "use_cases": "Medical records, financial data, biometrics",
                "notes": "Often suggested for HIPAA/GDPR high-risk data (neither regulation sets an epsilon)",
            },
            "1-4": {
                "privacy_level": "Strong",
                "typical_accuracy_drop": "3-10%",
                "use_cases": "Sensitive personal data, employee records",
                "notes": "Good balance for many sensitive workloads",
            },
            "4-10": {
                "privacy_level": "Moderate",
                "typical_accuracy_drop": "1-5%",
                "use_cases": "General personal data, user behavior",
                "notes": "Common production choice for large datasets",
            },
            "> 10": {
                "privacy_level": "Weak",
                "typical_accuracy_drop": "< 1%",
                "use_cases": "Non-sensitive data",
                "notes": "Consider whether DP overhead is worth the protection",
            },
        }
    }
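To make the clip-then-noise mechanics concrete without a library, here is a toy version of a single DP-SGD gradient aggregation step using plain Python lists. It is a sketch of the idea only: it is not how Opacus implements the step, it performs no privacy accounting, and `clip_gradient` and `dp_sgd_aggregate` are hypothetical names.

```python
import math
import random


def clip_gradient(grad: list[float], max_norm: float) -> list[float]:
    """Scale grad down so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, max_norm / (norm + 1e-12))
    return [g * scale for g in grad]


def dp_sgd_aggregate(
    per_example_grads: list[list[float]],
    max_norm: float,
    noise_multiplier: float,
    rng: random.Random,
) -> list[float]:
    """Clip each example's gradient, sum, add Gaussian noise, then average."""
    clipped = [clip_gradient(g, max_norm) for g in per_example_grads]
    dim = len(per_example_grads[0])
    summed = [sum(g[i] for g in clipped) for i in range(dim)]
    # Noise std scales with the clip norm, the most one example can contribute
    sigma = noise_multiplier * max_norm
    noisy = [s + rng.gauss(0.0, sigma) for s in summed]
    return [v / len(per_example_grads) for v in noisy]


rng = random.Random(0)
grads = [[3.0, 4.0], [0.3, 0.4]]  # the first example has an outsized gradient
print(dp_sgd_aggregate(grads, max_norm=1.0, noise_multiplier=0.0, rng=rng))
print(dp_sgd_aggregate(grads, max_norm=1.0, noise_multiplier=1.0, rng=rng))
```

With `noise_multiplier=0.0` the output shows the effect of clipping alone: the large gradient is scaled to norm 1 before averaging, so no single record dominates the update. A positive multiplier then adds the randomness that the formal guarantee depends on.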
2. Regularization and Early Stopping
Less formal but simpler: reduce overfitting to reduce the train-test gap that membership inference exploits:
import torch
import torch.nn as nn


def train_with_privacy_regularization(
    model: nn.Module,
    train_loader,
    val_loader,
    epochs: int = 50,
    l2_lambda: float = 0.01,
    dropout_rate: float = 0.3,
    early_stop_patience: int = 5,
    learning_rate: float = 1e-3,
    target_gen_gap: float = 0.1,  # Stop if generalization gap exceeds this
) -> dict:
    """
    Privacy-preserving training via regularization + early stopping.

    Core idea: reduce train-test loss gap → reduce MI attack signal.

    Techniques:
    - L2 regularization: penalizes large weights, discourages memorization
    - Dropout: random deactivation discourages reliance on specific examples
      (applied only to nn.Dropout layers the model already contains)
    - Early stopping: halt before model memorizes training data
    - Gradient clipping: limits the size of each batch update

    Not a formal guarantee like DP-SGD, but it can reduce the MI signal,
    often at a small accuracy cost. Measure the result rather than assume it.
    """
    # Apply dropout_rate to any dropout layers already present in the model
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.p = dropout_rate

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=learning_rate,
        weight_decay=l2_lambda,  # L2 regularization
    )
    criterion = nn.CrossEntropyLoss()
    best_val_loss = float('inf')
    patience_count = 0
    training_history = []

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0.0
        n_train_batches = 0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            # Gradient clipping: bounds the norm of the aggregated batch gradient.
            # This is NOT per-example clipping; that requires DP-SGD.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()
            n_train_batches += 1
        train_loss /= max(n_train_batches, 1)

        # Validation
        model.eval()
        val_loss = 0.0
        n_val_batches = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                outputs = model(inputs)
                val_loss += criterion(outputs, labels).item()
                n_val_batches += 1
        val_loss /= max(n_val_batches, 1)

        generalization_gap = val_loss - train_loss
        training_history.append({
            "epoch": epoch,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "generalization_gap": generalization_gap,
            "mi_risk_estimate": "high" if generalization_gap > 0.3 else "medium" if generalization_gap > 0.1 else "low",
        })

        # Privacy-aware early stopping: stop if gap is too large
        if generalization_gap > target_gen_gap and epoch > 5:
            print(f"Privacy-aware early stop: generalization gap {generalization_gap:.3f} > target {target_gen_gap}")
            break

        # Standard early stopping on val loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_count = 0
        else:
            patience_count += 1
            if patience_count >= early_stop_patience:
                print(f"Early stopping at epoch {epoch+1} (patience exceeded)")
                break

    final_entry = training_history[-1]
    return {
        "epochs_trained": len(training_history),
        "final_generalization_gap": final_entry["generalization_gap"],
        "final_train_loss": final_entry["train_loss"],
        "final_val_loss": final_entry["val_loss"],
        "mi_risk_estimate": final_entry["mi_risk_estimate"],
        "training_history": training_history,
    }
3. Machine Unlearning (Right to Be Forgotten)
When a user requests data deletion under GDPR, removing their records from storage may not be enough: the organization may also need the trained model to "unlearn" them:
from typing import Callable

import numpy as np
import torch


def gradient_ascent_unlearning(
    model: torch.nn.Module,
    forget_set: list,
    retain_set: list,
    get_loss_fn: Callable,
    unlearn_lr: float = 1e-4,
    unlearn_steps: int = 100,
    retain_regularization: float = 0.5,
) -> dict:
    """
    Machine unlearning via gradient ascent on forget set.

    Core idea:
    - Gradient DESCENT minimizes loss (trains on data)
    - Gradient ASCENT maximizes loss on forget set (untrains those examples)
    - Regularize with gradient DESCENT on retain set to preserve performance

    This is an approximation of full retraining - less expensive but
    provides no formal guarantees. Verify with MI test afterward.

    Args:
        model: Trained model to unlearn from
        forget_set: Examples to forget (e.g., user's deleted records)
        retain_set: Examples to maintain performance on
        get_loss_fn: Function(model, batch) → scalar loss
        unlearn_lr: Learning rate for unlearning steps
        unlearn_steps: Number of gradient steps
        retain_regularization: Weight for retain set loss (higher = more conservative)
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=unlearn_lr)
    batch_size = 8
    forget_losses_history = []
    retain_losses_history = []

    for step in range(unlearn_steps):
        optimizer.zero_grad()

        # Sample forget batch
        start_idx = (step * batch_size) % len(forget_set)
        forget_batch = forget_set[start_idx:start_idx + batch_size]
        if not forget_batch:
            forget_batch = forget_set[:batch_size]

        # Gradient ASCENT on forget set
        forget_loss = get_loss_fn(model, forget_batch)
        ascent_loss = -forget_loss  # Negative = ascending

        # Sample retain batch
        retain_start = (step * batch_size * 4) % len(retain_set)
        retain_batch = retain_set[retain_start:retain_start + batch_size * 4]
        if not retain_batch:
            retain_batch = retain_set[:batch_size * 4]

        # Gradient DESCENT on retain set
        retain_loss = get_loss_fn(model, retain_batch)

        # Combined objective
        total_loss = ascent_loss + retain_regularization * retain_loss
        total_loss.backward()

        # Clip gradients to prevent instability
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        forget_losses_history.append(forget_loss.item())
        retain_losses_history.append(retain_loss.item())
        if (step + 1) % 20 == 0:
            print(f"Step {step+1}/{unlearn_steps}: forget_loss={forget_loss.item():.4f}, retain_loss={retain_loss.item():.4f}")

    # Evaluate unlearning success
    initial_forget_loss = forget_losses_history[0]
    final_forget_loss = forget_losses_history[-1]
    # Successful unlearning: forget loss should have increased significantly
    loss_increase_ratio = final_forget_loss / max(initial_forget_loss, 1e-8)
    return {
        "unlearning_steps": unlearn_steps,
        "initial_forget_loss": initial_forget_loss,
        "final_forget_loss": final_forget_loss,
        "loss_increase_ratio": loss_increase_ratio,
        "unlearning_successful": loss_increase_ratio > 2.0,
        "final_retain_loss": retain_losses_history[-1],
        "recommendation": "Run MI verification test to confirm unlearning" if loss_increase_ratio > 2.0 else "Increase unlearn_steps or lr and retry",
    }


def verify_unlearning_with_mi_test(
    model_before_unlearning: Callable,
    model_after_unlearning: Callable,
    forget_set_examples: list,
    held_out_examples: list,
    get_confidence_fn: Callable,
) -> dict:
    """
    Verify that unlearning succeeded using a membership inference test.

    After successful unlearning:
    - Model's confidence on forget_set should match held-out (never-seen) examples
    - Membership inference attack should perform near chance on forget set

    Args:
        model_before_unlearning: Original model
        model_after_unlearning: Unlearned model
        forget_set_examples: Examples that should be forgotten
        held_out_examples: Examples never in training (ground truth non-members)
        get_confidence_fn: Function(model, example) → float confidence
    """
    # Before unlearning: forget set should look like members
    before_forget_confidences = [get_confidence_fn(model_before_unlearning, ex)
                                 for ex in forget_set_examples]
    before_held_out_confidences = [get_confidence_fn(model_before_unlearning, ex)
                                   for ex in held_out_examples]

    # After unlearning: forget set should look like non-members.
    # Held-out confidence is re-measured on the unlearned model so that each
    # gap compares two sets scored by the same model.
    after_forget_confidences = [get_confidence_fn(model_after_unlearning, ex)
                                for ex in forget_set_examples]
    after_held_out_confidences = [get_confidence_fn(model_after_unlearning, ex)
                                  for ex in held_out_examples]

    before_gap = np.mean(before_forget_confidences) - np.mean(before_held_out_confidences)
    after_gap = np.mean(after_forget_confidences) - np.mean(after_held_out_confidences)

    # Unlearning is verified if the gap closes to near-zero
    gap_reduction = 1 - abs(after_gap) / max(abs(before_gap), 1e-8)
    return {
        "before_forget_confidence": float(np.mean(before_forget_confidences)),
        "after_forget_confidence": float(np.mean(after_forget_confidences)),
        "held_out_confidence": float(np.mean(before_held_out_confidences)),
        "held_out_confidence_after": float(np.mean(after_held_out_confidences)),
        "membership_gap_before": float(before_gap),
        "membership_gap_after": float(after_gap),
        "gap_reduction": float(gap_reduction),
        "unlearning_verified": abs(after_gap) < 0.05,  # Gap below 0.05 in confidence
        "recommendation": (
            "Unlearning complete - forget set indistinguishable from non-members"
            if abs(after_gap) < 0.05
            else f"Unlearning incomplete - gap={after_gap:.3f}. Run more steps or retrain."
        ),
    }
Regulatory Compliance
Membership inference intersects directly with privacy regulations. None of these laws prescribes a specific technique or epsilon value; the right-hand column lists common engineering responses:
| Regulation | MI-Relevant Obligation | Possible Technical Response |
|---|---|---|
| GDPR Art. 17 | Right to erasure of personal data, which may extend to data used in training | Machine unlearning + verification |
| GDPR Art. 5(1)(c) | Data minimization | Train on minimum necessary data |
| HIPAA | PHI must be safeguarded against unauthorized disclosure, including via model outputs | DP-SGD (e.g., epsilon < 4) + output filtering |
| CCPA | Delete personal data on request | Unlearning + full provenance tracking |
| EU AI Act | High-risk systems carry data-governance and risk-management duties | DP training + MI risk assessment |
Privacy Risk Assessment Checklist
```python
# Data categories treated as sensitive throughout the assessment.
SENSITIVE_CATEGORIES = ("medical", "financial", "legal", "biometric")


def privacy_risk_assessment(model_config: dict) -> dict:
    """
    Assess membership inference risk for a model deployment.
    Returns a structured risk report with actionable recommendations.

    model_config keys:
    - data_sensitivity: "medical" | "financial" | "legal" | "biometric" | "general"
    - generalization_gap: float (optional, from train/val evaluation)
    - n_training_examples: int
    - returns_logprobs: bool
    - includes_eu_residents: bool
    - training_approach: "standard" | "dp_sgd" | "federated"
    - model_size: "small" | "medium" | "large" | "xlarge"
    """
    risks = []
    recommendations = []

    # Check 1: Training data sensitivity
    data_sensitivity = model_config.get("data_sensitivity", "general")
    if data_sensitivity in SENSITIVE_CATEGORIES:
        risks.append({
            "category": "data_sensitivity",
            "level": "critical",
            "detail": f"Training on {data_sensitivity} data creates severe MI privacy risk"
        })
        recommendations.append(f"Apply DP-SGD with epsilon < 4 for {data_sensitivity} data")
        recommendations.append("Run pre-deployment MI risk measurement")

    # Check 2: Generalization gap (skipped when no measurement is supplied)
    gen_gap = model_config.get("generalization_gap")
    if gen_gap is not None:
        if gen_gap > 0.3:
            risks.append({
                "category": "overfitting",
                "level": "high",
                "detail": f"Large generalization gap ({gen_gap:.2f}) amplifies MI attack signal"
            })
            recommendations.append("Apply L2 regularization, dropout, and early stopping")
        elif gen_gap > 0.1:
            risks.append({
                "category": "moderate_overfitting",
                "level": "medium",
                "detail": f"Moderate generalization gap ({gen_gap:.2f}) - some MI risk"
            })

    # Check 3: Training data size
    # A missing value defaults to 0, so it is conservatively treated as a very small dataset.
    n_training = model_config.get("n_training_examples", 0)
    if n_training < 1000:
        risks.append({
            "category": "very_small_dataset",
            "level": "critical",
            "detail": "Very small training sets have extremely high memorization and MI risk"
        })
        recommendations.append("Use synthetic data generation or federated learning")
    elif n_training < 10000:
        risks.append({
            "category": "small_dataset",
            "level": "medium",
            "detail": "Small training sets have higher memorization and MI risk"
        })
        recommendations.append("Consider data augmentation and DP-SGD")

    # Check 4: Logprob exposure
    if model_config.get("returns_logprobs", False):
        risks.append({
            "category": "logprob_exposure",
            "level": "high",
            "detail": "Returning logprobs gives attackers rich membership signal (full distribution)"
        })
        recommendations.append("Restrict logprob access or add calibrated noise to logprobs")

    # Check 5: GDPR scope
    if model_config.get("includes_eu_residents", False):
        risks.append({
            "category": "gdpr_scope",
            "level": "medium",
            "detail": "EU resident data triggers GDPR Art. 17 right-to-erasure requirements"
        })
        recommendations.append("Implement machine unlearning capability before EU deployment")
        recommendations.append("Maintain per-user training data provenance for deletion requests")

    # Check 6: Training approach (same sensitive categories as Check 1, including biometric)
    training_approach = model_config.get("training_approach", "standard")
    if training_approach == "standard" and data_sensitivity in SENSITIVE_CATEGORIES:
        risks.append({
            "category": "no_privacy_training",
            "level": "critical",
            "detail": "Sensitive data trained without differential privacy"
        })

    # Check 7: Large model memorization risk
    model_size = model_config.get("model_size", "medium")
    if model_size in ("large", "xlarge") and n_training < 100000:
        risks.append({
            "category": "large_model_memorization",
            "level": "high",
            "detail": "Large models on small datasets have high verbatim memorization risk"
        })
        recommendations.append("Test for verbatim memorization before deployment")

    # Overall risk: any critical finding blocks deployment
    critical_risks = sum(1 for r in risks if r["level"] == "critical")
    high_risks = sum(1 for r in risks if r["level"] == "high")
    if critical_risks >= 1:
        overall_risk = "critical"
    elif high_risks >= 2:
        overall_risk = "high"
    elif risks:
        overall_risk = "medium"
    else:
        overall_risk = "low"

    return {
        "overall_risk": overall_risk,
        "risks": risks,
        "recommendations": list(dict.fromkeys(recommendations)),  # Deduplicate, keep order
        "approved_for_deployment": overall_risk != "critical",
        "requires_dp_training": (
            overall_risk in ("critical", "high")
            and data_sensitivity in SENSITIVE_CATEGORIES
        ),
    }
```
Common Mistakes
:::danger Mistake 1: Training on Sensitive Data Without DP
Training on healthcare, financial, or legal records without differential privacy is a significant liability. Even with access controls on the training data, a deployed model that's queryable can leak membership information. The hospital settlement in our opening story was $2.3M - and that was for a single plaintiff. Apply DP-SGD from the start for sensitive data; retrofitting is much harder.
:::
:::danger Mistake 2: No Machine Unlearning Capability
Under GDPR Article 17, individuals have the right to erasure of their personal data. If that data was used in model training, you need a way to remove its influence. Build machine unlearning infrastructure before you launch in EU markets. Courts have accepted MI attack evidence as proof of training membership - this is not theoretical.
:::
:::warning Mistake 3: Exposing Logprobs by Default
Raw logprob access gives attackers dramatically more information for membership inference than text outputs alone. The full probability distribution reveals far more about training membership than a sampled response. Restrict logprob access to verified use cases and add noise where possible.
:::
:::warning Mistake 4: Ignoring Verbatim Memorization for LLMs
LLMs don't just have statistical MI risk - they can reproduce training text verbatim. Test your models for verbatim memorization before deployment, especially if training data includes sensitive documents, PII, or proprietary text. Memorization rate tends to increase with model size and decrease with dataset size.
:::
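One way to run the memorization test mentioned above is a prefix-completion probe: feed the model the start of a training document and check whether it reproduces the true continuation exactly. The sketch below is a minimal illustration under stated assumptions. `generate` is a hypothetical callable `(prompt, max_chars) -> str` that you would wrap around your own model with greedy decoding, and the split is character-based for simplicity (a real test would more likely split on tokens).

```python
from typing import Callable, Iterable


def verbatim_memorization_rate(
    generate: Callable[[str, int], str],  # hypothetical wrapper around your model
    documents: Iterable[str],
    prefix_chars: int = 50,
    suffix_chars: int = 50,
) -> float:
    """Fraction of probed documents whose true continuation is reproduced exactly."""
    tested = 0
    memorized = 0
    for doc in documents:
        if len(doc) < prefix_chars + suffix_chars:
            continue  # too short to split into a prefix and a full suffix
        prefix = doc[:prefix_chars]
        true_suffix = doc[prefix_chars:prefix_chars + suffix_chars]
        completion = generate(prefix, suffix_chars)
        tested += 1
        # Exact match on the first suffix_chars characters counts as memorized.
        if completion[:suffix_chars] == true_suffix:
            memorized += 1
    return memorized / tested if tested else 0.0
```

Running the same probe on documents that were not in the training set gives a baseline: a non-zero rate there suggests the text is simply predictable rather than memorized.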
:::tip Best Practice: Privacy-Utility Profiling Before Deployment
Before training, profile the expected privacy-utility tradeoff for your epsilon budget. Train small-scale models with different epsilon values and measure both task accuracy and MI-attack AUC. This tells you the epsilon where privacy benefits start outweighing performance costs for your specific task and dataset. Different tasks have very different epsilon sensitivity.
:::
Interview Questions and Answers
Q1: What is membership inference and why does it matter in production systems?
Membership inference is the ability to determine, given a model and a data record, whether that record was used in training the model. It matters because: (1) it's a concrete privacy violation - knowing that someone's medical record was in a clinical model's training data reveals they were a patient; (2) it has direct regulatory consequences under GDPR Article 17 and HIPAA; (3) it can be used to confirm a suspected leak or unauthorized use of data, by checking whether specific records behave like training members. In production, the main risk is that models trained on sensitive personal data can be queried to reveal membership of specific individuals - and courts have accepted MI attack results as evidence.
Q2: What is the generalization gap and how does it relate to membership inference?
The generalization gap is the difference between training loss and validation/test loss. A model with a small gap generalizes well - it performs similarly on seen and unseen data. A model with a large gap is overfit - it has "memorized" training data and performs much better on training examples. Membership inference exploits this gap: training examples tend to have lower loss than non-training examples, so even a simple loss-threshold classifier can separate them better than chance. As a rough illustration (actual values depend on the model, data, and attack), a model with a gap of 0.1 might have an MI attack AUC around 0.55 (barely above chance), while a model with a gap of 0.35 might reach 0.75 (significant privacy risk). Reducing overfitting via regularization, dropout, and early stopping directly reduces MI vulnerability.
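The loss-threshold attack described in this answer can be scored without choosing a threshold at all: its AUC is the probability that a randomly chosen training example has lower loss than a randomly chosen held-out example. The sketch below assumes you have already computed per-example losses for a sample of known members and known non-members; the function names are illustrative, not from any library. The gap is taken here as mean non-member loss minus mean member loss, so it is positive for an overfit model.

```python
import numpy as np


def mi_attack_auc(member_losses, nonmember_losses) -> float:
    """AUC of the loss-threshold attack (lower loss => predicted member).

    Computed as the fraction of (member, non-member) pairs in which the
    member has the lower loss, counting ties as half a win.
    """
    m = np.asarray(member_losses, dtype=float)
    n = np.asarray(nonmember_losses, dtype=float)
    # Pairwise comparison: O(len(m) * len(n)) memory, fine for a few thousand examples.
    wins = (m[:, None] < n[None, :]).sum()
    ties = (m[:, None] == n[None, :]).sum()
    return float((wins + 0.5 * ties) / (m.size * n.size))


def generalization_gap(member_losses, nonmember_losses) -> float:
    """Mean held-out loss minus mean training loss."""
    return float(np.mean(nonmember_losses) - np.mean(member_losses))
```

An AUC near 0.5 means the two loss distributions are indistinguishable to this attack; stronger attacks can still do better, so treat the result as a lower bound on risk.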
Q3: How does differential privacy defend against membership inference?
Differential privacy provides a mathematical guarantee: any single training example's influence on the trained model is bounded. DP-SGD achieves this by: (1) clipping per-example gradients to a maximum norm, limiting each example's contribution; (2) adding calibrated Gaussian noise to the sum of clipped gradients before the update. The formal guarantee is (epsilon, delta)-DP: for any two training sets D and D' that differ by one example, and any set S of possible output models, Pr[M(D) in S] <= e^epsilon * Pr[M(D') in S] + delta. In words, the two output distributions are within a multiplicative factor of e^epsilon of each other, up to a small additive slack delta. When epsilon is small, the model's behavior is nearly identical whether or not any single record was included - making membership inference unreliable. The tradeoff: smaller epsilon = stronger privacy guarantee but larger noise = greater accuracy cost.
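The two DP-SGD mechanics named in this answer, per-example clipping and Gaussian noise, fit in a few lines of NumPy. This is a minimal sketch of a single update step, assuming per-example gradients have already been computed as a `(batch_size, n_params)` array; `dp_sgd_step` and its parameter names are illustrative. It does not track the resulting epsilon, which requires a separate privacy accountant, so in practice a maintained DP training library would normally be used.

```python
import numpy as np


def dp_sgd_step(params, per_example_grads, lr, clip_norm, noise_multiplier, rng):
    """One DP-SGD update from a batch of per-example gradients."""
    grads = np.asarray(per_example_grads, dtype=float)
    batch_size = grads.shape[0]

    # (1) Clip: rescale each example's gradient so its L2 norm is at most clip_norm.
    norms = np.linalg.norm(grads, axis=1, keepdims=True)
    scale = np.minimum(1.0, clip_norm / np.maximum(norms, 1e-12))
    clipped = grads * scale

    # (2) Noise: add Gaussian noise with std = noise_multiplier * clip_norm
    #     to the summed clipped gradients, then average over the batch.
    noise = rng.normal(0.0, noise_multiplier * clip_norm, size=grads.shape[1])
    noisy_mean = (clipped.sum(axis=0) + noise) / batch_size

    return np.asarray(params, dtype=float) - lr * noisy_mean
```

Because of the clipping step, no single example can move the parameters by more than `lr * clip_norm / batch_size` before noise, which is the bounded influence the guarantee relies on.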
Q4: What is machine unlearning and why is it technically challenging?
Machine unlearning is the process of removing the influence of specific training examples from a trained model, without retraining from scratch. It's technically challenging because: (1) model weights are a non-linear function of all training examples combined - there's no simple undo for one example; (2) naive approaches (gradient ascent on the forget set) can degrade performance on other examples if not carefully regularized; (3) verifying that unlearning succeeded requires running membership inference tests - if the forget set still looks like members, unlearning failed; (4) it's computationally expensive to run frequently for large models. Current best approaches: gradient ascent with retain-set regularization (most practical), data partitioning with partial retraining (more principled), or model-agnostic approaches that require no access to weights (for API-based deployments). Full verification requires comparing to a model retrained without the forget set.
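Gradient ascent with retain-set regularization can be illustrated on a model small enough to read in full. The sketch below uses a NumPy logistic regression as a stand-in (an assumption made for brevity; the same update applies to any differentiable model) and treats `lr`, `retain_weight`, and `steps` as hypothetical hyperparameters to be tuned. Each step raises the loss on the forget set while lowering it on the retain set.

```python
import numpy as np


def _sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def log_loss(w, X, y) -> float:
    """Mean binary cross-entropy of a logistic regression with weights w."""
    p = np.clip(_sigmoid(X @ w), 1e-12, 1 - 1e-12)
    return float(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))


def log_loss_grad(w, X, y):
    """Gradient of the mean binary cross-entropy with respect to w."""
    return X.T @ (_sigmoid(X @ w) - y) / len(y)


def unlearn(w, X_forget, y_forget, X_retain, y_retain,
            lr=0.1, retain_weight=1.0, steps=20):
    """Gradient ascent on the forget set, regularized by descent on the retain set."""
    w = np.array(w, dtype=float)  # copy so the original weights are untouched
    for _ in range(steps):
        ascent = log_loss_grad(w, X_forget, y_forget)   # push forget loss up
        descent = log_loss_grad(w, X_retain, y_retain)  # keep retain loss down
        w = w + lr * ascent - lr * retain_weight * descent
    return w
```

The ascent term is unbounded, so the number of steps has to be capped, and a higher forget-set loss alone does not prove the data's influence is gone. As the answer notes, verification means rerunning membership inference on the forget set and, where feasible, comparing against a model retrained without it.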
Q5: How would you evaluate whether your model is GDPR-compliant with respect to training data membership?
Five steps: (1) Scope assessment - identify what personal data of EU residents is in the training set and obtain proper legal basis for its use. (2) Privacy risk measurement - compute MI-attack AUC on held-out examples; AUC above 0.65 is a concern. (3) Technical controls - apply DP if AUC is high; test for verbatim memorization of personal data; ensure no PII appears in model outputs. (4) Process controls - maintain a record of which individuals' data was used in each model version; implement a deletion request workflow with unlearning capability. (5) Machine unlearning capability - before deploying on EU resident data, verify you can process Art. 17 deletion requests, either via gradient ascent unlearning or by retraining without the deleted data. Document all of this for GDPR Article 30 record-keeping requirements.
Q6: When should you use DP-SGD versus regularization-only approaches for privacy?
Use DP-SGD when: (1) your data is in a sensitive category under GDPR or HIPAA (medical, financial, biometric, legal); (2) you need formal, quantifiable privacy guarantees for regulatory compliance; (3) individuals could face real harm if their membership is inferred (patients, employees, crime victims). Use regularization-only when: (1) data is general consumer behavior without heightened sensitivity; (2) the model is not queryable externally; (3) the accuracy cost of DP-SGD (typically 3-10%) is unacceptable for your use case. Regularization reduces MI risk but provides no formal guarantee - an adversary with enough queries can still identify members. DP-SGD provides a formal bound on membership advantage. For truly sensitive data, DP-SGD is the defensible choice even with the accuracy cost.
