"""Partial Dependence Plot (PDP) and Individual Conditional Expectation (ICE) visualizer.
Interactive PDP/ICE plots for understanding feature effects in any
sklearn-compatible model. Supports 1D PDP with optional ICE lines, and
2D PDP as a heatmap.
Example
-------
>>> from endgame.visualization import PDPVisualizer
>>> from sklearn.ensemble import RandomForestClassifier
>>> clf = RandomForestClassifier().fit(X_train, y_train)
>>> viz = PDPVisualizer.from_estimator(clf, X_train, feature=0)
>>> viz.save("pdp.html")
"""
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
import numpy as np
from endgame.visualization._base import BaseVisualizer
class PDPVisualizer(BaseVisualizer):
    """Interactive partial dependence / ICE plot visualizer.

    Parameters
    ----------
    grid_values : list of float
        Feature values on the grid.
    pdp_values : list of float
        Partial dependence values (mean ICE).
    ice_lines : list of list of float, optional
        Individual conditional expectation curves, one inner sequence per
        curve. An empty collection is stored as ``None``.
    feature_name : str, default=''
        Feature name for axis label.
    is_categorical : bool, default=False
        Whether feature is categorical.
    title : str, optional
        Chart title. When empty, falls back to ``"PDP – <feature_name>"``
        or, if ``feature_name`` is empty too, ``"Partial Dependence"``.
    palette : str, default='tableau'
        Color palette.
    width : int, default=750
        Chart width.
    height : int, default=500
        Chart height.
    theme : str, default='dark'
        'dark' or 'light'.
    """

    def __init__(
        self,
        grid_values: Sequence[float],
        pdp_values: Sequence[float],
        *,
        ice_lines: Sequence[Sequence[float]] | None = None,
        feature_name: str = "",
        is_categorical: bool = False,
        title: str = "",
        palette: str = "tableau",
        width: int = 750,
        height: int = 500,
        theme: str = "dark",
    ):
        # Resolve the fallback title in its own statement. Written inline as
        # ``title or A if feature_name else B`` the conditional expression
        # binds looser than ``or`` and an explicit title would be dropped
        # whenever ``feature_name`` is empty.
        if not title:
            title = f"PDP – {feature_name}" if feature_name else "Partial Dependence"
        super().__init__(
            title=title, palette=palette, width=width, height=height, theme=theme,
        )
        self._grid = list(grid_values)
        self._pdp = list(pdp_values)
        # Copy first, then test the plain list: a truth test directly on a
        # NumPy array argument would raise. Empty input is normalised to None.
        ice_copy = [list(line) for line in ice_lines] if ice_lines is not None else []
        self._ice = ice_copy or None
        self.feature_name = feature_name
        self.is_categorical = is_categorical

    @classmethod
    def from_estimator(
        cls,
        model: Any,
        X: Any,
        *,
        feature: int | str = 0,
        ice: bool = True,
        n_ice_lines: int = 50,
        grid_resolution: int = 50,
        percentiles: tuple[float, float] = (0.05, 0.95),
        target_class: int = 1,
        **kwargs,
    ) -> PDPVisualizer:
        """Create PDP/ICE from a fitted model.

        For every grid value the chosen feature column is overwritten in a
        copy of ``X`` and the model is evaluated on it; the mean prediction
        is the PDP value and the predictions of a subset of rows form the
        ICE curves.

        Parameters
        ----------
        model : estimator
            Fitted sklearn-compatible estimator. ``predict_proba`` is used
            when the model has it, otherwise ``predict``.
        X : array-like of shape (n_samples, n_features)
            Training data (used to build the grid and compute ICE). A 1-D
            input is treated as a single feature column.
        feature : int or str
            Feature index or name. A name requires ``X`` to expose
            ``columns``.
        ice : bool, default=True
            Whether to include ICE lines.
        n_ice_lines : int, default=50
            Number of ICE lines to sample. When ``X`` has more rows than
            this, rows are drawn without replacement via
            ``np.random.choice`` (global NumPy random state).
        grid_resolution : int, default=50
            Number of grid points (continuous features only).
        percentiles : tuple, default=(0.05, 0.95)
            Feature range percentiles, given as fractions.
        target_class : int, default=1
            For classifiers, which class probability to show.
        **kwargs
            Additional keyword arguments forwarded to the constructor.
            ``feature_name`` and ``is_categorical`` are filled in only if
            not supplied.

        Returns
        -------
        PDPVisualizer

        Raises
        ------
        ValueError
            If ``feature`` is a name and ``X`` has no ``columns``, or the
            name is not among the columns.
        """
        X_arr = np.asarray(X)
        if X_arr.ndim == 1:
            X_arr = X_arr.reshape(-1, 1)
        # Grid values are generally non-integral. Writing them into an
        # integer array would truncate them, so widen integer data up front.
        if np.issubdtype(X_arr.dtype, np.integer):
            X_arr = X_arr.astype(float)

        # Resolve feature index and display name.
        has_columns = hasattr(X, "columns")
        if isinstance(feature, str):
            if not has_columns:
                raise ValueError(f"Cannot resolve feature name '{feature}' from non-DataFrame input")
            feat_idx = list(X.columns).index(feature)
            feat_name = feature
        else:
            feat_idx = int(feature)
            feat_name = str(X.columns[feat_idx]) if has_columns else f"Feature {feat_idx}"

        col = X_arr[:, feat_idx]

        # Heuristic: a feature with at most 20 distinct, integer-valued
        # levels is treated as categorical and evaluated at exactly those
        # levels; anything else gets an evenly spaced percentile grid.
        unique_vals = np.unique(col)
        is_cat = len(unique_vals) <= 20 and np.all(unique_vals == unique_vals.astype(int))
        if is_cat:
            grid = sorted(unique_vals.tolist())
        else:
            lo, hi = np.percentile(col, [percentiles[0] * 100, percentiles[1] * 100])
            grid = np.linspace(lo, hi, grid_resolution).tolist()

        # Choose which rows contribute ICE curves.
        n_samples = X_arr.shape[0]
        if not ice:
            ice_idx = np.array([], dtype=int)
        elif n_samples > n_ice_lines:
            ice_idx = np.random.choice(n_samples, size=n_ice_lines, replace=False)
        else:
            ice_idx = np.arange(n_samples)
        collect_ice = len(ice_idx) > 0

        has_proba = hasattr(model, "predict_proba")
        ice_by_grid = []  # one row per grid point, one entry per ICE sample
        pdp_vals = []
        for g_val in grid:
            # Fresh copy each time: the model is handed this array and may
            # not leave it untouched.
            X_mod = X_arr.copy()
            X_mod[:, feat_idx] = g_val
            if has_proba:
                preds = model.predict_proba(X_mod)[:, target_class]
            else:
                preds = model.predict(X_mod)
            pdp_vals.append(round(float(preds.mean()), 6))
            if collect_ice:
                ice_by_grid.append([round(float(v), 6) for v in preds[ice_idx]])

        # Transpose [grid_points x n_ice] -> [n_ice x grid_points].
        ice_lines_out = [list(curve) for curve in zip(*ice_by_grid)] if ice_by_grid else None

        grid_out = [round(float(v), 4) for v in grid]
        kwargs.setdefault("feature_name", feat_name)
        kwargs.setdefault("is_categorical", bool(is_cat))
        return cls(grid_out, pdp_vals, ice_lines=ice_lines_out, **kwargs)

    @classmethod
    def from_precomputed(
        cls,
        grid_values: Sequence[float],
        pdp_values: Sequence[float],
        *,
        ice_lines: Sequence[Sequence[float]] | None = None,
        feature_name: str = "",
        **kwargs,
    ) -> PDPVisualizer:
        """Create from precomputed PDP/ICE values.

        Parameters
        ----------
        grid_values : list of float
            Feature grid values.
        pdp_values : list of float
            Partial dependence values.
        ice_lines : list of list of float, optional
            ICE lines.
        feature_name : str, optional
            Feature name.
        **kwargs
            Additional keyword arguments forwarded to the constructor.

        Returns
        -------
        PDPVisualizer
        """
        return cls(grid_values, pdp_values, ice_lines=ice_lines,
                   feature_name=feature_name, **kwargs)

    def _build_data(self) -> dict[str, Any]:
        """Return the payload read by the chart script (keys are camelCase)."""
        return {
            "grid": self._grid,
            "pdp": self._pdp,
            "ice": self._ice,
            "featureName": self.feature_name,
            "isCategorical": self.is_categorical,
        }

    def _chart_type(self) -> str:
        """Return the chart type identifier, ``"pdp"``."""
        return "pdp"

    def _get_chart_js(self) -> str:
        """Return the JavaScript source for the 1D PDP/ICE chart."""
        return _PDP_JS
