# Source code for endgame.preprocessing.imbalance_geometric

"""Modern geometric SMOTE extensions for class imbalance handling.

This module provides advanced oversampling methods that extend SMOTE with
geometric and statistical techniques. All methods use only numpy/scipy
(no optional dependencies).

Algorithms
----------
- MultivariateGaussianSMOTE: Local Gaussian sampling around minority points
- SimplicialSMOTE: Simplicial complex-based sampling with Dirichlet coordinates
- CVSMOTEResampler: Cross-validation guided synthetic sample selection
- OverlapRegionDetector: Overlap-aware meta-method for any base sampler

References
----------
- "Do we need rebalancing strategies?" (ICLR 2025)
- Simplicial complex extension of SMOTE (KDD 2025)
- CV-informed SMOTE (ICLR 2025)
- Overlap Region Detection (AAAI 2025)
"""

from __future__ import annotations

from typing import Any

import numpy as np
from numpy.typing import ArrayLike
from sklearn.base import BaseEstimator, clone
from sklearn.utils.validation import check_X_y


def _compute_sampling_targets(
    y: np.ndarray,
    sampling_strategy: str | float | dict = "auto",
) -> dict:
    """Compute number of synthetic samples to generate per class.

    Interprets the ``sampling_strategy`` parameter following imblearn semantics
    and returns a dict ``{class_label: n_to_generate}``.

    Parameters
    ----------
    y : ndarray of shape (n_samples,)
        Target array.
    sampling_strategy : str, float, or dict
        - ``'auto'`` / ``'minority'`` / ``'not majority'``: up-sample relevant
          classes to match the majority count.
        - ``'all'``: up-sample every class to the majority count.
        - float: desired minority-to-majority ratio (0 < r <= 1).
        - dict: ``{class_label: desired_total_count}``.

    Returns
    -------
    targets : dict
        ``{class_label: n_synthetic_to_generate}`` (values >= 0).
    """
    unique, counts = np.unique(y, return_counts=True)
    class_counts = dict(zip(unique, counts))
    max_count = counts.max()
    majority_class = unique[np.argmax(counts)]

    if isinstance(sampling_strategy, dict):
        targets = {}
        for cls, desired in sampling_strategy.items():
            current = class_counts.get(cls, 0)
            targets[cls] = max(0, desired - current)
        return targets

    if isinstance(sampling_strategy, (int, float)) and not isinstance(
        sampling_strategy, bool
    ):
        # Float ratio: desired = ratio * majority_count
        desired = int(round(float(sampling_strategy) * max_count))
        targets = {}
        for cls, cnt in class_counts.items():
            if cls != majority_class and cnt < desired:
                targets[cls] = desired - cnt
        return targets

    # String strategies
    if sampling_strategy in ("auto", "minority", "not majority"):
        targets = {}
        for cls, cnt in class_counts.items():
            if cls != majority_class and cnt < max_count:
                targets[cls] = max_count - cnt
        return targets

    if sampling_strategy == "all":
        targets = {}
        for cls, cnt in class_counts.items():
            if cnt < max_count:
                targets[cls] = max_count - cnt
        return targets

    raise ValueError(
        f"Unknown sampling_strategy: {sampling_strategy!r}. "
        "Expected 'auto', 'minority', 'not majority', 'all', a float, or a dict."
    )


# =============================================================================
# MultivariateGaussianSMOTE
# =============================================================================


