from __future__ import annotations
"""Spatial filtering and Riemannian geometry for multi-channel signals.
Provides sklearn-compatible spatial filtering methods:
- Common Spatial Patterns (CSP)
- Tangent Space projection
- Covariance estimation
These methods operate on multi-channel signals (e.g., EEG) and extract
features that capture spatial relationships between channels.
References
----------
- Koles et al. (1990): Common Spatial Patterns
- Barachant et al. (2012): Riemannian geometry for EEG
- pyriemann library: Riemannian geometry implementations
"""
import numpy as np
from scipy import linalg
from endgame.signal.base import (
BaseFeatureExtractor,
BaseSignalTransformer,
)
def _covariance(X: np.ndarray, estimator: str = "empirical") -> np.ndarray:
    """Compute the covariance matrix of one multi-channel signal.

    Parameters
    ----------
    X : np.ndarray
        Signal of shape (n_channels, n_samples).
    estimator : str, default='empirical'
        Covariance estimator:

        - 'empirical': ``np.cov`` (unbiased, mean-removed).
        - 'scm': sample covariance ``X @ X.T / n_samples`` (no mean removal).
        - 'oas': Oracle Approximating Shrinkage (scikit-learn).
        - 'lwf': Ledoit-Wolf shrinkage (scikit-learn).

    Returns
    -------
    np.ndarray
        Covariance matrix of shape (n_channels, n_channels).

    Raises
    ------
    ValueError
        If ``estimator`` is not one of the names above, or if ``X`` is not
        two-dimensional.

    Warns
    -----
    RuntimeWarning
        If a shrinkage estimator is requested but scikit-learn cannot be
        imported; the empirical covariance is returned instead.
    """
    # Unpacking doubles as the 2-D check: any other rank raises ValueError.
    _, n_samples = X.shape

    if estimator == "empirical":
        # np.cov collapses a single-channel input to a 0-d array; restore the
        # documented (n_channels, n_channels) shape.
        return np.atleast_2d(np.cov(X))

    if estimator == "scm":
        # Sample covariance matrix (not unbiased, no mean removal).
        return X @ X.T / n_samples

    if estimator in ("oas", "lwf"):
        try:
            from sklearn.covariance import OAS, LedoitWolf
        except ImportError:
            import warnings

            # Keep the best-effort fallback, but make the estimator swap visible.
            warnings.warn(
                f"scikit-learn is not available; falling back to the empirical "
                f"covariance instead of '{estimator}'.",
                RuntimeWarning,
                stacklevel=2,
            )
            return np.atleast_2d(np.cov(X))
        shrinkage = OAS() if estimator == "oas" else LedoitWolf()
        # scikit-learn expects (n_samples, n_features), hence the transpose.
        return shrinkage.fit(X.T).covariance_

    raise ValueError(f"Unknown estimator: {estimator}")
def _matrix_power(A: np.ndarray, power: float) -> np.ndarray:
    """Raise a symmetric positive definite matrix to a real power.

    The matrix is diagonalised, its eigenvalues are floored at ``1e-10``
    and raised to ``power``, and the result is rebuilt in the original basis.

    Parameters
    ----------
    A : np.ndarray
        Symmetric positive definite matrix.
    power : float
        Exponent applied to the eigenvalues.

    Returns
    -------
    np.ndarray
        ``A`` raised to ``power``.
    """
    spectrum, basis = linalg.eigh(A)
    # Floor the spectrum so that round-off can never yield a non-positive
    # eigenvalue (which would break fractional or negative powers).
    floored = np.maximum(spectrum, 1e-10)
    # Scaling the columns of ``basis`` equals multiplying by a diagonal matrix.
    return (basis * floored**power) @ basis.T
def _logm(A: np.ndarray) -> np.ndarray:
    """Return the matrix logarithm of a symmetric positive definite matrix.

    Computed through the eigendecomposition, with eigenvalues floored at
    ``1e-10`` before the logarithm is taken.

    Parameters
    ----------
    A : np.ndarray
        Symmetric positive definite matrix.

    Returns
    -------
    np.ndarray
        Matrix logarithm of ``A``.
    """
    spectrum, basis = linalg.eigh(A)
    # The floor keeps the logarithm finite for (numerically) singular input.
    log_spectrum = np.log(np.maximum(spectrum, 1e-10))
    # Scaling the columns of ``basis`` equals multiplying by a diagonal matrix.
    return (basis * log_spectrum) @ basis.T