class PDP2DVisualizer(BaseVisualizer):
    """2D partial dependence plot (heatmap).

    Parameters
    ----------
    grid_x : list of float
        Grid values for feature x.
    grid_y : list of float
        Grid values for feature y.
    values : list of list of float
        2D matrix of PD values (rows = y, cols = x).
    feature_x : str, default=''
        Feature x name.
    feature_y : str, default=''
        Feature y name.
    title : str, optional
        Chart title. When empty, defaults to
        ``"2D PDP – <feature_x> × <feature_y>"``.
    palette : str, default='tableau'
        Color palette.
    width : int, default=700
        Chart width.
    height : int, default=600
        Chart height.
    theme : str, default='dark'
        Chart theme.
    """

    def __init__(
        self,
        grid_x: Sequence[float],
        grid_y: Sequence[float],
        values: Sequence[Sequence[float]],
        *,
        feature_x: str = "",
        feature_y: str = "",
        title: str = "",
        palette: str = "tableau",
        width: int = 700,
        height: int = 600,
        theme: str = "dark",
    ):
        super().__init__(
            title=title or f"2D PDP – {feature_x} × {feature_y}",
            palette=palette, width=width, height=height, theme=theme,
        )
        # Store plain-list copies of everything that was passed in.
        self._grid_x = list(grid_x)
        self._grid_y = list(grid_y)
        self._values = [list(row) for row in values]
        self.feature_x = feature_x
        self.feature_y = feature_y

    @classmethod
    def from_estimator(
        cls,
        model: Any,
        X: Any,
        *,
        features: tuple[int | str, int | str] = (0, 1),
        grid_resolution: int = 25,
        percentiles: tuple[float, float] = (0.05, 0.95),
        target_class: int = 1,
        **kwargs,
    ) -> PDP2DVisualizer:
        """Create 2D PDP from a fitted model.

        The model is evaluated once per (x, y) grid cell with both feature
        columns overwritten in a copy of ``X``; the mean prediction becomes
        the cell value.

        Parameters
        ----------
        model : estimator
            Fitted sklearn-compatible estimator. ``predict_proba`` is used
            when the model has it, otherwise ``predict``.
        X : array-like
            Training data.
        features : tuple of (int or str, int or str)
            Two feature indices or names; the first is plotted on x, the
            second on y. Names require ``X`` to expose ``columns``.
        grid_resolution : int, default=25
            Grid resolution per axis.
        percentiles : tuple, default=(0.05, 0.95)
            Feature range percentiles, given as fractions.
        target_class : int, default=1
            For classifiers, which class probability.
        **kwargs
            Additional keyword arguments forwarded to the constructor.
            ``feature_x`` and ``feature_y`` are filled in only if not
            supplied.

        Returns
        -------
        PDP2DVisualizer

        Raises
        ------
        ValueError
            If a feature is given by name and ``X`` has no ``columns``, or
            the name is not among the columns.
        """
        X_arr = np.asarray(X)
        # Grid values come from ``np.linspace`` and are generally
        # non-integral. Writing them into an integer array would truncate
        # them, so widen integer data up front.
        if np.issubdtype(X_arr.dtype, np.integer):
            X_arr = X_arr.astype(float)

        has_cols = hasattr(X, "columns")
        feat_idxs = []
        feat_names = []
        for f in features:
            if isinstance(f, str):
                if not has_cols:
                    raise ValueError(f"Cannot resolve feature name '{f}'")
                feat_idxs.append(list(X.columns).index(f))
                feat_names.append(f)
            else:
                feat_idxs.append(int(f))
                feat_names.append(str(X.columns[int(f)]) if has_cols else f"Feature {f}")

        # One evenly spaced grid per feature, spanning the requested
        # percentile range of that column.
        grids = []
        for fi in feat_idxs:
            lo, hi = np.percentile(
                X_arr[:, fi], [percentiles[0] * 100, percentiles[1] * 100]
            )
            grids.append(np.linspace(lo, hi, grid_resolution))

        x_idx, y_idx = feat_idxs[0], feat_idxs[1]
        has_proba = hasattr(model, "predict_proba")
        values = []  # rows follow the y grid, columns follow the x grid
        for y_val in grids[1]:
            row = []
            for x_val in grids[0]:
                # Fresh copy per cell: the model is handed this array and
                # may not leave it untouched.
                X_mod = X_arr.copy()
                X_mod[:, x_idx] = x_val
                X_mod[:, y_idx] = y_val
                if has_proba:
                    preds = model.predict_proba(X_mod)[:, target_class]
                else:
                    preds = model.predict(X_mod)
                row.append(round(float(preds.mean()), 6))
            values.append(row)

        kwargs.setdefault("feature_x", feat_names[0])
        kwargs.setdefault("feature_y", feat_names[1])
        return cls(
            [round(float(v), 4) for v in grids[0]],
            [round(float(v), 4) for v in grids[1]],
            values, **kwargs,
        )

    def _build_data(self) -> dict[str, Any]:
        """Return the payload read by the chart script (keys are camelCase)."""
        return {
            "gridX": self._grid_x,
            "gridY": self._grid_y,
            "values": self._values,
            "featureX": self.feature_x,
            "featureY": self.feature_y,
        }

    def _chart_type(self) -> str:
        """Return the chart type identifier, ``"pdp2d"``."""
        return "pdp2d"

    def _get_chart_js(self) -> str:
        """Return the JavaScript source for the 2D PDP heatmap."""
        return _PDP2D_JS
