Source code for endgame.signal.wavelets
from __future__ import annotations
"""Wavelet transform module for signal processing.
Provides sklearn-compatible wavelet transforms:
- Continuous Wavelet Transform (CWT)
- Discrete Wavelet Transform (DWT)
- Wavelet Packet Decomposition
- Wavelet-based feature extraction
Requires: PyWavelets (pywt)
References
----------
- PyWavelets: https://pywavelets.readthedocs.io/
- MNE-Python: Time-frequency analysis
"""
import numpy as np
from endgame.signal.base import (
BaseFeatureExtractor,
BaseSignalTransformer,
ensure_2d_signals,
)
# Import pywt (optional dependency)
try:
import pywt
HAS_PYWT = True
except ImportError:
HAS_PYWT = False
def _check_pywt():
    """Ensure the optional PyWavelets dependency can be used.

    Raises
    ------
    ImportError
        If ``pywt`` could not be imported when this module was loaded.
    """
    if HAS_PYWT:
        return
    install_hint = (
        "PyWavelets is required for wavelet transforms. "
        "Install with: pip install PyWavelets"
    )
    raise ImportError(install_hint)
[docs]
class CWTTransformer(BaseSignalTransformer):
    """Continuous Wavelet Transform for time-frequency analysis.

    Computes the CWT with ``pywt.cwt`` using a Morlet or other continuous
    wavelet, producing a time-frequency representation.

    Parameters
    ----------
    fs : float
        Sample rate in Hz.
    wavelet : str, default='morl'
        Wavelet to use. Options: 'morl' (Morlet), 'cmor' (complex Morlet),
        'cgau' (complex Gaussian), 'mexh' (Mexican hat), 'gaus'.
    freqs : array-like, optional
        Frequencies to analyze in Hz. Every value must be finite and
        strictly positive. If None, uses ``n_freqs`` logarithmically spaced
        frequencies from 1 Hz to half the Nyquist frequency (``fs / 4``).
    n_freqs : int, default=32
        Number of frequency bins if freqs is None.
    output : str, default='power'
        Output type: 'complex', 'magnitude', 'power', 'phase'.
    normalize : bool, default=True
        If True, divide the coefficients at each scale by ``sqrt(scale)``.
    copy : bool, default=True
        Whether to copy input data.

    Attributes
    ----------
    scales_ : np.ndarray
        Wavelet scales used: ``central_frequency(wavelet) * fs / freqs_``.
    freqs_ : np.ndarray
        Corresponding frequencies in Hz.

    Raises
    ------
    ValueError
        From ``__init__`` if ``output`` is not a supported value. From
        ``fit`` if ``freqs`` is empty, not 1D, or contains non-positive or
        non-finite values.

    Examples
    --------
    >>> cwt = CWTTransformer(fs=256, freqs=np.arange(1, 50))
    >>> tfr = cwt.fit_transform(signal)  # (n_freqs, n_samples)
    """

    def __init__(
        self,
        fs: float,
        wavelet: str = "morl",
        freqs: np.ndarray | None = None,
        n_freqs: int = 32,
        output: str = "power",
        normalize: bool = True,
        copy: bool = True,
    ):
        _check_pywt()
        super().__init__(fs=fs, copy=copy)
        self.wavelet = wavelet
        self.freqs = freqs
        self.n_freqs = n_freqs
        self.output = output
        self.normalize = normalize
        if output not in ("complex", "magnitude", "power", "phase"):
            raise ValueError(
                f"output must be 'complex', 'magnitude', 'power', or 'phase', "
                f"got {output}"
            )

    def fit(self, X, y=None, **fit_params) -> CWTTransformer:
        """Fit the transformer by computing the analysis frequencies and scales.

        Parameters
        ----------
        X : array-like
            Input signal.
        y : ignored

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If user-supplied ``freqs`` is empty, not 1D, or contains values
            that are not finite and strictly positive.
        """
        X = self._validate_signal(X)
        super().fit(X, y, **fit_params)

        if self.freqs is not None:
            freqs = np.atleast_1d(np.asarray(self.freqs))
            if freqs.ndim != 1 or freqs.size == 0:
                raise ValueError(
                    "freqs must be a non-empty 1D array of frequencies in Hz"
                )
            # A zero/negative/non-finite frequency would turn into an
            # infinite or negative scale below.
            if not np.all(np.isfinite(freqs)) or np.any(freqs <= 0):
                raise ValueError(
                    "freqs must contain only finite, strictly positive values"
                )
            self.freqs_ = freqs
        else:
            # Logarithmically spaced from 1 Hz to Nyquist/2 (= fs / 4).
            nyquist = self.fs / 2
            self.freqs_ = np.logspace(0, np.log10(nyquist / 2), self.n_freqs)

        # Convert frequencies to scales: scale = center_frequency * fs / freq
        center_freq = pywt.central_frequency(self.wavelet)
        self.scales_ = center_freq * self.fs / self.freqs_
        return self

    def transform(self, X) -> np.ndarray:
        """Compute the CWT of the signal(s).

        Parameters
        ----------
        X : array-like
            Input signal.

        Returns
        -------
        np.ndarray
            CWT coefficients in the representation selected by ``output``.
            Shape depends on input:

            - 1D input: (n_freqs, n_samples)
            - 2D input: (n_channels, n_freqs, n_samples)
            - 3D input: (n_trials, n_channels, n_freqs, n_samples)
        """
        self._check_is_fitted()
        X = self._validate_signal(X)
        if self.copy:
            X = X.copy()
        X_2d, was_1d, original_shape = ensure_2d_signals(X)

        results = []
        for sig in X_2d:
            # The sampling period only affects the frequencies pywt returns,
            # which are discarded in favour of freqs_.
            coeffs, _ = pywt.cwt(sig, self.scales_, self.wavelet, 1 / self.fs)
            if self.normalize:
                # Divide each scale's row by sqrt(scale).
                coeffs = coeffs / np.sqrt(self.scales_[:, np.newaxis])

            if self.output == "complex":
                result = coeffs
            elif self.output == "magnitude":
                result = np.abs(coeffs)
            elif self.output == "power":
                result = np.abs(coeffs) ** 2
            elif self.output == "phase":
                result = np.angle(coeffs)
            results.append(result)
        results = np.array(results)

        if was_1d:
            return results[0]  # (n_freqs, n_samples)
        elif len(original_shape) == 2:
            return results  # (n_channels, n_freqs, n_samples)
        else:
            # 3D input was flattened to (n_trials * n_channels, n_samples).
            n_trials, n_channels, _ = original_shape
            n_freqs, n_samples = results.shape[1], results.shape[2]
            return results.reshape(n_trials, n_channels, n_freqs, n_samples)

    def get_frequencies(self) -> np.ndarray:
        """Return the analysis frequencies (``freqs_``) in Hz."""
        self._check_is_fitted()
        return self.freqs_
