Source code for endgame.signal.base
from __future__ import annotations
"""Base classes for signal processing transformers.
Provides sklearn-compatible base classes that handle:
- Flexible input formats (1D, 2D, 3D arrays)
- Sample rate tracking
- Channel-wise operations
- Integration with time series module
"""
from abc import abstractmethod
from typing import Any
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
[docs]
class SignalMixin:
"""Mixin providing common signal processing functionality.
Handles input validation, reshaping, and sample rate management
for all signal processing transformers.
"""
_estimator_type = "transformer"
def _validate_signal(
self,
X: Any,
ensure_2d: bool = True,
) -> np.ndarray:
"""Validate and convert input signal.
Parameters
----------
X : array-like
Input signal. Can be:
- 1D: (n_samples,) single channel
- 2D: (n_samples, n_channels) or (n_trials, n_samples)
- 3D: (n_trials, n_channels, n_samples)
ensure_2d : bool, default=True
If True, ensures output is at least 2D.
Returns
-------
np.ndarray
Validated signal array.
"""
X = np.asarray(X, dtype=np.float64)
if X.ndim == 1:
if ensure_2d:
X = X.reshape(1, -1) # (1, n_samples)
elif X.ndim == 2:
pass # Already 2D
elif X.ndim == 3:
pass # 3D is fine
else:
raise ValueError(f"Expected 1D, 2D, or 3D array, got shape {X.shape}")
return X
def _get_n_samples(self, X: np.ndarray) -> int:
"""Get number of time samples from array."""
if X.ndim == 1:
return X.shape[0]
elif X.ndim == 2:
return X.shape[1] # (n_channels, n_samples)
else:
return X.shape[2] # (n_trials, n_channels, n_samples)
def _get_n_channels(self, X: np.ndarray) -> int:
"""Get number of channels from array."""
if X.ndim == 1:
return 1
elif X.ndim == 2:
return X.shape[0]
else:
return X.shape[1]
def _apply_along_axis(
self,
func,
X: np.ndarray,
axis: int = -1,
**kwargs,
) -> np.ndarray:
"""Apply function along time axis.
Parameters
----------
func : callable
Function to apply.
X : np.ndarray
Input signal.
axis : int, default=-1
Axis to apply along (typically time axis).
**kwargs
Additional arguments to pass to func.
Returns
-------
np.ndarray
Transformed signal.
"""
return np.apply_along_axis(func, axis, X, **kwargs)
def _check_fs(self, fs: float | None = None) -> float:
"""Check and return sample rate."""
if fs is not None:
return float(fs)
if hasattr(self, 'fs') and self.fs is not None:
return float(self.fs)
raise ValueError(
"Sample rate (fs) must be specified either in constructor or method call"
)
[docs]
class BaseSignalTransformer(BaseEstimator, TransformerMixin, SignalMixin):
"""Base class for all signal processing transformers.
Provides sklearn-compatible interface with signal-specific extensions.
Parameters
----------
fs : float, optional
Sample rate in Hz. Required for frequency-dependent operations.
copy : bool, default=True
Whether to copy input data before processing.
Attributes
----------
n_samples_seen_ : int
Number of samples processed during fit.
n_channels_seen_ : int
Number of channels seen during fit.
"""
def __init__(
self,
fs: float | None = None,
copy: bool = True,
):
self.fs = fs
self.copy = copy
self._is_fitted = False
self.n_samples_seen_: int | None = None
self.n_channels_seen_: int | None = None
def _check_is_fitted(self) -> None:
"""Raise error if not fitted."""
if not self._is_fitted:
raise RuntimeError(
f"{self.__class__.__name__} has not been fitted. "
"Call 'fit' before using this transformer."
)
[docs]
def fit(self, X, y=None, **fit_params) -> BaseSignalTransformer:
"""Fit the transformer.
Most signal transformers don't need fitting, but this provides
a consistent sklearn interface.
Parameters
----------
X : array-like
Input signal.
y : ignored
**fit_params : dict
Additional parameters.
Returns
-------
self
"""
X = self._validate_signal(X)
self.n_samples_seen_ = self._get_n_samples(X)
self.n_channels_seen_ = self._get_n_channels(X)
self._is_fitted = True
return self
[docs]
@abstractmethod
def transform(self, X) -> np.ndarray:
"""Transform the signal.
Parameters
----------
X : array-like
Input signal.
Returns
-------
np.ndarray
Transformed signal.
"""
pass
[docs]
def fit_transform(self, X, y=None, **fit_params) -> np.ndarray:
"""Fit and transform in one step."""
return self.fit(X, y, **fit_params).transform(X)
[docs]
def inverse_transform(self, X) -> np.ndarray:
"""Inverse transform (if applicable).
Parameters
----------
X : array-like
Transformed signal.
Returns
-------
np.ndarray
Reconstructed signal.
Raises
------
NotImplementedError
If inverse transform is not supported.
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not support inverse_transform"
)
[docs]
class BaseFeatureExtractor(BaseEstimator, TransformerMixin, SignalMixin):
"""Base class for feature extraction from signals.
Unlike transformers that output signals, feature extractors output
feature vectors suitable for machine learning.
Parameters
----------
fs : float, optional
Sample rate in Hz.
"""
def __init__(self, fs: float | None = None):
self.fs = fs
self._is_fitted = False
self.feature_names_: list[str] | None = None
[docs]
def fit(self, X, y=None, **fit_params) -> BaseFeatureExtractor:
"""Fit the feature extractor.
Parameters
----------
X : array-like
Input signals of shape (n_samples, n_timepoints) or
(n_samples, n_channels, n_timepoints).
y : ignored
Returns
-------
self
"""
X = self._validate_signal(X)
self._is_fitted = True
return self
[docs]
@abstractmethod
def transform(self, X) -> np.ndarray:
"""Extract features from signals.
Parameters
----------
X : array-like
Input signals.
Returns
-------
np.ndarray of shape (n_samples, n_features)
Extracted features.
"""
pass
[docs]
def fit_transform(self, X, y=None, **fit_params) -> np.ndarray:
"""Fit and transform in one step."""
return self.fit(X, y, **fit_params).transform(X)
[docs]
def get_feature_names_out(self, input_features=None) -> list[str]:
"""Get output feature names.
Returns
-------
List[str]
Feature names.
"""
if self.feature_names_ is None:
raise RuntimeError("Feature names not available. Call fit first.")
return self.feature_names_
def ensure_2d_signals(X: np.ndarray) -> tuple[np.ndarray, bool, tuple]:
"""Ensure signals are 2D for processing.
Parameters
----------
X : np.ndarray
Input of shape (n_samples,), (n_channels, n_samples), or
(n_trials, n_channels, n_samples).
Returns
-------
X_2d : np.ndarray
2D array of shape (n_signals, n_samples).
was_1d : bool
Whether input was 1D.
original_shape : tuple
Original shape for reshaping back.
"""
original_shape = X.shape
was_1d = X.ndim == 1
if X.ndim == 1:
X_2d = X.reshape(1, -1)
elif X.ndim == 2:
X_2d = X
elif X.ndim == 3:
# Flatten trials and channels
n_trials, n_channels, n_samples = X.shape
X_2d = X.reshape(n_trials * n_channels, n_samples)
else:
raise ValueError(f"Expected 1D, 2D, or 3D array, got shape {X.shape}")
return X_2d, was_1d, original_shape
def restore_shape(X: np.ndarray, was_1d: bool, original_shape: tuple) -> np.ndarray:
"""Restore array to original shape after processing.
Parameters
----------
X : np.ndarray
Processed 2D array.
was_1d : bool
Whether original was 1D.
original_shape : tuple
Original shape.
Returns
-------
np.ndarray
Array reshaped to match original dimensionality.
"""
if was_1d:
return X.flatten()
elif len(original_shape) == 2:
return X
elif len(original_shape) == 3:
n_trials, n_channels, _ = original_shape
n_samples_out = X.shape[1]
return X.reshape(n_trials, n_channels, n_samples_out)
return X