Source code for endgame.signal.spectral
from __future__ import annotations
"""Spectral analysis transformers for signal processing.
Provides sklearn-compatible spectral analysis methods:
- FFT for frequency spectrum
- Welch's method for PSD estimation
- Multitaper PSD for high-resolution spectral analysis
- Band power extraction for frequency band features
References
----------
- scipy.signal: Welch PSD, periodogram
- MNE-Python: Multitaper implementation
- neurokit2: Band power calculations
"""
from typing import Any
import numpy as np
from scipy import signal
from scipy.fft import fft, fftfreq, rfft, rfftfreq
from endgame.signal.base import (
BaseFeatureExtractor,
BaseSignalTransformer,
ensure_2d_signals,
)
[docs]
class FFTTransformer(BaseSignalTransformer):
"""Fast Fourier Transform for frequency spectrum analysis.
Computes the FFT of input signals, returning either the full
complex spectrum, magnitude, power, or phase.
Parameters
----------
fs : float
Sample rate in Hz.
output : str, default='magnitude'
Output type: 'complex', 'magnitude', 'power', 'phase', 'db'.
one_sided : bool, default=True
If True, return only positive frequencies (real FFT).
n_fft : int, optional
FFT length. If None, uses signal length.
normalize : bool, default=True
If True, normalize FFT by signal length.
copy : bool, default=True
Whether to copy input data.
Attributes
----------
freqs_ : np.ndarray
Frequency bins after fit.
Examples
--------
>>> fft_trans = FFTTransformer(fs=256, output='magnitude')
>>> spectrum = fft_trans.fit_transform(signal)
>>> freqs = fft_trans.freqs_
"""
def __init__(
self,
fs: float,
output: str = "magnitude",
one_sided: bool = True,
n_fft: int | None = None,
normalize: bool = True,
copy: bool = True,
):
super().__init__(fs=fs, copy=copy)
self.output = output
self.one_sided = one_sided
self.n_fft = n_fft
self.normalize = normalize
if output not in ("complex", "magnitude", "power", "phase", "db"):
raise ValueError(
f"output must be 'complex', 'magnitude', 'power', 'phase', or 'db', "
f"got {output}"
)
[docs]
def fit(self, X, y=None, **fit_params) -> FFTTransformer:
"""Fit the transformer (compute frequency bins).
Parameters
----------
X : array-like
Input signal.
y : ignored
Returns
-------
self
"""
X = self._validate_signal(X)
super().fit(X, y, **fit_params)
n_samples = self._get_n_samples(X)
n_fft = self.n_fft if self.n_fft is not None else n_samples
if self.one_sided:
self.freqs_ = rfftfreq(n_fft, 1 / self.fs)
else:
self.freqs_ = fftfreq(n_fft, 1 / self.fs)
return self
[docs]
def transform(self, X) -> np.ndarray:
"""Compute FFT of signal.
Parameters
----------
X : array-like
Input signal.
Returns
-------
np.ndarray
FFT result based on output parameter.
"""
self._check_is_fitted()
X = self._validate_signal(X)
if self.copy:
X = X.copy()
X_2d, was_1d, original_shape = ensure_2d_signals(X)
n_signals, n_samples = X_2d.shape
n_fft = self.n_fft if self.n_fft is not None else n_samples
# Compute FFT
if self.one_sided:
fft_result = rfft(X_2d, n=n_fft, axis=-1)
else:
fft_result = fft(X_2d, n=n_fft, axis=-1)
# Normalize
if self.normalize:
fft_result = fft_result / n_samples
# Convert to requested output
if self.output == "complex":
result = fft_result
elif self.output == "magnitude":
result = np.abs(fft_result)
elif self.output == "power":
result = np.abs(fft_result) ** 2
elif self.output == "phase":
result = np.angle(fft_result)
elif self.output == "db":
result = 20 * np.log10(np.abs(fft_result) + 1e-10)
# Restore shape for 3D input
if len(original_shape) == 3:
n_trials, n_channels, _ = original_shape
n_freqs = result.shape[-1]
result = result.reshape(n_trials, n_channels, n_freqs)
elif was_1d:
result = result.flatten()
return result
[docs]
def get_frequency_resolution(self) -> float:
"""Get frequency resolution in Hz."""
self._check_is_fitted()
return self.freqs_[1] - self.freqs_[0]
[docs]
class WelchPSD(BaseSignalTransformer):
"""Welch's method for Power Spectral Density estimation.
Estimates PSD by averaging modified periodograms from
overlapping segments.
Parameters
----------
fs : float
Sample rate in Hz.
nperseg : int, optional
Length of each segment. Default is 256 or signal length.
noverlap : int, optional
Number of points to overlap. Default is nperseg // 2.
nfft : int, optional
FFT length. Default is nperseg.
window : str or tuple, default='hann'
Window function to use.
detrend : str or False, default='constant'
Detrending method: 'constant', 'linear', or False.
scaling : str, default='density'
'density' for V^2/Hz, 'spectrum' for V^2.
average : str, default='mean'
Averaging method: 'mean' or 'median'.
copy : bool, default=True
Whether to copy input data.
Attributes
----------
freqs_ : np.ndarray
Frequency bins after fit.
Examples
--------
>>> psd = WelchPSD(fs=256, nperseg=256)
>>> power = psd.fit_transform(signal)
>>> freqs = psd.freqs_
"""
def __init__(
self,
fs: float,
nperseg: int | None = None,
noverlap: int | None = None,
nfft: int | None = None,
window: str = "hann",
detrend: str | bool = "constant",
scaling: str = "density",
average: str = "mean",
copy: bool = True,
):
super().__init__(fs=fs, copy=copy)
self.nperseg = nperseg
self.noverlap = noverlap
self.nfft = nfft
self.window = window
self.detrend = detrend
self.scaling = scaling
self.average = average
[docs]
def fit(self, X, y=None, **fit_params) -> WelchPSD:
"""Fit the transformer.
Parameters
----------
X : array-like
Input signal.
y : ignored
Returns
-------
self
"""
X = self._validate_signal(X)
super().fit(X, y, **fit_params)
n_samples = self._get_n_samples(X)
nperseg = self.nperseg if self.nperseg is not None else min(256, n_samples)
# Compute frequencies
self.freqs_, _ = signal.welch(
np.zeros(n_samples),
fs=self.fs,
nperseg=nperseg,
noverlap=self.noverlap,
nfft=self.nfft,
window=self.window,
)
return self
[docs]
def transform(self, X) -> np.ndarray:
"""Compute PSD using Welch's method.
Parameters
----------
X : array-like
Input signal.
Returns
-------
np.ndarray
Power spectral density.
"""
self._check_is_fitted()
X = self._validate_signal(X)
if self.copy:
X = X.copy()
X_2d, was_1d, original_shape = ensure_2d_signals(X)
n_samples = X_2d.shape[1]
nperseg = self.nperseg if self.nperseg is not None else min(256, n_samples)
results = []
for sig in X_2d:
freqs, psd = signal.welch(
sig,
fs=self.fs,
nperseg=nperseg,
noverlap=self.noverlap,
nfft=self.nfft,
window=self.window,
detrend=self.detrend,
scaling=self.scaling,
average=self.average,
)
results.append(psd)
result = np.array(results)
# Restore shape for 3D input
if len(original_shape) == 3:
n_trials, n_channels, _ = original_shape
n_freqs = result.shape[-1]
result = result.reshape(n_trials, n_channels, n_freqs)
elif was_1d:
result = result.flatten()
return result
[docs]
def get_frequency_bands(
self,
X: np.ndarray,
bands: dict[str, tuple[float, float]],
) -> dict[str, np.ndarray]:
"""Get PSD values within specified frequency bands.
Parameters
----------
X : np.ndarray
Input signal.
bands : dict
Dictionary mapping band names to (low, high) frequency tuples.
Returns
-------
dict
Dictionary mapping band names to PSD arrays within each band.
"""
psd = self.transform(X)
freqs = self.freqs_
band_psds = {}
for name, (low, high) in bands.items():
mask = (freqs >= low) & (freqs <= high)
if psd.ndim == 1:
band_psds[name] = psd[mask]
else:
band_psds[name] = psd[..., mask]
return band_psds
[docs]
class MultitaperPSD(BaseSignalTransformer):
"""Multitaper Power Spectral Density estimation.
Uses multiple orthogonal tapers (DPSS/Slepian sequences) to reduce
variance in PSD estimation while maintaining frequency resolution.
Parameters
----------
fs : float
Sample rate in Hz.
bandwidth : float, optional
Frequency bandwidth of the tapers in Hz.
Default is 4 * fs / n_samples (NW=4).
n_tapers : int, optional
Number of tapers to use. Default is 2 * bandwidth - 1.
low_bias : bool, default=True
Only use tapers with concentration ratio > 0.9.
adaptive : bool, default=False
Use adaptive weighting for combining tapers.
normalization : str, default='full'
PSD normalization: 'full', 'length'.
copy : bool, default=True
Whether to copy input data.
Attributes
----------
freqs_ : np.ndarray
Frequency bins after fit.
dpss_ : np.ndarray
DPSS tapers used.
eigenvalues_ : np.ndarray
Eigenvalues of the tapers.
References
----------
Thomson, D.J. (1982). Spectrum estimation and harmonic analysis.
Proceedings of the IEEE, 70(9), 1055-1096.
Examples
--------
>>> mt_psd = MultitaperPSD(fs=256, bandwidth=4)
>>> power = mt_psd.fit_transform(signal)
"""
def __init__(
self,
fs: float,
bandwidth: float | None = None,
n_tapers: int | None = None,
low_bias: bool = True,
adaptive: bool = False,
normalization: str = "full",
copy: bool = True,
):
super().__init__(fs=fs, copy=copy)
self.bandwidth = bandwidth
self.n_tapers = n_tapers
self.low_bias = low_bias
self.adaptive = adaptive
self.normalization = normalization
[docs]
def fit(self, X, y=None, **fit_params) -> MultitaperPSD:
"""Fit the transformer (compute DPSS tapers).
Parameters
----------
X : array-like
Input signal.
y : ignored
Returns
-------
self
"""
X = self._validate_signal(X)
super().fit(X, y, **fit_params)
n_samples = self._get_n_samples(X)
# Compute time-bandwidth product
if self.bandwidth is not None:
NW = self.bandwidth * n_samples / (2 * self.fs)
else:
NW = 4 # Default
# Number of tapers
if self.n_tapers is not None:
K = self.n_tapers
else:
K = int(2 * NW - 1)
K = max(1, K)
# Compute DPSS tapers
self.dpss_, self.eigenvalues_ = signal.windows.dpss(
n_samples, NW, K, return_ratios=True
)
# Filter low-bias tapers
if self.low_bias:
keep = self.eigenvalues_ > 0.9
if keep.sum() > 0:
self.dpss_ = self.dpss_[keep]
self.eigenvalues_ = self.eigenvalues_[keep]
# Compute frequency bins
self.freqs_ = rfftfreq(n_samples, 1 / self.fs)
self._n_samples = n_samples
return self
[docs]
def transform(self, X) -> np.ndarray:
"""Compute multitaper PSD.
Parameters
----------
X : array-like
Input signal.
Returns
-------
np.ndarray
Power spectral density.
"""
self._check_is_fitted()
X = self._validate_signal(X)
if self.copy:
X = X.copy()
X_2d, was_1d, original_shape = ensure_2d_signals(X)
results = []
for sig in X_2d:
psd = self._compute_multitaper_psd(sig)
results.append(psd)
result = np.array(results)
# Restore shape for 3D input
if len(original_shape) == 3:
n_trials, n_channels, _ = original_shape
n_freqs = result.shape[-1]
result = result.reshape(n_trials, n_channels, n_freqs)
elif was_1d:
result = result.flatten()
return result
def _compute_multitaper_psd(self, x: np.ndarray) -> np.ndarray:
"""Compute multitaper PSD for a single signal."""
n_tapers = len(self.dpss_)
# Apply each taper and compute FFT
tapered_spectra = []
for taper in self.dpss_:
tapered = x * taper
spectrum = rfft(tapered)
tapered_spectra.append(np.abs(spectrum) ** 2)
tapered_spectra = np.array(tapered_spectra)
if self.adaptive and n_tapers > 1:
# Adaptive weighting
psd = self._adaptive_weights(tapered_spectra)
else:
# Simple average
weights = self.eigenvalues_ / self.eigenvalues_.sum()
psd = np.average(tapered_spectra, axis=0, weights=weights)
# Normalization
if self.normalization == "full":
psd = psd / self.fs
elif self.normalization == "length":
psd = psd / len(x)
return psd
def _adaptive_weights(self, tapered_spectra: np.ndarray) -> np.ndarray:
"""Compute adaptive weights for multitaper combination.
Uses iterative method from Thomson (1982).
"""
n_tapers, n_freqs = tapered_spectra.shape
# Initial estimate (simple average)
psd = np.mean(tapered_spectra, axis=0)
# Iterate to find optimal weights
for _ in range(5):
weights = np.zeros((n_tapers, n_freqs))
for k in range(n_tapers):
weights[k] = (
self.eigenvalues_[k]
* psd
/ (
self.eigenvalues_[k] * psd
+ (1 - self.eigenvalues_[k]) * np.var(psd)
)
)
# Normalize weights
weights = weights / weights.sum(axis=0, keepdims=True)
# Update PSD estimate
psd = np.sum(weights * tapered_spectra, axis=0)
return psd
[docs]
class BandPowerExtractor(BaseFeatureExtractor):
"""Extract power in specified frequency bands.
Computes absolute and relative band powers from signals,
commonly used for EEG analysis.
Parameters
----------
fs : float
Sample rate in Hz.
bands : dict
Dictionary mapping band names to (low, high) frequency tuples.
Default includes standard EEG bands.
method : str, default='welch'
PSD estimation method: 'welch', 'multitaper', 'fft'.
relative : bool, default=True
If True, also compute relative band powers.
log_power : bool, default=False
If True, return log10 of band powers.
welch_params : dict, optional
Additional parameters for Welch's method.
multitaper_params : dict, optional
Additional parameters for multitaper method.
Attributes
----------
feature_names_ : list of str
Names of extracted features.
Examples
--------
>>> bands = {'alpha': (8, 13), 'beta': (13, 30)}
>>> bp = BandPowerExtractor(fs=256, bands=bands)
>>> features = bp.fit_transform(eeg_signals)
"""
# Default EEG bands
DEFAULT_BANDS = {
"delta": (0.5, 4),
"theta": (4, 8),
"alpha": (8, 13),
"beta": (13, 30),
"gamma": (30, 100),
}
def __init__(
self,
fs: float,
bands: dict[str, tuple[float, float]] | None = None,
method: str = "welch",
relative: bool = True,
log_power: bool = False,
welch_params: dict[str, Any] | None = None,
multitaper_params: dict[str, Any] | None = None,
):
super().__init__(fs=fs)
self.bands = bands if bands is not None else self.DEFAULT_BANDS.copy()
self.method = method
self.relative = relative
self.log_power = log_power
self.welch_params = welch_params or {}
self.multitaper_params = multitaper_params or {}
if method not in ("welch", "multitaper", "fft"):
raise ValueError(f"method must be 'welch', 'multitaper', or 'fft', got {method}")
[docs]
def fit(self, X, y=None, **fit_params) -> BandPowerExtractor:
"""Fit the extractor.
Parameters
----------
X : array-like
Input signals.
y : ignored
Returns
-------
self
"""
X = self._validate_signal(X)
super().fit(X, y, **fit_params)
# Create PSD estimator
if self.method == "welch":
self._psd = WelchPSD(fs=self.fs, **self.welch_params)
elif self.method == "multitaper":
self._psd = MultitaperPSD(fs=self.fs, **self.multitaper_params)
else:
self._psd = FFTTransformer(fs=self.fs, output="power")
self._psd.fit(X)
# Build feature names
self.feature_names_ = []
for band_name in self.bands:
self.feature_names_.append(f"bp_{band_name}_abs")
if self.relative:
for band_name in self.bands:
self.feature_names_.append(f"bp_{band_name}_rel")
return self
[docs]
def transform(self, X) -> np.ndarray:
"""Extract band power features.
Parameters
----------
X : array-like
Input signals of shape (n_samples, n_timepoints) or
(n_samples, n_channels, n_timepoints).
Returns
-------
np.ndarray of shape (n_samples, n_features)
Extracted band power features.
"""
X = self._validate_signal(X)
X_2d, was_1d, original_shape = ensure_2d_signals(X)
# Get PSD
psd = self._psd.transform(X_2d)
freqs = self._psd.freqs_
# Frequency resolution for integration
freq_res = freqs[1] - freqs[0]
features = []
for sig_psd in psd:
sig_features = []
# Total power for relative calculation
total_power = np.trapz(sig_psd, dx=freq_res)
# Absolute band powers
abs_powers = {}
for band_name, (low, high) in self.bands.items():
mask = (freqs >= low) & (freqs <= high)
band_power = np.trapz(sig_psd[mask], dx=freq_res)
abs_powers[band_name] = band_power
if self.log_power:
sig_features.append(np.log10(band_power + 1e-10))
else:
sig_features.append(band_power)
# Relative band powers
if self.relative:
for band_name in self.bands:
rel_power = abs_powers[band_name] / (total_power + 1e-10)
sig_features.append(rel_power)
features.append(sig_features)
return np.array(features)
[docs]
def get_feature_names_out(self, input_features=None) -> list[str]:
"""Get output feature names."""
return self.feature_names_
[docs]
class SpectralFeatureExtractor(BaseFeatureExtractor):
"""Extract comprehensive spectral features from signals.
Computes a variety of frequency-domain features including:
- Spectral centroid, spread, skewness, kurtosis
- Spectral entropy
- Spectral edge frequencies
- Spectral flatness and rolloff
- Peak frequency
Parameters
----------
fs : float
Sample rate in Hz.
method : str, default='welch'
PSD estimation method: 'welch', 'multitaper', 'fft'.
edge_percentiles : list of float, default=[0.5, 0.75, 0.9, 0.95]
Percentiles for spectral edge frequency computation.
welch_params : dict, optional
Additional parameters for Welch's method.
Attributes
----------
feature_names_ : list of str
Names of extracted features.
Examples
--------
>>> extractor = SpectralFeatureExtractor(fs=256)
>>> features = extractor.fit_transform(signals)
"""
def __init__(
self,
fs: float,
method: str = "welch",
edge_percentiles: list[float] | None = None,
welch_params: dict[str, Any] | None = None,
):
super().__init__(fs=fs)
self.method = method
self.edge_percentiles = edge_percentiles or [0.5, 0.75, 0.9, 0.95]
self.welch_params = welch_params or {}
[docs]
def fit(self, X, y=None, **fit_params) -> SpectralFeatureExtractor:
"""Fit the extractor.
Parameters
----------
X : array-like
Input signals.
y : ignored
Returns
-------
self
"""
X = self._validate_signal(X)
super().fit(X, y, **fit_params)
# Create PSD estimator
if self.method == "welch":
self._psd = WelchPSD(fs=self.fs, **self.welch_params)
elif self.method == "multitaper":
self._psd = MultitaperPSD(fs=self.fs)
else:
self._psd = FFTTransformer(fs=self.fs, output="power")
self._psd.fit(X)
# Build feature names
self.feature_names_ = [
"spectral_centroid",
"spectral_spread",
"spectral_skewness",
"spectral_kurtosis",
"spectral_entropy",
"spectral_flatness",
"spectral_rolloff",
"peak_frequency",
"mean_frequency",
"median_frequency",
]
for pct in self.edge_percentiles:
self.feature_names_.append(f"spectral_edge_{int(pct * 100)}")
return self
[docs]
def transform(self, X) -> np.ndarray:
"""Extract spectral features.
Parameters
----------
X : array-like
Input signals.
Returns
-------
np.ndarray of shape (n_samples, n_features)
Extracted spectral features.
"""
X = self._validate_signal(X)
X_2d, was_1d, original_shape = ensure_2d_signals(X)
# Get PSD
psd = self._psd.transform(X_2d)
freqs = self._psd.freqs_
features = []
for sig_psd in psd:
sig_features = self._compute_spectral_features(sig_psd, freqs)
features.append(sig_features)
return np.array(features)
def _compute_spectral_features(
self, psd: np.ndarray, freqs: np.ndarray
) -> list[float]:
"""Compute spectral features for a single PSD."""
# Normalize PSD to probability distribution
psd_norm = psd / (psd.sum() + 1e-10)
# Spectral centroid (mean frequency weighted by power)
centroid = np.sum(freqs * psd_norm)
# Spectral spread (std of frequency distribution)
spread = np.sqrt(np.sum(((freqs - centroid) ** 2) * psd_norm))
# Spectral skewness
skewness = np.sum(((freqs - centroid) ** 3) * psd_norm) / (spread**3 + 1e-10)
# Spectral kurtosis
kurtosis = np.sum(((freqs - centroid) ** 4) * psd_norm) / (spread**4 + 1e-10)
# Spectral entropy
psd_prob = psd_norm + 1e-10
entropy = -np.sum(psd_prob * np.log2(psd_prob))
# Spectral flatness (geometric mean / arithmetic mean)
geom_mean = np.exp(np.mean(np.log(psd + 1e-10)))
arith_mean = np.mean(psd)
flatness = geom_mean / (arith_mean + 1e-10)
# Spectral rolloff (frequency below which 85% of power lies)
cumsum = np.cumsum(psd)
rolloff_idx = np.searchsorted(cumsum, 0.85 * cumsum[-1])
rolloff = freqs[min(rolloff_idx, len(freqs) - 1)]
# Peak frequency
peak_freq = freqs[np.argmax(psd)]
# Mean frequency
mean_freq = centroid
# Median frequency
median_idx = np.searchsorted(cumsum, 0.5 * cumsum[-1])
median_freq = freqs[min(median_idx, len(freqs) - 1)]
features = [
centroid,
spread,
skewness,
kurtosis,
entropy,
flatness,
rolloff,
peak_freq,
mean_freq,
median_freq,
]
# Spectral edge frequencies
for pct in self.edge_percentiles:
edge_idx = np.searchsorted(cumsum, pct * cumsum[-1])
edge_freq = freqs[min(edge_idx, len(freqs) - 1)]
features.append(edge_freq)
return features
[docs]
def get_feature_names_out(self, input_features=None) -> list[str]:
"""Get output feature names."""
return self.feature_names_
[docs]
def compute_psd(
x: np.ndarray,
fs: float,
method: str = "welch",
**kwargs,
) -> tuple[np.ndarray, np.ndarray]:
"""Convenience function to compute PSD.
Parameters
----------
x : np.ndarray
Input signal.
fs : float
Sample rate in Hz.
method : str, default='welch'
PSD estimation method: 'welch', 'multitaper', 'fft'.
**kwargs
Additional parameters for the PSD method.
Returns
-------
freqs : np.ndarray
Frequency bins.
psd : np.ndarray
Power spectral density.
Examples
--------
>>> freqs, psd = compute_psd(signal, fs=256, method='welch')
"""
if method == "welch":
estimator = WelchPSD(fs=fs, **kwargs)
elif method == "multitaper":
estimator = MultitaperPSD(fs=fs, **kwargs)
elif method == "fft":
estimator = FFTTransformer(fs=fs, output="power", **kwargs)
else:
raise ValueError(f"Unknown method: {method}")
estimator.fit(x)
psd = estimator.transform(x)
return estimator.freqs_, psd
[docs]
def compute_band_power(
x: np.ndarray,
fs: float,
band: tuple[float, float],
method: str = "welch",
relative: bool = False,
) -> float:
"""Compute power in a frequency band.
Parameters
----------
x : np.ndarray
Input signal.
fs : float
Sample rate in Hz.
band : tuple of (low, high)
Frequency band in Hz.
method : str, default='welch'
PSD estimation method.
relative : bool, default=False
If True, return relative band power.
Returns
-------
float
Band power.
Examples
--------
>>> alpha_power = compute_band_power(eeg, fs=256, band=(8, 13))
"""
freqs, psd = compute_psd(x, fs, method=method)
# Flatten if needed
if psd.ndim > 1:
psd = psd.flatten()
freq_res = freqs[1] - freqs[0]
mask = (freqs >= band[0]) & (freqs <= band[1])
band_power = np.trapz(psd[mask], dx=freq_res)
if relative:
total_power = np.trapz(psd, dx=freq_res)
return band_power / (total_power + 1e-10)
return band_power