[docs]
class DWTTransformer(BaseSignalTransformer):
    """Discrete Wavelet Transform for multi-resolution analysis.

    Decomposes a signal into approximation and detail coefficients at
    multiple scales using ``pywt.wavedec``.

    Parameters
    ----------
    fs : float, optional
        Sample rate in Hz.
    wavelet : str, default='db4'
        Wavelet to use. Options: 'db1'-'db20', 'sym2'-'sym20',
        'coif1'-'coif5', 'bior', 'rbio', etc.
    level : int, optional
        Decomposition level. If None, uses maximum level.
    mode : str, default='symmetric'
        Signal extension mode: 'symmetric', 'periodic', 'reflect', etc.
    output : str, default='coeffs'
        Output type: 'coeffs' (all coefficients), 'detail' (detail only),
        'approx' (approximation only).
    copy : bool, default=True
        Whether to copy input data.

    Attributes
    ----------
    level_ : int
        Actual decomposition level used.
    coeff_lengths_ : list of int
        Length of each coefficient array that ``pywt.wavedec`` produces for a
        signal of the fitted length, in wavedec order ``[cA_n, cD_n, ..., cD_1]``.
    signal_length_ : int
        Number of samples per signal seen in ``fit``. ``inverse_transform``
        trims reconstructions to this length.

    Examples
    --------
    >>> dwt = DWTTransformer(wavelet='db4', level=4)
    >>> coeffs = dwt.fit_transform(signal)
    """

    def __init__(
        self,
        fs: float | None = None,
        wavelet: str = "db4",
        level: int | None = None,
        mode: str = "symmetric",
        output: str = "coeffs",
        copy: bool = True,
    ):
        _check_pywt()
        super().__init__(fs=fs, copy=copy)
        self.wavelet = wavelet
        self.level = level
        self.mode = mode
        self.output = output
        if output not in ("coeffs", "detail", "approx"):
            raise ValueError(
                f"output must be 'coeffs', 'detail', or 'approx', got {output}"
            )

    def fit(self, X, y=None, **fit_params) -> DWTTransformer:
        """Fit the transformer (decomposition level and coefficient layout).

        Parameters
        ----------
        X : array-like
            Input signal.
        y : ignored

        Returns
        -------
        self
        """
        X = self._validate_signal(X)
        super().fit(X, y, **fit_params)
        n_samples = self._get_n_samples(X)
        self.signal_length_ = n_samples

        if self.level is not None:
            self.level_ = self.level
        else:
            self.level_ = pywt.dwt_max_level(n_samples, self.wavelet)

        # Record the coefficient layout with a trial decomposition so that
        # inverse_transform can split concatenated vectors exactly.
        trial_coeffs = pywt.wavedec(
            np.zeros(n_samples), self.wavelet, mode=self.mode, level=self.level_
        )
        self.coeff_lengths_ = [len(c) for c in trial_coeffs]
        return self

    def transform(self, X) -> np.ndarray:
        """Compute the DWT of the signal(s).

        Parameters
        ----------
        X : array-like
            Input signal.

        Returns
        -------
        np.ndarray
            DWT coefficients selected by ``output``, concatenated per signal
            in wavedec order. 1D input gives a 1D array, 2D input gives
            (n_signals, n_coeffs), and 3D input gives
            (n_trials, n_channels, n_coeffs).
        """
        self._check_is_fitted()
        X = self._validate_signal(X)
        if self.copy:
            X = X.copy()
        X_2d, was_1d, original_shape = ensure_2d_signals(X)

        results = []
        for sig in X_2d:
            coeffs = pywt.wavedec(sig, self.wavelet, mode=self.mode, level=self.level_)
            if self.output == "coeffs":
                result = np.concatenate(coeffs)
            elif self.output == "approx":
                result = coeffs[0]
            elif self.output == "detail":
                details = coeffs[1:]
                # With level_ == 0 there are no detail arrays, and
                # np.concatenate([]) would raise.
                result = (
                    np.concatenate(details)
                    if details
                    else np.empty(0, dtype=coeffs[0].dtype)
                )
            results.append(result)
        result = np.array(results)

        if was_1d:
            return result.flatten()
        elif len(original_shape) == 3:
            n_trials, n_channels, _ = original_shape
            n_coeffs = result.shape[-1]
            return result.reshape(n_trials, n_channels, n_coeffs)
        return result

    def inverse_transform(self, coeffs) -> np.ndarray:
        """Reconstruct signal(s) from concatenated DWT coefficients.

        Parameters
        ----------
        coeffs : array-like
            DWT coefficients as produced by ``transform`` with
            ``output='coeffs'`` (1D, 2D, or 3D).

        Returns
        -------
        np.ndarray
            Reconstructed signal(s) of length ``signal_length_``, with the
            same leading dimensions as ``coeffs``.

        Raises
        ------
        ValueError
            If ``output != 'coeffs'``, or if a coefficient vector does not
            match the layout recorded in ``fit``.
        """
        self._check_is_fitted()
        if self.output != "coeffs":
            raise ValueError("inverse_transform only works with output='coeffs'")
        coeffs = np.asarray(coeffs)
        coeffs_2d, was_1d, original_shape = ensure_2d_signals(coeffs)

        results = []
        for coeff_vec in coeffs_2d:
            coeff_list = self._split_coeffs(coeff_vec)
            sig = pywt.waverec(coeff_list, self.wavelet, mode=self.mode)
            # waverec may return one extra sample for odd-length signals.
            results.append(sig[: self.signal_length_])
        result = np.array(results)

        if was_1d:
            return result.flatten()
        if len(original_shape) == 3:
            # Mirror transform: restore the (n_trials, n_channels, ...) layout.
            n_trials, n_channels, _ = original_shape
            return result.reshape(n_trials, n_channels, result.shape[-1])
        return result

    def _split_coeffs(self, coeff_vec: np.ndarray) -> list[np.ndarray]:
        """Split one concatenated coefficient vector into the wavedec list.

        Uses the per-level lengths recorded in ``fit`` (``coeff_lengths_``).

        Raises
        ------
        ValueError
            If ``coeff_vec`` does not hold exactly ``sum(coeff_lengths_)``
            values.
        """
        lengths = self.coeff_lengths_
        expected = int(sum(lengths))
        if coeff_vec.shape[0] != expected:
            raise ValueError(
                f"Expected {expected} concatenated DWT coefficients, "
                f"got {coeff_vec.shape[0]}"
            )
        return np.split(coeff_vec, np.cumsum(lengths)[:-1])
