Source code for endgame.explain.pdp

"""Partial Dependence Plots (1-D and 2-D).

Computes the marginal effect of one or two features on the model's
predictions, using sklearn's ``partial_dependence`` under the hood.

Example
-------
>>> from endgame.explain import PartialDependence
>>> pdp = PartialDependence(model, feature_names=feature_names)
>>> explanation = pdp.explain(X_train, features=[0, 1])
>>> explanation.plot()
"""

from __future__ import annotations

from typing import Any, Literal

import numpy as np
from sklearn.base import BaseEstimator

from endgame.explain._base import BaseExplainer, Explanation


class PartialDependence(BaseExplainer):
    """Partial Dependence computation for 1-D and 2-D feature effects.

    All computation is delegated to
    :func:`sklearn.inspection.partial_dependence`, so the model must satisfy
    the estimator contract that function expects. Models exposing
    ``predict_proba`` are evaluated through it; otherwise sklearn's
    ``response_method='auto'`` selection applies.

    Parameters
    ----------
    model : sklearn-compatible estimator
        A fitted model with ``predict`` (regression) or ``predict_proba``
        (classification).
    grid_resolution : int, default=50
        Number of evenly-spaced grid points along each feature axis (sklearn
        uses the feature's unique values instead when there are fewer).
    percentiles : tuple of float, default=(0.05, 0.95)
        Lower and upper percentile bounds of the grid.
    kind : str, default='average'
        ``'average'`` for the marginal expectation (classic PDP), or
        ``'individual'`` for Individual Conditional Expectation (ICE).
    feature_names : list of str, optional
        Feature names.
    random_state : int, optional
        Random seed (forwarded to the base explainer).
    verbose : bool, default=False
        Verbose output.

    Examples
    --------
    >>> pdp = PartialDependence(model)
    >>> # 1-D partial dependence for feature 0
    >>> exp1d = pdp.explain(X_train, features=[0])
    >>> exp1d.plot()
    >>>
    >>> # 2-D partial dependence for feature pair (0, 1)
    >>> exp2d = pdp.explain(X_train, features=[(0, 1)])
    >>> exp2d.plot()
    """

    def __init__(
        self,
        model: BaseEstimator,
        grid_resolution: int = 50,
        percentiles: tuple[float, float] = (0.05, 0.95),
        kind: Literal["average", "individual"] = "average",
        feature_names: list[str] | None = None,
        random_state: int | None = None,
        verbose: bool = False,
    ):
        super().__init__(
            model=model,
            feature_names=feature_names,
            random_state=random_state,
            verbose=verbose,
        )
        self.grid_resolution = grid_resolution
        self.percentiles = percentiles
        self.kind = kind

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def explain(
        self,
        X: np.ndarray,
        *,
        features: list[int | tuple[int, int]] | None = None,
        target_class: int | None = None,
    ) -> Explanation:
        """Compute partial dependence for the requested features.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Data used to marginalise over.
        features : list of int or list of (int, int), optional
            Feature indices to compute PDP for. Each element is either a
            single int (1-D PDP) or a pair of ints (2-D PDP). NumPy integer
            indices are accepted. If ``None``, computes 1-D PDP for all
            features.
        target_class : int, optional
            For classifiers, the index of the class to explain. Defaults to
            the positive class (``classes_[1]``) for binary classification
            and to output 0 otherwise. For binary classifiers evaluated via
            ``predict_proba``, ``target_class=0`` yields ``1 - P(class 1)``.

        Returns
        -------
        Explanation
            ``values`` holds, per feature, the range (max - min) of its 1-D
            PD curve (for ICE, of the mean curve); features not requested as
            a single index keep 0. ``metadata['pdp_results']`` maps
            ``str(feature)`` to a dict with ``pdp_values``, ``grid_values``
            and ``feature``.

        Raises
        ------
        IndexError
            If ``target_class`` is out of range for the model's outputs.
        ValueError
            If ``target_class=0`` is requested for a binary classifier that
            is not evaluated through ``predict_proba``.
        """
        from sklearn.inspection import partial_dependence

        X = self._to_numpy(X)
        names = self._resolve_feature_names(X)

        if features is None:
            features = list(range(X.shape[1]))

        # Prefer probabilities for classifiers; otherwise let sklearn choose.
        response_method = (
            "predict_proba" if hasattr(self.model, "predict_proba") else "auto"
        )

        results: dict[str, Any] = {}
        global_importance = np.zeros(X.shape[1])

        self._log(f"Computing partial dependence for {len(features)} feature(s) ...")

        for raw_feat in features:
            feat = self._normalize_feature(raw_feat)
            # sklearn expects a sequence for the features argument.
            feat_seq = feat if isinstance(feat, (tuple, list)) else [feat]

            pd_result = partial_dependence(
                self.model,
                X,
                features=feat_seq,
                grid_resolution=self.grid_resolution,
                percentiles=self.percentiles,
                kind=self.kind,
                response_method=response_method,
            )

            # pd_result is a Bunch with 'average'/'individual' and 'grid_values'.
            raw_values = (
                pd_result["average"] if self.kind == "average" else pd_result["individual"]
            )
            pdp_values = self._select_output(raw_values, target_class, response_method)

            results[str(feat)] = {
                "pdp_values": pdp_values,
                "grid_values": [np.asarray(g) for g in pd_result["grid_values"]],
                "feature": feat,
            }

            # For 1-D features, measure importance as the range of the PD
            # curve. ICE output carries a leading per-sample axis; averaging
            # it recovers the PD curve so between-sample spread is not
            # mistaken for a feature effect.
            if isinstance(feat, int):
                curve = pdp_values.mean(axis=0) if pdp_values.ndim > 1 else pdp_values
                global_importance[feat] = float(np.ptp(curve))

        return Explanation(
            values=global_importance,
            base_value=None,
            feature_names=names,
            method="pdp",
            metadata={
                "pdp_results": results,
                "kind": self.kind,
                "grid_resolution": self.grid_resolution,
                "percentiles": self.percentiles,
            },
        )

    def plot_feature(
        self,
        explanation: Explanation,
        feature: int | tuple[int, int],
        *,
        ax: Any | None = None,
        show: bool = True,
        **kwargs: Any,
    ) -> Any:
        """Plot partial dependence for a single feature (or pair).

        Parameters
        ----------
        explanation : Explanation
            Result from :meth:`explain`.
        feature : int or (int, int)
            Feature index (or pair for 2-D). A pair may be given as a tuple
            or a list regardless of which form was passed to :meth:`explain`.
        ax : matplotlib.axes.Axes, optional
            Axes to plot on. If ``None``, a new figure is created.
        show : bool, default=True
            Whether to call ``plt.show()``.
        **kwargs
            ``figsize`` sizes a newly created figure and is never forwarded.
            For 2-D plots ``cmap`` (default ``'viridis'``) and ``levels``
            (default 20) configure ``contourf``. Remaining kwargs go to
            ``contourf`` (2-D) or ``plot`` (1-D, ``linewidth`` default 2).

        Returns
        -------
        matplotlib.figure.Figure

        Raises
        ------
        ValueError
            If ``feature`` is not present in ``explanation``.

        Notes
        -----
        ICE results (``kind='individual'``) are drawn as one faint line per
        sample plus their mean for 1-D features; for 2-D features the
        per-sample surfaces are averaged into a single contour plot.
        """
        import matplotlib.pyplot as plt

        pdp_results = explanation.metadata["pdp_results"]
        pdp_data = pdp_results.get(str(feature))
        if pdp_data is None and isinstance(feature, (tuple, list)):
            # Results are keyed by str(feature) as passed to explain(), so a
            # pair may have been stored as "(0, 1)" or as "[0, 1]".
            other_form = list(feature) if isinstance(feature, tuple) else tuple(feature)
            pdp_data = pdp_results.get(str(other_form))
        if pdp_data is None:
            raise ValueError(
                f"Feature {feature} not found in explanation. "
                f"Available: {list(explanation.metadata['pdp_results'].keys())}"
            )

        names = explanation.feature_names or []
        pdp_values = np.asarray(pdp_data["pdp_values"])
        grid_values = pdp_data["grid_values"]

        def label(idx: int) -> str:
            return names[idx] if idx < len(names) else f"Feature {idx}"

        if isinstance(feature, (tuple, list)) and len(feature) == 2:
            # 2-D contour plot. figsize is popped even when `ax` is given so
            # it never reaches contourf.
            figsize = kwargs.pop("figsize", (8, 6))
            fig, ax_ = plt.subplots(figsize=figsize) if ax is None else (ax.figure, ax)
            # "ij" indexing yields arrays of shape (len(grid_0), len(grid_1)),
            # matching the PD layout where axis 0 follows feature[0]. The
            # default "xy" indexing would transpose the surface (and fail
            # when the two grids differ in length).
            XX, YY = np.meshgrid(grid_values[0], grid_values[1], indexing="ij")
            ZZ = pdp_values
            if ZZ.ndim == 1:
                ZZ = ZZ.reshape(XX.shape)
            elif ZZ.ndim == 3:
                # ICE: (n_samples, len(grid_0), len(grid_1)) -> mean surface.
                ZZ = ZZ.mean(axis=0)
            cs = ax_.contourf(
                XX,
                YY,
                ZZ,
                levels=kwargs.pop("levels", 20),
                cmap=kwargs.pop("cmap", "viridis"),
                **kwargs,
            )
            fig.colorbar(cs, ax=ax_)
            f0_name = label(feature[0])
            f1_name = label(feature[1])
            ax_.set_xlabel(f0_name)
            ax_.set_ylabel(f1_name)
            ax_.set_title(f"2-D Partial Dependence: {f0_name} vs {f1_name}")
        else:
            # 1-D line plot.
            figsize = kwargs.pop("figsize", (8, 5))
            fig, ax_ = plt.subplots(figsize=figsize) if ax is None else (ax.figure, ax)
            feat_idx = feature if isinstance(feature, (int, np.integer)) else feature[0]
            f_name = label(feat_idx)
            grid = grid_values[0]
            # User kwargs override the default line width instead of clashing.
            line_kwargs = {"linewidth": 2, **kwargs}
            if pdp_values.ndim == 2:
                # ICE: one faint curve per sample, then their mean (the PDP).
                ax_.plot(grid, pdp_values.T, color="0.6", alpha=0.3, linewidth=0.5)
                ax_.plot(grid, pdp_values.mean(axis=0), **line_kwargs)
            else:
                ax_.plot(grid, pdp_values.ravel(), **line_kwargs)
            ax_.set_xlabel(f_name)
            ax_.set_ylabel("Partial Dependence")
            ax_.set_title(f"PDP: {f_name}")

        # Lay out the figure actually drawn on, which may not be plt.gcf()
        # when the caller supplied `ax`.
        fig.tight_layout()
        if show:
            plt.show()
        return fig

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _normalize_feature(feature: Any) -> Any:
        """Convert NumPy integer indices to ``int``, keeping tuple/list form."""
        if isinstance(feature, (tuple, list)):
            return type(feature)(
                int(f) if isinstance(f, np.integer) else f for f in feature
            )
        return int(feature) if isinstance(feature, np.integer) else feature

    def _select_output(
        self,
        values: np.ndarray,
        target_class: int | None,
        response_method: str,
    ) -> np.ndarray:
        """Pick one output (class or regression target) from sklearn's PD array.

        sklearn stacks outputs along axis 0: one row per class for
        multiclass, one per target for multi-output regression, and a single
        row for single-output regression and for binary classification,
        where it reports only the positive class (``classes_[1]``).
        """
        if values.ndim < 2:
            return values
        if target_class is None:
            # Default: positive class for binary, first output otherwise.
            return values[0]

        n_outputs = values.shape[0]
        is_binary = n_outputs == 1 and len(getattr(self.model, "classes_", ())) == 2
        if is_binary and target_class in (1, -1):
            return values[0]
        if is_binary and target_class in (0, -2):
            if response_method != "predict_proba":
                raise ValueError(
                    "target_class=0 is only supported for binary classifiers "
                    "evaluated with predict_proba; sklearn reports only the "
                    "positive class."
                )
            # P(class 0) = 1 - P(class 1), and averaging is linear, so the
            # class-0 PD/ICE is the complement of the reported one.
            return 1.0 - values[0]
        if not -n_outputs <= target_class < n_outputs:
            raise IndexError(
                f"target_class={target_class} is out of range for a model "
                f"with {n_outputs} output(s)."
            )
        return values[target_class]