# ---------------------------------------------------------------------------
# 1D PDP/ICE JavaScript
# ---------------------------------------------------------------------------
# JavaScript source returned by ``PDPVisualizer._get_chart_js``. It defines
# ``renderChart(data, config)``, which reads the ``grid``, ``pdp``, ``ice``,
# ``featureName`` and ``isCategorical`` keys built by ``_build_data``.
# The feature name is caller-supplied text, so it goes through ``EG.esc``
# before being concatenated into tooltip HTML.
_PDP_JS = r"""
function renderChart(data, config) {
const container = document.getElementById('chart-container');
const margin = {top: 20, right: 20, bottom: 55, left: 60};
const ctx = EG.createSVG(container, config.width, config.height, margin);
const {g, width: W, height: H} = ctx;
const palette = config.palette;
const grid = data.grid;
const pdp = data.pdp;
const ice = data.ice;
const n = grid.length;
// Find y range from PDP + ICE
let yMin = Infinity, yMax = -Infinity;
pdp.forEach(function(v) { if (v < yMin) yMin = v; if (v > yMax) yMax = v; });
if (ice) {
ice.forEach(function(line) {
line.forEach(function(v) { if (v < yMin) yMin = v; if (v > yMax) yMax = v; });
});
}
const yPad = (yMax - yMin) * 0.08 || 0.01;
yMin -= yPad; yMax += yPad;
let xScale, isBand = false;
if (data.isCategorical) {
isBand = true;
xScale = EG.scaleBand(grid.map(String), [0, W], 0.3);
} else {
xScale = EG.scaleLinear([grid[0], grid[n-1]], [0, W]);
}
const yScale = EG.scaleLinear([yMin, yMax], [H, 0]);
EG.drawXAxis(g, xScale, H, data.featureName || 'Feature Value', isBand);
EG.drawYAxis(g, yScale, W, 'Partial Dependence');
// ICE lines (subtle background)
if (ice && ice.length > 0) {
const iceG = EG.svg('g', {opacity: 0.15});
g.appendChild(iceG);
ice.forEach(function(line) {
let d = '';
for (let i = 0; i < n; i++) {
const x = isBand ? xScale(String(grid[i])) + xScale.bandwidth / 2 : xScale(grid[i]);
d += (i === 0 ? 'M' : ' L') + x + ' ' + yScale(line[i]);
}
iceG.appendChild(EG.svg('path', {d:d, fill:'none', stroke: palette[0], 'stroke-width': 1}));
});
}
// PDP main line
let pdpD = '';
for (let i = 0; i < n; i++) {
const x = isBand ? xScale(String(grid[i])) + xScale.bandwidth / 2 : xScale(grid[i]);
pdpD += (i === 0 ? 'M' : ' L') + x + ' ' + yScale(pdp[i]);
}
g.appendChild(EG.svg('path', {
d: pdpD, fill: 'none', stroke: palette[0],
'stroke-width': 3.5, 'stroke-linejoin': 'round'
}));
// Interactive dots on PDP
for (let i = 0; i < n; i++) {
const x = isBand ? xScale(String(grid[i])) + xScale.bandwidth / 2 : xScale(grid[i]);
const y = yScale(pdp[i]);
const dot = EG.svg('circle', {cx: x, cy: y, r: 4.5, fill: palette[0], stroke: 'var(--bg-card)', 'stroke-width': 2, opacity: 0});
dot.addEventListener('mouseenter', function(e) {
dot.setAttribute('opacity', '1');
dot.setAttribute('r', '6');
let html = '<b>' + EG.esc(data.featureName || 'Feature') + ' = ' + EG.fmt(grid[i], 3) + '</b><br>PD = ' + EG.fmt(pdp[i], 4);
if (ice) html += '<br><span style="opacity:0.6">' + ice.length + ' ICE lines</span>';
EG.tooltip.show(e, html);
});
dot.addEventListener('mouseleave', function() {
dot.setAttribute('opacity', '0');
dot.setAttribute('r', '4.5');
EG.tooltip.hide();
});
g.appendChild(dot);
}
// Legend
const items = [{label: 'PDP (mean)', color: palette[0]}];
if (ice) items.push({label: 'ICE lines', color: palette[0]});
EG.drawLegend(container, items);
}
"""
# ---------------------------------------------------------------------------
# 2D PDP JavaScript (heatmap-style)
# ---------------------------------------------------------------------------
# JavaScript source returned by ``PDP2DVisualizer._get_chart_js``. It defines
# ``renderChart(data, config)``, which reads the ``gridX``, ``gridY``,
# ``values``, ``featureX`` and ``featureY`` keys built by ``_build_data``.
# Cell (xi, yi) is coloured by ``values[yi][xi]``, so the axis scales map
# grid values to cell centres rather than to the plot edges.
_PDP2D_JS = r"""
function renderChart(data, config) {
const container = document.getElementById('chart-container');
const margin = {top: 20, right: 90, bottom: 55, left: 65};
const ctx = EG.createSVG(container, config.width, config.height, margin);
const {g, width: W, height: H} = ctx;
const gx = data.gridX, gy = data.gridY;
const nx = gx.length, ny = gy.length;
const vals = data.values;
// Flatten to find range
let vMin = Infinity, vMax = -Infinity;
vals.forEach(function(row) { row.forEach(function(v) { if (v < vMin) vMin = v; if (v > vMax) vMax = v; }); });
const cellW = W / nx;
const cellH = H / ny;
// Color scale (blue → white → red)
function heatColor(v) {
const t = (v - vMin) / (vMax - vMin + 1e-10);
if (t <= 0.5) {
const s = t * 2;
return interpolate('#2166ac', '#f7f7f7', s);
} else {
const s = (t - 0.5) * 2;
return interpolate('#f7f7f7', '#b2182b', s);
}
}
function interpolate(c1, c2, t) {
const r1 = parseInt(c1.slice(1,3),16), g1 = parseInt(c1.slice(3,5),16), b1 = parseInt(c1.slice(5,7),16);
const r2 = parseInt(c2.slice(1,3),16), g2 = parseInt(c2.slice(3,5),16), b2 = parseInt(c2.slice(5,7),16);
const r = Math.round(r1 + (r2-r1)*t), gg = Math.round(g1 + (g2-g1)*t), b = Math.round(b1 + (b2-b1)*t);
return 'rgb('+r+','+gg+','+b+')';
}
// Draw cells
for (let yi = 0; yi < ny; yi++) {
for (let xi = 0; xi < nx; xi++) {
const v = vals[yi][xi];
const rect = EG.svg('rect', {
x: xi * cellW, y: (ny - 1 - yi) * cellH,
width: cellW + 0.5, height: cellH + 0.5,
fill: heatColor(v), stroke: 'none'
});
rect.addEventListener('mouseenter', function(e) {
rect.setAttribute('stroke', 'var(--text-primary)');
rect.setAttribute('stroke-width', '2');
EG.tooltip.show(e,
'<b>' + EG.esc(data.featureX) + '</b> = ' + EG.fmt(gx[xi], 3) + '<br>' +
'<b>' + EG.esc(data.featureY) + '</b> = ' + EG.fmt(gy[yi], 3) + '<br>' +
'PD = ' + EG.fmt(v, 4));
});
rect.addEventListener('mouseleave', function() {
rect.setAttribute('stroke', 'none');
EG.tooltip.hide();
});
g.appendChild(rect);
}
}
// Axes
// Cell i is centred at (i + 0.5) * cell size, so inset the ranges by half a cell.
const xTicks = EG.niceTicks(gx[0], gx[nx-1], 5);
const xS = EG.scaleLinear([gx[0], gx[nx-1]], [cellW / 2, W - cellW / 2]);
const yS = EG.scaleLinear([gy[0], gy[ny-1]], [H - cellH / 2, cellH / 2]);
xTicks.forEach(function(v) {
g.appendChild(EG.svg('text', {x:xS(v), y:H+18, 'text-anchor':'middle', fill:'var(--text-secondary)', 'font-size':'10px'})).textContent = EG.fmt(v, 2);
});
g.appendChild(EG.svg('text', {x:W/2, y:H+42, 'text-anchor':'middle', fill:'var(--text-secondary)', 'font-size':'12px', 'font-weight':'500'})).textContent = data.featureX;
const yTicks = EG.niceTicks(gy[0], gy[ny-1], 5);
yTicks.forEach(function(v) {
g.appendChild(EG.svg('text', {x:-8, y:yS(v)+4, 'text-anchor':'end', fill:'var(--text-secondary)', 'font-size':'10px'})).textContent = EG.fmt(v, 2);
});
g.appendChild(EG.svg('text', {'text-anchor':'middle', fill:'var(--text-secondary)', 'font-size':'12px', 'font-weight':'500',
transform:`translate(-50,${H/2}) rotate(-90)`})).textContent = data.featureY;
// Color bar
const barX = W + 15, barW = 18, barH = H;
const nSteps = 50;
for (let i = 0; i < nSteps; i++) {
const t = i / (nSteps - 1);
const v = vMin + t * (vMax - vMin);
g.appendChild(EG.svg('rect', {
x: barX, y: barH - (i+1) * barH/nSteps,
width: barW, height: barH/nSteps + 0.5,
fill: heatColor(v), stroke: 'none'
}));
}
g.appendChild(EG.svg('text', {x:barX+barW+5, y:5, fill:'var(--text-secondary)', 'font-size':'10px', 'dominant-baseline':'middle'})).textContent = EG.fmt(vMax, 3);
g.appendChild(EG.svg('text', {x:barX+barW+5, y:H, fill:'var(--text-secondary)', 'font-size':'10px', 'dominant-baseline':'middle'})).textContent = EG.fmt(vMin, 3);
}
"""