Source code for endgame.signal.connectivity
from __future__ import annotations
"""Connectivity and EEG-specific feature extraction.
Provides sklearn-compatible feature extractors for:
- Coherence between channels
- Phase-locking value (PLV)
- Cross-correlation
- Burst and suppression detection
- Spike detection
These measures quantify relationships between channels and
detect specific patterns in biosignals.
References
----------
- Nunez & Srinivasan (2006): EEG coherence
- Lachaux et al. (1999): Phase-locking value
- Steriade et al. (1994): Burst-suppression patterns
"""
import numpy as np
from scipy import signal as scipy_signal
from endgame.signal.base import (
BaseFeatureExtractor,
ensure_2d_signals,
)
[docs]
def coherence(
x: np.ndarray,
y: np.ndarray,
fs: float,
nperseg: int | None = None,
noverlap: int | None = None,
) -> tuple[np.ndarray, np.ndarray]:
"""Compute magnitude-squared coherence between two signals.
Parameters
----------
x, y : np.ndarray
Input signals.
fs : float
Sample rate in Hz.
nperseg : int, optional
Segment length.
noverlap : int, optional
Overlap between segments.
Returns
-------
freqs : np.ndarray
Frequency bins.
coh : np.ndarray
Coherence values (0-1).
"""
nperseg = nperseg if nperseg is not None else min(256, len(x))
freqs, coh = scipy_signal.coherence(x, y, fs=fs, nperseg=nperseg, noverlap=noverlap)
return freqs, coh
[docs]
def phase_locking_value(
x: np.ndarray,
y: np.ndarray,
fs: float,
band: tuple[float, float] | None = None,
) -> float:
"""Compute phase-locking value between two signals.
PLV measures the consistency of phase difference between signals.
Parameters
----------
x, y : np.ndarray
Input signals.
fs : float
Sample rate in Hz.
band : tuple, optional
Frequency band (low, high) to filter before computing PLV.
Returns
-------
float
Phase-locking value (0-1).
References
----------
Lachaux, J. P., et al. (1999). Measuring phase synchrony in brain signals.
Human brain mapping, 8(4), 194-208.
"""
from scipy.signal import butter, hilbert, sosfiltfilt
# Optionally bandpass filter
if band is not None:
sos = butter(4, band, btype="bandpass", fs=fs, output="sos")
x = sosfiltfilt(sos, x)
y = sosfiltfilt(sos, y)
# Compute instantaneous phase using Hilbert transform
phase_x = np.angle(hilbert(x))
phase_y = np.angle(hilbert(y))
# Phase difference
phase_diff = phase_x - phase_y
# PLV is the magnitude of the mean phase difference vector
plv = np.abs(np.mean(np.exp(1j * phase_diff)))
return plv
[docs]
def cross_correlation(
x: np.ndarray,
y: np.ndarray,
max_lag: int | None = None,
normalize: bool = True,
) -> tuple[np.ndarray, np.ndarray]:
"""Compute cross-correlation between two signals.
Parameters
----------
x, y : np.ndarray
Input signals.
max_lag : int, optional
Maximum lag. If None, uses len(x) - 1.
normalize : bool, default=True
Normalize to correlation coefficient (-1 to 1).
Returns
-------
lags : np.ndarray
Lag values.
corr : np.ndarray
Cross-correlation values.
"""
n = len(x)
if max_lag is None:
max_lag = n - 1
# Full cross-correlation
corr = np.correlate(x - np.mean(x), y - np.mean(y), mode="full")
if normalize:
corr = corr / (n * np.std(x) * np.std(y))
# Extract relevant lags
mid = len(corr) // 2
lags = np.arange(-max_lag, max_lag + 1)
corr = corr[mid - max_lag : mid + max_lag + 1]
return lags, corr
[docs]
def detect_bursts(
x: np.ndarray,
fs: float,
threshold_std: float = 2.0,
min_duration_ms: float = 100.0,
) -> list[tuple[int, int]]:
"""Detect burst periods in a signal.
Bursts are periods where signal amplitude exceeds a threshold.
Parameters
----------
x : np.ndarray
Input signal.
fs : float
Sample rate in Hz.
threshold_std : float, default=2.0
Threshold in standard deviations above mean.
min_duration_ms : float, default=100.0
Minimum burst duration in milliseconds.
Returns
-------
list of (start, end) tuples
Burst intervals as sample indices.
"""
# Compute envelope using Hilbert transform
envelope = np.abs(scipy_signal.hilbert(x))
# Threshold
threshold = np.mean(envelope) + threshold_std * np.std(envelope)
# Find above-threshold regions
above = envelope > threshold
# Find burst boundaries
diff = np.diff(above.astype(int))
starts = np.where(diff == 1)[0] + 1
ends = np.where(diff == -1)[0] + 1
# Handle edge cases
if above[0]:
starts = np.concatenate([[0], starts])
if above[-1]:
ends = np.concatenate([ends, [len(x)]])
# Filter by minimum duration
min_samples = int(min_duration_ms * fs / 1000)
bursts = []
for start, end in zip(starts, ends):
if end - start >= min_samples:
bursts.append((start, end))
return bursts
[docs]
def detect_suppressions(
x: np.ndarray,
fs: float,
threshold_uv: float = 10.0,
min_duration_ms: float = 500.0,
) -> list[tuple[int, int]]:
"""Detect suppression periods in a signal.
Suppressions are periods of low amplitude activity.
Parameters
----------
x : np.ndarray
Input signal (assumed to be in microvolts).
fs : float
Sample rate in Hz.
threshold_uv : float, default=10.0
Amplitude threshold in microvolts.
min_duration_ms : float, default=500.0
Minimum suppression duration in milliseconds.
Returns
-------
list of (start, end) tuples
Suppression intervals as sample indices.
"""
# Compute envelope
envelope = np.abs(scipy_signal.hilbert(x))
# Find below-threshold regions
below = envelope < threshold_uv
# Find suppression boundaries
diff = np.diff(below.astype(int))
starts = np.where(diff == 1)[0] + 1
ends = np.where(diff == -1)[0] + 1
# Handle edge cases
if below[0]:
starts = np.concatenate([[0], starts])
if below[-1]:
ends = np.concatenate([ends, [len(x)]])
# Filter by minimum duration
min_samples = int(min_duration_ms * fs / 1000)
suppressions = []
for start, end in zip(starts, ends):
if end - start >= min_samples:
suppressions.append((start, end))
return suppressions
[docs]
def detect_spikes(
x: np.ndarray,
fs: float,
threshold_std: float = 3.0,
max_duration_ms: float = 70.0,
min_duration_ms: float = 20.0,
) -> list[int]:
"""Detect spikes in a signal.
Spikes are brief, high-amplitude transients.
Parameters
----------
x : np.ndarray
Input signal.
fs : float
Sample rate in Hz.
threshold_std : float, default=3.0
Threshold in standard deviations.
max_duration_ms : float, default=70.0
Maximum spike duration in milliseconds.
min_duration_ms : float, default=20.0
Minimum spike duration in milliseconds.
Returns
-------
list of int
Spike peak locations as sample indices.
"""
# Threshold
threshold = np.mean(x) + threshold_std * np.std(x)
# Find peaks above threshold
peaks, properties = scipy_signal.find_peaks(
np.abs(x),
height=threshold,
distance=int(min_duration_ms * fs / 1000),
)
# Filter by width
max_samples = int(max_duration_ms * fs / 1000)
min_samples = int(min_duration_ms * fs / 1000)
widths, _, _, _ = scipy_signal.peak_widths(np.abs(x), peaks)
valid_spikes = []
for peak, width in zip(peaks, widths):
if min_samples <= width <= max_samples:
valid_spikes.append(peak)
return valid_spikes
[docs]
class CoherenceFeatureExtractor(BaseFeatureExtractor):
"""Extract coherence features between all channel pairs.
Parameters
----------
fs : float
Sample rate in Hz.
bands : dict, optional
Frequency bands for band-averaged coherence.
Default includes standard EEG bands.
nperseg : int, optional
Segment length for coherence estimation.
Examples
--------
>>> coh = CoherenceFeatureExtractor(fs=256)
>>> features = coh.fit_transform(X) # X: (n_trials, n_channels, n_samples)
"""
DEFAULT_BANDS = {
"delta": (0.5, 4),
"theta": (4, 8),
"alpha": (8, 13),
"beta": (13, 30),
"gamma": (30, 50),
}
def __init__(
self,
fs: float,
bands: dict[str, tuple[float, float]] | None = None,
nperseg: int | None = None,
):
super().__init__(fs=fs)
self.bands = bands if bands is not None else self.DEFAULT_BANDS.copy()
self.nperseg = nperseg
[docs]
def fit(self, X, y=None, **fit_params) -> CoherenceFeatureExtractor:
"""Fit the extractor.
Parameters
----------
X : np.ndarray
Multi-channel signals of shape (n_trials, n_channels, n_samples).
y : ignored
Returns
-------
self
"""
X = np.asarray(X)
if X.ndim != 3:
raise ValueError(f"Expected 3D array, got shape {X.shape}")
self.n_channels_ = X.shape[1]
# Number of channel pairs
n_pairs = self.n_channels_ * (self.n_channels_ - 1) // 2
# Feature names
self.feature_names_ = []
self._channel_pairs = []
for i in range(self.n_channels_):
for j in range(i + 1, self.n_channels_):
self._channel_pairs.append((i, j))
for band_name in self.bands:
self.feature_names_.append(f"coh_{i}_{j}_{band_name}")
self._is_fitted = True
return self
[docs]
def transform(self, X) -> np.ndarray:
"""Extract coherence features.
Parameters
----------
X : np.ndarray
Multi-channel signals of shape (n_trials, n_channels, n_samples).
Returns
-------
np.ndarray
Coherence features of shape (n_trials, n_features).
"""
X = np.asarray(X)
n_trials = X.shape[0]
all_features = []
for trial in range(n_trials):
features = []
for i, j in self._channel_pairs:
freqs, coh = coherence(
X[trial, i], X[trial, j], self.fs, self.nperseg
)
# Band-averaged coherence
for band_name, (low, high) in self.bands.items():
mask = (freqs >= low) & (freqs <= high)
band_coh = np.mean(coh[mask]) if np.any(mask) else 0.0
features.append(band_coh)
all_features.append(features)
return np.array(all_features)
[docs]
class PLVFeatureExtractor(BaseFeatureExtractor):
"""Extract phase-locking value features between channel pairs.
Parameters
----------
fs : float
Sample rate in Hz.
bands : dict, optional
Frequency bands for band-specific PLV.
Examples
--------
>>> plv = PLVFeatureExtractor(fs=256)
>>> features = plv.fit_transform(X)
"""
DEFAULT_BANDS = {
"theta": (4, 8),
"alpha": (8, 13),
"beta": (13, 30),
}
def __init__(
self,
fs: float,
bands: dict[str, tuple[float, float]] | None = None,
):
super().__init__(fs=fs)
self.bands = bands if bands is not None else self.DEFAULT_BANDS.copy()
[docs]
def fit(self, X, y=None, **fit_params) -> PLVFeatureExtractor:
X = np.asarray(X)
if X.ndim != 3:
raise ValueError(f"Expected 3D array, got shape {X.shape}")
self.n_channels_ = X.shape[1]
self.feature_names_ = []
self._channel_pairs = []
for i in range(self.n_channels_):
for j in range(i + 1, self.n_channels_):
self._channel_pairs.append((i, j))
for band_name in self.bands:
self.feature_names_.append(f"plv_{i}_{j}_{band_name}")
self._is_fitted = True
return self
[docs]
def transform(self, X) -> np.ndarray:
X = np.asarray(X)
n_trials = X.shape[0]
all_features = []
for trial in range(n_trials):
features = []
for i, j in self._channel_pairs:
for band_name, band in self.bands.items():
plv = phase_locking_value(
X[trial, i], X[trial, j], self.fs, band
)
features.append(plv)
all_features.append(features)
return np.array(all_features)
[docs]
class BurstSuppressionFeatures(BaseFeatureExtractor):
"""Extract burst-suppression features from signals.
Commonly used in EEG analysis for detecting anesthesia depth
and certain pathological states.
Parameters
----------
fs : float
Sample rate in Hz.
burst_threshold_std : float, default=2.0
Threshold for burst detection.
suppression_threshold : float, default=10.0
Threshold for suppression detection.
min_burst_ms : float, default=100.0
Minimum burst duration.
min_suppression_ms : float, default=500.0
Minimum suppression duration.
Examples
--------
>>> bsr = BurstSuppressionFeatures(fs=256)
>>> features = bsr.fit_transform(signals)
"""
def __init__(
self,
fs: float,
burst_threshold_std: float = 2.0,
suppression_threshold: float = 10.0,
min_burst_ms: float = 100.0,
min_suppression_ms: float = 500.0,
):
super().__init__(fs=fs)
self.burst_threshold_std = burst_threshold_std
self.suppression_threshold = suppression_threshold
self.min_burst_ms = min_burst_ms
self.min_suppression_ms = min_suppression_ms
[docs]
def fit(self, X, y=None, **fit_params) -> BurstSuppressionFeatures:
X = self._validate_signal(X)
super().fit(X, y, **fit_params)
self.feature_names_ = [
"n_bursts",
"burst_rate",
"mean_burst_duration",
"std_burst_duration",
"n_suppressions",
"suppression_rate",
"mean_suppression_duration",
"std_suppression_duration",
"burst_suppression_ratio",
]
return self
[docs]
def transform(self, X) -> np.ndarray:
X = self._validate_signal(X)
X_2d, was_1d, original_shape = ensure_2d_signals(X)
all_features = []
for sig in X_2d:
# Detect bursts
bursts = detect_bursts(
sig, self.fs, self.burst_threshold_std, self.min_burst_ms
)
# Detect suppressions
suppressions = detect_suppressions(
sig, self.fs, self.suppression_threshold, self.min_suppression_ms
)
# Compute features
n_samples = len(sig)
duration_sec = n_samples / self.fs
# Burst statistics
n_bursts = len(bursts)
burst_rate = n_bursts / duration_sec if duration_sec > 0 else 0
if n_bursts > 0:
burst_durations = [(end - start) / self.fs * 1000 for start, end in bursts]
mean_burst_dur = np.mean(burst_durations)
std_burst_dur = np.std(burst_durations) if n_bursts > 1 else 0
total_burst_samples = sum(end - start for start, end in bursts)
else:
mean_burst_dur = 0
std_burst_dur = 0
total_burst_samples = 0
# Suppression statistics
n_suppressions = len(suppressions)
suppression_rate = n_suppressions / duration_sec if duration_sec > 0 else 0
if n_suppressions > 0:
supp_durations = [(end - start) / self.fs * 1000 for start, end in suppressions]
mean_supp_dur = np.mean(supp_durations)
std_supp_dur = np.std(supp_durations) if n_suppressions > 1 else 0
total_supp_samples = sum(end - start for start, end in suppressions)
else:
mean_supp_dur = 0
std_supp_dur = 0
total_supp_samples = 0
# Burst-suppression ratio
bsr = total_supp_samples / n_samples if n_samples > 0 else 0
features = [
n_bursts,
burst_rate,
mean_burst_dur,
std_burst_dur,
n_suppressions,
suppression_rate,
mean_supp_dur,
std_supp_dur,
bsr,
]
all_features.append(features)
return np.array(all_features)
[docs]
class SpikeFeatures(BaseFeatureExtractor):
"""Extract spike-related features from signals.
Parameters
----------
fs : float
Sample rate in Hz.
threshold_std : float, default=3.0
Detection threshold in standard deviations.
max_duration_ms : float, default=70.0
Maximum spike duration.
Examples
--------
>>> spike_feat = SpikeFeatures(fs=256)
>>> features = spike_feat.fit_transform(signals)
"""
def __init__(
self,
fs: float,
threshold_std: float = 3.0,
max_duration_ms: float = 70.0,
):
super().__init__(fs=fs)
self.threshold_std = threshold_std
self.max_duration_ms = max_duration_ms
[docs]
def fit(self, X, y=None, **fit_params) -> SpikeFeatures:
X = self._validate_signal(X)
super().fit(X, y, **fit_params)
self.feature_names_ = [
"n_spikes",
"spike_rate",
"mean_spike_amplitude",
"std_spike_amplitude",
"mean_spike_interval",
]
return self
[docs]
def transform(self, X) -> np.ndarray:
X = self._validate_signal(X)
X_2d, was_1d, original_shape = ensure_2d_signals(X)
all_features = []
for sig in X_2d:
spikes = detect_spikes(sig, self.fs, self.threshold_std, self.max_duration_ms)
n_samples = len(sig)
duration_sec = n_samples / self.fs
n_spikes = len(spikes)
spike_rate = n_spikes / duration_sec if duration_sec > 0 else 0
if n_spikes > 0:
amplitudes = np.abs(sig[spikes])
mean_amp = np.mean(amplitudes)
std_amp = np.std(amplitudes) if n_spikes > 1 else 0
if n_spikes > 1:
intervals = np.diff(spikes) / self.fs * 1000 # in ms
mean_interval = np.mean(intervals)
else:
mean_interval = 0
else:
mean_amp = 0
std_amp = 0
mean_interval = 0
features = [n_spikes, spike_rate, mean_amp, std_amp, mean_interval]
all_features.append(features)
return np.array(all_features)
[docs]
class ConnectivityFeatureExtractor(BaseFeatureExtractor):
"""Comprehensive connectivity feature extractor.
Combines coherence, PLV, and cross-correlation features.
Parameters
----------
fs : float
Sample rate in Hz.
include_coherence : bool, default=True
Include coherence features.
include_plv : bool, default=True
Include PLV features.
bands : dict, optional
Frequency bands for analysis.
Examples
--------
>>> conn = ConnectivityFeatureExtractor(fs=256)
>>> features = conn.fit_transform(X)
"""
DEFAULT_BANDS = {
"theta": (4, 8),
"alpha": (8, 13),
"beta": (13, 30),
}
def __init__(
self,
fs: float,
include_coherence: bool = True,
include_plv: bool = True,
bands: dict[str, tuple[float, float]] | None = None,
):
super().__init__(fs=fs)
self.include_coherence = include_coherence
self.include_plv = include_plv
self.bands = bands if bands is not None else self.DEFAULT_BANDS.copy()
[docs]
def fit(self, X, y=None, **fit_params) -> ConnectivityFeatureExtractor:
X = np.asarray(X)
if X.ndim != 3:
raise ValueError(f"Expected 3D array, got shape {X.shape}")
self._extractors = []
self.feature_names_ = []
if self.include_coherence:
coh = CoherenceFeatureExtractor(fs=self.fs, bands=self.bands)
coh.fit(X)
self._extractors.append(coh)
self.feature_names_.extend(coh.feature_names_)
if self.include_plv:
plv = PLVFeatureExtractor(fs=self.fs, bands=self.bands)
plv.fit(X)
self._extractors.append(plv)
self.feature_names_.extend(plv.feature_names_)
self._is_fitted = True
return self
[docs]
def transform(self, X) -> np.ndarray:
X = np.asarray(X)
feature_arrays = []
for ext in self._extractors:
feats = ext.transform(X)
feature_arrays.append(feats)
return np.hstack(feature_arrays)