def _expm(A: np.ndarray) -> np.ndarray:
    """Return the matrix exponential of a symmetric matrix.

    Computed through the eigendecomposition: the eigenvalues are
    exponentiated and the matrix is rebuilt in the original basis.

    Parameters
    ----------
    A : np.ndarray
        Symmetric matrix.

    Returns
    -------
    np.ndarray
        Matrix exponential of ``A``.
    """
    spectrum, basis = linalg.eigh(A)
    # Scaling the columns of ``basis`` equals multiplying by a diagonal matrix.
    return (basis * np.exp(spectrum)) @ basis.T
def _mean_riemann(covmats: np.ndarray, tol: float = 1e-8, max_iter: int = 50) -> np.ndarray:
    """Estimate the Riemannian (Frechet) mean of a set of covariance matrices.

    Starting from the arithmetic mean, each iteration maps every matrix to
    the tangent space at the current estimate, averages the tangent vectors,
    and moves the estimate along that average. Iteration stops once the
    Frobenius norm of the average drops below ``tol`` or after ``max_iter``
    iterations.

    Parameters
    ----------
    covmats : np.ndarray
        Covariance matrices of shape (n_matrices, n_channels, n_channels).
    tol : float, default=1e-8
        Convergence tolerance on the Frobenius norm of the mean tangent vector.
    max_iter : int, default=50
        Maximum number of iterations.

    Returns
    -------
    np.ndarray
        Riemannian mean of shape (n_channels, n_channels).
    """
    count = len(covmats)
    estimate = np.mean(covmats, axis=0)  # arithmetic mean as the starting point

    for _ in range(max_iter):
        # Whitening by estimate^(-1/2) expresses every matrix relative to the
        # current estimate before taking the matrix logarithm.
        whitener = _matrix_power(estimate, -0.5)
        tangent_total = sum(
            (_logm(whitener @ cov @ whitener) for cov in covmats),
            np.zeros_like(estimate),
        )
        step = tangent_total / count

        # Converged: the average tangent vector is (numerically) zero.
        if np.linalg.norm(step, "fro") < tol:
            break

        # Map the averaged tangent vector back onto the manifold.
        root = _matrix_power(estimate, 0.5)
        estimate = root @ _expm(step) @ root

    return estimate
[docs]
class CovarianceEstimator(BaseSignalTransformer):
    """Turn multi-channel trials into per-trial covariance matrices.

    Parameters
    ----------
    fs : float, optional
        Sample rate in Hz (forwarded to the base class).
    estimator : str, default='empirical'
        Covariance estimator: 'empirical', 'oas', 'lwf', 'scm'.
    copy : bool, default=True
        Whether to copy input data (forwarded to the base class).

    Attributes
    ----------
    n_channels_ : int
        Number of channels seen during ``fit``.

    Examples
    --------
    >>> cov_est = CovarianceEstimator(estimator='oas')
    >>> covmats = cov_est.fit_transform(X)  # X: (n_trials, n_channels, n_samples)
    """

    def __init__(
        self,
        fs: float | None = None,
        estimator: str = "empirical",
        copy: bool = True,
    ):
        super().__init__(fs=fs, copy=copy)
        self.estimator = estimator

    def fit(self, X, y=None, **fit_params) -> CovarianceEstimator:
        """Record the channel count of the training data.

        Nothing is learned from the data beyond its shape.

        Parameters
        ----------
        X : np.ndarray
            Multi-channel signals of shape (n_trials, n_channels, n_samples).
        y : ignored

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``X`` is not a 3D array.
        """
        X = np.asarray(X)
        if X.ndim == 3:
            self.n_channels_ = X.shape[1]
            self._is_fitted = True
            return self
        raise ValueError(f"Expected 3D array (trials, channels, samples), got shape {X.shape}")