[docs]
class WaveletPacketTransformer(BaseSignalTransformer):
    """Wavelet Packet Decomposition for full frequency resolution.

    Unlike DWT, which only decomposes approximation coefficients, wavelet
    packets decompose both approximation and detail coefficients at each
    level.

    Parameters
    ----------
    fs : float, optional
        Sample rate in Hz. Required only by ``get_frequency_bands``.
    wavelet : str, default='db4'
        Wavelet to use.
    level : int, optional
        Decomposition level. If None, uses ``min(4, max_level)``.
    mode : str, default='symmetric'
        Signal extension mode.
    order : {'freq', 'natural'}, default='freq'
        Node ordering: 'freq' (frequency order) or 'natural' (filter-bank
        path order).
    copy : bool, default=True
        Whether to copy input data.

    Attributes
    ----------
    level_ : int
        Actual decomposition level used.
    n_nodes_ : int
        Number of terminal nodes (2^level).

    Raises
    ------
    ValueError
        If ``order`` is not 'freq' or 'natural'.

    Examples
    --------
    >>> wp = WaveletPacketTransformer(wavelet='db4', level=3)
    >>> coeffs = wp.fit_transform(signal)  # (8, n_coeffs)
    """

    def __init__(
        self,
        fs: float | None = None,
        wavelet: str = "db4",
        level: int | None = None,
        mode: str = "symmetric",
        order: str = "freq",
        copy: bool = True,
    ):
        _check_pywt()
        super().__init__(fs=fs, copy=copy)
        self.wavelet = wavelet
        self.level = level
        self.mode = mode
        self.order = order
        if order not in ("freq", "natural"):
            raise ValueError(f"order must be 'freq' or 'natural', got {order}")

    def fit(self, X, y=None, **fit_params) -> WaveletPacketTransformer:
        """Fit the transformer (decomposition level and node count).

        Parameters
        ----------
        X : array-like
            Input signal.
        y : ignored

        Returns
        -------
        self
        """
        X = self._validate_signal(X)
        super().fit(X, y, **fit_params)
        n_samples = self._get_n_samples(X)

        if self.level is not None:
            self.level_ = self.level
        else:
            self.level_ = min(4, pywt.dwt_max_level(n_samples, self.wavelet))
        self.n_nodes_ = 2**self.level_
        return self

    def transform(self, X) -> np.ndarray:
        """Compute the wavelet packet decomposition.

        Parameters
        ----------
        X : array-like
            Input signal.

        Returns
        -------
        np.ndarray
            Wavelet packet coefficients of the terminal nodes at ``level_``,
            ordered according to ``order``.
            Shape: (n_signals, n_nodes, n_coeffs_per_node), or
            (n_nodes, n_coeffs_per_node) for 1D input.
        """
        self._check_is_fitted()
        X = self._validate_signal(X)
        if self.copy:
            X = X.copy()
        X_2d, was_1d, original_shape = ensure_2d_signals(X)

        results = []
        for sig in X_2d:
            wp = pywt.WaveletPacket(sig, self.wavelet, mode=self.mode)
            nodes = [
                node.path for node in wp.get_level(self.level_, order=self.order)
            ]
            node_coeffs = [wp[path].data for path in nodes]
            # Zero-pad so the per-node arrays stack into one matrix.
            max_len = max(len(c) for c in node_coeffs)
            padded = [np.pad(c, (0, max_len - len(c))) for c in node_coeffs]
            results.append(np.array(padded))
        result = np.array(results)

        if was_1d:
            return result[0]  # (n_nodes, n_coeffs)
        return result

    def get_frequency_bands(self) -> list[tuple[float, float]]:
        """Get the nominal frequency band of each node, in ``transform`` order.

        Bands assume ideal half-band filters, so the terminal nodes split
        ``[0, fs / 2]`` into ``n_nodes_`` equal slices. With
        ``order='freq'``, node ``i`` covers slice ``i``. With
        ``order='natural'``, the filter-bank path ordering is a Gray-code
        permutation of frequency order, so node ``i`` covers the slice whose
        Gray code is ``i``.

        Returns
        -------
        list of (low, high) tuples
            Frequency bands in Hz for each node.

        Raises
        ------
        ValueError
            If ``fs`` was not specified.
        """
        self._check_is_fitted()
        if self.fs is None:
            raise ValueError("fs must be specified to get frequency bands")

        nyquist = self.fs / 2
        band_width = nyquist / self.n_nodes_
        bands = []
        for i in range(self.n_nodes_):
            band_index = i if self.order == "freq" else self._gray_to_binary(i)
            bands.append((band_index * band_width, (band_index + 1) * band_width))
        return bands

    @staticmethod
    def _gray_to_binary(code: int) -> int:
        """Decode a reflected binary Gray code to its integer rank."""
        value = code
        code >>= 1
        while code:
            value ^= code
            code >>= 1
        return value
