from __future__ import annotations
"""Digital filtering for signal processing.
Provides sklearn-compatible wrappers around scipy.signal filters
with sensible defaults for biosignal processing.
Supported filters:
- Butterworth (lowpass, highpass, bandpass, bandstop)
- FIR (window-based design)
- Savitzky-Golay (smoothing with polynomial fitting)
- Notch (powerline interference removal)
- Median (impulse noise removal)
Examples
--------
>>> from endgame.signal import ButterworthFilter, NotchFilter
>>>
>>> # Bandpass filter for EEG
>>> bp = ButterworthFilter(lowcut=0.5, highcut=50, fs=256, order=4)
>>> filtered = bp.fit_transform(raw_eeg)
>>>
>>> # Remove powerline noise
>>> notch = NotchFilter(freq=50, fs=256, Q=30)
>>> clean = notch.fit_transform(filtered)
"""
from typing import Literal
import numpy as np
from scipy import signal as scipy_signal
from endgame.signal.base import (
BaseSignalTransformer,
ensure_2d_signals,
restore_shape,
)
class ButterworthFilter(BaseSignalTransformer):
    """Butterworth IIR filter with zero-phase filtering.

    The filter coefficients are designed with ``scipy.signal.butter`` in
    second-order sections (SOS) form when :meth:`fit` is called, and are
    stored on the instance for later use.

    Parameters
    ----------
    lowcut : float, optional
        Low cutoff frequency in Hz. Required for highpass, bandpass and
        bandstop filters.
    highcut : float, optional
        High cutoff frequency in Hz. Required for lowpass, bandpass and
        bandstop filters.
    fs : float
        Sample rate in Hz.
    order : int, default=4
        Filter order. Higher = sharper rolloff but more ringing.
    btype : str, optional
        Filter type: 'lowpass', 'highpass', 'bandpass', 'bandstop'.
        When omitted it is inferred from the cutoffs: both set gives
        'bandpass', only ``lowcut`` gives 'highpass', only ``highcut``
        gives 'lowpass'.
    padlen : int, optional
        Padding length for the forward-backward filtering step. It is
        stored unchanged and does not influence the filter design.

    Examples
    --------
    >>> # Bandpass filter 0.5-50 Hz
    >>> filt = ButterworthFilter(lowcut=0.5, highcut=50, fs=256, order=4)
    >>> filtered = filt.fit_transform(signal)

    >>> # Highpass filter (remove DC and drift)
    >>> filt = ButterworthFilter(lowcut=0.1, fs=256)
    >>> filtered = filt.fit_transform(signal)

    Notes
    -----
    Uses second-order sections (SOS) format for numerical stability,
    especially at higher orders.
    """

    def __init__(
        self,
        lowcut: float | None = None,
        highcut: float | None = None,
        fs: float | None = None,
        order: int = 4,
        btype: Literal["lowpass", "highpass", "bandpass", "bandstop"] | None = None,
        padlen: int | None = None,
    ):
        super().__init__(fs=fs)
        self.lowcut = lowcut
        self.highcut = highcut
        self.order = order
        self.btype = btype
        self.padlen = padlen
        # Designed SOS coefficients; populated by fit().
        self._sos = None

    def _design_filter(self) -> np.ndarray:
        """Design the Butterworth filter and return its SOS coefficients.

        Returns
        -------
        np.ndarray
            Second-order sections as returned by ``scipy.signal.butter``
            with ``output="sos"``.

        Raises
        ------
        ValueError
            If neither cutoff is set, ``btype`` is not a known filter type,
            a cutoff required by ``btype`` is missing, a cutoff lies outside
            the open interval (0, Nyquist), or ``lowcut >= highcut`` for a
            band filter.
        """
        fs = self._check_fs()
        nyq = fs / 2

        # An explicit btype wins; otherwise infer it from which cutoffs exist.
        if self.btype is not None:
            btype = self.btype
        elif self.lowcut is not None and self.highcut is not None:
            btype = "bandpass"
        elif self.lowcut is not None:
            btype = "highpass"
        elif self.highcut is not None:
            btype = "lowpass"
        else:
            raise ValueError("Must specify lowcut and/or highcut")

        # Cutoff attributes each filter type needs, in the order scipy expects.
        required_cutoffs = {
            "lowpass": ("highcut",),
            "highpass": ("lowcut",),
            "bandpass": ("lowcut", "highcut"),
            "bandstop": ("lowcut", "highcut"),
        }
        if btype not in required_cutoffs:
            raise ValueError(f"Unknown btype: {btype}")

        # An explicitly requested btype may lack a cutoff it depends on;
        # report that directly instead of failing later on ``None / nyq``.
        missing = [
            name for name in required_cutoffs[btype] if getattr(self, name) is None
        ]
        if missing:
            raise ValueError(
                f"btype={btype!r} requires {' and '.join(missing)} to be specified"
            )

        # Normalise the critical frequencies to the Nyquist frequency.
        normalised = [getattr(self, name) / nyq for name in required_cutoffs[btype]]
        Wn = np.asarray(normalised[0] if len(normalised) == 1 else normalised)

        if np.any(Wn <= 0) or np.any(Wn >= 1):
            raise ValueError(
                f"Critical frequencies must be between 0 and Nyquist ({nyq} Hz). "
                f"Got: lowcut={self.lowcut}, highcut={self.highcut}"
            )
        if Wn.ndim == 1 and not Wn[0] < Wn[1]:
            raise ValueError(
                "lowcut must be less than highcut. "
                f"Got: lowcut={self.lowcut}, highcut={self.highcut}"
            )

        # SOS output is used for numerical stability.
        return scipy_signal.butter(self.order, Wn, btype=btype, output="sos")

    def fit(self, X, y=None, **fit_params) -> ButterworthFilter:
        """Fit the filter (designs the filter coefficients).

        Parameters
        ----------
        X : array-like
            Input signals, forwarded to the base class ``fit``.
        y : ignored
            Forwarded to the base class ``fit``.
        **fit_params
            Extra keyword arguments forwarded to the base class ``fit``.

        Returns
        -------
        ButterworthFilter
            The fitted instance, with the designed SOS stored in ``_sos``.
        """
        super().fit(X, y, **fit_params)
        self._sos = self._design_filter()
        return self

    def get_frequency_response(
        self,
        n_points: int = 512,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Get the filter's frequency response.

        Parameters
        ----------
        n_points : int, default=512
            Number of frequency points.

        Returns
        -------
        freqs : np.ndarray
            Frequency array in Hz.
        response : np.ndarray
            Magnitude response (absolute value).
        """
        self._check_is_fitted()
        fs = self._check_fs()
        freqs, response = scipy_signal.sosfreqz(self._sos, worN=n_points, fs=fs)
        return freqs, np.abs(response)
class FIRFilter(BaseSignalTransformer):
    """FIR filter using window-based design.

    The tap coefficients are computed with ``scipy.signal.firwin`` when
    :meth:`fit` is called and are stored on the instance.

    Parameters
    ----------
    lowcut : float, optional
        Low cutoff frequency in Hz.
    highcut : float, optional
        High cutoff frequency in Hz.
    fs : float
        Sample rate in Hz.
    numtaps : int, default=101
        Number of filter taps (filter length). An even value is increased
        by one before the design so the tap count is always odd.
    window : str, default='hamming'
        Window function: 'hamming', 'hann', 'blackman', 'kaiser', etc.
    pass_zero : bool or str, optional
        Forwarded to ``firwin``. When left as ``None`` it defaults to
        ``True`` if only ``highcut`` is given and to ``False`` otherwise.

    Examples
    --------
    >>> # Bandpass FIR filter
    >>> filt = FIRFilter(lowcut=1, highcut=40, fs=256, numtaps=101)
    >>> filtered = filt.fit_transform(signal)

    Notes
    -----
    FIR filters are inherently stable and have linear phase, but require
    more taps than IIR filters for sharp transitions.
    """

    def __init__(
        self,
        lowcut: float | None = None,
        highcut: float | None = None,
        fs: float = None,
        numtaps: int = 101,
        window: str = "hamming",
        pass_zero: bool | str | None = None,
    ):
        super().__init__(fs=fs)
        self.lowcut = lowcut
        self.highcut = highcut
        self.numtaps = numtaps
        self.window = window
        self.pass_zero = pass_zero
        # Designed tap coefficients; populated by fit().
        self._b = None

    def _design_filter(self) -> np.ndarray:
        """Compute the FIR tap coefficients from the current parameters.

        Returns
        -------
        np.ndarray
            Coefficients returned by ``scipy.signal.firwin``.

        Raises
        ------
        ValueError
            If neither ``lowcut`` nor ``highcut`` is set.
        """
        sample_rate = self._check_fs()

        # Bump an even tap count to the next odd value.
        taps = self.numtaps + 1 if self.numtaps % 2 == 0 else self.numtaps

        has_low = self.lowcut is not None
        has_high = self.highcut is not None
        if not (has_low or has_high):
            raise ValueError("Must specify lowcut and/or highcut")

        # Pick the band edges and the pass_zero fallback for this layout.
        if has_low and has_high:
            edges, fallback = [self.lowcut, self.highcut], False
        elif has_low:
            edges, fallback = self.lowcut, False
        else:
            edges, fallback = self.highcut, True

        # A user-supplied pass_zero always overrides the fallback.
        dc_setting = fallback if self.pass_zero is None else self.pass_zero

        return scipy_signal.firwin(
            taps,
            edges,
            fs=sample_rate,
            window=self.window,
            pass_zero=dc_setting,
        )

    def fit(self, X, y=None, **fit_params) -> FIRFilter:
        """Run the base-class fit, then design and store the taps.

        Returns
        -------
        FIRFilter
            The fitted instance, with the coefficients stored in ``_b``.
        """
        super().fit(X, y, **fit_params)
        self._b = self._design_filter()
        return self
class SavgolFilter(BaseSignalTransformer):
    """Savitzky-Golay smoothing filter.

    Holds the configuration for polynomial smoothing by local
    least-squares fitting; :meth:`fit` checks that the window and
    polynomial order are compatible.

    Parameters
    ----------
    window_length : int, default=11
        Window size in samples. Must be odd and > polyorder.
    polyorder : int, default=3
        Polynomial order for fitting.
    deriv : int, default=0
        Derivative order (0 = smoothing, 1 = first derivative, etc.)
    delta : float, default=1.0
        Sample spacing (for derivative calculation).
    mode : str, default='interp'
        Edge handling: 'mirror', 'constant', 'nearest', 'wrap', 'interp'.
    fs : float, optional
        Sample rate in Hz, passed to the base transformer.

    Examples
    --------
    >>> # Smooth signal
    >>> filt = SavgolFilter(window_length=11, polyorder=3)
    >>> smoothed = filt.fit_transform(noisy_signal)

    >>> # Compute smooth derivative
    >>> filt = SavgolFilter(window_length=11, polyorder=3, deriv=1, delta=1/fs)
    >>> velocity = filt.fit_transform(position)

    Notes
    -----
    Savitzky-Golay filters are particularly good for preserving peaks
    and other high-frequency features while removing noise.
    """

    def __init__(
        self,
        window_length: int = 11,
        polyorder: int = 3,
        deriv: int = 0,
        delta: float = 1.0,
        mode: str = "interp",
        fs: float | None = None,
    ):
        super().__init__(fs=fs)
        self.window_length = window_length
        self.polyorder = polyorder
        self.deriv = deriv
        self.delta = delta
        self.mode = mode

    def fit(self, X, y=None, **fit_params) -> SavgolFilter:
        """Run the base-class fit and validate the window configuration.

        Raises
        ------
        ValueError
            If ``window_length`` is even, or is not strictly greater than
            ``polyorder``.
        """
        super().fit(X, y, **fit_params)

        window, degree = self.window_length, self.polyorder
        # The parity check runs first so an even window is always reported
        # as such, even when it is also too short.
        if window % 2 == 0:
            raise ValueError("window_length must be odd")
        if degree >= window:
            raise ValueError("window_length must be > polyorder")
        return self
class NotchFilter(BaseSignalTransformer):
    """Notch filter for removing specific frequencies.

    Typically used to remove powerline interference (50/60 Hz)
    and its harmonics. One second-order notch is designed per target
    frequency with ``scipy.signal.iirnotch`` when :meth:`fit` is called.

    Parameters
    ----------
    freq : float or list of float
        Notch frequency/frequencies in Hz. Python numbers, NumPy scalars
        and any iterable of numbers are accepted.
    fs : float
        Sample rate in Hz.
    Q : float, default=30
        Quality factor. Higher Q = narrower notch.
    harmonics : int, default=0
        Number of harmonics to also filter (0 = fundamental only).
        Harmonics at or above the Nyquist frequency are omitted.

    Examples
    --------
    >>> # Remove 60 Hz powerline noise
    >>> notch = NotchFilter(freq=60, fs=256, Q=30)
    >>> clean = notch.fit_transform(signal)

    >>> # Remove 50 Hz and first 3 harmonics
    >>> notch = NotchFilter(freq=50, fs=1000, Q=30, harmonics=3)
    >>> clean = notch.fit_transform(signal)
    """

    def __init__(
        self,
        freq: float | list[float],
        fs: float,
        Q: float = 30,
        harmonics: int = 0,
    ):
        super().__init__(fs=fs)
        self.freq = freq
        self.Q = Q
        self.harmonics = harmonics
        # List of (b, a) coefficient pairs; rebuilt on every fit().
        self._filters = []

    def fit(self, X, y=None, **fit_params) -> NotchFilter:
        """Design notch filter(s).

        Parameters
        ----------
        X : array-like
            Input signals, forwarded to the base class ``fit``.
        y : ignored
            Forwarded to the base class ``fit``.
        **fit_params
            Extra keyword arguments forwarded to the base class ``fit``.

        Returns
        -------
        NotchFilter
            The fitted instance; ``_filters`` holds one ``(b, a)`` pair per
            notched frequency.

        Warns
        -----
        UserWarning
            If any requested fundamental frequency is at or above the
            Nyquist frequency and is therefore skipped.
        """
        # Local import keeps this block self-contained.
        import warnings

        super().fit(X, y, **fit_params)
        fs = self._check_fs()
        nyq = fs / 2

        # np.ndim treats Python numbers, NumPy scalars and 0-d arrays alike,
        # whereas an isinstance check against (int, float) rejects e.g.
        # np.int64 or np.float32 and then fails when trying to iterate them.
        if np.ndim(self.freq) == 0:
            fundamentals = [self.freq]
        else:
            fundamentals = list(self.freq)

        # Dropping a harmonic above Nyquist is expected; dropping a frequency
        # the caller asked for by name deserves a notice.
        skipped = [f for f in fundamentals if not f < nyq]
        if skipped:
            warnings.warn(
                f"Notch frequencies {skipped} are not below the Nyquist "
                f"frequency ({nyq} Hz) and will be skipped.",
                UserWarning,
                stacklevel=2,
            )

        # Fundamentals first, then the in-band harmonics of each fundamental.
        targets = list(fundamentals)
        if self.harmonics > 0:
            for f in fundamentals:
                targets.extend(
                    f * h
                    for h in range(2, self.harmonics + 2)
                    if f * h < nyq
                )

        # iirnotch returns the (b, a) pair that is stored per frequency.
        self._filters = [
            scipy_signal.iirnotch(f, self.Q, fs) for f in targets if f < nyq
        ]
        return self
class FilterBank(BaseSignalTransformer):
    """Bank of parallel bandpass filters.

    Builds one bandpass filter per named frequency band when :meth:`fit`
    is called. Useful for spectral analysis and feature extraction.

    Parameters
    ----------
    bands : dict of str to (float, float)
        Dictionary mapping band names to (low, high) frequency tuples.
        Example: {'alpha': (8, 13), 'beta': (13, 30)}
    fs : float
        Sample rate in Hz.
    filter_type : str, default='butterworth'
        Filter type: 'butterworth' or 'fir'.
    order : int, default=4
        Filter order (for Butterworth).
    numtaps : int, default=101
        Number of taps (for FIR).
    output : str, default='dict'
        Output format: 'dict' (separate bands) or 'stack' (3D array).

    Examples
    --------
    >>> bands = {
    ...     'delta': (0.5, 4),
    ...     'theta': (4, 8),
    ...     'alpha': (8, 13),
    ...     'beta': (13, 30),
    ...     'gamma': (30, 100),
    ... }
    >>> fb = FilterBank(bands=bands, fs=256)
    >>> band_signals = fb.fit_transform(eeg_signal)
    >>> alpha_signal = band_signals['alpha']
    """

    def __init__(
        self,
        bands: dict,
        fs: float,
        filter_type: Literal["butterworth", "fir"] = "butterworth",
        order: int = 4,
        numtaps: int = 101,
        output: Literal["dict", "stack"] = "dict",
    ):
        super().__init__(fs=fs)
        self.bands = bands
        self.filter_type = filter_type
        self.order = order
        self.numtaps = numtaps
        self.output = output
        # Maps band name -> fitted filter; rebuilt on every fit().
        self._filters = {}

    def fit(self, X, y=None, **fit_params) -> FilterBank:
        """Design all bandpass filters.

        Parameters
        ----------
        X : array-like
            Input signals, forwarded to the base class ``fit`` and to each
            per-band filter's ``fit``.
        y : ignored
            Forwarded to the base class ``fit``.
        **fit_params
            Extra keyword arguments forwarded to the base class ``fit``.

        Returns
        -------
        FilterBank
            The fitted instance; ``_filters`` maps each band name to its
            fitted filter.

        Raises
        ------
        ValueError
            If ``filter_type`` is neither 'butterworth' nor 'fir'.
        """
        # Reject a mistyped filter_type instead of silently building FIR.
        if self.filter_type not in ("butterworth", "fir"):
            raise ValueError(
                f"Unknown filter_type: {self.filter_type!r}. "
                "Expected 'butterworth' or 'fir'."
            )

        super().fit(X, y, **fit_params)

        # Build into a fresh dict so a refit never keeps filters for bands
        # that were removed, and a failed design leaves no partial state.
        designed = {}
        for name, (low, high) in self.bands.items():
            if self.filter_type == "butterworth":
                band_filter = ButterworthFilter(
                    lowcut=low, highcut=high, fs=self.fs, order=self.order
                )
            else:
                band_filter = FIRFilter(
                    lowcut=low, highcut=high, fs=self.fs, numtaps=self.numtaps
                )
            band_filter.fit(X)
            designed[name] = band_filter

        self._filters = designed
        return self

    @property
    def band_names(self) -> list[str]:
        """Get list of band names."""
        return list(self.bands)