[docs]
class CSP(BaseSignalTransformer):
    """Common Spatial Patterns for discriminative spatial filtering.

    CSP finds spatial filters that maximize variance for one class
    while minimizing variance for another class.

    Parameters
    ----------
    fs : float, optional
        Sample rate in Hz.
    n_components : int, default=4
        Number of CSP components (filters) taken from each end of the
        eigenvalue spectrum. Capped at ``n_channels // 2`` during ``fit``.
    reg : float or str, optional
        Regularization. A number ``r`` shrinks each class covariance toward
        the identity, ``(1 - r) * C + r * I``. A string selects a shrinkage
        estimator for the per-trial covariances and takes precedence over
        ``cov_estimator``: 'ledoit_wolf' (alias 'lwf') or 'oas'.
    log : bool, default=True
        Apply log transform to variances (recommended for classification).
    cov_estimator : str, default='empirical'
        Covariance estimator used for each trial.
    copy : bool, default=True
        Whether to copy input data.

    Attributes
    ----------
    filters_ : np.ndarray
        Spatial filters of shape (n_filters, n_channels), one filter per row,
        where ``n_filters = 2 * min(n_components, n_channels // 2)``.
    patterns_ : np.ndarray
        Spatial patterns of shape (n_filters, n_channels), one pattern per
        row (transposed pseudo-inverse of ``filters_``).
    eigenvalues_ : np.ndarray
        Eigenvalues corresponding to each filter: the smallest ones first,
        followed by the largest ones.
    n_channels_ : int
        Number of channels seen during ``fit``.

    References
    ----------
    Koles, Z. J., et al. (1990). Spatial patterns underlying population
    differences in the background EEG.

    Examples
    --------
    >>> csp = CSP(n_components=4)
    >>> features = csp.fit_transform(X, y)  # X: (n_trials, n_channels, n_samples)
    """

    # Accepted string values of ``reg`` -> estimator name used per trial.
    _REG_TO_ESTIMATOR = {"ledoit_wolf": "lwf", "lwf": "lwf", "oas": "oas"}

    def __init__(
        self,
        fs: float | None = None,
        n_components: int = 4,
        reg: float | str | None = None,
        log: bool = True,
        cov_estimator: str = "empirical",
        copy: bool = True,
    ):
        super().__init__(fs=fs, copy=copy)
        self.n_components = n_components
        self.reg = reg
        self.log = log
        self.cov_estimator = cov_estimator

    def _resolve_cov_estimator(self) -> str:
        """Return the per-trial covariance estimator, honouring a string ``reg``."""
        if not isinstance(self.reg, str):
            return self.cov_estimator
        try:
            return self._REG_TO_ESTIMATOR[self.reg]
        except KeyError:
            raise ValueError(
                f"Unknown reg method: {self.reg!r}; expected 'ledoit_wolf' or 'oas'"
            ) from None

    def fit(self, X, y, **fit_params) -> CSP:
        """Fit CSP spatial filters.

        Parameters
        ----------
        X : np.ndarray
            Multi-channel signals of shape (n_trials, n_channels, n_samples).
        y : np.ndarray
            Binary class labels, one per trial.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``X`` is not 3D, if ``y`` does not contain exactly two classes,
            if ``reg`` is an unrecognised string, or if no component can be
            selected (``n_components < 1`` or fewer than two channels).
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if X.ndim != 3:
            raise ValueError(f"Expected 3D array (trials, channels, samples), got shape {X.shape}")

        classes = np.unique(y)
        if len(classes) != 2:
            raise ValueError(f"CSP requires exactly 2 classes, got {len(classes)}")

        n_trials, n_channels, _ = X.shape

        # Components are taken from both ends of the spectrum, so at most
        # n_channels // 2 per end. Reject a count below one explicitly:
        # ``idx[-0:]`` would otherwise silently select every component.
        n_comp = min(self.n_components, n_channels // 2)
        if n_comp < 1:
            raise ValueError(
                "CSP needs n_components >= 1 and at least 2 channels, got "
                f"n_components={self.n_components}, n_channels={n_channels}"
            )

        # Per-trial covariances (float64 buffer regardless of input dtype).
        estimator = self._resolve_cov_estimator()
        covmats = np.zeros((n_trials, n_channels, n_channels))
        for i, trial in enumerate(X):
            covmats[i] = _covariance(trial, estimator)

        # Average covariance per class.
        cov_class1 = np.mean(covmats[y == classes[0]], axis=0)
        cov_class2 = np.mean(covmats[y == classes[1]], axis=0)

        # Numeric regularization: shrink both class covariances toward identity.
        if isinstance(self.reg, (int, float, np.integer, np.floating)):
            identity = np.eye(n_channels)
            cov_class1 = (1 - self.reg) * cov_class1 + self.reg * identity
            cov_class2 = (1 - self.reg) * cov_class2 + self.reg * identity

        # Solve the generalized eigenvalue problem
        #   cov_class1 * W = cov_composite * W * D
        # by whitening the composite covariance first.
        cov_composite = cov_class1 + cov_class2
        comp_values, comp_vectors = linalg.eigh(cov_composite)
        comp_values = np.maximum(comp_values, 1e-10)  # guard the inverse sqrt
        whitening = comp_vectors @ np.diag(1.0 / np.sqrt(comp_values)) @ comp_vectors.T

        # Eigenvectors of the whitened class-1 covariance.
        cov_class1_white = whitening @ cov_class1 @ whitening.T
        eigenvalues, eigenvectors = linalg.eigh(cov_class1_white)

        # Both ends of the sorted spectrum are the most discriminative:
        # smallest eigenvalues favour class 2, largest favour class 1.
        sorted_idx = np.argsort(eigenvalues)
        selected_idx = np.concatenate([sorted_idx[:n_comp], sorted_idx[-n_comp:]])

        self.eigenvalues_ = eigenvalues[selected_idx]
        # Spatial filters, one per row.
        self.filters_ = (whitening.T @ eigenvectors[:, selected_idx]).T
        # Spatial patterns (for visualization), one per row.
        self.patterns_ = linalg.pinv(self.filters_).T
        self.n_channels_ = n_channels
        self._is_fitted = True
        return self
[docs]
class TangentSpace(BaseFeatureExtractor):
    """Tangent space projection for Riemannian geometry.

    Projects covariance matrices onto the tangent space at a reference
    point (by default a mean of the training covariances), enabling use of
    Euclidean classifiers.

    Parameters
    ----------
    fs : float, optional
        Sample rate in Hz.
    metric : str, default='riemann'
        Metric for mean computation: 'riemann', 'euclid', 'logeuclid'.
        Only consulted when ``reference`` is None.
    reference : np.ndarray, optional
        Reference point of shape (n_channels, n_channels). If None, it is
        computed from the training data using ``metric``.
    cov_estimator : str, default='empirical'
        Covariance estimator for input signals.

    Attributes
    ----------
    reference_ : np.ndarray
        Reference covariance matrix of shape (n_channels, n_channels).
    feature_names_ : list
        Names of tangent space features, ``ts_0 ... ts_{k-1}`` with
        ``k = n_channels * (n_channels + 1) // 2``.
    n_channels_ : int
        Number of channels seen during ``fit``.

    References
    ----------
    Barachant, A., et al. (2012). Multiclass brain-computer interface
    classification by Riemannian geometry.

    Examples
    --------
    >>> ts = TangentSpace()
    >>> # From covariance matrices
    >>> features = ts.fit_transform(covmats)
    >>> # Or from raw signals (computes covariances internally)
    >>> features = ts.fit_transform(X, input_type='signals')
    """

    def __init__(
        self,
        fs: float | None = None,
        metric: str = "riemann",
        reference: np.ndarray | None = None,
        cov_estimator: str = "empirical",
    ):
        super().__init__(fs=fs)
        self.metric = metric
        self.reference = reference
        self.cov_estimator = cov_estimator

    def fit(self, X, y=None, input_type: str = "covariances", **fit_params) -> TangentSpace:
        """Fit the tangent space projector.

        Parameters
        ----------
        X : np.ndarray
            Either covariance matrices (n_trials, n_channels, n_channels)
            or raw signals (n_trials, n_channels, n_samples).
        y : ignored
        input_type : str, default='covariances'
            Type of input: 'covariances' or 'signals'.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``input_type`` or ``metric`` is unknown, if ``X`` does not have
            the shape required by ``input_type``, or if a user-supplied
            ``reference`` does not have shape (n_channels, n_channels).
        """
        # Reject typos up front: an unrecognised value would otherwise be
        # treated as 'covariances' and could misread raw signals.
        if input_type not in ("covariances", "signals"):
            raise ValueError(
                f"Unknown input_type: {input_type!r}; expected 'covariances' or 'signals'"
            )

        X = np.asarray(X)

        # Convert signals to covariances if needed.
        if input_type == "signals":
            if X.ndim != 3:
                raise ValueError(f"Expected 3D array for signals, got shape {X.shape}")
            covmats = np.zeros((X.shape[0], X.shape[1], X.shape[1]))
            for i, trial in enumerate(X):
                covmats[i] = _covariance(trial, self.cov_estimator)
        else:
            covmats = X

        if covmats.ndim != 3 or covmats.shape[1] != covmats.shape[2]:
            raise ValueError(
                f"Expected covariance matrices (n, ch, ch), got shape {covmats.shape}"
            )

        self.n_channels_ = covmats.shape[1]

        # Compute (or adopt) the reference point.
        if self.reference is not None:
            reference = np.asarray(self.reference)
            expected_shape = (self.n_channels_, self.n_channels_)
            # Catch a mismatched reference here rather than as an obscure
            # linear-algebra error when it is first used.
            if reference.shape != expected_shape:
                raise ValueError(
                    f"reference must have shape {expected_shape}, got {reference.shape}"
                )
            self.reference_ = reference
        elif self.metric == "riemann":
            self.reference_ = _mean_riemann(covmats)
        elif self.metric == "euclid":
            self.reference_ = np.mean(covmats, axis=0)
        elif self.metric == "logeuclid":
            # Log-Euclidean mean: average in the matrix-log domain, map back.
            log_covmats = np.array([_logm(cov) for cov in covmats])
            self.reference_ = _expm(np.mean(log_covmats, axis=0))
        else:
            raise ValueError(f"Unknown metric: {self.metric}")

        # One feature per upper-triangular element (diagonal included).
        n_features = self.n_channels_ * (self.n_channels_ + 1) // 2
        self.feature_names_ = [f"ts_{i}" for i in range(n_features)]
        self._is_fitted = True
        return self
[docs]
class FilterBankCSP(BaseSignalTransformer):
    """Filter Bank Common Spatial Patterns.

    Applies CSP to multiple frequency bands and concatenates features.

    Parameters
    ----------
    fs : float
        Sample rate in Hz.
    bands : list of tuple
        Frequency bands as (low, high) tuples, in Hz.
    n_components : int, default=4
        Number of CSP components per band.
    filter_order : int, default=5
        Order of the Butterworth bandpass filters.

    Attributes
    ----------
    csps_ : list of CSP
        One fitted CSP per band, in the order of ``bands``.
    filters_ : list of np.ndarray
        Second-order-section coefficients of the bandpass filter for each
        band, in the order of ``bands``.

    Examples
    --------
    >>> bands = [(4, 8), (8, 12), (12, 30)]
    >>> fbcsp = FilterBankCSP(fs=256, bands=bands)
    >>> features = fbcsp.fit_transform(X, y)
    """

    def __init__(
        self,
        fs: float,
        bands: list[tuple[float, float]],
        n_components: int = 4,
        filter_order: int = 5,
    ):
        super().__init__(fs=fs)
        self.bands = bands
        self.n_components = n_components
        self.filter_order = filter_order

    def fit(self, X, y, **fit_params) -> FilterBankCSP:
        """Fit one CSP per frequency band.

        Parameters
        ----------
        X : np.ndarray
            Multi-channel signals of shape (n_trials, n_channels, n_samples).
        y : np.ndarray
            Binary class labels.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``X`` is not a 3D array (or if a per-band CSP fit fails its own
            validation).
        """
        from scipy.signal import butter, sosfiltfilt

        X = np.asarray(X)
        y = np.asarray(y)
        if X.ndim != 3:
            raise ValueError(f"Expected 3D array (trials, channels, samples), got shape {X.shape}")

        # Collect results locally so a failure on one band does not leave the
        # estimator with half-populated fitted attributes.
        csps = []
        sos_filters = []

        for low, high in self.bands:
            # Design the bandpass filter for this band.
            sos = butter(self.filter_order, [low, high], btype="bandpass", fs=self.fs, output="sos")
            sos_filters.append(sos)

            # Zero-phase filtering along the sample axis for all trials and
            # channels at once. Taking sosfiltfilt's output directly (instead
            # of writing into a buffer with X's dtype) keeps the result in
            # floating point even when X holds integers.
            X_filtered = sosfiltfilt(sos, X, axis=-1)

            # Fit CSP for this band.
            csp = CSP(n_components=self.n_components)
            csp.fit(X_filtered, y)
            csps.append(csp)

        self.csps_ = csps
        self.filters_ = sos_filters
        self._is_fitted = True
        return self