class MultivariateGaussianSMOTE(BaseEstimator):
    """Multivariate Gaussian SMOTE oversampler.

    For each minority sample, fits a local multivariate Gaussian from its
    k-nearest minority neighbours and samples new points from it. The
    Gaussian is centred on the mean of those neighbours (the anchor point is
    excluded from the mean) and uses their sample covariance plus a diagonal
    ridge.

    Parameters
    ----------
    sampling_strategy : str, float, or dict, default='auto'
        See :func:`_compute_sampling_targets` for semantics.
    k_neighbors : int, default=5
        Number of nearest minority neighbours for covariance estimation.
        Clamped per class to ``n_class_samples - 1``.
    regularization : float, default=1e-6
        Ridge added to the diagonal of local covariance matrices to ensure
        positive-definiteness. Also used as the noise scale when a class has
        only one sample.
    random_state : int or None, default=None
        Random seed.

    References
    ----------
    "Do we need rebalancing strategies?" (ICLR 2025)
    """

    def __init__(
        self,
        sampling_strategy: str | float | dict = "auto",
        k_neighbors: int = 5,
        regularization: float = 1e-6,
        random_state: int | None = None,
    ):
        self.sampling_strategy = sampling_strategy
        self.k_neighbors = k_neighbors
        self.regularization = regularization
        self.random_state = random_state

    def fit(self, X: ArrayLike, y: ArrayLike) -> MultivariateGaussianSMOTE:
        """Fit the sampler (validates input and computes targets).

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        y : array-like of shape (n_samples,)

        Returns
        -------
        self
        """
        X, y = check_X_y(X, y)
        self.targets_ = _compute_sampling_targets(y, self.sampling_strategy)
        return self

    def fit_resample(
        self, X: ArrayLike, y: ArrayLike
    ) -> tuple[np.ndarray, np.ndarray]:
        """Fit and resample the dataset.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        y : array-like of shape (n_samples,)

        Returns
        -------
        X_resampled : ndarray
            The original samples followed by the synthetic samples.
        y_resampled : ndarray
            Matching labels.
        """
        from scipy.spatial import KDTree

        self.fit(X, y)
        X = np.asarray(X, dtype=np.float64)
        y = np.asarray(y)
        n_features = X.shape[1]
        rng = np.random.RandomState(self.random_state)

        synthetic_X: list[np.ndarray] = []
        synthetic_y: list[np.ndarray] = []

        for cls, n_synthetic in self.targets_.items():
            if n_synthetic <= 0:
                continue
            X_cls = X[y == cls]
            n_cls = len(X_cls)
            if n_cls == 0:
                continue

            k = min(self.k_neighbors, n_cls - 1)
            if k < 1:
                # Only one sample — replicate with small noise
                noise = rng.randn(n_synthetic, n_features) * self.regularization
                synthetic_X.append(np.tile(X_cls[0], (n_synthetic, 1)) + noise)
                synthetic_y.append(np.full(n_synthetic, cls))
                continue

            # One batched neighbour query for the whole class (k + 1 because
            # each point's own index comes back in column 0).
            _, nn_all = KDTree(X_cls).query(X_cls, k=k + 1)

            # Distribute samples uniformly across minority points
            samples_per_point = np.full(n_cls, n_synthetic // n_cls)
            samples_per_point[: n_synthetic % n_cls] += 1

            for i, n_gen in enumerate(samples_per_point):
                if n_gen == 0:
                    continue
                neighbours = X_cls[nn_all[i, 1:]]  # exclude self
                mean = neighbours.mean(axis=0)
                # np.cov needs >= 2 observations: with a single neighbour
                # (k == 1) it divides by zero and returns NaN, which breaks
                # multivariate_normal. In that case estimate the covariance
                # from the anchor point together with its neighbour.
                cov_points = neighbours if k >= 2 else X_cls[nn_all[i]]
                # np.cov returns a 0-d array for a single feature.
                cov = np.atleast_2d(np.cov(cov_points, rowvar=False))
                cov = cov + np.eye(n_features) * self.regularization
                samples = rng.multivariate_normal(mean, cov, size=int(n_gen))
                synthetic_X.append(samples)
                synthetic_y.append(np.full(int(n_gen), cls))

        if not synthetic_X:
            return X.copy(), y.copy()
        return np.vstack([X] + synthetic_X), np.concatenate([y] + synthetic_y)
# =============================================================================
# SimplicialSMOTE
# =============================================================================
class SimplicialSMOTE(BaseEstimator):
    """Oversampler drawing synthetic points from inside minority simplices.

    For every minority point, random simplices are formed from that point and
    ``simplex_dim`` of its k nearest same-class neighbours. Each synthetic
    sample is a convex combination of the vertices of a randomly chosen
    simplex, with barycentric weights drawn from a flat Dirichlet
    distribution.

    Parameters
    ----------
    sampling_strategy : str, float, or dict, default='auto'
        See :func:`_compute_sampling_targets`.
    k_neighbors : int, default=5
        Number of nearest neighbours considered around each point.
    simplex_dim : int, default=2
        Dimension of the simplices to sample from (2 = triangles,
        3 = tetrahedra). Clamped per class to
        ``min(simplex_dim, k_neighbors, n_class_samples - 1)``.
    random_state : int or None, default=None
        Random seed.

    References
    ----------
    Simplicial complex extension of SMOTE (KDD 2025)
    """

    def __init__(
        self,
        sampling_strategy: str | float | dict = "auto",
        k_neighbors: int = 5,
        simplex_dim: int = 2,
        random_state: int | None = None,
    ):
        self.sampling_strategy = sampling_strategy
        self.k_neighbors = k_neighbors
        self.simplex_dim = simplex_dim
        self.random_state = random_state

    def fit(self, X: ArrayLike, y: ArrayLike) -> SimplicialSMOTE:
        """Validate the data and store per-class targets in ``targets_``."""
        _, labels = check_X_y(X, y)
        self.targets_ = _compute_sampling_targets(labels, self.sampling_strategy)
        return self

    def _build_simplices(
        self,
        X_cls: np.ndarray,
        k: int,
        simplex_dim: int,
        rng: np.random.RandomState,
    ) -> list[np.ndarray]:
        """Sample simplices from each point's k-NN neighbourhood.

        Every returned index array has length ``simplex_dim + 1``: the anchor
        point followed by ``simplex_dim`` distinct neighbours of it.
        """
        from scipy.spatial import KDTree

        n_points = len(X_cls)
        _, neighbour_table = KDTree(X_cls).query(X_cls, k=min(k + 1, n_points))

        simplices: list[np.ndarray] = []
        for anchor in range(n_points):
            candidates = neighbour_table[anchor, 1:]  # drop the anchor itself
            n_candidates = len(candidates)
            if n_candidates < simplex_dim:
                continue
            n_draws = max(1, n_candidates // simplex_dim)
            simplices.extend(
                np.concatenate(
                    [[anchor], rng.choice(candidates, size=simplex_dim, replace=False)]
                )
                for _ in range(n_draws)
            )
        return simplices

    def fit_resample(
        self, X: ArrayLike, y: ArrayLike
    ) -> tuple[np.ndarray, np.ndarray]:
        """Fit the sampler, then append simplex-interpolated samples.

        Returns
        -------
        X_resampled : ndarray
            The input samples followed by the synthetic ones.
        y_resampled : ndarray
            Matching labels.
        """
        self.fit(X, y)
        X = np.asarray(X, dtype=np.float64)
        y = np.asarray(y)
        rng = np.random.RandomState(self.random_state)

        extra_X: list[np.ndarray] = []
        extra_y: list[np.ndarray] = []

        for label, n_new in self.targets_.items():
            if n_new <= 0:
                continue
            members = X[y == label]
            n_members = len(members)
            if n_members == 0:
                continue

            k = min(self.k_neighbors, n_members - 1)
            dim = min(self.simplex_dim, k)
            if k < 1 or dim < 1:
                # Too few points for any simplex: jittered copies of the
                # first class member.
                jitter = rng.randn(n_new, X.shape[1]) * 1e-6
                extra_X.append(np.tile(members[0], (n_new, 1)) + jitter)
                extra_y.append(np.full(n_new, label))
                continue

            simplices = self._build_simplices(members, k, dim, rng)
            if not simplices:
                continue

            n_simplices = len(simplices)
            points = []
            for _ in range(n_new):
                # Pick a simplex, then a point inside it via random
                # barycentric weights that sum to one.
                vertex_ids = simplices[rng.randint(n_simplices)]
                weights = rng.dirichlet(np.ones(len(vertex_ids)))
                points.append(weights @ members[vertex_ids])
            extra_X.append(np.array(points))
            extra_y.append(np.full(n_new, label))

        if not extra_X:
            return X.copy(), y.copy()
        return np.vstack([X] + extra_X), np.concatenate([y] + extra_y)
# =============================================================================
# CVSMOTEResampler
# =============================================================================
class CVSMOTEResampler(BaseEstimator):
    """Cross-validation guided SMOTE oversampler.

    Generates a pool of candidate synthetic samples via SMOTE-style
    interpolation, then evaluates them in batches with cross-validation.
    A batch is kept when adding it does not lower the CV score relative to
    the best score seen so far. Any shortfall is filled with fresh,
    unscored SMOTE samples.

    Parameters
    ----------
    sampling_strategy : str, float, or dict, default='auto'
        See :func:`_compute_sampling_targets`.
    k_neighbors : int, default=5
        Nearest neighbours for SMOTE interpolation.
    cv : int, default=3
        Maximum number of stratified cross-validation folds for candidate
        evaluation. The fold count is additionally bounded by the smallest
        class size, and is never below 2.
    estimator : estimator or None, default=None
        Classifier used to score candidate batches. Defaults to
        ``LogisticRegression(max_iter=500)``.
    scoring : str, default='f1_macro'
        Scoring metric for cross-validation (sklearn convention).
    candidate_pool_factor : float, default=2.0
        Generate this many times the required synthetic samples as
        candidates, then keep the best subset.
    random_state : int or None, default=None
        Random seed.

    References
    ----------
    CV-informed SMOTE (ICLR 2025)
    """

    def __init__(
        self,
        sampling_strategy: str | float | dict = "auto",
        k_neighbors: int = 5,
        cv: int = 3,
        estimator: Any = None,
        scoring: str = "f1_macro",
        candidate_pool_factor: float = 2.0,
        random_state: int | None = None,
    ):
        self.sampling_strategy = sampling_strategy
        self.k_neighbors = k_neighbors
        self.cv = cv
        self.estimator = estimator
        self.scoring = scoring
        self.candidate_pool_factor = candidate_pool_factor
        self.random_state = random_state

    def fit(self, X: ArrayLike, y: ArrayLike) -> CVSMOTEResampler:
        """Fit the sampler."""
        X, y = check_X_y(X, y)
        self.targets_ = _compute_sampling_targets(y, self.sampling_strategy)
        return self

    @staticmethod
    def _smote_interpolate(
        X_cls: np.ndarray,
        n_synthetic: int,
        k: int,
        rng: np.random.RandomState,
    ) -> np.ndarray:
        """Generate SMOTE-style interpolated samples.

        Each sample lies on the segment between a random class member and
        one of its k nearest same-class neighbours. Expects ``k >= 1``.
        """
        from scipy.spatial import KDTree

        n_cls = len(X_cls)
        tree = KDTree(X_cls)
        _, nn_indices = tree.query(X_cls, k=min(k + 1, n_cls))
        samples = []
        for _ in range(n_synthetic):
            idx = rng.randint(n_cls)
            neighbours = nn_indices[idx, 1:]
            nn = neighbours[rng.randint(len(neighbours))]
            lam = rng.uniform()
            samples.append(X_cls[idx] + lam * (X_cls[nn] - X_cls[idx]))
        return np.array(samples)

    @staticmethod
    def _stratified_n_splits(labels: np.ndarray, cv: int) -> int:
        """Fold count usable for ``labels``: <= cv, <= smallest class, >= 2.

        ``np.unique`` works for any label dtype (strings, negative or
        non-contiguous integers), unlike ``np.bincount``.
        """
        _, counts = np.unique(labels, return_counts=True)
        return max(2, min(int(cv), int(counts.min())))

    def _cv_score(self, estimator: Any, X: np.ndarray, y: np.ndarray) -> float:
        """Mean stratified-CV score of a fresh clone of ``estimator``."""
        from sklearn.model_selection import StratifiedKFold, cross_val_score

        splitter = StratifiedKFold(
            n_splits=self._stratified_n_splits(y, self.cv),
            shuffle=True,
            random_state=self.random_state,
        )
        return float(
            np.mean(
                cross_val_score(
                    clone(estimator), X, y, cv=splitter, scoring=self.scoring
                )
            )
        )

    def fit_resample(
        self, X: ArrayLike, y: ArrayLike
    ) -> tuple[np.ndarray, np.ndarray]:
        """Fit and resample.

        Returns
        -------
        X_resampled : ndarray
            The original samples followed by the synthetic samples.
        y_resampled : ndarray
            Matching labels.
        """
        from sklearn.linear_model import LogisticRegression

        self.fit(X, y)
        X = np.asarray(X, dtype=np.float64)
        y = np.asarray(y)
        rng = np.random.RandomState(self.random_state)

        estimator = self.estimator
        if estimator is None:
            estimator = LogisticRegression(
                max_iter=500, random_state=self.random_state
            )

        # CV score of the un-augmented data. It is the same for every class,
        # so compute it at most once, and only if some class needs it.
        initial_score: float | None = None

        synthetic_X: list[np.ndarray] = []
        synthetic_y: list[np.ndarray] = []

        for cls, n_synthetic in self.targets_.items():
            if n_synthetic <= 0:
                continue
            X_cls = X[y == cls]
            n_cls = len(X_cls)
            if n_cls == 0:
                continue

            k = min(self.k_neighbors, n_cls - 1)
            if k < 1:
                # Single member: no neighbour to interpolate with.
                synthetic_X.append(np.tile(X_cls[0], (n_synthetic, 1)))
                synthetic_y.append(np.full(n_synthetic, cls))
                continue

            # Generate candidate pool
            n_candidates = int(n_synthetic * self.candidate_pool_factor)
            candidates = self._smote_interpolate(X_cls, n_candidates, k, rng)

            # Score candidates in batches rather than one-by-one
            batch_size = max(1, n_synthetic // 5)
            rng.shuffle(candidates)

            if initial_score is None:
                initial_score = self._cv_score(estimator, X, y)
            baseline = initial_score

            accepted: list[np.ndarray] = []  # accepted batches (2-D arrays)
            n_accepted = 0
            for start in range(0, len(candidates), batch_size):
                batch = candidates[start : start + batch_size]
                X_aug = np.vstack([X, *accepted, batch])
                y_aug = np.concatenate(
                    [y, np.full(n_accepted + len(batch), cls)]
                )
                score = self._cv_score(estimator, X_aug, y_aug)
                if score >= baseline:
                    accepted.append(batch)
                    n_accepted += len(batch)
                    baseline = score
                    if n_accepted >= n_synthetic:
                        break

            # If CV selection did not yield enough, top up with fresh samples
            if n_accepted < n_synthetic:
                accepted.append(
                    self._smote_interpolate(
                        X_cls, n_synthetic - n_accepted, k, rng
                    )
                )

            synthetic_X.append(np.vstack(accepted)[:n_synthetic])
            synthetic_y.append(np.full(n_synthetic, cls))

        if not synthetic_X:
            return X.copy(), y.copy()
        return np.vstack([X] + synthetic_X), np.concatenate([y] + synthetic_y)
# =============================================================================
# OverlapRegionDetector
# =============================================================================
class OverlapRegionDetector(BaseEstimator):
    """Overlap Region Detection meta-method for class imbalance.

    Identifies the overlap region between classes using classifier
    uncertainty, then applies a base sampler with overlap awareness.

    Algorithm
    ---------
    1. Obtain cross-validated predicted probabilities from a classifier.
    2. Samples with high uncertainty (``max prob < 1 - threshold``) are
       relabelled with an extra "overlap" label.
    3. Apply the base sampler on the augmented label space.
    4. Original samples keep their true labels. Synthetic samples carrying
       the overlap label are assigned to the class whose centroid (in the
       original data) is nearest.

    If any augmented class has fewer than two samples, or the base sampler
    raises on the augmented labels, the base sampler is applied directly to
    the original labels instead.

    Parameters
    ----------
    sampling_strategy : str, float, or dict, default='auto'
        See :func:`_compute_sampling_targets`. Stored as ``targets_`` by
        :meth:`fit`. Note that a base sampler built from a string is always
        constructed with ``sampling_strategy='auto'``.
    base_sampler : str or estimator, default='smote'
        Base oversampling method. If a string, looked up in the combined
        sampler registries. Otherwise must support ``fit_resample(X, y)``.
    overlap_estimator : estimator or None, default=None
        Classifier for overlap detection. Defaults to
        ``RandomForestClassifier(n_estimators=100)``.
    k_neighbors : int, default=5
        Passed to base sampler when constructed from string.
    threshold : float, default=0.3
        Uncertainty threshold: a sample is in the overlap region if
        ``max(predicted_proba) < 1 - threshold``.
    random_state : int or None, default=None
        Random seed.

    Attributes
    ----------
    n_overlap_ : int
        Number of training samples flagged as overlap in the last
        :meth:`fit_resample` call.

    References
    ----------
    Overlap Region Detection (AAAI 2025)
    """

    def __init__(
        self,
        sampling_strategy: str | float | dict = "auto",
        base_sampler: str | Any = "smote",
        overlap_estimator: Any = None,
        k_neighbors: int = 5,
        threshold: float = 0.3,
        random_state: int | None = None,
    ):
        self.sampling_strategy = sampling_strategy
        self.base_sampler = base_sampler
        self.overlap_estimator = overlap_estimator
        self.k_neighbors = k_neighbors
        self.threshold = threshold
        self.random_state = random_state

    def fit(self, X: ArrayLike, y: ArrayLike) -> OverlapRegionDetector:
        """Fit the sampler."""
        X, y = check_X_y(X, y)
        self.targets_ = _compute_sampling_targets(y, self.sampling_strategy)
        return self

    def _get_base_sampler(self):
        """Resolve the base sampler from string or return a clone of it."""
        if isinstance(self.base_sampler, str):
            from endgame.preprocessing.imbalance import ALL_SAMPLERS

            key = self.base_sampler
            if key not in ALL_SAMPLERS:
                raise ValueError(
                    f"Unknown base_sampler '{key}'. "
                    f"Available: {list(ALL_SAMPLERS.keys())}"
                )
            SamplerClass = ALL_SAMPLERS[key]
            import inspect

            # Only pass the constructor arguments the sampler accepts.
            sig = inspect.signature(SamplerClass.__init__)
            params: dict[str, Any] = {}
            if "random_state" in sig.parameters:
                params["random_state"] = self.random_state
            if "k_neighbors" in sig.parameters:
                params["k_neighbors"] = self.k_neighbors
            if "sampling_strategy" in sig.parameters:
                params["sampling_strategy"] = "auto"
            return SamplerClass(**params)
        return clone(self.base_sampler)

    @staticmethod
    def _working_labels(y: np.ndarray) -> tuple[np.ndarray, Any, Any]:
        """Return ``(y_work, classes, overlap_label)`` for the augmentation.

        Numeric labels are used as-is, with ``max + 1`` as the overlap label
        (so ``classes`` is None). Other label types (strings, booleans,
        objects) cannot be incremented, or would collide (``True + 1``
        stored in a bool array becomes ``True``). Those labels are encoded
        to ``0..n_classes-1``, ``classes`` is the decoding table, and
        ``n_classes`` is the overlap label.
        """
        if np.issubdtype(y.dtype, np.number):
            return y, None, np.unique(y).max() + 1
        classes, encoded = np.unique(y, return_inverse=True)
        return encoded, classes, len(classes)

    def fit_resample(
        self, X: ArrayLike, y: ArrayLike
    ) -> tuple[np.ndarray, np.ndarray]:
        """Fit and resample with overlap awareness.

        Returns
        -------
        X_resampled : ndarray
        y_resampled : ndarray
            Labels in the original label space. The internal overlap label
            does not appear in the output.
        """
        from scipy.spatial import KDTree
        from sklearn.ensemble import RandomForestClassifier
        from sklearn.model_selection import cross_val_predict

        self.fit(X, y)
        X = np.asarray(X, dtype=np.float64)
        y = np.asarray(y)
        n_samples = len(X)

        # Step 1: Detect overlap region
        overlap_est = self.overlap_estimator
        if overlap_est is None:
            overlap_est = RandomForestClassifier(
                n_estimators=100, random_state=self.random_state, n_jobs=-1
            )
        # Cross-validated predictions avoid overconfident in-sample
        # probabilities. The fold count is bounded by the smallest class;
        # np.unique handles any label dtype, unlike np.bincount.
        _, class_sizes = np.unique(y, return_counts=True)
        n_cv = max(2, min(5, int(class_sizes.min())))
        proba = cross_val_predict(
            clone(overlap_est), X, y, cv=n_cv, method="predict_proba"
        )

        # Step 2: Identify overlap samples (high uncertainty)
        overlap_mask = proba.max(axis=1) < (1.0 - self.threshold)
        self.n_overlap_ = int(overlap_mask.sum())

        # Step 3: Augmented labels with a collision-free overlap label
        y_work, classes, overlap_label = self._working_labels(y)
        y_augmented = y_work.copy()
        y_augmented[overlap_mask] = overlap_label

        # Step 4: Apply base sampler
        base = self._get_base_sampler()
        _, aug_sizes = np.unique(y_augmented, return_counts=True)
        if aug_sizes.min() < 2:
            # Some augmented class is too small to resample; fall back to
            # the original labels.
            return self._get_base_sampler().fit_resample(X, y)

        try:
            X_res, y_res = base.fit_resample(X, y_augmented)
        except Exception:
            # Deliberate best-effort: an arbitrary base sampler may reject the
            # augmented label space, so retry on the original labels.
            return self._get_base_sampler().fit_resample(X, y)

        # Step 5: Map the overlap label back to real classes
        X_res = np.asarray(X_res)
        y_res = np.array(y_res)  # copy, modified below
        needs_label = y_res == overlap_label

        # Oversamplers conventionally return the input rows first, unchanged.
        # When that holds, restore those rows' true labels rather than
        # guessing them from class centroids.
        if len(X_res) >= n_samples and np.array_equal(X_res[:n_samples], X):
            y_res[:n_samples] = y_work
            needs_label[:n_samples] = False

        if needs_label.any():
            # Assign each remaining overlap row to the class (over all
            # original classes) whose centroid is nearest.
            work_classes = np.unique(y_work)
            centres = np.array(
                [X[y_work == c].mean(axis=0) for c in work_classes]
            )
            _, nearest = KDTree(centres).query(X_res[needs_label])
            y_res[needs_label] = work_classes[nearest]

        if classes is not None:
            # Decode integer-encoded labels back to the original label type.
            y_res = classes[y_res.astype(np.intp)]
        return X_res, y_res
# Category dict for registration GEOMETRIC_SAMPLERS = { "multivariate_gaussian_smote": MultivariateGaussianSMOTE, "simplicial_smote": SimplicialSMOTE, "cv_smote": CVSMOTEResampler, "overlap_region_detector": OverlapRegionDetector, }