
Wavelets and Multiscale Analysis

Reading time: ~40 minutes
Interview relevance: Medium-high for audio/speech ML, IoT, and signal processing roles; conceptual knowledge expected for ML Research Engineers
Target roles: ML Engineer (Audio/Speech), Research Engineer, IoT/Sensor ML, Signal Processing ML

The Real Interview Moment

You're interviewing at a medical AI company. The interviewer asks: "We have ECG signals sampled at 500 Hz. We need to detect both slow baseline wander (0.05–1 Hz) and rapid QRS complexes (5–40 Hz) simultaneously. Fourier analysis doesn't work well here - why not, and what would you use instead?"

The correct diagnosis: "Fourier analysis assumes stationarity - it gives you global frequency content but no temporal localization. A QRS complex lasts 80 ms; its spectral content averaged over 10 seconds is meaningless. STFT helps, but a single window size fixes the time-frequency resolution for every frequency: a window long enough to resolve the slow drift smears the QRS in time, and a window short enough to localize the QRS cannot resolve the drift.

Wavelets solve this by using a basis that is simultaneously localized in time AND frequency. They use shorter windows for high frequencies (good time resolution) and longer windows for low frequencies (good frequency resolution) - matching how signals actually behave. For ECG, a Daubechies wavelet at multiple scales would give you both the slow drift and the sharp QRS."

This is the core value proposition of wavelets: adaptive time-frequency resolution that the Fourier transform cannot provide.

The Limitation of Fourier: No Time Localization

The DFT of a signal asks: "Which frequencies are present globally?" But for non-stationary signals, we need: "Which frequencies are present at which time?"

The STFT solves this partially via windowing - but with a fundamental trade-off:

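Heisenberg uncertainty principle for signals: $\Delta t \cdot \Delta f \geq \frac{1}{4\pi}$, where $\Delta t$ and $\Delta f$ are the RMS duration and RMS bandwidth of the analysis window (a Gaussian window attains equality).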

You cannot simultaneously have arbitrarily precise time AND frequency resolution. The window size determines this trade-off:

  • Short window: Good time resolution (locate events precisely in time), poor frequency resolution (blurry spectrum)
  • Long window: Good frequency resolution, poor time resolution

The STFT uses a fixed window, so it has the same time-frequency resolution at all frequencies. For a 40 ms window:

  • A 500 Hz component is well resolved (20 cycles fit in the window)
  • A 5 Hz component is unresolvable (only 0.2 cycles fit in the window)
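The numbers in these bullets follow from two rules of thumb: a window of length $T$ holds $f \cdot T$ cycles of a tone at $f$ Hz, and its DFT bins are spaced $1/T$ apart. A tiny sketch (the helper name is ours, for illustration):

```python
def stft_window_stats(freq_hz: float, window_s: float) -> tuple[float, float]:
    """Cycles of a freq_hz tone inside a window of window_s seconds,
    and the DFT bin spacing (Hz) of that window."""
    cycles = freq_hz * window_s
    bin_spacing_hz = 1.0 / window_s   # frequency resolution ~ 1/T
    return cycles, bin_spacing_hz

for f in (500.0, 5.0):
    cycles, df = stft_window_stats(f, window_s=0.040)
    print(f"{f:5.0f} Hz: {cycles:4.1f} cycles in a 40 ms window, bins every {df:.0f} Hz")
```

With bins 25 Hz apart, a 5 Hz tone sits well inside the lowest bin, so this STFT effectively cannot separate it from a slowly varying offset.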

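Wavelets use a variable window length: shorter windows at high frequencies, longer windows at low frequencies. The product $\Delta t \cdot \Delta f$ stays fixed, but the split between time and frequency resolution adapts to each frequency (constant relative bandwidth $\Delta f / f$), which suits signals whose fast events are brief and whose slow components last long.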
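The trade-off can be checked numerically. The sketch below treats $|g|^2$ and $|G|^2$ as probability densities, measures the RMS duration and RMS bandwidth of a Gaussian window, and stretches the window by a factor $s$, as a wavelet does across scales: $\Delta t$ grows by $s$, $\Delta f$ shrinks by $1/s$, and the product stays at the $1/(4\pi)$ bound, which a Gaussian attains. The helper name, sampling step, and window widths are illustrative choices.

```python
import numpy as np

def rms_spreads(g: np.ndarray, dt: float) -> tuple[float, float]:
    """RMS duration Δt (s) and RMS bandwidth Δf (Hz) of a window g sampled every
    dt seconds, treating |g|² and |G|² as probability densities."""
    n = len(g)
    t = (np.arange(n) - n // 2) * dt
    p_t = np.abs(g) ** 2
    p_t /= p_t.sum()
    delta_t = np.sqrt(np.sum((t - np.sum(t * p_t)) ** 2 * p_t))

    f = np.fft.fftfreq(n, d=dt)            # bin frequencies in Hz, in np.fft.fft order
    p_f = np.abs(np.fft.fft(g)) ** 2
    p_f /= p_f.sum()
    delta_f = np.sqrt(np.sum((f - np.sum(f * p_f)) ** 2 * p_f))
    return delta_t, delta_f

dt, n = 1e-3, 8192
t = (np.arange(n) - n // 2) * dt
results = []
for s in (1.0, 2.0, 4.0):                          # stretch factor, as in ψ(t / s)
    g = np.exp(-t ** 2 / (2 * (0.01 * s) ** 2))    # Gaussian window, width 10 ms × s
    d_t, d_f = rms_spreads(g, dt)
    results.append((s, d_t, d_f))
    print(f"s={s:.0f}: Δt={d_t * 1e3:.2f} ms  Δf={d_f:.2f} Hz  Δt·Δf={d_t * d_f:.4f}")
print(f"bound 1/(4π) = {1 / (4 * np.pi):.4f}")
```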

The Wavelet: A Localized Oscillation

A wavelet is a brief oscillating function ψ(t)\psi(t) that satisfies:

$$\int_{-\infty}^{\infty} \psi(t)\, dt = 0 \quad \text{(zero mean)}$$

$$\int_{-\infty}^{\infty} |\psi(t)|^2\, dt = 1 \quad \text{(unit energy)}$$

The zero-mean condition ensures $\psi$ is oscillatory: it has both positive and negative parts, a "small wave", hence the name wavelet. The unit-energy condition is a normalization.

From a mother wavelet $\psi(t)$, we generate a family of daughter wavelets by scaling (stretching/compressing) and translating (shifting):

$$\psi_{s,\tau}(t) = \frac{1}{\sqrt{s}} \psi\!\left(\frac{t - \tau}{s}\right)$$

where:

  • $s > 0$: scale parameter (inversely related to frequency: large $s$ = low frequency)
  • $\tau$: translation parameter (time position of the wavelet)
  • $\frac{1}{\sqrt{s}}$: normalization that keeps the energy of $\psi_{s,\tau}$ equal to that of $\psi$ at every scale

Large scale $s$ → stretched wavelet → captures slow variations (low frequency).
Small scale $s$ → compressed wavelet → captures rapid variations (high frequency).
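A quick numerical check of the $1/\sqrt{s}$ factor: stretching or compressing the unit-energy Mexican hat and rescaling by $1/\sqrt{s}$ leaves its energy at 1 and its mean at 0 for every scale and shift. The function names and the grid are illustrative choices.

```python
import numpy as np

def mexican_hat(t: np.ndarray) -> np.ndarray:
    """Unit-energy Mexican hat (Ricker) mother wavelet."""
    return 2 / (np.sqrt(3) * np.pi ** 0.25) * (1 - t ** 2) * np.exp(-t ** 2 / 2)

def daughter(t: np.ndarray, s: float, tau: float) -> np.ndarray:
    """Daughter wavelet ψ_{s,τ}(t) = s^(-1/2) ψ((t - τ) / s)."""
    return mexican_hat((t - tau) / s) / np.sqrt(s)

t = np.linspace(-100, 100, 400_001)   # fine grid, wide enough for the largest scale
dt = t[1] - t[0]
energies, means = {}, {}
for s in (0.5, 1.0, 4.0, 10.0):
    psi = daughter(t, s=s, tau=3.0)
    energies[s] = np.sum(psi ** 2) * dt   # Riemann sum for ∫ |ψ_{s,τ}(t)|² dt
    means[s] = np.sum(psi) * dt           # Riemann sum for ∫ ψ_{s,τ}(t) dt
    print(f"s={s:>4}: energy={energies[s]:.6f}  mean={means[s]:+.1e}")
```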

Continuous Wavelet Transform (CWT)

The CWT measures the similarity between the signal and the daughter wavelet at each $(s, \tau)$:

$$W(s, \tau) = \int_{-\infty}^{\infty} x(t) \cdot \frac{1}{\sqrt{s}}\psi^*\!\left(\frac{t-\tau}{s}\right) dt$$

where $\psi^*$ denotes the complex conjugate.

The result is a 2D scalogram: scale × time, showing "which scales (frequencies) are present at each time."
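To connect the integral to code, the CWT can be computed directly as a Riemann sum: for each scale, sample the daughter wavelet and correlate it with the signal. This is a slow, minimal sketch using the unit-energy Mexican hat (real and even, so the conjugate is trivial); the function names, the ±6 cutoff, and the test signal are our illustrative choices, and libraries such as PyWavelets use more careful implementations.

```python
import numpy as np

def ricker(t: np.ndarray) -> np.ndarray:
    """Unit-energy Mexican hat (Ricker) mother wavelet; real-valued and even."""
    return 2 / (np.sqrt(3) * np.pi ** 0.25) * (1 - t ** 2) * np.exp(-t ** 2 / 2)

def cwt_direct(x: np.ndarray, fs: float, scales: np.ndarray) -> np.ndarray:
    """Riemann-sum approximation of W(s, τ) = ∫ x(t) s^(-1/2) ψ*((t - τ)/s) dt.

    Scales are in seconds; τ runs over the sample times of x.
    Returns an array of shape (len(scales), len(x)).
    """
    dt = 1.0 / fs
    W = np.empty((len(scales), len(x)))
    for i, s in enumerate(scales):
        half = int(np.ceil(6 * s * fs))          # ψ(u) is negligible for |u| > 6
        lags = np.arange(-half, half + 1) * dt
        psi = ricker(lags / s) / np.sqrt(s)       # daughter wavelet centred at τ = 0
        if len(psi) > len(x):
            raise ValueError("scale too large for this signal length")
        # Correlation with ψ* = convolution with the reversed conjugate; for a real,
        # even wavelet both steps are no-ops, kept to mirror the formula.
        W[i] = np.convolve(x, np.conj(psi[::-1]), mode="same") * dt
    return W

fs = 500
t = np.arange(2 * fs) / fs                        # 2 s of samples
x = np.where(t < 1.0, np.sin(2 * np.pi * 10 * t), np.sin(2 * np.pi * 40 * t))
scales = np.geomspace(0.003, 0.05, 40)
W = cwt_direct(x, fs, scales)

# Average W² over a stretch of each half, then pick the best-matching scale
first = (t > 0.3) & (t < 0.7)
second = (t > 1.3) & (t < 1.7)
s_low = scales[np.argmax(np.mean(W[:, first] ** 2, axis=1))]
s_high = scales[np.argmax(np.mean(W[:, second] ** 2, axis=1))]
print(f"best scale, 10 Hz half: {s_low * 1e3:.1f} ms; 40 Hz half: {s_high * 1e3:.1f} ms")
```

For a pure tone at $f$ Hz, this wavelet with the $1/\sqrt{s}$ normalization responds most strongly near $s \approx \sqrt{2.5}/(2\pi f)$, slightly larger than the $\sqrt{2}/(2\pi f)$ where the wavelet's own spectrum peaks, because the normalization leaves an overall $\sqrt{s}$ weight on the response. Either way the ridge follows $1/f$: the 40 Hz half picks a scale about four times smaller than the 10 Hz half.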

```python
import numpy as np
import matplotlib.pyplot as plt
import pywt
from scipy.integrate import trapezoid  # np.trapz is deprecated in NumPy 2.x

# ─── Common Mother Wavelets ──────────────────────────────────────────────────
def mexican_hat_wavelet(t: np.ndarray) -> np.ndarray:
    """
    Mexican Hat (Ricker) wavelet: negated, normalized second derivative of a Gaussian.
    ψ(t) = 2 / (√3 · π^(1/4)) · (1 - t²) · exp(-t²/2)   (unit energy)

    Good for: detecting peaks, ridges, blob-like features.
    Used in: seismic analysis, image edge detection.
    """
    return 2 / (np.sqrt(3) * np.pi**0.25) * (1 - t**2) * np.exp(-t**2 / 2)

def morlet_wavelet(t: np.ndarray, omega0: float = 5.0) -> np.ndarray:
    """
    Morlet wavelet: complex sinusoid modulated by a Gaussian.
    ψ(t) = exp(iω₀t) · exp(-t²/2)   (real part returned: cos-modulated Gaussian)

    Good for: oscillatory signals, speech, audio, EEG.
    omega0: central frequency parameter (typically 5-6).
    The larger omega0, the more oscillations under the envelope
    (better frequency resolution, worse time resolution).
    Note: unnormalized; for omega0 ≈ 5 the mean is tiny but not exactly zero.
    """
    return np.real(np.exp(1j * omega0 * t) * np.exp(-t**2 / 2))

def haar_wavelet(t: np.ndarray) -> np.ndarray:
    """
    Haar wavelet: simplest possible wavelet.
    ψ(t) = +1 for t ∈ [0, 0.5), -1 for t ∈ [0.5, 1), 0 elsewhere.

    Good for: step detection, simple compression.
    Pros: compact support, computationally efficient.
    Cons: discontinuous - poor for smooth signals.
    """
    result = np.zeros_like(t, dtype=float)
    result[(t >= 0) & (t < 0.5)] = 1.0
    result[(t >= 0.5) & (t < 1.0)] = -1.0
    return result

# Display wavelet properties on a common grid
t = np.linspace(-4, 4, 1000)
dt = t[1] - t[0]
wavelets = {
    'Mexican Hat': mexican_hat_wavelet(t),
    'Morlet (ω₀=5)': morlet_wavelet(t, omega0=5.0),
    'Haar': haar_wavelet(t),  # support [0, 1) already lies inside [-4, 4]
}

print("Wavelet Properties:")
print(f"{'Name':<20} | {'Mean':>10} | {'Energy':>10} | {'Support':>10}")
print("-" * 60)
for name, w in wavelets.items():
    mean = trapezoid(w, t)                    # ≈ 0 for an admissible wavelet
    energy = trapezoid(w**2, t)               # 1 for unit-energy wavelets
    support = np.sum(np.abs(w) > 0.01) * dt   # effective support: where |ψ| > 0.01
    print(f"{name:<20} | {mean:>10.4f} | {energy:>10.4f} | {support:>10.2f}")
```

Computing the CWT with PyWavelets

```python
def compute_cwt(
    signal: np.ndarray,
    fs: float = 1.0,
    wavelet: str = 'cmor1.5-1.0',  # complex Morlet
    scales: np.ndarray = None,
    freqs_hz: np.ndarray = None
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute the Continuous Wavelet Transform using PyWavelets.

    wavelet options (pywt):
    - 'cmor1.5-1.0': complex Morlet (bandwidth=1.5, center_freq=1.0)
    - 'mexh': Mexican Hat (real)
    - 'morl': Morlet (real approximation)
    - 'gaus1'...'gaus8': Gaussian derivatives

    Pass either `scales` directly or `freqs_hz` (target frequencies in Hz).

    Returns:
    coefs: (n_scales, T) complex array - CWT coefficients
    freqs: (n_scales,) array - corresponding frequencies in Hz
    scales: (n_scales,) array - scale parameters used
    """
    if scales is None:
        if freqs_hz is None:
            # 50 log-spaced target frequencies from 1 Hz up to Nyquist (fs/2)
            freqs_hz = np.logspace(np.log10(1), np.log10(fs / 2), num=50)
        # frequency2scale expects normalized frequency (cycles per sample)
        scales = pywt.frequency2scale(wavelet, freqs_hz / fs)

    coefs, freqs_normalized = pywt.cwt(signal, scales, wavelet)
    freqs = freqs_normalized * fs  # cycles/sample → Hz

    return coefs, freqs, scales

# Test signal: two frequency components appear at different times
np.random.seed(0)
fs = 500
t = np.linspace(0, 2, 2 * fs, endpoint=False)

# 50 Hz in the first second, 150 Hz in the second
signal_nonstationary = np.where(t < 1.0,
                                np.sin(2 * np.pi * 50 * t),
                                np.sin(2 * np.pi * 150 * t))
signal_nonstationary += 0.1 * np.random.randn(len(t))

coefs, freqs, scales = compute_cwt(signal_nonstationary, fs=fs)
power = np.abs(coefs)**2

print(f"CWT output shape: {coefs.shape} ({len(freqs)} scales × {len(t)} timepoints)")
print(f"Frequency range: {freqs.min():.1f} to {freqs.max():.1f} Hz")

# Find the peak frequency at two time points
t1_idx = len(t) // 4      # t = 0.5 s (in the 50 Hz zone)
t2_idx = 3 * len(t) // 4  # t = 1.5 s (in the 150 Hz zone)

peak_f1 = freqs[np.argmax(power[:, t1_idx])]
peak_f2 = freqs[np.argmax(power[:, t2_idx])]
print(f"\nPeak frequency at t=0.5s: {peak_f1:.1f} Hz (true: 50 Hz)")
print(f"Peak frequency at t=1.5s: {peak_f2:.1f} Hz (true: 150 Hz)")
```

Discrete Wavelet Transform (DWT)

The CWT is highly redundant: there are infinitely many $(s, \tau)$ combinations. The DWT uses a dyadic (powers of 2) grid of scales and translations, giving a non-redundant representation (orthonormal for orthogonal wavelets such as Haar and Daubechies): a length-$N$ signal yields $N$ coefficients in total, up to boundary padding.

Dyadic subsampling:

$$W[j, k] = \int x(t) \cdot 2^{-j/2} \psi(2^{-j}t - k)\, dt$$

where $j$ is the level (scale $2^j$) and $k$ is the translation index (time position $k \cdot 2^j$).

Multiresolution Analysis (MRA)

The DWT implements Multiresolution Analysis: decompose the signal into approximations (low frequency content) and details (high frequency content) at each level.

At each level $j$:

  • Approximation coefficients $a_j$: low-pass filtered + downsampled (coarse signal)
  • Detail coefficients $d_j$: high-pass filtered + downsampled (fine structure at scale $2^j$)

The filter pair $\{h, g\}$ (low-pass and high-pass) is the key: different wavelets correspond to different filter choices.
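A Haar filter bank makes the MRA concrete: $h = [1, 1]/\sqrt{2}$ and $g = [1, -1]/\sqrt{2}$ reduce to pairwise sums and differences. The NumPy-only sketch below (function names are ours; it assumes the length is divisible by $2^L$, and detail sign conventions vary between libraries) checks non-redundancy ($N$ inputs give $N$ coefficients), energy preservation, and perfect reconstruction.

```python
import numpy as np

def haar_dwt(x: np.ndarray, levels: int) -> list[np.ndarray]:
    """Multi-level orthonormal Haar DWT (len(x) must be divisible by 2**levels).

    Returns [a_L, d_L, d_{L-1}, ..., d_1], the same ordering pywt.wavedec uses.
    """
    details = []
    a = np.asarray(x, dtype=float)
    for _ in range(levels):
        even, odd = a[0::2], a[1::2]
        details.append((even - odd) / np.sqrt(2))  # high-pass g, then downsample by 2
        a = (even + odd) / np.sqrt(2)              # low-pass h, then downsample by 2
    return [a] + details[::-1]

def haar_idwt(coeffs: list[np.ndarray]) -> np.ndarray:
    """Synthesis filter bank: inverts haar_dwt level by level."""
    a = coeffs[0]
    for d in coeffs[1:]:
        out = np.empty(2 * len(a))
        out[0::2] = (a + d) / np.sqrt(2)
        out[1::2] = (a - d) / np.sqrt(2)
        a = out
    return a

rng = np.random.default_rng(0)
x = rng.standard_normal(1024)
c = haar_dwt(x, levels=5)
print([len(ci) for ci in c])                                      # [32, 32, 64, 128, 256, 512]
print(sum(len(ci) for ci in c) == len(x))                         # True: non-redundant
print(np.isclose(sum(np.sum(ci**2) for ci in c), np.sum(x**2)))   # True: energy preserved
print(np.allclose(haar_idwt(c), x))                               # True: perfect reconstruction
```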

```python
def dwt_multiresolution(
    signal: np.ndarray,
    wavelet: str = 'db4',
    levels: int = 5,
    fs: float = 1.0
) -> dict:
    """
    Discrete Wavelet Transform with Multiresolution Analysis.

    Decomposition tree:
    signal → [A1, D1]            (level 1: A1=approx, D1=detail)
           → [A2, D2] from A1
           → [A3, D3] from A2
           ...
           → [A_L, D_L] from A_{L-1}

    Reconstruction: signal = A_L + D_L + D_{L-1} + ... + D_1
    (each term is the inverse DWT of one coefficient band with the others zeroed)

    wavelet options:
    - 'haar': Haar (simplest, step-like)
    - 'db4': Daubechies 4 (smooth, 4 vanishing moments)
    - 'db8': Daubechies 8 (smoother, longer support)
    - 'sym8': Symlet 8 (nearly symmetric, smooth)
    - 'coif3': Coiflet 3 (near-symmetric, good for signals with smooth trends)
    - 'bior3.5': Biorthogonal (different filters for analysis and synthesis)

    fs: sampling rate, used only to label the frequency bands (1.0 = normalized)
    """
    # Decompose. 'periodization' gives exactly len(signal) / 2^j coefficients at level j
    # (for lengths divisible by 2^levels) and exact energy preservation for orthogonal wavelets.
    coeffs = pywt.wavedec(signal, wavelet=wavelet, level=levels, mode='periodization')
    # coeffs = [cA_L, cD_L, cD_{L-1}, ..., cD_1]

    print(f"DWT Decomposition (wavelet={wavelet}, levels={levels})")
    print(f"{'Level':>8} | {'Type':>8} | {'Length':>8} | {'Frequency band':>20} | {'Energy':>10}")
    print("-" * 66)

    energies = []
    for i, c in enumerate(coeffs):
        if i == 0:
            level = levels
            ctype = 'Approx'
            f_low, f_high = 0.0, fs / 2**(levels + 1)
        else:
            level = levels - i + 1  # coeffs[i] holds the detail at level j = L - i + 1
            ctype = f'Detail {level}'
            f_low, f_high = fs / 2**(level + 1), fs / 2**level

        energy = np.sum(c**2)
        energies.append(energy)
        band = f"{f_low:.3f} to {f_high:.3f}"
        print(f"{level:>8} | {ctype:>8} | {len(c):>8} | {band:>20} | {energy:>10.2f}")

    total_energy = np.sum(energies)
    print(f"\nTotal energy preserved: {total_energy:.4f} (original: {np.sum(signal**2):.4f})")

    return {
        'coeffs': coeffs,
        'wavelet': wavelet,
        'levels': levels
    }

# Apply DWT to an ECG-like signal
np.random.seed(42)
n = 1024                               # power of 2 for a clean DWT
fs_ecg = 256
t_ecg = np.arange(n) / fs_ecg          # 4 seconds at 256 Hz

# Simulate ECG: slow baseline wander + QRS complexes
baseline = 0.1 * np.sin(2 * np.pi * 0.05 * t_ecg)  # 0.05 Hz baseline drift
qrs = np.zeros(n)
pulse = np.exp(-np.linspace(-3, 3, 20)**2)          # 20-sample Gaussian bump
for beat_t in [0.5, 1.3, 2.1, 2.9]:                 # 4 beats
    idx = int(beat_t * fs_ecg)
    lo, hi = max(0, idx - 10), min(n, idx + 10)
    # Clip the pulse consistently if a beat falls near either edge
    qrs[lo:hi] += pulse[lo - (idx - 10): 20 - ((idx + 10) - hi)]
noise = 0.05 * np.random.randn(n)
ecg_signal = baseline + qrs + noise

dwt_result = dwt_multiresolution(ecg_signal, wavelet='db4', levels=5, fs=fs_ecg)
```

Wavelet Denoising: The Thresholding Approach

One of the most powerful wavelet applications is signal denoising. The key insight: signal energy is concentrated in a few large wavelet coefficients; noise is spread across many small coefficients.

Denoising algorithm (Donoho & Johnstone, 1994):

  1. DWT: decompose the signal into wavelet coefficients
  2. Threshold: zero out or shrink small coefficients (noise)
  3. Inverse DWT: reconstruct from thresholded coefficients

Two thresholding strategies:

  • Hard thresholding: $\hat{c} = c \cdot \mathbf{1}_{|c| > \lambda}$ - keeps or kills coefficients
  • Soft thresholding: $\hat{c} = \text{sign}(c) \cdot \max(|c| - \lambda, 0)$ - shrinks by $\lambda$

The universal threshold (VisuShrink): $\lambda^* = \sigma\sqrt{2 \ln N}$, where $N$ is the signal length and $\sigma$ is the noise level (estimated from the finest-scale detail coefficients).
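The two rules and the universal threshold fit in a few lines of NumPy. This sketch implements the formulas directly rather than calling `pywt.threshold`, so the math is visible; the function names are ours, and the MAD/0.6745 noise estimate is the same one `wavelet_denoise` uses.

```python
import numpy as np

def hard_threshold(c: np.ndarray, lam: float) -> np.ndarray:
    """Keep coefficients with |c| > λ, zero the rest."""
    return c * (np.abs(c) > lam)

def soft_threshold(c: np.ndarray, lam: float) -> np.ndarray:
    """Shrink every coefficient toward zero by λ (zero if |c| ≤ λ)."""
    return np.sign(c) * np.maximum(np.abs(c) - lam, 0.0)

def universal_threshold(finest_detail: np.ndarray, n: int) -> float:
    """VisuShrink λ = σ̂ √(2 ln N), with σ̂ = MAD / 0.6745 of the finest detail level."""
    sigma_hat = np.median(np.abs(finest_detail)) / 0.6745
    return sigma_hat * np.sqrt(2 * np.log(n))

c = np.array([-3.0, -0.5, 0.2, 1.0, 2.5])
print(hard_threshold(c, 1.0))
print(soft_threshold(c, 1.0))

# Finest-level details of pure noise with σ = 0.5: λ should be close to 0.5 · √(2 ln 1024)
rng = np.random.default_rng(0)
lam = universal_threshold(rng.normal(0.0, 0.5, size=100_000), n=1024)
print(f"λ = {lam:.3f}")
```

Hard thresholding leaves -3.0 and 2.5 untouched; soft thresholding turns them into -2.0 and 1.5. That constant shrinkage is the bias that soft thresholding trades for a continuous (smoother) estimator.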

```python
def wavelet_denoise(
    signal: np.ndarray,
    wavelet: str = 'db4',
    levels: int = None,
    threshold_method: str = 'soft',
    threshold_rule: str = 'universal'
) -> np.ndarray:
    """
    Denoise a signal using wavelet thresholding.

    Steps:
    1. DWT decomposition
    2. Estimate noise from the finest-scale detail coefficients
    3. Apply a threshold to detail coefficients (not the approximation!)
    4. Inverse DWT reconstruction

    threshold_method: 'soft' (shrink) or 'hard' (kill/keep)
    threshold_rule:
    'universal': λ = σ√(2 ln N) - tends to over-smooth
    'bayes': BayesShrink, per-level λ_j = σ² / σ_x,j - minimizes an estimate
             of the Bayes risk (often better)
    'minimax': λ = σ(0.3936 + 0.1829 log₂ N), a widely used approximation
               of the minimax threshold (applied for N > 32)
    """
    if threshold_method not in ('soft', 'hard'):
        raise ValueError(f"threshold_method must be 'soft' or 'hard', got {threshold_method!r}")

    n = len(signal)
    if levels is None:
        levels = min(5, pywt.dwt_max_level(n, wavelet))

    # Decompose: coeffs = [cA_L, cD_L, ..., cD_1]
    coeffs = pywt.wavedec(signal, wavelet=wavelet, level=levels)

    # Estimate the noise standard deviation from the finest detail (level 1)
    # Robust estimate: median absolute deviation / 0.6745 (MAD estimator)
    finest_detail = coeffs[-1]
    sigma_est = np.median(np.abs(finest_detail)) / 0.6745
    print(f"Noise level estimate σ = {sigma_est:.4f} (rule: {threshold_rule})")

    # Apply thresholding to DETAIL coefficients only
    # (preserve approximation coefficients - they contain the signal trend)
    coeffs_thresholded = [coeffs[0]]
    for detail_c in coeffs[1:]:
        if threshold_rule == 'universal':
            threshold = sigma_est * np.sqrt(2 * np.log(n))
        elif threshold_rule == 'minimax':
            threshold = sigma_est * (0.3936 + 0.1829 * np.log2(n)) if n > 32 else 0.0
        elif threshold_rule == 'bayes':
            # BayesShrink: σ_x² = max(E[d²] - σ², 0) estimates the signal variance at this level
            sigma_x = np.sqrt(max(np.mean(detail_c**2) - sigma_est**2, 0.0))
            # A level that looks like pure noise (σ_x = 0) is removed entirely
            threshold = sigma_est**2 / sigma_x if sigma_x > 0 else np.max(np.abs(detail_c))
        else:
            raise ValueError(f"Unknown threshold_rule: {threshold_rule!r}")
        coeffs_thresholded.append(pywt.threshold(detail_c, threshold, mode=threshold_method))

    # Reconstruct (waverec can return one extra sample for odd lengths)
    denoised = pywt.waverec(coeffs_thresholded, wavelet=wavelet)

    return denoised[:n]  # trim to the original length


# Test denoising
noise_level = 0.2
clean_signal = np.sin(2 * np.pi * 5 * t_ecg) + 0.3 * np.sin(2 * np.pi * 15 * t_ecg)
noisy_signal = clean_signal + np.random.normal(0, noise_level, n)

denoised_soft = wavelet_denoise(noisy_signal, wavelet='db4', threshold_method='soft')
denoised_hard = wavelet_denoise(noisy_signal, wavelet='db4', threshold_method='hard')

# Evaluate
noisy_rmse = np.sqrt(np.mean((noisy_signal - clean_signal)**2))
soft_rmse = np.sqrt(np.mean((denoised_soft - clean_signal)**2))
hard_rmse = np.sqrt(np.mean((denoised_hard - clean_signal)**2))

print("\nDenoising Results:")
print(f"  Noisy signal RMSE:      {noisy_rmse:.4f}")
print(f"  Soft thresholding RMSE: {soft_rmse:.4f} ({(1 - soft_rmse / noisy_rmse) * 100:.1f}% reduction)")
print(f"  Hard thresholding RMSE: {hard_rmse:.4f} ({(1 - hard_rmse / noisy_rmse) * 100:.1f}% reduction)")
```

Wavelet Packets: Full Decomposition Tree

Standard DWT only decomposes the approximation branch (low-frequency subband). Wavelet packets decompose both approximation and detail at every level, giving the full binary tree of subbands.
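Frequency order differs from the natural order of the tree: high-pass filtering followed by downsampling mirrors the spectrum, so the low-pass child of a "d" node holds the upper half of that node's band. Reordering by the Gray code of the band index fixes this. The sketch below computes that ordering and the nominal band edges; to our understanding this matches the ordering `wp.get_level(levels, 'freq')` uses in PyWavelets, but treat that correspondence as an assumption to verify. Helper names are ours.

```python
def natural_order_paths(level: int) -> list[str]:
    """Node paths in natural (filter-bank) order: 'a' = low-pass, 'd' = high-pass."""
    return [format(i, f"0{level}b").translate(str.maketrans("01", "ad"))
            for i in range(2**level)]

def frequency_order_paths(level: int) -> list[str]:
    """Reorder so the k-th path covers the k-th lowest band: natural index = k XOR (k >> 1)."""
    natural = natural_order_paths(level)
    return [natural[k ^ (k >> 1)] for k in range(2**level)]

def nominal_band_edges(level: int, fs: float) -> list[tuple[float, float]]:
    """2**level equal-width bands tiling [0, fs/2]; real filters overlap at the edges."""
    width = fs / 2 ** (level + 1)
    return [(k * width, (k + 1) * width) for k in range(2**level)]

print(frequency_order_paths(2))  # ['aa', 'ad', 'dd', 'da']
for path, (lo, hi) in zip(frequency_order_paths(3), nominal_band_edges(3, fs=256)):
    print(f"{path}: {lo:6.1f} to {hi:6.1f} Hz")
```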

```python
def wavelet_packet_features(
    signal: np.ndarray,
    wavelet: str = 'db4',
    levels: int = 4
) -> dict:
    """
    Wavelet packet decomposition for feature extraction.

    Produces 2^L subbands at level L, each covering a nominal frequency band
    of width (fs/2) / 2^L = fs / 2^(L+1).
    From each subband, extract: energy, entropy, mean, std, max_abs.

    Applications:
    - EEG: delta (0-4Hz), theta (4-8Hz), alpha (8-13Hz), beta (13-30Hz), gamma (>30Hz)
    - Fault diagnosis: specific frequency bands indicate specific fault types
    - Audio: sub-band features for speech/music classification
    """
    wp = pywt.WaveletPacket(signal, wavelet=wavelet, mode='symmetric', maxlevel=levels)

    # All nodes at level `levels`, ordered from the lowest to the highest frequency band
    nodes = wp.get_level(levels, 'freq')

    features = {}
    for node in nodes:
        coefs = node.data
        band = f"band_{node.path}"

        # Energy in this subband
        energy = np.sum(coefs**2)

        # Shannon entropy of normalized squared coefficients
        p = coefs**2 / (energy + 1e-10)
        p = p[p > 0]
        entropy = -np.sum(p * np.log(p))

        features[band] = {
            'energy': energy,
            'entropy': entropy,
            'mean': np.mean(coefs),
            'std': np.std(coefs),
            'max_abs': np.max(np.abs(coefs))
        }

    # Print energy distribution
    total_energy = sum(f['energy'] for f in features.values())
    print(f"\nWavelet Packet Energy Distribution (wavelet={wavelet}, levels={levels}):")
    print(f"Total energy: {total_energy:.4f}")
    print(f"Number of subbands: {len(features)}")

    # Top 3 energy bands
    sorted_bands = sorted(features.items(), key=lambda x: x[1]['energy'], reverse=True)
    print("\nTop 3 energy subbands:")
    for band, f in sorted_bands[:3]:
        pct = f['energy'] / total_energy * 100
        print(f"  {band}: energy={f['energy']:.4f} ({pct:.1f}%)")

    return features

wp_feats = wavelet_packet_features(ecg_signal, wavelet='db4', levels=4)

# Build a feature vector for ML: all subband energies followed by all entropies
feature_vector = np.array(
    [f['energy'] for f in wp_feats.values()] +
    [f['entropy'] for f in wp_feats.values()]
)
print(f"\nFeature vector shape: {feature_vector.shape}")  # 2 * 2^4 = 32 features
```

Wavelet Features for ML Models

```python
def extract_wavelet_ml_features(
    series: np.ndarray,
    wavelet: str = 'db4',
    levels: int = 4,
    feature_types: list = None
) -> np.ndarray:
    """
    Extract wavelet-based features for use in standard ML models.

    For time series classification (e.g., motor fault detection,
    activity recognition, ECG classification), wavelet features often
    outperform raw samples because they:
    1. Capture the signal at multiple temporal scales
    2. Summarize each band with statistics that change little under small
       temporal shifts (the raw DWT coefficients themselves are not shift-invariant)
    3. Separate smooth trends from transient events
    4. Provide a compact representation (a few numbers per subband)

    feature_types: subset of ['energy', 'entropy', 'stats']
    With all types: (levels + 1) bands × 6 features.
    """
    if feature_types is None:
        feature_types = ['energy', 'entropy', 'stats']

    coeffs = pywt.wavedec(series, wavelet=wavelet, level=levels)

    features = []

    for level_coefs in coeffs:  # [cA_L, cD_L, ..., cD_1]
        if 'energy' in feature_types:
            features.append(np.sum(level_coefs**2))

        if 'entropy' in feature_types:
            # Shannon entropy of the normalized coefficient energies
            p = level_coefs**2
            p_norm = p / (np.sum(p) + 1e-10)
            p_norm = p_norm[p_norm > 0]
            features.append(-np.sum(p_norm * np.log(p_norm)))

        if 'stats' in feature_types:
            features.extend([
                np.mean(np.abs(level_coefs)),
                np.std(level_coefs),
                np.max(np.abs(level_coefs)),
                np.percentile(np.abs(level_coefs), 75)
            ])

    return np.array(features)

# Example: classify two types of signals
np.random.seed(0)
n_samples = 200
n_signal_len = 256
t_cls = np.linspace(0, 1, n_signal_len)

# Class 0: low-frequency dominant signal
# Class 1: high-frequency dominant signal
X_features = []
y_labels = []

for _ in range(n_samples // 2):
    # Class 0: slow oscillation (3 Hz) with noise
    sig = np.sin(2 * np.pi * 3 * t_cls)
    sig += 0.3 * np.random.randn(n_signal_len)
    X_features.append(extract_wavelet_ml_features(sig, levels=4))
    y_labels.append(0)

for _ in range(n_samples // 2):
    # Class 1: fast oscillation (40 Hz) with noise
    sig = np.sin(2 * np.pi * 40 * t_cls)
    sig += 0.3 * np.random.randn(n_signal_len)
    X_features.append(extract_wavelet_ml_features(sig, levels=4))
    y_labels.append(1)

X = np.array(X_features)
y = np.array(y_labels)

print(f"Feature matrix shape: {X.shape}")
print(f"Feature range: [{X.min():.4f}, {X.max():.4f}]")

# Train a simple classifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

clf = RandomForestClassifier(n_estimators=50, random_state=42)
scores = cross_val_score(clf, X, y, cv=5, scoring='accuracy')
print("\nRandom Forest on wavelet features:")
print(f"  CV accuracy: {scores.mean():.4f} ± {scores.std():.4f}")
```

The Connection to Deep Learning: WaveNet and CNNs

Wavelet-Inspired CNN Architectures

Several modern deep learning architectures echo wavelet concepts, often without being derived from wavelet theory.

WaveNet (van den Oord et al., 2016): A convolutional architecture for raw audio generation that uses dilated causal convolutions. The dilations follow powers of 2: 1, 2, 4, 8, 16, ... - exactly the dyadic scale hierarchy of wavelets.

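At dilation $d = 2^k$, the convolution sees samples spaced $2^k$ apart, roughly equivalent to processing the signal at scale $2^k$. Stacking $K$ layers with kernel size 2 and dilations $1, 2, \ldots, 2^{K-1}$ gives a receptive field of $2^K$ samples, so context grows exponentially with depth while the parameter count grows only linearly.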
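The receptive-field arithmetic generalizes: a causal convolution with kernel size $k$ and dilation $d$ adds $(k-1) \cdot d$ samples of context. A quick check in plain Python (the function name is ours):

```python
def receptive_field(kernel_size: int, dilations: list[int]) -> int:
    """Receptive field (in samples) of stacked dilated causal convs: 1 + Σ (k - 1) · d."""
    return 1 + sum((kernel_size - 1) * d for d in dilations)

K = 10
print(receptive_field(2, [2**k for k in range(K)]))  # 1024 = 2**K
print(receptive_field(2, [1, 2, 4, 8] * 2))          # 31: two stacks of dilations 1, 2, 4, 8
```

Repeating the dilation pattern in several stacks, as in the model below, adds context linearly rather than exponentially, but keeps every individual dilation small.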

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DilatedCausalConv(nn.Module):
    """
    Dilated causal convolution - the core of WaveNet.
    Dilation d: the kernel sees samples at positions t, t-d, t-2d, ...

    With kernel_size=2, stacking d = 1, 2, 4, 8, ..., 2^K gives a receptive field of 2^(K+1).
    This mirrors the dyadic scales of wavelet multiresolution, with learned filters.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int = 2, dilation: int = 1):
        super().__init__()
        # Pad on the left only, so output[t] never depends on input samples after t
        self.left_pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(
            in_channels, out_channels,
            kernel_size=kernel_size,
            dilation=dilation
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, time) -> (batch, out_channels, time), causal
        return self.conv(F.pad(x, (self.left_pad, 0)))

class WaveNetBlock(nn.Module):
    """
    WaveNet residual block with gated activation.
    """
    def __init__(self, channels: int, dilation: int):
        super().__init__()
        self.filter_conv = DilatedCausalConv(channels, channels, dilation=dilation)
        self.gate_conv = DilatedCausalConv(channels, channels, dilation=dilation)
        self.residual_conv = nn.Conv1d(channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Gated activation: tanh(filter) * sigmoid(gate)
        h = torch.tanh(self.filter_conv(x)) * torch.sigmoid(self.gate_conv(x))
        skip = self.residual_conv(h)
        return x + skip, skip  # residual output, skip connection

class SimpleWaveNet(nn.Module):
    """
    Simplified WaveNet for time series.
    Stack of dilated causal convolutions with exponentially increasing dilation.
    """
    def __init__(
        self,
        input_size: int = 1,
        channels: int = 32,
        n_blocks: int = 4,
        n_layers_per_block: int = 5,  # dilations 1, 2, 4, 8, 16 in each block
        output_size: int = 1
    ):
        super().__init__()

        self.input_conv = nn.Conv1d(input_size, channels, kernel_size=1)

        self.blocks = nn.ModuleList()
        for _ in range(n_blocks):
            for layer in range(n_layers_per_block):
                self.blocks.append(WaveNetBlock(channels, dilation=2**layer))

        self.output_net = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(channels, channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(channels, output_size, kernel_size=1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, input_size, time) -> (batch, output_size, time)
        h = self.input_conv(x)

        skip_sum = 0
        for block in self.blocks:
            h, skip = block(h)
            skip_sum = skip_sum + skip

        return self.output_net(skip_sum)


# Test WaveNet
batch_size, channels, seq_len = 4, 1, 256
x_test = torch.randn(batch_size, channels, seq_len)
model = SimpleWaveNet(input_size=1, channels=16, n_blocks=2, n_layers_per_block=4)
y_test = model(x_test)

# Receptive field with kernel_size=2: each layer adds its dilation
receptive_field = 1 + 2 * sum(2**layer for layer in range(4))  # 2 blocks of dilations 1, 2, 4, 8
print(f"WaveNet receptive field: {receptive_field} samples")
print(f"Input shape: {x_test.shape}")
print(f"Output shape: {y_test.shape}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
```

Wavelet Pooling in Neural Networks

Replacing max-pooling with wavelet-based pooling gives antialiased downsampling that is more robust to small input shifts (though not strictly shift-invariant):

```python
class WaveletPool1D(nn.Module):
    """
    Wavelet-based downsampling: preserves more information than max/avg pooling.
    Uses the approximation coefficients from one level of DWT.

    Advantages over max-pooling:
    - Low-pass filtering before stride-2 subsampling reduces aliasing, so outputs
      change less under small input shifts (more shift-robust, not strictly invariant)
    - Preserves low-frequency content (approximation coefficients)
    - Antialiasing by design (the wavelet scaling filter is low-pass)
    """

    def __init__(self, wavelet: str = 'db2'):
        super().__init__()
        self.wavelet = wavelet
        w = pywt.Wavelet(wavelet)
        # conv1d computes cross-correlation, so flip the filters to get true convolution.
        # Buffers follow .to(device) and are saved in state_dict, but are not trained.
        self.register_buffer('lo_d', torch.tensor(w.dec_lo[::-1], dtype=torch.float32))
        self.register_buffer('hi_d', torch.tensor(w.dec_hi[::-1], dtype=torch.float32))  # detail filter, unused here

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input: (batch, channels, T)
        Output: (batch, channels, ceil(T/2)) - approximation coefficients only
        """
        batch, channels, T = x.shape
        filter_len = self.lo_d.numel()

        # Same low-pass filter for every channel (grouped / depthwise convolution)
        lo_filter = self.lo_d.to(x.dtype).view(1, 1, -1).repeat(channels, 1, 1)

        # Left-pad so every output sample sees a full filter window
        pad = filter_len - 1
        x_padded = torch.nn.functional.pad(x, (pad, 0))

        # Low-pass filter (approximation) + downsample by 2.
        # The DC gain is √2 (sum of dec_lo), as in an orthonormal DWT.
        approx = torch.nn.functional.conv1d(
            x_padded,
            lo_filter,
            stride=2,
            groups=channels
        )
        return approx
```
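In the Haar case, wavelet pooling reduces to average pooling scaled by $\sqrt{2}$, which makes its anti-aliasing effect easy to see. A tone just below Nyquist aliases at full strength under plain stride-2 subsampling, but is strongly attenuated when the low-pass step runs first. Daubechies filters with more vanishing moments have more zeros at Nyquist and suppress such tones further. This is a NumPy-only sketch; the helper names and the test tone are ours.

```python
import numpy as np

def haar_pool(x: np.ndarray) -> np.ndarray:
    """One Haar approximation level: (x[2k] + x[2k+1]) / √2 (len(x) even)."""
    return (x[0::2] + x[1::2]) / np.sqrt(2)

def avg_pool(x: np.ndarray) -> np.ndarray:
    """Average pooling with window 2 and stride 2."""
    return 0.5 * (x[0::2] + x[1::2])

def rms(v: np.ndarray) -> float:
    return float(np.sqrt(np.mean(v ** 2)))

n = np.arange(1000)
tone = np.cos(2 * np.pi * 0.45 * n)   # 0.45 cycles/sample, just below Nyquist (0.5)

naive = tone[0::2]                     # plain stride-2 subsampling: aliases to 0.1 cycles/sample
pooled = haar_pool(tone) / np.sqrt(2)  # divide by √2, the DC gain of the Haar low-pass filter

print(f"RMS after plain subsampling:             {rms(naive):.3f}")
print(f"RMS after Haar pooling (DC gain removed): {rms(pooled):.3f}")
print(np.allclose(haar_pool(tone), np.sqrt(2) * avg_pool(tone)))  # True
```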

Choosing the Right Wavelet

| Wavelet | Properties | Best for |
|---|---|---|
| Haar | Step-like, compact support | Step detection, simple compression |
| Daubechies db2–db20 | Smooth, N vanishing moments | Smooth signals, compression |
| Symlet sym2–sym20 | Nearly symmetric Daubechies variant | Images, signals requiring symmetry |
| Coiflet coif1–coif5 | Near-symmetric; both the wavelet and the scaling function have vanishing moments | Smooth, slowly varying signals |
| Morlet | Complex sinusoid × Gaussian (CWT only) | Oscillatory signals (EEG, audio, vibration) |
| Mexican Hat | 2nd derivative of Gaussian (CWT only) | Peaks, ridges, blob detection |
| Biorthogonal bior | Different analysis/synthesis filters | Compression (JPEG2000), perfect reconstruction |

Vanishing moments $N$: a wavelet has $N$ vanishing moments if $\int t^k \psi(t)\, dt = 0$ for $k = 0, 1, \ldots, N-1$. More vanishing moments → polynomials up to degree $N-1$ produce zero detail coefficients → better compression for smooth signals.
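The definition can be checked on the discrete filters: for an orthogonal wavelet, the DWT detail filter $g$ inherits the vanishing moments as discrete moments $\sum_k k^m g[k] = 0$. The sketch below uses the closed-form 4-tap Daubechies filter (db2 in PyWavelets naming, 2 vanishing moments) and the quadrature-mirror relation $g[k] = (-1)^k h[L-1-k]$, one common convention; libraries differ in sign and ordering, which does not affect the zero moments. It then shows that interior detail coefficients vanish for constant and linear inputs but not for a quadratic.

```python
import numpy as np

# Daubechies 4-tap low-pass filter ('db2': 2 vanishing moments), closed form
s3 = np.sqrt(3.0)
h = np.array([1 + s3, 3 + s3, 3 - s3, 1 - s3]) / (4 * np.sqrt(2.0))
# Quadrature-mirror high-pass (detail) filter: g[k] = (-1)^k h[L-1-k]
g = np.array([(-1) ** k * h[len(h) - 1 - k] for k in range(len(h))])

k = np.arange(len(g))
moments = [float(np.sum(k ** m * g)) for m in range(3)]
print("discrete moments m=0,1,2:", np.round(moments, 6))  # first two vanish, third does not

n = np.arange(64, dtype=float)
max_detail = {}
for name, poly in [("constant", np.ones_like(n)), ("linear", 2 * n + 1), ("quadratic", n ** 2)]:
    detail = np.convolve(poly, g, mode="valid")[::2]  # interior detail coefficients only
    max_detail[name] = float(np.max(np.abs(detail)))
    print(f"{name:>9}: max |detail| = {max_detail[name]:.2e}")
```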

```python
def select_wavelet(
    signal: np.ndarray,
    candidate_wavelets: list = None,
    levels: int = 4,
    metric: str = 'energy_compaction'
) -> str:
    """
    Select the best wavelet for a signal.

    'energy_compaction': good wavelets concentrate energy in few large coefficients.
        Measure: fraction of energy in the top 10% of coefficients (higher is better).
    'sparsity': fraction of coefficients below 1% of the largest magnitude (higher is better).
    """
    if metric not in ('energy_compaction', 'sparsity'):
        raise ValueError(f"Unknown metric: {metric!r}")
    if candidate_wavelets is None:
        candidate_wavelets = ['haar', 'db2', 'db4', 'db8', 'sym4', 'sym8', 'coif3']

    results = {}
    for wname in candidate_wavelets:
        try:
            coeffs = pywt.wavedec(signal, wavelet=wname, level=levels)
        except ValueError:
            continue  # e.g. an unknown wavelet name
        all_coefs = np.concatenate(coeffs)

        if metric == 'energy_compaction':
            # What fraction of energy is in the top 10% of coefficients?
            energy = all_coefs**2
            total_energy = np.sum(energy)
            n_top = max(1, len(all_coefs) // 10)
            top_energy = np.sum(np.sort(energy)[-n_top:])
            results[wname] = top_energy / total_energy
        else:  # 'sparsity'
            # What fraction of coefficients are insignificant (below threshold)?
            threshold = 0.01 * np.max(np.abs(all_coefs))
            results[wname] = 1.0 - np.mean(np.abs(all_coefs) > threshold)

    best_wavelet = max(results, key=results.get)

    print(f"Wavelet selection (metric: {metric}):")
    for name, score in sorted(results.items(), key=lambda x: x[1], reverse=True):
        marker = " ← BEST" if name == best_wavelet else ""
        print(f"  {name:<12}: {score:.4f}{marker}")

    return best_wavelet

print("Selecting wavelet for smooth sinusoidal signal:")
smooth_signal = np.sin(2 * np.pi * 5 * np.linspace(0, 2, 512))
best_w = select_wavelet(smooth_signal)

print("\nSelecting wavelet for step signal:")
step_signal = np.where(np.linspace(0, 2, 512) > 1, 1.0, 0.0)
best_w_step = select_wavelet(step_signal)
```

Interview Questions

Q1: What is the difference between Fourier analysis and wavelet analysis? When do you use each?

Fourier analysis decomposes a signal into sine waves of different frequencies. It gives a global frequency representation - tells you which frequencies exist in the entire signal, but not when they appear. The Fourier transform is ideal for stationary signals where frequency content doesn't change over time.

Wavelet analysis decomposes a signal into scaled and translated copies of a localized "mother wavelet." It gives simultaneous time-frequency information - at each time and scale, how much of that frequency component exists at that moment.

Key mathematical difference: Fourier basis functions ($e^{i\omega t}$) have infinite support (they extend over all time); wavelets are localized in time (compactly supported, like Haar and Daubechies, or rapidly decaying, like Morlet and Mexican hat). This gives wavelets time localization that Fourier lacks.

Heisenberg uncertainty trade-off: all three obey $\Delta t \cdot \Delta f \geq 1/(4\pi)$.

  • Fourier: $\Delta t = \infty$, $\Delta f = 0$ (infinite time support, perfect frequency localization)
  • STFT: fixed window → fixed $\Delta t$, $\Delta f$ for all frequencies
  • Wavelets: $\Delta t$ scales with frequency (shorter at high frequency, longer at low), keeping the relative bandwidth $\Delta f / f$ constant → resolution adapted to each frequency band

When to use Fourier:

  • Stationary signals (constant frequency content)
  • Periodic signals needing precise frequency identification
  • Fast convolution (convolution theorem)
  • When only global spectrum matters

When to use wavelets:

  • Non-stationary signals (ECG, speech, seismic)
  • Signal denoising (sparse representation + thresholding)
  • Multi-scale analysis where events occur at different temporal scales
  • Image compression (JPEG2000 uses wavelets)
Q2: Explain wavelet denoising. Why does it work?

Wavelet denoising works because signal and noise have fundamentally different representations in the wavelet domain:

Signal (smooth or piecewise regular): Most energy is concentrated in a few large wavelet coefficients at the relevant scales. A smooth sine wave has very sparse wavelet coefficients; a sharp edge produces large coefficients only at the few positions whose wavelet support straddles the edge (at every scale), while the smooth regions around it give near-zero coefficients.

Gaussian white noise: An orthonormal wavelet transform maps white noise to white noise with the same variance, so noise energy is spread evenly across ALL wavelet coefficients at ALL scales and times. Each coefficient gets only a small amount of noise energy.
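This is easy to check numerically: after one level of the orthonormal Haar transform, white Gaussian noise keeps its standard deviation in both bands and stays uncorrelated, which is why the median absolute deviation of the finest details, divided by 0.6745, recovers $\sigma$. A sketch with an arbitrary $\sigma$ and seed:

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = 0.3
noise = rng.normal(0.0, sigma, size=2**16)   # pure white Gaussian noise

# One level of the orthonormal Haar DWT (pairwise sums and differences / √2)
approx = (noise[0::2] + noise[1::2]) / np.sqrt(2)
detail = (noise[0::2] - noise[1::2]) / np.sqrt(2)

lag1_corr = np.corrcoef(detail[:-1], detail[1:])[0, 1]
sigma_mad = np.median(np.abs(detail)) / 0.6745   # robust σ estimate used for thresholding

print(f"std: noise {noise.std():.4f}, approx {approx.std():.4f}, detail {detail.std():.4f}")
print(f"lag-1 correlation of detail coefficients: {lag1_corr:+.4f}")
print(f"MAD estimate of sigma: {sigma_mad:.4f} (true {sigma})")
```

Because the median ignores a handful of large values, the same estimate stays close to $\sigma$ when a sparse signal adds a few big coefficients to the finest level.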

The algorithm (Donoho-Johnstone, 1994):

  1. DWT: transform the noisy signal $y = f + \epsilon$ to the wavelet domain: $\tilde{y} = \tilde{f} + \tilde{\epsilon}$
  2. Threshold: apply $\hat{\tilde{f}} = T_\lambda(\tilde{y})$, where $T_\lambda$ is soft or hard thresholding
    • Hard: zero out coefficients with $|\tilde{y}| < \lambda$, keep the others
    • Soft: zero out coefficients with $|\tilde{y}| < \lambda$, shrink the others toward zero by $\lambda$
  3. Inverse DWT: reconstruct $\hat{f} = W^{-1}(\hat{\tilde{f}})$

Universal threshold: $\lambda^* = \hat{\sigma}\sqrt{2\ln N}$. For Gaussian noise it removes pure-noise coefficients with high probability, and the risk of soft thresholding at $\lambda^*$ is within a factor of order $\ln N$ of an ideal oracle that knows which coefficients to keep.

Why it works theoretically: Wavelets provide an approximately sparse representation of signals in common function classes (Hölder continuous, piecewise smooth, Sobolev spaces). Thresholding in a sparse basis is a near-optimal denoising procedure.

Soft vs hard:

  • Hard thresholding: the threshold function jumps at $\pm\lambda$, so the estimate is more variable and can show ringing artifacts around kept coefficients
  • Soft thresholding: continuous, so the reconstruction is smoother; it introduces bias (subtracts $\lambda$ from the magnitude of every kept coefficient) but often gives lower MSE in practice
Q3: How are WaveNet's dilated causal convolutions related to wavelets?

WaveNet (van den Oord et al., 2016) uses dilated causal convolutions with exponentially increasing dilation factors: 1, 2, 4, 8, ..., $2^{K-1}$. This creates a dyadic hierarchy, the same scale structure as the DWT.

Structural parallels:

  • DWT level $j$: filters the signal at scale $2^j$ and responds to patterns roughly $2^j$ samples long
  • WaveNet dilation $2^j$: the convolution sees samples at positions $t, t-2^j, t-2\cdot 2^j, \ldots$ (just $t$ and $t-2^j$ with kernel size 2), i.e. patterns at temporal scale $2^j$
  • DWT filters: fixed, analytically designed (Daubechies filters optimize smoothness/vanishing moments)
  • WaveNet filters: learned end-to-end from data; the network discovers which "wavelet-like" filters work best for audio generation
  • DWT approximation: low-pass filter + downsample
  • WaveNet stacked dilations: same time resolution throughout but an increasing receptive field (no downsampling, so there is one output per input sample, as sample-by-sample autoregressive generation requires)

Key difference: DWT downsamples (reduces temporal resolution at each level); WaveNet doesn't (maintains full resolution for generation). Wavelets are a fixed basis; WaveNet learns a data-adaptive one.

Practical implication: The intuition that "audio has temporal structure at multiple scales" - pitch (milliseconds), phonemes (tens of ms), words (hundreds of ms) - is directly encoded in WaveNet's architecture via the dyadic dilation structure.

Q4: What are vanishing moments in wavelets, and why do they matter for compression?

A wavelet $\psi$ has $N$ vanishing moments if:

$$\int_{-\infty}^{\infty} t^k \psi(t)\, dt = 0 \quad \text{for } k = 0, 1, \ldots, N-1$$

Practically: $\psi$ is orthogonal to all polynomials of degree $< N$.

Why this matters for compression:

If a signal region is well-approximated by a polynomial of degree $< N$, then all wavelet detail coefficients whose support lies inside that region are (approximately) zero. A smooth signal segment → few non-zero wavelet coefficients → sparse representation → high compression ratio.

Example: The Haar wavelet has 1 vanishing moment: it produces zero detail coefficients for constant signal segments but non-zero ones for linear trends. Daubechies db4 has 4 vanishing moments: it produces zero detail coefficients for constant, linear, quadratic, and cubic polynomial segments - much better for smooth signals.
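A quick numerical check of both claims, with one substitution: the sketch uses the closed-form db2 filter (4 taps, 2 vanishing moments in PyWavelets naming) instead of db4, whose 8 taps are omitted here; the same check works for any filter, for example the taps in PyWavelets' `pywt.Wavelet("db4").dec_hi`. The discrete moments $\sum_n n^k g[n]$ of the high-pass filter $g$ stand in for the moment integral, and `detail_coeffs` (a hypothetical helper) is one interior analysis step with no boundary handling. Filter sign and ordering conventions vary between libraries; they do not change which moments or coefficients are zero.

```python
import numpy as np

# Wavelet (high-pass) analysis filters; sign and ordering conventions differ between libraries.
HAAR_G = np.array([1.0, -1.0]) / np.sqrt(2)                            # Haar: 1 vanishing moment
S3 = np.sqrt(3.0)
DB2_H = np.array([1 + S3, 3 + S3, 3 - S3, 1 - S3]) / (4 * np.sqrt(2))  # db2 low-pass (4 taps)
# Quadrature-mirror construction: g[n] = (-1)^n * h[L-1-n]
DB2_G = np.array([(-1) ** n * DB2_H[len(DB2_H) - 1 - n] for n in range(len(DB2_H))])

def discrete_moments(g, max_k):
    """Discrete analogue of the moment integral: sum_n n^k g[n] for k = 0..max_k."""
    n = np.arange(len(g))
    return [float(np.sum(n ** k * g)) for k in range(max_k + 1)]

def detail_coeffs(x, g):
    """One interior analysis step: d[k] = sum_n g[n] * x[2k + n] (no boundary handling)."""
    return np.correlate(x, g, mode="valid")[::2]

t = np.arange(64, dtype=float)
segments = {
    "constant": np.full_like(t, 3.0),
    "linear": 2.0 * t + 1.0,
    "quadratic": 0.05 * t ** 2,
}
for name, g in [("haar", HAAR_G), ("db2", DB2_G)]:
    print(name, "moments k=0..2:", np.round(discrete_moments(g, 2), 6))
    for seg_name, x in segments.items():
        print(f"  {seg_name:9s} max|d| = {np.max(np.abs(detail_coeffs(x, g))):.2e}")
```

The expected pattern: Haar zeroes only the constant segment, while db2 also zeroes the linear ramp but not the quadratic, matching 1 and 2 vanishing moments; db4's filter would additionally zero the quadratic (and cubic) segments.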

Practical consequence:

  • Images and audio have many smooth regions (locally polynomial) → db4 and higher-order wavelets compress them well
  • JPEG2000's lossy mode uses the CDF 9/7 biorthogonal wavelet, which has 4 vanishing moments (its lossless mode uses the reversible LeGall 5/3 wavelet)
  • For signals with step discontinuities (seismic, fault signals): fewer vanishing moments (even Haar) may suffice and give shorter support (better time localization)

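Trade-off: More vanishing moments = sparser coefficients for smooth signals, but longer filters: a Daubechies wavelet with $N$ vanishing moments needs $2N$ filter taps, which means more computation, more coefficients affected by each discontinuity, and more edge artifacts for finite-length signals.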

Q5: How would you use wavelets for anomaly detection in a multivariate time series?

Wavelet-based anomaly detection pipeline:

  1. Decompose each channel: Apply a multilevel DWT to each dimension of the time series. Use an appropriate wavelet (Daubechies db4 for smooth signals, db2 or Haar for step-like signals).

  2. Extract subband features: For each detail level $j$ and sliding window:

    • Energy: $E_j = \sum_k |d_j[k]|^2$
    • Entropy: $H_j = -\sum_k p_j[k] \log p_j[k]$, where $p_j[k] = |d_j[k]|^2 / E_j$ is the normalized energy of coefficient $k$
    • Max coefficient: $\max_k |d_j[k]|$
  3. Build a normal-behavior profile: Using data from normal operation only, fit a multivariate Gaussian to the subband feature vectors (or, more robustly, a KDE, a one-class SVM, or an Isolation Forest).

  4. Anomaly scoring: For each new window, compute the Mahalanobis distance to (or the log-likelihood under) the fitted normal-behavior model in wavelet feature space. High distance (low likelihood) = anomaly.

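  5. Multi-scale localization: Different faults manifest at different scales. Level $j$ detail coefficients cover roughly $f_s/2^{j+1}$ to $f_s/2^j$, so the level assignments below assume a sampling rate of a few hundred hertz (e.g. $f_s = 500$ Hz):

    • Bearing fault (150 Hz): anomaly in fine-scale detail coefficients (levels 1-2)
    • Imbalance (fundamental frequency, e.g., 25 Hz): anomaly in mid-scale detail coefficients (levels 3-4)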
    • Overheating trend (minutes-long): anomaly in coarse approximation
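A compact single-channel sketch of steps 1 to 5, under loud assumptions: a hand-rolled Haar DWT, an assumed sampling rate of 500 Hz, a synthetic signal (25 Hz rotation plus noise, with a hypothetical 150 Hz tone standing in for a bearing fault), and independently generated windows rather than a sliding window over one long recording. The entropy uses the normalized energy $p_j[k] = |d_j[k]|^2 / E_j$, and the flagging threshold (99th percentile of held-out normal scores) is an illustrative choice. For a multivariate series, the per-channel feature vectors would be concatenated before fitting the Gaussian. All helper names are hypothetical.

```python
import numpy as np

def haar_details(x, levels):
    """Detail coefficients [d_1 (finest), ..., d_levels] of an orthonormal Haar DWT."""
    a = np.asarray(x, dtype=float)
    details = []
    for _ in range(levels):
        even, odd = a[0::2], a[1::2]
        details.append((even - odd) / np.sqrt(2))
        a = (even + odd) / np.sqrt(2)
    return details

def window_features(window, levels=4):
    """Per detail level: energy E_j, energy entropy H_j, and max |d_j[k]|."""
    feats = []
    for d in haar_details(window, levels):
        energy = np.sum(d ** 2)
        p = d ** 2 / (energy + 1e-12)                # p_j[k]: normalized energy
        entropy = -np.sum(p * np.log(p + 1e-12))
        feats.extend([energy, entropy, np.max(np.abs(d))])
    return np.array(feats)

def fit_gaussian(F):
    """Normal-behavior profile: mean and inverse covariance of the feature vectors."""
    mu = F.mean(axis=0)
    cov = np.cov(F, rowvar=False) + 1e-6 * np.eye(F.shape[1])   # small ridge for stability
    return mu, np.linalg.inv(cov)

def mahalanobis(F, mu, cov_inv):
    diff = F - mu
    return np.sqrt(np.einsum("ij,jk,ik->i", diff, cov_inv, diff))

fs, win = 500, 256                       # assumed sampling rate (Hz) and window length
rng = np.random.default_rng(1)

def make_window(fault=False):
    """Synthetic vibration: 25 Hz rotation + noise; a 'fault' adds a 150 Hz tone."""
    t = np.arange(win) / fs
    x = np.sin(2 * np.pi * 25 * t + rng.uniform(0, 2 * np.pi)) + 0.2 * rng.normal(size=win)
    if fault:
        x = x + 0.5 * np.sin(2 * np.pi * 150 * t)
    return x

train = np.array([window_features(make_window()) for _ in range(500)])
mu, cov_inv = fit_gaussian(train)
normal_feats = np.array([window_features(make_window()) for _ in range(200)])
fault_feats = np.array([window_features(make_window(fault=True)) for _ in range(200)])
normal_scores = mahalanobis(normal_feats, mu, cov_inv)
fault_scores = mahalanobis(fault_feats, mu, cov_inv)
threshold = np.percentile(normal_scores, 99)
print("fraction of fault windows flagged:", np.mean(fault_scores > threshold))

# Localization: which level's energy moved most, in units of training std?
z = (fault_feats.mean(axis=0) - mu) / train.std(axis=0)
print("most affected detail level:", 1 + int(np.argmax(np.abs(z[0::3]))))
```

At $f_s = 500$ Hz, level 1 covers roughly 125 to 250 Hz, so the 150 Hz tone should show up mainly in the level-1 energy feature.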

Why wavelets beat raw-signal anomaly detection:

  • They separate the fault signature from the baseline signal (noise + normal vibration)
  • Faults are localized in specific frequency bands, and wavelet subbands isolate these bands
  • Wavelet coefficients keep time localization, so they are more robust than global FFT-based features when fault timing is unknown (non-stationary signals)

Real-world applications: Rolling bearing fault detection, wind turbine gearbox monitoring, motor current signature analysis, seismic event detection.

Key Takeaways

  • Wavelets = localized oscillations: simultaneously localized in time AND frequency, unlike the Fourier basis
  • CWT gives a continuous time-scale (time-frequency) map - the scalogram
  • DWT gives a discrete multi-resolution decomposition (orthonormal for orthogonal wavelets such as Haar and Daubechies): approximation + detail at each level
  • Multiresolution Analysis: DWT splits signal into coarse approximation (low-freq) and fine details (high-freq) recursively
  • Wavelet denoising: the signal is sparse in the wavelet domain while white noise spreads over all coefficients → threshold the detail coefficients to remove noise
  • More vanishing moments → better compression of smooth signals, at the cost of longer filter support
  • WaveNet uses dyadic dilated convolutions - the same scale hierarchy as DWT, but with learned filters
  • Choose the wavelet based on signal characteristics: Daubechies for smooth signals, Haar for steps, Morlet (a CWT wavelet) for oscillatory signals

This completes Module 10: Time Series Mathematics.

© 2026 EngineersOfAI. All rights reserved.