# Source code for endgame.explain.counterfactual

"""Counterfactual explanations via DiCE (Diverse Counterfactual Explanations).

Generates "what-if" explanations showing the minimal feature changes needed
to alter the model's prediction.

Reference
---------
Mothilal, R.K., Sharma, A. & Tan, C. (2020). "Explaining Machine
Learning Classifiers through Diverse Counterfactual Explanations."
*FAT* '20.

Example
-------
>>> from endgame.explain import CounterfactualExplainer
>>> cf = CounterfactualExplainer(model, X_train, feature_names=feature_names)
>>> explanation = cf.explain(X_test[:1], desired_class=1)
>>> explanation.to_dataframe()
"""

from __future__ import annotations

from typing import Any, Literal

import numpy as np
from sklearn.base import BaseEstimator

from endgame.explain._base import BaseExplainer, Explanation


def _check_dice_installed() -> None:
    """Raise ImportError if the dice-ml package is unavailable."""
    try:
        import dice_ml  # noqa: F401
    except ImportError:
        raise ImportError(
            "The 'dice-ml' package is required for CounterfactualExplainer. "
            "Install it with: pip install dice-ml"
        )


class CounterfactualExplainer(BaseExplainer):
    """Counterfactual explanation generator using DiCE.

    Finds diverse sets of minimal feature perturbations that change the
    model's prediction to a desired outcome.

    Parameters
    ----------
    model : sklearn-compatible estimator
        A fitted classifier or regressor. Estimators exposing
        ``predict_proba`` are handed to DiCE as classifiers, all others as
        regressors (regressors need ``desired_range`` in :meth:`explain`).
    training_data : array-like or pd.DataFrame
        Training data used to define feature ranges and constraints.
    continuous_features : list of str, optional
        Names of continuous features. If ``None``, all features are assumed
        continuous.
    outcome_name : str, default='outcome'
        Name of the target column (used internally by DiCE).
    feature_names : list of str, optional
        Feature names.
    random_state : int, optional
        Random seed. Forwarded to DiCE as ``random_seed`` when
        ``method='random'``; other generation methods are not seeded here.
    verbose : bool, default=False
        Verbose output.

    Examples
    --------
    >>> cf = CounterfactualExplainer(model, X_train)
    >>> explanation = cf.explain(X_test[:1], desired_class=1, n_counterfactuals=3)
    >>> print(explanation.metadata['counterfactuals'])
    """

    def __init__(
        self,
        model: BaseEstimator,
        training_data: Any,
        continuous_features: list[str] | None = None,
        outcome_name: str = "outcome",
        feature_names: list[str] | None = None,
        random_state: int | None = None,
        verbose: bool = False,
    ):
        super().__init__(
            model=model,
            feature_names=feature_names,
            random_state=random_state,
            verbose=verbose,
        )
        self.training_data = training_data
        self.continuous_features = continuous_features
        self.outcome_name = outcome_name

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _prepare_dice_data(self, X: np.ndarray, names: list[str]) -> Any:
        """Build a ``dice_ml.Data`` object from the training set.

        Parameters
        ----------
        X : np.ndarray
            Training data (features only).
        names : list of str
            Feature names.

        Returns
        -------
        dice_ml.Data
        """
        import dice_ml
        import pandas as pd

        df = pd.DataFrame(X, columns=names)
        # DiCE requires an outcome column in the data schema; it is filled
        # with a constant 0. NOTE(review): assumes DiCE does not rely on the
        # true labels of this frame for the chosen method — confirm.
        df[self.outcome_name] = 0

        continuous = self.continuous_features
        if continuous is None:
            continuous = list(names)

        return dice_ml.Data(
            dataframe=df,
            continuous_features=continuous,
            outcome_name=self.outcome_name,
        )

    def _prepare_dice_model(self) -> Any:
        """Wrap the sklearn model for DiCE.

        Returns
        -------
        dice_ml.Model
        """
        import dice_ml

        # Anything with predict_proba is treated as a classifier.
        model_type = (
            "classifier" if hasattr(self.model, "predict_proba") else "regressor"
        )
        return dice_ml.Model(
            model=self.model,
            backend="sklearn",
            model_type=model_type,
        )

    def _collect_counterfactuals(
        self, result: Any, X: np.ndarray, names: list[str]
    ) -> tuple[list[Any], np.ndarray]:
        """Gather per-instance counterfactual frames and summed |changes|.

        Parameters
        ----------
        result : dice_ml counterfactual result
            Object returned by ``generate_counterfactuals``; its
            ``cf_examples_list`` is read in the same order as the rows of ``X``.
        X : np.ndarray of shape (n_instances, n_features)
            The original query instances.
        names : list of str
            Feature names (column order of ``X``).

        Returns
        -------
        frames : list of pd.DataFrame
            Non-empty counterfactual frames with the outcome column removed.
        total_changes : np.ndarray of shape (n_features,)
            Sum over all counterfactuals of the absolute per-feature change
            relative to the corresponding original instance.
        """
        frames = []
        total_changes = np.zeros(X.shape[1])
        for original, cf_example in zip(X, result.cf_examples_list):
            cf_df = cf_example.final_cfs_df
            if cf_df is None or len(cf_df) == 0:
                continue
            # Drop the outcome column if present.
            cf_features = cf_df.drop(columns=[self.outcome_name], errors="ignore")
            frames.append(cf_features)
            # One vectorised reduction per instance instead of iterrows().
            cf_values = cf_features[names].to_numpy().astype(float)
            total_changes += np.abs(cf_values - original).sum(axis=0)
        return frames, total_changes

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def explain(
        self,
        X: np.ndarray,
        *,
        desired_class: int | str = "opposite",
        n_counterfactuals: int = 3,
        method: Literal["random", "genetic", "kdtree"] = "random",
        features_to_vary: list[str] | None = None,
        permitted_range: dict[str, list[float]] | None = None,
        desired_range: list[float] | None = None,
    ) -> Explanation:
        """Generate counterfactual explanations.

        Parameters
        ----------
        X : array-like of shape (n_instances, n_features)
            Instance(s) to explain. Typically a single row; a 1-D input is
            treated as one row.
        desired_class : int or str, default='opposite'
            Target class for the counterfactual. ``'opposite'`` flips a
            binary prediction.
        n_counterfactuals : int, default=3
            Number of diverse counterfactuals to generate per instance.
        method : str, default='random'
            DiCE generation method:

            - ``'random'``: Random perturbation search.
            - ``'genetic'``: Genetic algorithm search.
            - ``'kdtree'``: KD-tree nearest-neighbour search.
        features_to_vary : list of str, optional
            Restrict changes to these features only. If ``None``, all
            features may be varied.
        permitted_range : dict, optional
            Per-feature permitted ranges, e.g.
            ``{'age': [18, 65], 'income': [0, 200000]}``.
        desired_range : list of float, optional
            ``[low, high]`` target interval for the prediction. Forwarded to
            DiCE only when given; DiCE expects it for regression models and
            not for classifiers.

        Returns
        -------
        Explanation
            An :class:`Explanation` with:

            - ``values``: Mean absolute change across counterfactuals
              (feature-level importance of change).
            - ``metadata['counterfactuals']``: DataFrame of generated
              counterfactuals.
            - ``metadata['original']``: The original instance(s).
        """
        _check_dice_installed()
        import dice_ml
        import pandas as pd

        X = self._to_numpy(X)
        if X.ndim == 1:
            X = X.reshape(1, -1)
        train_data = self._to_numpy(self.training_data)
        names = self._resolve_feature_names(X)

        self._log(
            f"Generating {n_counterfactuals} counterfactual(s) for "
            f"{X.shape[0]} instance(s) using method='{method}' ..."
        )

        dice_data = self._prepare_dice_data(train_data, names)
        dice_model = self._prepare_dice_model()
        dice_exp = dice_ml.Dice(dice_data, dice_model, method=method)

        generate_kwargs: dict[str, Any] = {
            "query_instances": pd.DataFrame(X, columns=names),
            "total_CFs": n_counterfactuals,
            "desired_class": desired_class,
        }
        if desired_range is not None:
            generate_kwargs["desired_range"] = desired_range
        if features_to_vary is not None:
            generate_kwargs["features_to_vary"] = features_to_vary
        if permitted_range is not None:
            generate_kwargs["permitted_range"] = permitted_range
        # Honour the constructor's random_state. Only DiCE's random method
        # takes a seed, so it is not passed for the other methods.
        seed = getattr(self, "random_state", None)
        if method == "random" and seed is not None:
            generate_kwargs["random_seed"] = seed

        result = dice_exp.generate_counterfactuals(**generate_kwargs)

        cf_frames, total_changes = self._collect_counterfactuals(result, X, names)
        n_generated = sum(len(frame) for frame in cf_frames)
        # Guard against division by zero when DiCE found nothing.
        mean_changes = total_changes / max(n_generated, 1)

        counterfactual_df = (
            pd.concat(cf_frames, ignore_index=True)
            if cf_frames
            else pd.DataFrame(columns=names)
        )

        return Explanation(
            values=mean_changes,
            base_value=None,
            feature_names=names,
            method="counterfactual",
            metadata={
                "counterfactuals": counterfactual_df,
                "original": pd.DataFrame(X, columns=names),
                "n_counterfactuals_requested": n_counterfactuals,
                "n_counterfactuals_generated": len(counterfactual_df),
                "desired_class": desired_class,
                "desired_range": desired_range,
                "method": method,
            },
        )