# Source code for endgame.models.linear.mars

"""Multivariate Adaptive Regression Splines (MARS) Implementation.

MARS builds piecewise linear regression models by automatically discovering knots
(thresholds) in the data where the relationship between features and target changes
slope. It was invented by Jerome Friedman in 1991.

This module provides:
- MARSRegressor: MARS for regression tasks
- MARSClassifier: MARS for classification via logistic regression on basis functions

References
----------
Friedman, J. (1991). Multivariate adaptive regression splines.
The Annals of Statistics, 19(1), 1-67.

Friedman, J. (1993). Fast MARS. Stanford University Technical Report 110.

Milborrow, S. Earth package vignette (R implementation reference).
"""

from __future__ import annotations

import numpy as np
from numpy.typing import ArrayLike, NDArray
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.linear_model import LogisticRegression
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y

from endgame.models.linear.basis import BasisFunction, HingeSpec, LinearBasisFunction


class MARSRegressor(RegressorMixin, BaseEstimator):
    """Multivariate Adaptive Regression Splines for regression.

    MARS builds a piecewise linear model by discovering knots (thresholds)
    where the relationship between features and target changes. The model
    is an additive combination of hinge functions: max(0, x - knot) and
    max(0, knot - x).

    Parameters
    ----------
    max_terms : int, default=None
        Maximum number of basis functions (including intercept).
        If None, defaults to min(100, max(20, 2 * n_features)) + 1.
        Because hinge functions are added in pairs, the forward pass may
        end one term above this limit.
    max_degree : int, default=1
        Maximum degree of interactions.
        1 = additive model (no interactions)
        2 = pairwise interactions allowed
        3 = three-way interactions allowed
    penalty : float, default=3.0
        Generalized Cross-Validation (GCV) penalty per knot.
        Higher values produce simpler models. Typical range: 2-4.
    thresh : float, default=0.001
        Forward pass stopping threshold. Stops adding terms when
        R^2 improvement falls below this value.
    min_span : int, default=None
        Minimum number of observations between knots.
        If None, automatically calculated based on data size.
        NOTE: in this implementation the value is only used as the default
        for ``endspan``; candidate knots are thinned to at most
        ``_MAX_KNOTS_PER_FEATURE`` evenly spaced unique values.
    endspan : int, default=None
        Minimum observations before first knot and after last knot
        (counted in unique feature values). If None, equals ``min_span``.
    fast_k : int, default=20
        In the forward pass, only consider the best fast_k parent terms
        when searching for new basis functions. Set to 0 to consider all
        parents (slower but potentially better). This is "Fast MARS."
    feature_names : list of str, default=None
        Names for features (used in summary output). When None and ``X``
        has a ``columns`` attribute (e.g. a pandas DataFrame), the column
        names are used. The parameter itself is never modified by ``fit``.
    allow_linear : bool, default=True
        If True, allows linear terms (no hinge) for features that appear
        to have purely linear relationships.

    Attributes
    ----------
    basis_functions_ : list of BasisFunction
        The selected basis functions after pruning.
    coef_ : ndarray of shape (n_basis_functions,)
        Coefficients for each basis function.
    intercept_ : float
        The intercept (coefficient of the constant basis function).
    n_features_in_ : int
        Number of features seen during fit.
    feature_names_in_ : ndarray of shape (n_features_in_,)
        Names of features seen during fit.
    gcv_ : float
        Generalized Cross-Validation score of the final model.
    rsq_ : float
        R-squared of the final model on training data.
    forward_pass_record_ : list
        Record of terms added during forward pass (for diagnostics).
    pruning_record_ : list
        Record of pruning decisions (for diagnostics).

    Examples
    --------
    >>> from endgame.models import MARSRegressor
    >>> import numpy as np
    >>> X = np.random.randn(100, 3)
    >>> y = np.maximum(0, X[:, 0] - 0.5) + 2 * X[:, 1] + np.random.randn(100) * 0.1
    >>> model = MARSRegressor(max_degree=1)
    >>> model.fit(X, y)
    MARSRegressor(max_degree=1)
    >>> print(model.summary())  # doctest: +SKIP
    >>> predictions = model.predict(X)

    References
    ----------
    Friedman, J. (1991). Multivariate adaptive regression splines.
    The Annals of Statistics, 19(1), 1-67.

    Friedman, J. (1993). Fast MARS. Stanford University Technical Report 110.

    Milborrow, S. Earth package vignette (R implementation reference).
    """

    # The forward pass runs on a fixed-seed random subsample of this many
    # rows when the training set has more than 1.5x as many.
    _MAX_FORWARD_SAMPLES = 2000
    # Upper bound on the number of candidate knots evaluated per feature.
    _MAX_KNOTS_PER_FEATURE = 25

    def __init__(
        self,
        max_terms: int | None = None,
        max_degree: int = 1,
        penalty: float = 3.0,
        thresh: float = 0.001,
        min_span: int | None = None,
        endspan: int | None = None,
        fast_k: int = 20,
        feature_names: list[str] | None = None,
        allow_linear: bool = True,
    ):
        self.max_terms = max_terms
        self.max_degree = max_degree
        self.penalty = penalty
        self.thresh = thresh
        self.min_span = min_span
        self.endspan = endspan
        self.fast_k = fast_k
        self.feature_names = feature_names
        self.allow_linear = allow_linear

    def fit(
        self,
        X: ArrayLike,
        y: ArrayLike,
        sample_weight: ArrayLike | None = None,
    ) -> MARSRegressor:
        """Fit the MARS model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : array-like of shape (n_samples,)
            Target values.
        sample_weight : array-like of shape (n_samples,), default=None
            Individual weights for each sample. Weights are rescaled to sum
            to ``n_samples``. Zero weights are allowed.

        Returns
        -------
        self : object
            Fitted estimator.

        Raises
        ------
        ValueError
            If ``feature_names`` does not match the number of features, or
            if ``sample_weight`` is not 1-D, has the wrong length, contains
            negative or non-finite values, or sums to zero.
        """
        # Resolve names into a local: sklearn estimators must not modify
        # constructor parameters in fit (it breaks clone() and refitting on
        # differently-named data).
        feature_names = self.feature_names
        if feature_names is None and hasattr(X, "columns"):
            feature_names = list(X.columns)

        # Validate inputs
        X, y = check_X_y(X, y, dtype=np.float64, y_numeric=True)
        n_samples, n_features = X.shape

        # Store feature names
        self.n_features_in_ = n_features
        if feature_names is not None:
            if len(feature_names) != n_features:
                raise ValueError(
                    f"feature_names has {len(feature_names)} entries, "
                    f"expected {n_features}"
                )
            self.feature_names_in_ = np.array(feature_names)
        else:
            self.feature_names_in_ = np.array([f"x{i}" for i in range(n_features)])

        sample_weight = self._validate_sample_weight(sample_weight, n_samples)

        # Store references for variable importance (no copy needed — fit
        # owns these arrays and they're not mutated after this point).
        self._X_train = X
        self._y_train = y
        self._sample_weight = sample_weight

        # Calculate default parameters
        self._max_terms = self.max_terms
        if self._max_terms is None:
            self._max_terms = min(100, max(20, 2 * n_features)) + 1

        self._min_span = self.min_span
        if self._min_span is None:
            # Friedman's recommendation: aim for ~20-50 knots per feature
            self._min_span = max(1, int(n_samples / 50))

        self._endspan = self.endspan
        if self._endspan is None:
            self._endspan = self._min_span

        # Constant (near-zero variance) features never receive terms.
        self._feature_variance = np.var(X, axis=0)
        self._valid_features = self._feature_variance > 1e-10

        # Forward pass: build model greedily
        self.forward_pass_record_ = []
        forward_basis = self._forward_pass(X, y, sample_weight)

        # Backward pass: prune model using GCV
        self.pruning_record_ = []
        self.basis_functions_ = self._backward_pass(forward_basis, X, y, sample_weight)

        # Fit final coefficients
        B = self.get_basis_matrix(X)
        self.coef_, self._rss = self._fit_coefficients(B, y, sample_weight)

        # Store intercept separately (coefficient of first basis function)
        self.intercept_ = self.coef_[0]

        # Compute final statistics
        self.gcv_ = self._compute_gcv(self.basis_functions_, X, y, sample_weight)
        total_ss = np.sum(sample_weight * (y - np.average(y, weights=sample_weight)) ** 2)
        self.rsq_ = 1.0 - self._rss / total_ss if total_ss > 0 else 0.0

        return self

    @staticmethod
    def _validate_sample_weight(
        sample_weight: ArrayLike | None,
        n_samples: int,
    ) -> NDArray[np.floating]:
        """Validate sample weights and rescale them to sum to ``n_samples``.

        Returns unit weights when ``sample_weight`` is None.

        Raises
        ------
        ValueError
            If the weights are not 1-D, have the wrong length, contain
            negative or non-finite values, or sum to zero.
        """
        if sample_weight is None:
            return np.ones(n_samples, dtype=np.float64)

        sample_weight = np.asarray(sample_weight, dtype=np.float64)
        if sample_weight.ndim != 1:
            raise ValueError(
                f"sample_weight must be 1-dimensional, got shape {sample_weight.shape}"
            )
        if sample_weight.shape[0] != n_samples:
            raise ValueError(
                f"sample_weight has {sample_weight.shape[0]} samples, "
                f"expected {n_samples}"
            )
        if not np.all(np.isfinite(sample_weight)) or np.any(sample_weight < 0):
            raise ValueError("sample_weight must contain finite, non-negative values")

        total = np.sum(sample_weight)
        if total <= 0:
            raise ValueError("sample_weight must have a positive sum")

        # Normalize weights to sum to n_samples
        return sample_weight * (n_samples / total)

    def predict(self, X: ArrayLike) -> NDArray[np.floating]:
        """Predict using the fitted MARS model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        y_pred : ndarray of shape (n_samples,)
            Predicted values.

        Raises
        ------
        ValueError
            If ``X`` does not have ``n_features_in_`` columns.
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64)

        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"X has {X.shape[1]} features, expected {self.n_features_in_}"
            )

        B = self.get_basis_matrix(X)
        return B @ self.coef_

    def get_basis_matrix(self, X: ArrayLike) -> NDArray[np.floating]:
        """Compute the basis function matrix for given X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input data.

        Returns
        -------
        B : ndarray of shape (n_samples, n_basis_functions)
            Matrix where B[i, j] is the value of basis function j
            evaluated at sample i.
        """
        check_is_fitted(self)
        X = np.asarray(X, dtype=np.float64)
        n_samples = X.shape[0]
        n_basis = len(self.basis_functions_)

        B = np.empty((n_samples, n_basis), dtype=np.float64)
        for j, bf in enumerate(self.basis_functions_):
            B[:, j] = bf.evaluate(X)

        return B

    def summary(self) -> str:
        """Return a human-readable summary of the model.

        Returns a string showing:
        - Model equation with all basis functions
        - R^2 and GCV statistics
        - Variable importance

        Returns
        -------
        summary : str
            Formatted model summary.
        """
        check_is_fitted(self)

        feature_names = list(self.feature_names_in_)
        lines = []
        lines.append("MARS Model Summary")
        lines.append("=" * 50)
        lines.append("")

        # Call info
        params = []
        if self.max_terms is not None:
            params.append(f"max_terms={self._max_terms}")
        params.append(f"max_degree={self.max_degree}")
        params.append(f"penalty={self.penalty}")
        lines.append(f"Call: MARSRegressor({', '.join(params)})")
        lines.append("")

        # Basis functions
        lines.append(f"Basis Functions ({len(self.basis_functions_)} terms):")
        lines.append("-" * 40)
        for i, bf in enumerate(self.basis_functions_):
            if bf.is_intercept:
                name = "Intercept"
            else:
                name = bf.to_str_with_names(feature_names)
            coef = self.coef_[i]
            lines.append(f" {name:40s} {coef:10.4f}")
        lines.append("")

        # Model statistics
        lines.append("Model Statistics:")
        lines.append("-" * 40)
        lines.append(f" GCV: {self.gcv_:.4f}")
        lines.append(f" RSS: {self._rss:.4f}")
        lines.append(f" R-squared: {self.rsq_:.4f}")

        # Adjusted R-squared (p counts the intercept, hence n - p)
        n = len(self._y_train)
        p = len(self.basis_functions_)
        if n > p:
            adj_rsq = 1.0 - (1.0 - self.rsq_) * (n - 1) / (n - p)
            lines.append(f" Adj R-sq: {adj_rsq:.4f}")
        lines.append("")

        # Variable importance
        importance = self.compute_variable_importance()
        if importance:
            lines.append("Variable Importance:")
            lines.append("-" * 40)
            sorted_imp = sorted(importance.items(), key=lambda x: -x[1])
            for name, imp in sorted_imp:
                if imp > 0:
                    lines.append(f" {name:20s} {imp:6.1f}")
            lines.append("")

        # Equation
        lines.append("Equation:")
        lines.append("-" * 40)

        eq_parts = []
        for i, bf in enumerate(self.basis_functions_):
            coef = self.coef_[i]
            if bf.is_intercept:
                eq_parts.append(f"{coef:.4g}")
            else:
                bf_str = bf.to_str_with_names(feature_names)
                if coef >= 0:
                    eq_parts.append(f" + {coef:.4g}*{bf_str}")
                else:
                    eq_parts.append(f" - {abs(coef):.4g}*{bf_str}")

        equation = "y = " + "".join(eq_parts)

        # Wrap long equations
        if len(equation) > 70:
            wrapped = ["y = " + eq_parts[0]]
            current_line = " "
            for part in eq_parts[1:]:
                # Only break when the current line already holds a term, so an
                # over-long first term does not produce an empty line.
                if len(current_line) + len(part) > 66 and current_line.strip():
                    wrapped.append(current_line)
                    current_line = " " + part
                else:
                    current_line += part
            if current_line.strip():
                wrapped.append(current_line)
            equation = "\n".join(wrapped)

        lines.append(equation)

        return "\n".join(lines)

    def compute_variable_importance(self) -> dict[str, float]:
        """Compute variable importance based on GCV decrease.

        For each variable, compute how much GCV would increase if all
        basis functions involving that variable were removed.

        Returns
        -------
        importance : dict
            {feature_name: importance_score}
            Scores are normalized so max = 100.
        """
        check_is_fitted(self)

        importances = {}

        for feature_idx in range(self.n_features_in_):
            # Indices of basis functions that use this feature
            using_feature = {
                i for i, bf in enumerate(self.basis_functions_)
                if feature_idx in bf.feature_indices
            }

            if not using_feature:
                importances[feature_idx] = 0.0
                continue

            # Compute GCV without these basis functions (intercept always kept)
            subset = [
                bf for i, bf in enumerate(self.basis_functions_)
                if i not in using_feature or bf.is_intercept
            ]

            # Need at least intercept
            if not subset:
                subset = [BasisFunction()]

            gcv_without = self._compute_gcv(
                subset, self._X_train, self._y_train, self._sample_weight
            )

            # Importance = increase in GCV when variable is removed
            importances[feature_idx] = max(0.0, gcv_without - self.gcv_)

        # Normalize to 0-100 scale
        max_imp = max(importances.values()) if importances else 1
        if max_imp > 0:
            importances = {k: 100 * v / max_imp for k, v in importances.items()}

        # Convert to feature names
        return {self.feature_names_in_[k]: v for k, v in importances.items()}

    def _forward_pass(
        self,
        X: NDArray[np.floating],
        y: NDArray[np.floating],
        sample_weight: NDArray[np.floating],
    ) -> list[BasisFunction | LinearBasisFunction]:
        """Greedy forward pass to add basis functions.

        On large inputs the search runs on a fixed-seed random subsample of
        ``_MAX_FORWARD_SAMPLES`` positively weighted rows.

        Parameters
        ----------
        X : ndarray of shape (n_samples, n_features)
            Training data.
        y : ndarray of shape (n_samples,)
            Target values.
        sample_weight : ndarray of shape (n_samples,)
            Sample weights.

        Returns
        -------
        basis_functions : list
            List of basis functions added during forward pass.
        """
        n_samples, n_features = X.shape

        if n_samples > int(self._MAX_FORWARD_SAMPLES * 1.5):
            rng = np.random.RandomState(42)
            # Zero-weight rows contribute nothing to weighted least squares,
            # so draw only from positively weighted rows. When every weight is
            # positive this draws exactly the same rows as choosing from
            # range(n_samples).
            pool = np.flatnonzero(sample_weight > 0)
            n_keep = min(self._MAX_FORWARD_SAMPLES, pool.size)
            idx = rng.choice(pool, n_keep, replace=False)
            X = X[idx]
            y = y[idx]
            sample_weight = sample_weight[idx]
            sample_weight = sample_weight * (n_keep / np.sum(sample_weight))
            n_samples = n_keep

        # Start with intercept only
        basis_functions: list[BasisFunction | LinearBasisFunction] = [
            BasisFunction()
        ]

        # B holds UNWEIGHTED basis columns. Q is an orthonormal basis of the
        # WEIGHTED span (sqrt(w) * B), used to score candidates cheaply.
        sqrt_w = np.sqrt(sample_weight)
        y_weighted = y * sqrt_w
        B = np.ones((n_samples, 1), dtype=np.float64)
        Q, _ = np.linalg.qr(B * sqrt_w[:, np.newaxis])

        total_ss = np.sum(sample_weight * (y - np.average(y, weights=sample_weight)) ** 2)

        # Pre-compute candidate knots for each feature
        knot_candidates = {
            j: self._get_candidate_knots(X[:, j]) if self._valid_features[j] else np.array([])
            for j in range(n_features)
        }

        while len(basis_functions) < self._max_terms:
            best_decrease = 0.0
            # (h_plus, h_minus, unweighted col_plus, unweighted col_minus)
            best_pair: tuple | None = None
            # (LinearBasisFunction, unweighted column, feature index)
            best_linear: tuple | None = None

            # Determine which parents to consider (Fast MARS)
            if self.fast_k > 0 and len(basis_functions) > self.fast_k:
                parent_indices = self._select_top_parents(
                    basis_functions, B, y, sample_weight, self.fast_k
                )
            else:
                parent_indices = list(range(len(basis_functions)))

            for parent_idx in parent_indices:
                parent = basis_functions[parent_idx]
                if parent.degree >= self.max_degree:
                    continue

                parent_features = set(parent.feature_indices)
                # Pre-evaluate parent once (None means the all-ones intercept)
                parent_vals = None if parent.is_intercept else parent.evaluate(X)

                for feature_j in range(n_features):
                    if not self._valid_features[feature_j] or feature_j in parent_features:
                        continue

                    knots = knot_candidates[feature_j]
                    if len(knots) == 0:
                        continue

                    x_col = X[:, feature_j]

                    # Vectorised hinges for ALL knots: shape (n_samples, n_knots)
                    diff = x_col[:, np.newaxis] - knots[np.newaxis, :]
                    h_plus_all = np.maximum(0.0, diff)    # max(0, x - t)
                    h_minus_all = np.maximum(0.0, -diff)  # max(0, t - x)
                    if parent_vals is not None:
                        h_plus_all *= parent_vals[:, np.newaxis]
                        h_minus_all *= parent_vals[:, np.newaxis]

                    h_plus_w = h_plus_all * sqrt_w[:, np.newaxis]
                    h_minus_w = h_minus_all * sqrt_w[:, np.newaxis]

                    # Energy check: skip knots where either hinge is all zeros
                    energy_plus = np.einsum("ij,ij->j", h_plus_w, h_plus_w)
                    energy_minus = np.einsum("ij,ij->j", h_minus_w, h_minus_w)
                    valid = (energy_plus > 1e-10) & (energy_minus > 1e-10)

                    if np.any(valid):
                        decreases = self._compute_rss_decrease_pairs(
                            Q, y_weighted, h_plus_w[:, valid], h_minus_w[:, valid]
                        )
                        best_k = int(np.argmax(decreases))
                        if decreases[best_k] > best_decrease:
                            best_decrease = float(decreases[best_k])
                            # Map back to the original knot index
                            orig_idx = np.flatnonzero(valid)[best_k]
                            knot_t = knots[orig_idx]
                            best_pair = (
                                parent.extend(HingeSpec(feature_j, knot_t, +1)),
                                parent.extend(HingeSpec(feature_j, knot_t, -1)),
                                h_plus_all[:, orig_idx],
                                h_minus_all[:, orig_idx],
                            )
                            best_linear = None

                    # Also try a linear term if allowed and parent is intercept
                    if self.allow_linear and parent.is_intercept:
                        col_linear_w = x_col * sqrt_w
                        if np.dot(col_linear_w, col_linear_w) > 1e-10:
                            rss_dec = self._compute_rss_decrease_single(
                                Q, y_weighted, col_linear_w,
                            )
                            if rss_dec > best_decrease:
                                best_decrease = float(rss_dec)
                                best_pair = None
                                best_linear = (
                                    LinearBasisFunction(feature_j),
                                    x_col,
                                    feature_j,
                                )

            # Stopping criterion
            r2_improvement = best_decrease / total_ss if total_ss > 0 else 0.0
            if r2_improvement < self.thresh or best_decrease <= 0:
                break

            # Add the best term(s); columns appended to B are unweighted.
            if best_linear is not None:
                linear, col_linear, linear_feature = best_linear
                basis_functions.append(linear)
                B = np.column_stack([B, col_linear])
                self.forward_pass_record_.append({
                    "type": "linear",
                    "feature": linear_feature,
                    "rss_decrease": best_decrease,
                    "r2_improvement": r2_improvement,
                })
            elif best_pair is not None:
                h_plus, h_minus, col_plus, col_minus = best_pair
                basis_functions.append(h_plus)
                basis_functions.append(h_minus)
                B = np.column_stack([B, col_plus, col_minus])
                self.forward_pass_record_.append({
                    "type": "hinge_pair",
                    "feature": h_plus.hinges[-1].feature_idx,
                    "knot": h_plus.hinges[-1].knot,
                    "parent_degree": h_plus.degree - 1,
                    "rss_decrease": best_decrease,
                    "r2_improvement": r2_improvement,
                })
            else:
                break

            # Refactor so the next search projects onto the enlarged span
            Q, _ = np.linalg.qr(B * sqrt_w[:, np.newaxis])

        return basis_functions

    def _compute_rss_decrease_pairs(
        self,
        Q: NDArray[np.floating],
        y_weighted: NDArray[np.floating],
        Hp: NDArray[np.floating],
        Hm: NDArray[np.floating],
        tol: float = 1e-10,
    ) -> NDArray[np.floating]:
        """RSS decrease for many candidate hinge pairs at once.

        Column j of ``Hp``/``Hm`` is the weighted plus/minus hinge of
        candidate j. Each pair is orthogonalised against ``Q``, then the
        minus column is orthogonalised against the plus column. Columns whose
        residual norm is <= ``tol`` contribute nothing. The minus hinge is
        still scored when the plus hinge is already in the span.

        Returns
        -------
        rss_decrease : ndarray of shape (n_candidates,)
        """
        hp_orth = Hp - Q @ (Q.T @ Hp)
        hm_orth = Hm - Q @ (Q.T @ Hm)

        norms_p = np.sqrt(np.einsum("ij,ij->j", hp_orth, hp_orth))
        safe_p = norms_p > tol
        hp_unit = np.zeros_like(hp_orth)
        hp_unit[:, safe_p] = hp_orth[:, safe_p] / norms_p[safe_p]
        var1 = (hp_unit.T @ y_weighted) ** 2

        # Gram-Schmidt: remove the plus direction from the minus column
        dot_pm = np.einsum("ij,ij->j", hp_unit, hm_orth)
        hm_orth = hm_orth - hp_unit * dot_pm[np.newaxis, :]

        norms_m = np.sqrt(np.einsum("ij,ij->j", hm_orth, hm_orth))
        safe_m = norms_m > tol
        hm_unit = np.zeros_like(hm_orth)
        hm_unit[:, safe_m] = hm_orth[:, safe_m] / norms_m[safe_m]
        var2 = (hm_unit.T @ y_weighted) ** 2

        return var1 + var2

    def _backward_pass(
        self,
        basis_functions: list[BasisFunction | LinearBasisFunction],
        X: NDArray[np.floating],
        y: NDArray[np.floating],
        sample_weight: NDArray[np.floating],
    ) -> list[BasisFunction | LinearBasisFunction]:
        """Backward pass to prune basis functions using GCV.

        Uses the Gram matrix (normal equations) to avoid an O(n) lstsq per
        candidate removal. G = B_w^T B_w and h = B_w^T y_w are computed once.
        Each candidate evaluation is then O(k^3) instead of O(n * k^2). The
        intercept is never removed.
        """
        n = len(y)
        sqrt_w = np.sqrt(sample_weight)
        y_weighted = y * sqrt_w

        current = list(basis_functions)
        active = list(range(len(current)))

        B_full = np.column_stack([bf.evaluate(X) for bf in current])
        B_full_w = B_full * sqrt_w[:, np.newaxis]
        G = B_full_w.T @ B_full_w
        h = B_full_w.T @ y_weighted
        yWy = float(y_weighted @ y_weighted)

        def _gcv_subset(idx_list):
            idx = np.array(idx_list)
            G_sub = G[np.ix_(idx, idx)].copy()
            h_sub = h[idx]
            # Tiny ridge keeps the solve stable for near-collinear columns
            reg = 1e-12 * max(np.mean(np.diag(G_sub)), 1e-30)
            G_sub[np.diag_indices_from(G_sub)] += reg
            try:
                coef = np.linalg.solve(G_sub, h_sub)
            except np.linalg.LinAlgError:
                return np.inf
            rss = max(0.0, yWy - float(h_sub @ coef))
            n_coef = len(idx_list)
            n_knots = sum(current[i].degree for i in idx_list)
            eff = n_coef + self.penalty * n_knots
            # Once the model costs >= n degrees of freedom, (1 - eff/n)^2
            # grows again and would reward over-parameterised models.
            if eff >= n:
                return np.inf
            return rss / (n * (1.0 - eff / n) ** 2)

        best_gcv = _gcv_subset(active)
        best_model_indices = list(active)

        while len(active) > 1:
            best_removal_gcv = np.inf
            col_to_remove = None

            for i, a in enumerate(active):
                if current[a].is_intercept:
                    continue
                gcv = _gcv_subset(active[:i] + active[i + 1:])
                if gcv < best_removal_gcv:
                    best_removal_gcv = gcv
                    col_to_remove = a

            if col_to_remove is None:
                break

            self.pruning_record_.append({
                "removed": str(current[col_to_remove]),
                "gcv_after": best_removal_gcv,
                "n_terms_after": len(active) - 1,
            })

            active.remove(col_to_remove)

            if best_removal_gcv < best_gcv:
                best_gcv = best_removal_gcv
                best_model_indices = list(active)

        return [current[i] for i in best_model_indices]

    def _compute_gcv(
        self,
        basis_functions: list[BasisFunction | LinearBasisFunction],
        X: NDArray[np.floating],
        y: NDArray[np.floating],
        sample_weight: NDArray[np.floating],
    ) -> float:
        """Compute Generalized Cross-Validation score.

        GCV = RSS / (n * (1 - effective_params / n)^2)

        effective_params = n_coefficients + penalty * n_knots, where n_knots
        is the summed degree of the basis functions. The score is infinite
        when effective_params >= n.

        Parameters
        ----------
        basis_functions : list
            Current model's basis functions.
        X : ndarray of shape (n_samples, n_features)
            Training data.
        y : ndarray of shape (n_samples,)
            Target values.
        sample_weight : ndarray of shape (n_samples,)
            Sample weights.

        Returns
        -------
        gcv : float
            GCV score (lower is better).
        """
        n = len(y)

        B = np.column_stack([bf.evaluate(X) for bf in basis_functions])

        # Fit coefficients with weighted least squares
        sqrt_w = np.sqrt(sample_weight)
        B_weighted = B * sqrt_w[:, np.newaxis]
        y_weighted = y * sqrt_w

        try:
            coef = np.linalg.lstsq(B_weighted, y_weighted, rcond=None)[0]
        except np.linalg.LinAlgError:
            return np.inf

        y_pred = B @ coef
        rss = np.sum(sample_weight * (y - y_pred) ** 2)

        n_coefficients = B.shape[1]
        # Number of knots (hinges) - counts the knot selection cost
        n_knots = sum(bf.degree for bf in basis_functions)
        effective_params = n_coefficients + self.penalty * n_knots

        # Beyond n effective parameters the GCV denominator grows again.
        if effective_params >= n:
            return np.inf

        return rss / (n * (1 - effective_params / n) ** 2)

    def _fit_coefficients(
        self,
        B: NDArray[np.floating],
        y: NDArray[np.floating],
        sample_weight: NDArray[np.floating],
    ) -> tuple[NDArray[np.floating], float]:
        """Fit coefficients using weighted least squares.

        Parameters
        ----------
        B : ndarray of shape (n_samples, n_basis_functions)
            Basis matrix.
        y : ndarray of shape (n_samples,)
            Target values.
        sample_weight : ndarray of shape (n_samples,)
            Sample weights.

        Returns
        -------
        coef : ndarray of shape (n_basis_functions,)
            Fitted coefficients.
        rss : float
            Weighted residual sum of squares.
        """
        sqrt_w = np.sqrt(sample_weight)
        B_weighted = B * sqrt_w[:, np.newaxis]
        y_weighted = y * sqrt_w

        coef = np.linalg.lstsq(B_weighted, y_weighted, rcond=None)[0]

        y_pred = B @ coef
        rss = np.sum(sample_weight * (y - y_pred) ** 2)

        return coef, rss

    def _get_candidate_knots(
        self,
        x: NDArray[np.floating],
    ) -> NDArray[np.floating]:
        """Get candidate knot locations for a feature.

        ``self._endspan`` unique values are excluded at each end. The rest
        are thinned to at most ``_MAX_KNOTS_PER_FEATURE`` evenly spaced values.

        Parameters
        ----------
        x : ndarray of shape (n_samples,)
            Feature values.

        Returns
        -------
        knots : ndarray
            Candidate knot values (sorted).
        """
        # np.unique already returns sorted values
        sorted_unique = np.unique(x)
        n_unique = len(sorted_unique)

        if n_unique <= 1:
            return np.array([])

        endspan = max(0, int(self._endspan))

        if n_unique <= 2 * endspan:
            # For small number of unique values, use middle values
            if n_unique > 2:
                return sorted_unique[1:-1]
            return np.array([])

        # Exclude endpoints based on endspan. An explicit stop index keeps
        # endspan=0 working (a [0:-0] slice would be empty).
        candidates = sorted_unique[endspan:n_unique - endspan]

        if len(candidates) > self._MAX_KNOTS_PER_FEATURE:
            indices = np.linspace(
                0, len(candidates) - 1, self._MAX_KNOTS_PER_FEATURE, dtype=int
            )
            candidates = candidates[indices]

        return candidates

    def _compute_rss_decrease_fast(
        self,
        Q: NDArray[np.floating],
        y_weighted: NDArray[np.floating],
        new_col1: NDArray[np.floating],
        new_col2: NDArray[np.floating],
    ) -> float:
        """Compute RSS decrease from adding two columns using QR update.

        This is O(n * k) instead of O(n * k^2) for full refitting.

        Parameters
        ----------
        Q : ndarray
            Q matrix from QR decomposition of current basis.
        y_weighted : ndarray
            Weighted target values.
        new_col1 : ndarray
            First new basis function evaluated (weighted).
        new_col2 : ndarray
            Second new basis function evaluated (weighted).

        Returns
        -------
        rss_decrease : float
            Decrease in RSS from adding these columns.
        """
        # Project new columns onto orthogonal complement of current space
        new_col1_orth = new_col1 - Q @ (Q.T @ new_col1)
        new_col2_orth = new_col2 - Q @ (Q.T @ new_col2)

        # Gram-Schmidt on the two new columns
        norm1 = np.linalg.norm(new_col1_orth)
        if norm1 < 1e-10:
            new_col1_orth = np.zeros_like(new_col1_orth)
            var1 = 0.0
        else:
            new_col1_orth = new_col1_orth / norm1
            var1 = (new_col1_orth @ y_weighted) ** 2

        new_col2_orth = new_col2_orth - new_col1_orth * (new_col1_orth @ new_col2_orth)
        norm2 = np.linalg.norm(new_col2_orth)
        if norm2 < 1e-10:
            var2 = 0.0
        else:
            new_col2_orth = new_col2_orth / norm2
            var2 = (new_col2_orth @ y_weighted) ** 2

        # Additional variance explained
        return var1 + var2

    def _compute_rss_decrease_single(
        self,
        Q: NDArray[np.floating],
        y_weighted: NDArray[np.floating],
        new_col: NDArray[np.floating],
    ) -> float:
        """Compute RSS decrease from adding a single column.

        Parameters
        ----------
        Q : ndarray
            Q matrix from QR decomposition of current basis.
        y_weighted : ndarray
            Weighted target values.
        new_col : ndarray
            New basis function evaluated (weighted).

        Returns
        -------
        rss_decrease : float
            Decrease in RSS from adding this column.
        """
        # Project new column onto orthogonal complement
        new_col_orth = new_col - Q @ (Q.T @ new_col)
        norm = np.linalg.norm(new_col_orth)

        if norm < 1e-10:
            return 0.0

        new_col_orth = new_col_orth / norm
        return (new_col_orth @ y_weighted) ** 2

    def _select_top_parents(
        self,
        basis_functions: list[BasisFunction | LinearBasisFunction],
        B: NDArray[np.floating],
        y: NDArray[np.floating],
        sample_weight: NDArray[np.floating],
        k: int,
    ) -> list[int]:
        """Select top k parent basis functions for Fast MARS.

        The intercept (index 0) is always selected. The other parents are
        ranked by |coef| * std(column) of a weighted least-squares fit.

        Parameters
        ----------
        basis_functions : list
            Current basis functions.
        B : ndarray
            Current (unweighted) basis matrix.
        y : ndarray
            Target values.
        sample_weight : ndarray
            Sample weights.
        k : int
            Number of parents to select.

        Returns
        -------
        indices : list of int
            Indices of top k parents.
        """
        selected = [0]

        sqrt_w = np.sqrt(sample_weight)
        B_weighted = B * sqrt_w[:, np.newaxis]
        y_weighted = y * sqrt_w

        coef = np.linalg.lstsq(B_weighted, y_weighted, rcond=None)[0]

        # Compute contribution as |coef| * std(basis)
        contributions = [
            (i, abs(coef[i]) * np.std(B[:, i]))
            for i in range(1, len(basis_functions))
            if i < B.shape[1]
        ]

        # Sort by contribution and take top k-1 (intercept is already included)
        contributions.sort(key=lambda x: -x[1])
        selected.extend(i for i, _ in contributions[:k - 1])

        return selected
class MARSClassifier(ClassifierMixin, BaseEstimator):
    """MARS for classification via logistic regression on basis functions.

    Fits a MARS model to discover basis functions, then uses logistic
    regression on those basis functions for classification.

    Parameters
    ----------
    max_terms : int, default=None
        Maximum number of basis functions (including intercept).
        If None, defaults to min(100, max(20, 2 * n_features)) + 1.
    max_degree : int, default=1
        Maximum degree of interactions.
        1 = additive model (no interactions)
        2 = pairwise interactions allowed
        3 = three-way interactions allowed
    penalty : float, default=3.0
        Generalized Cross-Validation (GCV) penalty per knot.
        Higher values produce simpler models.
    thresh : float, default=0.001
        Forward pass stopping threshold.
    min_span : int, default=None
        Minimum number of observations between knots.
    endspan : int, default=None
        Minimum observations before first knot and after last knot.
    fast_k : int, default=20
        Fast MARS parameter (see MARSRegressor).
    feature_names : list of str, default=None
        Names for features. When None and ``X`` has a ``columns`` attribute
        (e.g. a pandas DataFrame), the column names are used. The parameter
        itself is never modified by ``fit``.
    allow_linear : bool, default=True
        If True, allows linear terms.
    method : str, default='logistic'
        Classification method:
        - 'logistic': Logistic regression on MARS basis functions
        - 'threshold': Threshold the MARS regression predictions. For binary
          problems the regressor models the indicator of ``classes_[1]`` and
          the threshold is 0.5. For multiclass problems it models the class
          index in ``classes_`` and predictions are rounded to the nearest
          index.
    logistic_C : float, default=1.0
        Regularization parameter for logistic regression.
        Only used when method='logistic'.

    Attributes
    ----------
    classes_ : ndarray of shape (n_classes,)
        Unique class labels (sorted).
    mars_regressor_ : MARSRegressor
        Underlying MARS model for basis function discovery.
    logistic_ : LogisticRegression
        Fitted logistic regression model (when method='logistic').

    Examples
    --------
    >>> from endgame.models import MARSClassifier
    >>> import numpy as np
    >>> X = np.random.randn(100, 3)
    >>> y = (X[:, 0] + X[:, 1] > 0).astype(int)
    >>> model = MARSClassifier(max_degree=1)
    >>> model.fit(X, y)
    MARSClassifier(max_degree=1)
    >>> predictions = model.predict(X)
    >>> probas = model.predict_proba(X)
    """

    _estimator_type = "classifier"

    def __init__(
        self,
        max_terms: int | None = None,
        max_degree: int = 1,
        penalty: float = 3.0,
        thresh: float = 0.001,
        min_span: int | None = None,
        endspan: int | None = None,
        fast_k: int = 20,
        feature_names: list[str] | None = None,
        allow_linear: bool = True,
        method: str = "logistic",
        logistic_C: float = 1.0,
    ):
        self.max_terms = max_terms
        self.max_degree = max_degree
        self.penalty = penalty
        self.thresh = thresh
        self.min_span = min_span
        self.endspan = endspan
        self.fast_k = fast_k
        self.feature_names = feature_names
        self.allow_linear = allow_linear
        self.method = method
        self.logistic_C = logistic_C

    def fit(
        self,
        X: ArrayLike,
        y: ArrayLike,
        sample_weight: ArrayLike | None = None,
    ) -> MARSClassifier:
        """Fit the MARS classifier.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : array-like of shape (n_samples,)
            Target class labels (any sortable label type, including strings).
        sample_weight : array-like of shape (n_samples,), default=None
            Individual weights for each sample.

        Returns
        -------
        self : object
            Fitted estimator.

        Raises
        ------
        ValueError
            If ``method`` is unknown, fewer than 2 classes are present, or
            ``feature_names`` does not match the number of features.
        """
        # Fail fast, before the (expensive) MARS fit.
        if self.method not in ("logistic", "threshold"):
            raise ValueError(f"Unknown method: {self.method}")

        # Resolve names into a local: constructor params must not be mutated.
        feature_names = self.feature_names
        if feature_names is None and hasattr(X, "columns"):
            feature_names = list(X.columns)

        X, y = check_X_y(X, y, dtype=np.float64)

        # classes_ is sorted; y_idx[i] is the position of y[i] in classes_.
        self.classes_, y_idx = np.unique(y, return_inverse=True)
        n_classes = len(self.classes_)

        if n_classes < 2:
            raise ValueError("Need at least 2 classes for classification")

        # Store feature count
        self.n_features_in_ = X.shape[1]
        if feature_names is not None:
            if len(feature_names) != self.n_features_in_:
                raise ValueError(
                    f"feature_names has {len(feature_names)} entries, "
                    f"expected {self.n_features_in_}"
                )
            self.feature_names_in_ = np.array(feature_names)
        else:
            self.feature_names_in_ = np.array([f"x{i}" for i in range(self.n_features_in_)])

        # Regression target used for basis discovery. For binary problems
        # this is the indicator of classes_[1]. For multiclass it is the class
        # index, which matches the round-to-index logic of the threshold
        # method and works for non-numeric labels.
        y_reg = np.asarray(y_idx, dtype=np.float64).ravel()

        self.mars_regressor_ = MARSRegressor(
            max_terms=self.max_terms,
            max_degree=self.max_degree,
            penalty=self.penalty,
            thresh=self.thresh,
            min_span=self.min_span,
            endspan=self.endspan,
            fast_k=self.fast_k,
            feature_names=feature_names,
            allow_linear=self.allow_linear,
        )
        self.mars_regressor_.fit(X, y_reg, sample_weight=sample_weight)

        if self.method == "logistic":
            # Fit logistic regression on the discovered basis functions
            B = self.mars_regressor_.get_basis_matrix(X)
            self.logistic_ = LogisticRegression(
                C=self.logistic_C,
                solver="lbfgs",
                max_iter=1000,
            )
            self.logistic_.fit(B, y, sample_weight=sample_weight)
        # method == "threshold": predictions come straight from the regressor

        return self

    def _validate_X(self, X: ArrayLike) -> NDArray[np.floating]:
        """Validate prediction input and check its feature count."""
        X = check_array(X, dtype=np.float64)
        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"X has {X.shape[1]} features, expected {self.n_features_in_}"
            )
        return X

    def predict(self, X: ArrayLike) -> NDArray:
        """Predict class labels.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        y_pred : ndarray of shape (n_samples,)
            Predicted class labels.
        """
        check_is_fitted(self)
        X = self._validate_X(X)

        if self.method == "logistic":
            B = self.mars_regressor_.get_basis_matrix(X)
            return self.logistic_.predict(B)

        # Threshold method
        y_pred_reg = self.mars_regressor_.predict(X)
        n_classes = len(self.classes_)
        if n_classes == 2:
            return np.where(y_pred_reg >= 0.5, self.classes_[1], self.classes_[0])
        # Regressor was fit on class indices: round to the nearest index.
        idx = np.clip(np.round(y_pred_reg).astype(int), 0, n_classes - 1)
        return self.classes_[idx]

    def predict_proba(self, X: ArrayLike) -> NDArray[np.floating]:
        """Predict class probabilities.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        proba : ndarray of shape (n_samples, n_classes)
            Predicted class probabilities, columns ordered as ``classes_``.
        """
        check_is_fitted(self)
        X = self._validate_X(X)

        if self.method == "logistic":
            B = self.mars_regressor_.get_basis_matrix(X)
            return self.logistic_.predict_proba(B)

        # Threshold method
        y_pred_reg = self.mars_regressor_.predict(X)
        n_classes = len(self.classes_)

        if n_classes == 2:
            p = np.clip(y_pred_reg, 0, 1)
            return np.column_stack([1 - p, p])

        # Multiclass: the regressor predicts a class index. Score each class by
        # exp(-|pred - index|) and normalise. The argmax agrees with predict().
        y_pred_reg = np.clip(y_pred_reg, 0, n_classes - 1)
        indices = np.arange(n_classes)
        proba = np.exp(-np.abs(y_pred_reg[:, np.newaxis] - indices[np.newaxis, :]))
        proba /= proba.sum(axis=1, keepdims=True)
        return proba

    def summary(self) -> str:
        """Return a human-readable summary of the model.

        Returns
        -------
        summary : str
            Formatted model summary.
        """
        check_is_fitted(self)

        lines = [
            "MARS Classifier Summary",
            "=" * 50,
            "",
            f"Classes: {self.classes_}",
            f"Method: {self.method}",
            "",
            "Underlying MARS Model:",
            "-" * 40,
        ]

        # Add regressor summary (indented)
        mars_summary = self.mars_regressor_.summary()
        lines.extend(" " + line for line in mars_summary.split("\n"))

        return "\n".join(lines)

    @property
    def basis_functions_(self):
        """Return basis functions from underlying MARS regressor."""
        check_is_fitted(self)
        return self.mars_regressor_.basis_functions_

    def get_basis_matrix(self, X: ArrayLike) -> NDArray[np.floating]:
        """Compute the basis function matrix for given X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input data.

        Returns
        -------
        B : ndarray of shape (n_samples, n_basis_functions)
            Basis matrix.
        """
        check_is_fitted(self)
        return self.mars_regressor_.get_basis_matrix(X)

    def compute_variable_importance(self) -> dict[str, float]:
        """Compute variable importance.

        Delegates to the underlying MARS regressor.

        Returns
        -------
        importance : dict
            {feature_name: importance_score}
        """
        check_is_fitted(self)
        return self.mars_regressor_.compute_variable_importance()