[docs]
class WaveletFeatureExtractor(BaseFeatureExtractor):
    """Extract statistical features from wavelet coefficients.

    Computes per-band statistics from a wavelet (or wavelet packet)
    decomposition at multiple scales.

    Parameters
    ----------
    fs : float, optional
        Sample rate in Hz.
    wavelet : str, default='db4'
        Wavelet to use.
    level : int, optional
        Decomposition level. If None, uses ``min(4, max_level)``.
    features : list of str, optional
        Features to extract. Default: ['energy', 'mean', 'std', 'entropy'].
        Supported: 'energy', 'mean', 'std', 'entropy', 'max' (max |c|),
        'min' (min |c|), 'range' (peak-to-peak).
    use_packet : bool, default=False
        If True, use wavelet packet decomposition (frequency-ordered nodes).

    Attributes
    ----------
    feature_names_ : list of str
        Names of extracted features, ``wavelet_{band}_{feature}``. For the
        DWT path, bands follow ``pywt.wavedec`` output order
        ``[cA_n, cD_n, ..., cD_1]``, so ``detail_1`` is the coarsest detail
        band and ``detail_{level}`` the finest.

    Raises
    ------
    ValueError
        From ``fit`` if ``features`` contains an unsupported name.

    Examples
    --------
    >>> extractor = WaveletFeatureExtractor(wavelet='db4', level=4)
    >>> features = extractor.fit_transform(signals)
    """

    DEFAULT_FEATURES = ["energy", "mean", "std", "entropy"]
    # Every name _compute_feature understands; used to reject typos in fit.
    SUPPORTED_FEATURES = ("energy", "mean", "std", "entropy", "max", "min", "range")

    def __init__(
        self,
        fs: float | None = None,
        wavelet: str = "db4",
        level: int | None = None,
        features: list[str] | None = None,
        use_packet: bool = False,
    ):
        _check_pywt()
        super().__init__(fs=fs)
        self.wavelet = wavelet
        self.level = level
        self.features = features if features is not None else self.DEFAULT_FEATURES.copy()
        self.use_packet = use_packet

    def fit(self, X, y=None, **fit_params) -> WaveletFeatureExtractor:
        """Fit the extractor (decomposition level and feature names).

        Parameters
        ----------
        X : array-like
            Input signals.
        y : ignored

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If a requested feature is not supported.
        """
        # Fail early instead of building feature_names_ that transform rejects.
        for feat in self.features:
            if feat not in self.SUPPORTED_FEATURES:
                raise ValueError(f"Unknown feature: {feat}")

        X = self._validate_signal(X)
        super().fit(X, y, **fit_params)
        n_samples = self._get_n_samples(X)

        if self.level is not None:
            self.level_ = self.level
        else:
            self.level_ = min(4, pywt.dwt_max_level(n_samples, self.wavelet))

        if self.use_packet:
            self.n_coeff_sets_ = 2**self.level_
            self._band_names = [f"wp_{i}" for i in range(self.n_coeff_sets_)]
        else:
            self.n_coeff_sets_ = self.level_ + 1  # approx + details
            self._band_names = ["approx"] + [f"detail_{i}" for i in range(1, self.level_ + 1)]

        self.feature_names_ = [
            f"wavelet_{band}_{feat}"
            for band in self._band_names
            for feat in self.features
        ]
        return self

    def transform(self, X) -> np.ndarray:
        """Extract wavelet features.

        Parameters
        ----------
        X : array-like
            Input signals.

        Returns
        -------
        np.ndarray of shape (n_signals, n_features)
            One row per signal row produced by ``ensure_2d_signals``, with
            columns in ``feature_names_`` order.
        """
        X = self._validate_signal(X)
        X_2d, was_1d, original_shape = ensure_2d_signals(X)

        all_features = []
        for sig in X_2d:
            if self.use_packet:
                coeff_sets = self._get_packet_coeffs(sig)
            else:
                coeff_sets = pywt.wavedec(sig, self.wavelet, level=self.level_)
            sig_features = [
                self._compute_feature(coeffs, feat)
                for coeffs in coeff_sets
                for feat in self.features
            ]
            all_features.append(sig_features)
        return np.array(all_features)

    def _get_packet_coeffs(self, x: np.ndarray) -> list[np.ndarray]:
        """Return the frequency-ordered wavelet packet node coefficients at level_."""
        wp = pywt.WaveletPacket(x, self.wavelet, mode="symmetric")
        nodes = [node.path for node in wp.get_level(self.level_, order="freq")]
        return [wp[path].data for path in nodes]

    def _compute_feature(self, coeffs: np.ndarray, feature: str) -> float:
        """Compute a single named feature from one coefficient array.

        Raises
        ------
        ValueError
            If ``feature`` is not supported.
        """
        if feature == "energy":
            return np.sum(coeffs**2)
        elif feature == "mean":
            return np.mean(coeffs)
        elif feature == "std":
            return np.std(coeffs)
        elif feature == "entropy":
            # Shannon entropy (bits) of the normalized squared coefficients;
            # the 1e-10 terms guard against division by zero and log2(0).
            squared = coeffs**2
            normalized = squared / (np.sum(squared) + 1e-10)
            return -np.sum(normalized * np.log2(normalized + 1e-10))
        elif feature == "max":
            return np.max(np.abs(coeffs))
        elif feature == "min":
            return np.min(np.abs(coeffs))
        elif feature == "range":
            return np.ptp(coeffs)
        else:
            raise ValueError(f"Unknown feature: {feature}")
