# Source code for endgame.explain.shap_explainer

"""SHAP-based model explanations with automatic explainer selection.

Wraps the ``shap`` library and auto-detects the most efficient explainer
(Tree, Linear, Deep, Kernel) based on the model type.

Example
-------
>>> from endgame.explain import SHAPExplainer
>>> explainer = SHAPExplainer(model)
>>> explanation = explainer.explain(X_test)
>>> explanation.plot(kind='bar')
"""

from __future__ import annotations

from typing import Any, Literal

import numpy as np
from sklearn.base import BaseEstimator

from endgame.explain._base import BaseExplainer, Explanation

# Lowercase tokens that _detect_explainer_type looks for as substrings of a
# model's class name when picking a SHAP backend (shared with
# endgame.feature_selection.importance.shap_importance).

# Tokens that route a model to the tree backend.
_TREE_MODEL_NAMES = frozenset({
    "lgbm",
    "xgb",
    "catboost",
    "randomforest",
    "gradientboosting",
    "extratrees",
    "decisiontree",
    "histgradientboosting",
    "rotationforest",
    "isolationforest",
})

# Tokens that route a model to the linear backend.
_LINEAR_MODEL_NAMES = frozenset({
    "linear",
    "logistic",
    "ridge",
    "lasso",
    "elasticnet",
    "sgd",
})

# Tokens that route a model to the deep (neural network) backend.
_DEEP_MODEL_NAMES = frozenset({
    "fttransformer",
    "saint",
    "node",
    "tabnet",
    "nam",
    "gandalf",
    "talr",
    "tabularresnet",
    "tabtransformer",
    "embeddingmlp",
})


def _check_shap_installed() -> None:
    """Raise ImportError if the shap package is unavailable.

    Raises
    ------
    ImportError
        If ``import shap`` fails. The original import failure is attached
        as ``__cause__`` so the underlying reason stays visible in the
        traceback.
    """
    try:
        import shap  # noqa: F401
    except ImportError as err:
        # Chain explicitly so the traceback reads "caused by" instead of
        # "during handling of the above exception, another occurred".
        raise ImportError(
            "The 'shap' package is required for SHAPExplainer. "
            "Install it with: pip install shap"
        ) from err


def _detect_explainer_type(model: Any) -> str:
    """Auto-detect the best SHAP explainer type for *model*.

    Detection is purely name based: the lowercased class name of *model*
    is searched for the substrings in ``_TREE_MODEL_NAMES``,
    ``_LINEAR_MODEL_NAMES`` and ``_DEEP_MODEL_NAMES`` (in that order). If
    the outer class name matches nothing, the class name of a wrapped
    model (``model_``, else ``estimator_``) is tried the same way. As a
    last resort, an object exposing ``steps`` is treated as a pipeline and
    its final step is inspected recursively.

    Parameters
    ----------
    model : estimator
        A fitted model.

    Returns
    -------
    str
        One of ``'tree'``, ``'linear'``, ``'deep'``, ``'kernel'``.
        ``'kernel'`` is the fallback when nothing matches.
    """
    name = type(model).__name__.lower()

    # Unwrap common wrappers. Compare against None explicitly: using
    # ``or`` would evaluate the truthiness of the wrapped model, which
    # skips models that are falsy (e.g. ``__len__`` returning 0) and can
    # raise for objects whose ``__bool__`` is not defined for them.
    inner = getattr(model, "model_", None)
    if inner is None:
        inner = getattr(model, "estimator_", None)
    inner_name = type(inner).__name__.lower() if inner is not None else ""

    families = (
        ("tree", _TREE_MODEL_NAMES),
        ("linear", _LINEAR_MODEL_NAMES),
        ("deep", _DEEP_MODEL_NAMES),
    )
    # The outer class name takes precedence over the wrapped model's name.
    for candidate in (name, inner_name):
        for label, tokens in families:
            if any(token in candidate for token in tokens):
                return label

    # Pipeline: inspect the last step.
    if hasattr(model, "steps"):
        try:
            last_step = model.steps[-1][1]
        except (IndexError, KeyError, TypeError):
            # ``steps`` is empty or is not a sequence of (name, estimator)
            # pairs (e.g. an integer hyper-parameter that happens to be
            # called ``steps``); use the model-agnostic backend.
            return "kernel"
        return _detect_explainer_type(last_step)

    return "kernel"


