Source code for endgame.explain._base

"""Base classes for the explainability module.

Provides the ``Explanation`` dataclass for storing and visualizing explanation
results, and ``BaseExplainer`` as the abstract base class for all explainers.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
from sklearn.base import BaseEstimator

if TYPE_CHECKING:
    import pandas as pd


@dataclass
class Explanation:
    """Container for feature-attribution explanations.

    Holds SHAP values, LIME weights, partial dependence grids, or any other
    per-feature attribution output. Provides convenience methods for plotting
    and DataFrame export.

    Attributes
    ----------
    values : np.ndarray
        Attribution values. Shape depends on the method:

        - Global: ``(n_features,)`` (mean absolute attributions).
        - Local: ``(n_samples, n_features)`` or ``(n_features,)`` for a
          single instance.
        - PDP: ``(n_grid_points,)`` or ``(n_grid_1, n_grid_2)`` for 2-D.
    base_value : float or np.ndarray or None
        Expected value / base prediction the attributions are relative to.
    feature_names : list of str or None
        Names corresponding to the feature axis of *values*. Any sequence
        (list, tuple, NumPy array, pandas Index) is accepted; ``None`` or an
        empty sequence means "generate ``Feature i`` names".
    method : str
        Name of the explanation method (e.g. ``'shap'``, ``'lime'``).
    metadata : dict
        Arbitrary extra information (explainer params, timings, etc.).

    Examples
    --------
    >>> explanation.plot(kind='bar')
    >>> df = explanation.to_dataframe()
    """

    values: np.ndarray
    base_value: float | np.ndarray | None = None
    feature_names: list[str] | None = None
    method: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _feature_importance(self) -> np.ndarray:
        """Return per-feature |attribution|, averaged over rows for 2-D values."""
        vals = np.asarray(self.values)
        if vals.ndim == 2:
            return np.mean(np.abs(vals), axis=0)
        return np.abs(vals)

    def _resolved_names(self, n_features: int) -> list[str]:
        """Return the feature names as a list, generating defaults if unset.

        The check is ``is None`` / ``len(...) == 0`` rather than plain
        truthiness because NumPy arrays and pandas Index objects raise
        ``ValueError`` when evaluated in a boolean context.
        """
        names = self.feature_names
        if names is None or len(names) == 0:
            return [f"Feature {i}" for i in range(n_features)]
        return list(names)

    # ------------------------------------------------------------------
    # Plotting
    # ------------------------------------------------------------------

    def plot(
        self,
        kind: Literal[
            "bar",
            "beeswarm",
            "waterfall",
            "heatmap",
            "force",
        ] = "bar",
        max_display: int = 20,
        show: bool = True,
        **kwargs: Any,
    ) -> Any:
        """Plot the explanation.

        Parameters
        ----------
        kind : str, default='bar'
            Plot type. Supported values depend on the explanation method:

            - ``'bar'``: Horizontal bar chart of mean |attributions|
              (all methods).
            - ``'beeswarm'``: SHAP beeswarm summary plot (SHAP only).
            - ``'waterfall'``: Waterfall plot for a single prediction
              (SHAP only).
            - ``'heatmap'``: SHAP heatmap of sample-level attributions.
            - ``'force'``: SHAP force plot for a single prediction.
        max_display : int, default=20
            Maximum number of features to display.
        show : bool, default=True
            Whether to call ``matplotlib.pyplot.show()`` (forwarded to the
            SHAP plotting functions for SHAP plot kinds).
        **kwargs
            Forwarded to the underlying plotting function. For ``'bar'`` the
            keys ``figsize`` and ``color`` are recognised.

        Returns
        -------
        matplotlib.figure.Figure or object or None
            The figure for ``kind='bar'``; the object returned by
            ``shap.plots.force`` for ``kind='force'`` (so the interactive
            plot can be displayed by the caller); ``None`` for the other
            SHAP plot kinds.

        Raises
        ------
        ImportError
            If a SHAP plot kind is requested and ``shap`` is not installed.
        ValueError
            If *kind* is not one of the supported plot types.
        """
        import matplotlib.pyplot as plt

        vals = np.asarray(self.values)

        # Collapse to 1-D for bar plots when we have sample-level values.
        if kind == "bar":
            importance = self._feature_importance()
            names = self._resolved_names(len(importance))

            order = np.argsort(importance)[::-1][:max_display]
            ordered_names = [names[int(i)] for i in order]
            ordered_vals = importance[order]

            fig, ax = plt.subplots(
                figsize=kwargs.pop("figsize", (8, 0.4 * max_display + 1))
            )
            y_pos = np.arange(len(ordered_names))
            ax.barh(
                y_pos,
                ordered_vals,
                align="center",
                color=kwargs.pop("color", "#1f77b4"),
            )
            ax.set_yticks(y_pos)
            ax.set_yticklabels(ordered_names)
            # Largest attribution at the top of the chart.
            ax.invert_yaxis()
            ax.set_xlabel("Mean |Attribution|")
            ax.set_title(f"Feature Attributions ({self.method})")
            plt.tight_layout()
            if show:
                plt.show()
            return fig

        # SHAP-native plots (delegate to the shap library).
        if kind in ("beeswarm", "waterfall", "heatmap", "force"):
            try:
                import shap as shap_lib
            except ImportError as err:
                raise ImportError(
                    f"Plot kind='{kind}' requires the 'shap' package. "
                    "Install it with: pip install shap"
                ) from err

            shap_explanation = shap_lib.Explanation(
                values=vals,
                base_values=self.base_value,
                feature_names=self.feature_names,
            )

            if kind == "beeswarm":
                shap_lib.plots.beeswarm(
                    shap_explanation, max_display=max_display, show=show, **kwargs
                )
            elif kind == "waterfall":
                # Waterfall explains one prediction: use the first sample.
                if vals.ndim == 2:
                    shap_explanation = shap_explanation[0]
                shap_lib.plots.waterfall(
                    shap_explanation, max_display=max_display, show=show, **kwargs
                )
            elif kind == "heatmap":
                shap_lib.plots.heatmap(
                    shap_explanation, max_display=max_display, show=show, **kwargs
                )
            elif kind == "force":
                # Force plots explain one prediction: use the first sample.
                if vals.ndim == 2:
                    shap_explanation = shap_explanation[0]
                # Hand the result back to the caller; dropping it would make
                # the interactive (non-matplotlib) force plot impossible to
                # display.
                return shap_lib.plots.force(shap_explanation, show=show, **kwargs)
            return None

        raise ValueError(
            f"Unknown plot kind '{kind}'. "
            "Supported: 'bar', 'beeswarm', 'waterfall', 'heatmap', 'force'."
        )

    # ------------------------------------------------------------------
    # DataFrame export
    # ------------------------------------------------------------------

    def to_dataframe(self) -> pd.DataFrame:
        """Convert explanation to a pandas DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame indexed by feature name. For 1-D values the single
            column ``attribution`` holds the raw values in feature order.
            Otherwise the single column ``mean_abs_attribution`` holds the
            mean |attribution| over axis 0, sorted in descending order.

        Raises
        ------
        ImportError
            If pandas is not installed.
        """
        try:
            import pandas as pd
        except ImportError as err:
            raise ImportError(
                "pandas is required for to_dataframe(). "
                "Install it with: pip install pandas"
            ) from err

        vals = np.asarray(self.values)
        n_features = vals.shape[-1] if vals.ndim >= 2 else vals.shape[0]
        names = self._resolved_names(n_features)

        if vals.ndim == 1:
            return pd.DataFrame({"attribution": vals}, index=names)

        # Sample-level attributions: summary with mean |attribution|.
        return pd.DataFrame(
            {"mean_abs_attribution": np.mean(np.abs(vals), axis=0)},
            index=names,
        ).sort_values("mean_abs_attribution", ascending=False)

    # ------------------------------------------------------------------
    # Convenience
    # ------------------------------------------------------------------

    def top_features(self, n: int = 10) -> list[str]:
        """Return the top-*n* most important feature names.

        Parameters
        ----------
        n : int, default=10
            Number of top features to return.

        Returns
        -------
        list of str
            Feature names ordered by descending mean |attribution|.
        """
        importance = self._feature_importance()
        names = self._resolved_names(len(importance))
        order = np.argsort(importance)[::-1][:n]
        return [names[int(i)] for i in order]

    def __repr__(self) -> str:
        """Return a compact summary: method, value shape and feature count."""
        shape = np.asarray(self.values).shape
        return (
            f"Explanation(method='{self.method}', shape={shape}, "
            f"n_features={shape[-1] if len(shape) >= 1 else 0})"
        )
class BaseExplainer(ABC):
    """Common foundation shared by every Endgame explainer.

    Concrete explainers subclass this and provide :meth:`explain`, which must
    return an :class:`Explanation`.

    Parameters
    ----------
    model : sklearn-compatible estimator
        The fitted model whose predictions are to be explained.
    feature_names : list of str, optional
        Names of the input features. When ``None``, names are taken from the
        data's ``columns`` attribute if present, otherwise generated.
    random_state : int, optional
        Seed stored for subclasses that need reproducible randomness.
    verbose : bool, default=False
        When true, :meth:`_log` prints progress messages.
    """

    def __init__(
        self,
        model: BaseEstimator,
        feature_names: list[str] | None = None,
        random_state: int | None = None,
        verbose: bool = False,
    ):
        self.model = model
        self.feature_names = feature_names
        self.random_state = random_state
        self.verbose = verbose

    @abstractmethod
    def explain(self, X: np.ndarray, **kwargs: Any) -> Explanation:
        """Produce an explanation for *X*.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Data to explain.
        **kwargs
            Arguments specific to the concrete explanation method.

        Returns
        -------
        Explanation
            The computed explanation.
        """

    def _resolve_feature_names(self, X: np.ndarray) -> list[str]:
        """Work out the feature names to attach to an explanation of *X*.

        Precedence: names given at construction time, then ``X.columns`` when
        the input exposes it (pandas / polars), then generated
        ``"Feature i"`` placeholders.

        Parameters
        ----------
        X : np.ndarray
            Input used to infer the number of features when names must be
            generated.

        Returns
        -------
        list of str
            One name per feature.
        """
        if self.feature_names is None:
            if hasattr(X, "columns"):
                return list(X.columns)
            # 1-D input: its only axis is treated as the feature axis.
            width = X.shape[0] if X.ndim < 2 else X.shape[1]
            return [f"Feature {idx}" for idx in range(width)]
        return list(self.feature_names)

    def _to_numpy(self, X: Any) -> np.ndarray:
        """Coerce *X* into a numpy array.

        Parameters
        ----------
        X : array-like
            Input data (numpy, pandas, polars, or anything ``np.asarray``
            accepts).

        Returns
        -------
        np.ndarray
            The input itself when it already is an ndarray, otherwise its
            numpy representation.
        """
        if isinstance(X, np.ndarray):
            return X

        # pandas and polars are optional: a failed import simply skips that
        # conversion path.
        try:
            import pandas as pd

            pandas_types = (pd.DataFrame, pd.Series)
            if isinstance(X, pandas_types):
                return X.values
        except ImportError:
            pass

        try:
            import polars as pl

            # A lazy frame has to be materialised before conversion.
            if isinstance(X, pl.LazyFrame):
                X = X.collect()
            if isinstance(X, pl.DataFrame):
                return X.to_numpy()
        except ImportError:
            pass

        return np.asarray(X)

    def _log(self, message: str) -> None:
        """Print *message* prefixed with the class name, if verbose is set."""
        if not self.verbose:
            return
        print(f"[INFO] {self.__class__.__name__}: {message}")

    def __repr__(self) -> str:
        """Return ``ClassName(model=ModelClassName)``."""
        wrapped = type(self.model).__name__
        return f"{self.__class__.__name__}(model={wrapped})"