[docs]
def compute_cwt(
    x: np.ndarray,
    fs: float,
    freqs: np.ndarray | None = None,
    wavelet: str = "morl",
) -> tuple[np.ndarray, np.ndarray]:
    """Compute the CWT power of a signal in one call.

    Thin wrapper that fits a power-output ``CWTTransformer`` on ``x`` and
    transforms the same signal.

    Parameters
    ----------
    x : np.ndarray
        Input signal.
    fs : float
        Sample rate in Hz.
    freqs : np.ndarray, optional
        Frequencies to analyze. If None, the transformer's default grid
        is used.
    wavelet : str, default='morl'
        Wavelet to use.

    Returns
    -------
    freqs : np.ndarray
        Analysis frequencies actually used.
    coeffs : np.ndarray
        CWT coefficients (power).

    Examples
    --------
    >>> freqs, power = compute_cwt(signal, fs=256)
    """
    _check_pywt()
    transformer = CWTTransformer(fs=fs, wavelet=wavelet, freqs=freqs, output="power")
    power = transformer.fit(x).transform(x)
    return transformer.freqs_, power
[docs]
def compute_dwt(
    x: np.ndarray,
    wavelet: str = "db4",
    level: int | None = None,
) -> list[np.ndarray]:
    """Compute a multilevel DWT of a signal in one call.

    Delegates to ``pywt.wavedec`` with its default extension mode.

    Parameters
    ----------
    x : np.ndarray
        Input signal.
    wavelet : str, default='db4'
        Wavelet to use.
    level : int, optional
        Decomposition level. If None, pywt chooses the maximum level.

    Returns
    -------
    list of np.ndarray
        Coefficients [approx, detail1, detail2, ...] in ``pywt.wavedec``
        order.

    Examples
    --------
    >>> coeffs = compute_dwt(signal, wavelet='db4', level=4)
    >>> approx, d1, d2, d3, d4 = coeffs
    """
    _check_pywt()
    decomposition = pywt.wavedec(x, wavelet, level=level)
    return decomposition
[docs]
def reconstruct_from_dwt(
    coeffs: list[np.ndarray],
    wavelet: str = "db4",
) -> np.ndarray:
    """Invert a multilevel DWT produced by ``compute_dwt``.

    Delegates to ``pywt.waverec`` with its default extension mode.

    Parameters
    ----------
    coeffs : list of np.ndarray
        Coefficients [approx, detail1, detail2, ...] in ``pywt.wavedec``
        order.
    wavelet : str, default='db4'
        Wavelet used for the decomposition.

    Returns
    -------
    np.ndarray
        Reconstructed signal.
    """
    _check_pywt()
    reconstructed = pywt.waverec(coeffs, wavelet)
    return reconstructed