class SHAPExplainer(BaseExplainer):
    """SHAP-based explainer with automatic backend selection.

    Supports ``TreeExplainer``, ``LinearExplainer``, ``DeepExplainer``, and
    ``KernelExplainer``. By default the most efficient backend is chosen
    automatically based on the model type.

    Parameters
    ----------
    model : sklearn-compatible estimator
        A fitted model.
    explainer_type : str, default='auto'
        SHAP explainer backend:

        - ``'auto'``: Auto-detect from model type.
        - ``'tree'``: ``shap.TreeExplainer`` (tree-based models).
        - ``'linear'``: ``shap.LinearExplainer`` (linear models).
        - ``'deep'``: ``shap.DeepExplainer`` (neural networks).
        - ``'kernel'``: ``shap.KernelExplainer`` (model-agnostic).
    background_samples : int, default=100
        Number of background samples for Kernel / Linear / Deep explainers.
    max_samples : int, optional
        If set, subsample *X* to at most this many rows before computing
        SHAP values (useful for large datasets with KernelExplainer).
    check_additivity : bool, default=False
        Whether to verify the SHAP additivity property.
    feature_names : list of str, optional
        Feature names.
    random_state : int, optional
        Random seed. Seeds both the background sampling and the optional
        row subsampling.
    verbose : bool, default=False
        Verbose output.

    Examples
    --------
    >>> from endgame.explain import SHAPExplainer
    >>> explainer = SHAPExplainer(model, explainer_type='auto')
    >>> explanation = explainer.explain(X_test)
    >>> explanation.plot(kind='beeswarm')
    >>> print(explanation.top_features(5))
    """

    def __init__(
        self,
        model: BaseEstimator,
        explainer_type: Literal["auto", "tree", "linear", "deep", "kernel"] = "auto",
        background_samples: int = 100,
        max_samples: int | None = None,
        check_additivity: bool = False,
        feature_names: list[str] | None = None,
        random_state: int | None = None,
        verbose: bool = False,
    ):
        super().__init__(
            model=model,
            feature_names=feature_names,
            random_state=random_state,
            verbose=verbose,
        )
        self.explainer_type = explainer_type
        self.background_samples = background_samples
        self.max_samples = max_samples
        self.check_additivity = check_additivity

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _resolve_type(self) -> str:
        """Resolve the effective explainer type string.

        Returns
        -------
        str
            ``self.explainer_type`` unless it is ``'auto'``, in which case
            the type detected from ``self.model`` is logged and returned.
        """
        if self.explainer_type == "auto":
            detected = _detect_explainer_type(self.model)
            self._log(f"Auto-detected SHAP explainer type: {detected}")
            return detected
        return self.explainer_type

    def _native_model(self) -> Any:
        """Return the object handed to ``shap.TreeExplainer``.

        Unwraps endgame wrappers (``model_`` attribute) so SHAP sees the
        native model; falls back to ``self.model`` when no such attribute
        exists.
        """
        return getattr(self.model, "model_", self.model)

    def _sample_background(self, X: np.ndarray) -> np.ndarray:
        """Draw the background rows used by Linear / Deep / Kernel backends.

        Samples ``min(background_samples, len(X))`` distinct rows of *X*
        with a generator seeded from ``random_state``.
        """
        rng = np.random.RandomState(self.random_state)
        n_bg = min(self.background_samples, len(X))
        idx = rng.choice(len(X), size=n_bg, replace=False)
        return X[idx]

    @staticmethod
    def _collapse_class_axis(values: Any, ndim_with_classes: int) -> Any:
        """Collapse a multi-class SHAP result to a single set of values.

        Older shap versions return a list with one array per class; newer
        versions append the class axis as the last dimension. In both
        cases the result is the mean of the absolute values across
        classes, i.e. an unsigned magnitude. Anything else is returned
        unchanged.

        Parameters
        ----------
        values : list or np.ndarray
            Raw output of the shap explainer.
        ndim_with_classes : int
            Number of dimensions an ndarray result has when it carries a
            trailing class axis.
        """
        if isinstance(values, list):
            return np.mean(np.abs(np.array(values)), axis=0)
        if isinstance(values, np.ndarray) and values.ndim == ndim_with_classes:
            return np.mean(np.abs(values), axis=-1)
        return values

    @staticmethod
    def _scalar_base_value(explainer: Any) -> Any:
        """Reduce ``explainer.expected_value`` to a single number.

        Lists and arrays with more than one entry along the first axis
        (multi-class) are averaged; single-entry ones yield their first
        element. An empty sequence yields ``None``. Values that are
        neither a list nor an ndarray (including a missing attribute,
        which gives ``None``) are returned unchanged.
        """
        base_value = getattr(explainer, "expected_value", None)
        if not isinstance(base_value, (list, np.ndarray)):
            return base_value
        arr = np.asarray(base_value)
        if arr.size == 0:
            return None
        if arr.ndim > 0 and len(arr) > 1:
            # Multi-class: take mean across classes.
            return float(np.mean(arr))
        return float(arr.flat[0])

    def _build_explainer(self, X: np.ndarray, explainer_type: str) -> Any:
        """Construct the underlying ``shap.Explainer`` instance.

        Parameters
        ----------
        X : np.ndarray
            Data used to derive background summaries when needed.
        explainer_type : str
            One of ``'tree'``, ``'linear'``, ``'deep'``, ``'kernel'``.

        Returns
        -------
        shap explainer instance

        Raises
        ------
        ValueError
            If *explainer_type* is not one of the supported values.
        """
        import shap

        if explainer_type == "tree":
            return shap.TreeExplainer(self._native_model())
        if explainer_type == "linear":
            return shap.LinearExplainer(self.model, self._sample_background(X))
        if explainer_type == "deep":
            return shap.DeepExplainer(self.model, self._sample_background(X))
        if explainer_type == "kernel":
            # Prefer probabilities when the model can provide them.
            predict_fn = (
                self.model.predict_proba
                if hasattr(self.model, "predict_proba")
                else self.model.predict
            )
            return shap.KernelExplainer(predict_fn, self._sample_background(X))

        raise ValueError(
            f"Unknown explainer_type '{explainer_type}'. "
            "Expected one of: 'tree', 'linear', 'deep', 'kernel'."
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def explain(
        self,
        X: np.ndarray,
        *,
        check_additivity: bool | None = None,
    ) -> Explanation:
        """Compute SHAP values for *X*.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Data to explain.
        check_additivity : bool, optional
            Override the instance-level ``check_additivity`` setting.

        Returns
        -------
        Explanation
            An :class:`Explanation` containing sample-level SHAP values of
            shape ``(n_samples, n_features)``. For multi-class output the
            values are the mean absolute SHAP value across classes and
            ``base_value`` is the mean expected value across classes.

        Raises
        ------
        ImportError
            If the ``shap`` package is not installed.
        ValueError
            If the resolved explainer type is not supported.
        """
        _check_shap_installed()

        X = self._to_numpy(X)
        names = self._resolve_feature_names(X)
        check_add = (
            check_additivity
            if check_additivity is not None
            else self.check_additivity
        )

        # Optional subsampling for expensive explainers.
        if self.max_samples is not None and len(X) > self.max_samples:
            rng = np.random.RandomState(self.random_state)
            idx = rng.choice(len(X), size=self.max_samples, replace=False)
            X_explain = X[idx]
        else:
            X_explain = X

        explainer_type = self._resolve_type()
        # The background is drawn from the full X, not the subsample.
        explainer = self._build_explainer(X, explainer_type)

        self._log(f"Computing SHAP values for {X_explain.shape[0]} samples ...")

        # check_additivity is only forwarded to the tree backend.
        shap_kwargs: dict[str, Any] = {}
        if explainer_type == "tree":
            shap_kwargs["check_additivity"] = check_add

        shap_values = explainer.shap_values(X_explain, **shap_kwargs)

        # Multi-class handling: collapse the class dimension.
        # Older shap versions return a list of (n_samples, n_features) arrays.
        # Newer shap versions return (n_samples, n_features, n_classes).
        shap_values = self._collapse_class_axis(shap_values, 3)

        return Explanation(
            values=np.asarray(shap_values),
            base_value=self._scalar_base_value(explainer),
            feature_names=names,
            method="shap",
            metadata={
                "explainer_type": explainer_type,
                "n_samples": X_explain.shape[0],
                "n_features": X_explain.shape[1],
                "check_additivity": check_add,
                "background_samples": self.background_samples,
            },
        )

    def explain_interaction(self, X: np.ndarray) -> Explanation:
        """Compute SHAP interaction values (tree models only).

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Data to explain.

        Returns
        -------
        Explanation
            An :class:`Explanation` with ``values`` of shape
            ``(n_samples, n_features, n_features)``. Multi-class output is
            collapsed to the mean absolute interaction value across
            classes, and ``base_value`` is reduced to a scalar in the same
            way as in :meth:`explain`.

        Raises
        ------
        ImportError
            If the ``shap`` package is not installed.
        ValueError
            If the model is not tree-based.
        """
        _check_shap_installed()

        X = self._to_numpy(X)
        names = self._resolve_feature_names(X)

        explainer_type = self._resolve_type()
        if explainer_type != "tree":
            raise ValueError(
                "SHAP interaction values are only supported for tree-based models. "
                f"Detected explainer type: '{explainer_type}'."
            )

        explainer = self._build_explainer(X, "tree")
        interaction_values = explainer.shap_interaction_values(X)

        # Multi-class handling: collapse the class dimension.
        # Older shap: list of (n_samples, n_features, n_features).
        # Newer shap: (n_samples, n_features, n_features, n_classes).
        interaction_values = self._collapse_class_axis(interaction_values, 4)

        return Explanation(
            values=np.asarray(interaction_values),
            # Keep base_value consistent with the class-collapsed values.
            base_value=self._scalar_base_value(explainer),
            feature_names=names,
            method="shap_interaction",
            metadata={"n_samples": X.shape[0], "n_features": X.shape[1]},
        )