"""Waterfall chart visualizer.

Interactive waterfall charts for explaining individual predictions
(SHAP-style) or showing sequential value contributions. Essential
for model interpretability.

Examples
--------
>>> from endgame.visualization import WaterfallVisualizer
>>> viz = WaterfallVisualizer.from_shap(shap_values, feature_names, base_value)
>>> viz.save("waterfall.html")
"""
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
import numpy as np
from endgame.visualization._base import BaseVisualizer
class WaterfallVisualizer(BaseVisualizer):
    """Interactive waterfall chart visualizer.

    Parameters
    ----------
    categories : sequence of str
        Category/feature labels, one per entry of ``values``.
    values : sequence of float
        Contribution values (positive = increase, negative = decrease).
    base_value : float, optional
        Starting/baseline value. Defaults to 0.0. The bundled renderer draws
        it as the first row, and only when it is non-zero.
    final_value : float, optional
        Final value, drawn as the last row. If None, computed as
        ``base_value + sum(values)``.
    show_connectors : bool, default=True
        Show connector lines between bars.
    sort_by : str, optional
        'abs' to sort by absolute contribution (largest first); any other
        value keeps the given order.
    max_display : int, optional
        Maximum features to display. Remaining ones are summed into a single
        trailing "Other (N features)" entry.
    title : str, optional
        Chart title. An empty title falls back to "Waterfall Chart".
    palette : str, default='tableau'
        Color palette.
    width : int, default=750
        Chart width.
    height : int, default=550
        Chart height.
    theme : str, default='dark'
        'dark' or 'light'.

    Raises
    ------
    ValueError
        If ``categories`` and ``values`` differ in length, or if
        ``max_display`` is negative.
    """

    def __init__(
        self,
        categories: Sequence[str],
        values: Sequence[float],
        *,
        base_value: float | None = None,
        final_value: float | None = None,
        show_connectors: bool = True,
        sort_by: str | None = None,
        max_display: int | None = None,
        title: str = "",
        palette: str = "tableau",
        width: int = 750,
        height: int = 550,
        theme: str = "dark",
    ) -> None:
        super().__init__(
            title=title or "Waterfall Chart",
            palette=palette,
            width=width,
            height=height,
            theme=theme,
        )

        cats = list(categories)
        # Coerce to built-in floats up front so that every derived number
        # (including the computed final value) is a plain ``float`` rather
        # than a NumPy scalar such as ``np.float32``.
        vals = [float(v) for v in values]

        if len(cats) != len(vals):
            raise ValueError(
                f"Length mismatch: {len(cats)} names vs {len(vals)} values"
            )
        if max_display is not None and max_display < 0:
            # A negative limit would slice from the wrong end and mislabel
            # the number of grouped features.
            raise ValueError(
                f"max_display must be non-negative, got {max_display}"
            )

        # Largest magnitude first; ``sorted`` is stable, so ties keep their
        # original relative order.
        if sort_by == "abs":
            ranked = sorted(
                zip(cats, vals), key=lambda pair: abs(pair[1]), reverse=True
            )
            cats = [name for name, _ in ranked]
            vals = [value for _, value in ranked]

        # Collapse everything past ``max_display`` into one "Other" entry so
        # the total contribution is unchanged.
        if max_display is not None and len(vals) > max_display:
            hidden = len(vals) - max_display
            remainder = sum(vals[max_display:])
            cats = [*cats[:max_display], f"Other ({hidden} features)"]
            vals = [*vals[:max_display], remainder]

        self._categories = cats
        self._values = [round(v, 6) for v in vals]
        self.base_value = float(base_value) if base_value is not None else 0.0
        if final_value is not None:
            self.final_value = float(final_value)
        else:
            # Summed from the unrounded contributions.
            self.final_value = self.base_value + sum(vals)
        self.show_connectors = show_connectors

    @classmethod
    def from_shap(
        cls,
        shap_values: Any,
        feature_names: Sequence[str],
        base_value: float,
        *,
        max_display: int = 15,
        **kwargs,
    ) -> WaterfallVisualizer:
        """Create from SHAP values for a single prediction.

        Parameters
        ----------
        shap_values : array-like of shape (n_features,)
            SHAP values for one sample. The input is flattened before use.
        feature_names : list of str
            Feature names.
        base_value : float
            Expected value (model output mean).
        max_display : int, default=15
            Max features to display.
        **kwargs
            Additional keyword arguments forwarded to the constructor.
            ``sort_by`` defaults to 'abs' unless given.

        Returns
        -------
        WaterfallVisualizer
            The configured visualizer.

        Raises
        ------
        ValueError
            If the number of feature names differs from the number of
            flattened SHAP values.
        """
        sv = np.asarray(shap_values).ravel()
        cats = list(feature_names)
        if len(cats) != len(sv):
            raise ValueError(f"Length mismatch: {len(cats)} names vs {len(sv)} values")
        # ``max_display`` is a named parameter, so it can never arrive via
        # ``kwargs``; only ``sort_by`` needs a default here.
        options = {"sort_by": "abs", **kwargs}
        return cls(
            cats, sv, base_value=base_value, max_display=max_display, **options
        )

    @classmethod
    def from_contributions(
        cls,
        categories: Sequence[str],
        values: Sequence[float],
        *,
        base_value: float = 0.0,
        **kwargs,
    ) -> WaterfallVisualizer:
        """Create from sequential contributions.

        Parameters
        ----------
        categories : list of str
            Step/category labels.
        values : list of float
            Contribution values.
        base_value : float, default=0.0
            Starting value.
        **kwargs
            Additional keyword arguments forwarded to the constructor.

        Returns
        -------
        WaterfallVisualizer
            The configured visualizer, with the given order preserved unless
            ``sort_by`` is passed.
        """
        return cls(categories, values, base_value=base_value, **kwargs)

    def _build_data(self) -> dict[str, Any]:
        """Return the chart data as a dict.

        The keys match the fields the bundled ``renderChart`` script reads
        from its ``data`` argument. Base and final values are rounded to six
        decimals, like the stored contributions.
        """
        return {
            "categories": self._categories,
            "values": self._values,
            "baseValue": round(self.base_value, 6),
            "finalValue": round(self.final_value, 6),
            "showConnectors": self.show_connectors,
        }

    def _chart_type(self) -> str:
        """Return the chart-type identifier, ``"waterfall"``."""
        return "waterfall"

    def _get_chart_js(self) -> str:
        """Return the JavaScript source defining ``renderChart``."""
        return _WATERFALL_JS
# JavaScript source returned by WaterfallVisualizer._get_chart_js().
#
# It defines renderChart(data, config), which builds an SVG inside the element
# with id 'chart-container'. Rows run top to bottom, values along the x axis:
#   * a "Base value" row, drawn only when data.baseValue is non-zero;
#   * one row per contribution, spanning the running total before and after
#     it (green when the value is >= 0, red otherwise);
#   * a final "Prediction" row for data.finalValue.
# Dashed connector lines are added when data.showConnectors is true, followed
# by an x axis labelled 'Model Output' and a dashed zero line when 0 lies
# inside the padded value range.
#
# The script relies on an `EG` helper object (svg, scaleLinear, tooltip, fmt,
# esc, niceTicks) that is not defined in this string.
#
# NOTE(review): the base and final bars are drawn starting from 0, but 0 is
# not added to the x-domain (only the running totals, base and final values
# are). If EG.scaleLinear does not clamp, those bars may extend outside the
# plot area when all values are far from zero -- confirm.
_WATERFALL_JS = r"""
function renderChart(data, config) {
const container = document.getElementById('chart-container');
const n = data.categories.length;
const needsBase = data.baseValue !== 0;
const totalRows = n + (needsBase ? 2 : 1); // bars + base + final
const rowH = Math.min(32, (config.height - 120) / totalRows);
const actualH = rowH * totalRows + 120;
const margin = {top: 20, right: 30, bottom: 55, left: 180};
const W = config.width - margin.left - margin.right;
const H = rowH * totalRows;
const svg = EG.svg('svg', {width: config.width, height: actualH});
container.appendChild(svg);
const g = EG.svg('g', {transform: `translate(${margin.left},${margin.top})`});
svg.appendChild(g);
const palette = config.palette;
const posColor = '#2ca02c';
const negColor = '#d62728';
const totalColor = palette[0];
// Build cumulative running values
const baseVal = data.baseValue;
const vals = data.values;
const cats = data.categories;
const running = [baseVal];
for (let i = 0; i < n; i++) {
running.push(running[i] + vals[i]);
}
const finalVal = data.finalValue;
// X scale range: min/max of all running values + base + final
let allVals = running.concat([finalVal, baseVal]);
let xMin = Math.min.apply(null, allVals);
let xMax = Math.max.apply(null, allVals);
const xPad = (xMax - xMin) * 0.12 || 0.1;
xMin -= xPad; xMax += xPad;
const xScale = EG.scaleLinear([xMin, xMax], [0, W]);
const barH = rowH * 0.65;
// Helper
function drawBar(yIdx, startVal, endVal, color, label, tooltip) {
const y = yIdx * rowH;
const x1 = xScale(Math.min(startVal, endVal));
const x2 = xScale(Math.max(startVal, endVal));
const w = Math.max(x2 - x1, 2);
const rect = EG.svg('rect', {
x: x1, y: y + (rowH - barH) / 2,
width: w, height: barH,
fill: color, rx: 3, opacity: 0.85
});
rect.addEventListener('mouseenter', function(e) {
rect.setAttribute('opacity', '1');
EG.tooltip.show(e, tooltip);
});
rect.addEventListener('mouseleave', function() {
rect.setAttribute('opacity', '0.85');
EG.tooltip.hide();
});
g.appendChild(rect);
// Label on bar
const textX = endVal >= startVal ? x2 + 5 : x1 - 5;
const anchor = endVal >= startVal ? 'start' : 'end';
const valText = EG.svg('text', {
x: textX, y: y + rowH / 2 + 4,
'text-anchor': anchor, fill: color,
'font-size': '10px', 'font-weight': '600'
});
valText.textContent = (endVal - startVal >= 0 ? '+' : '') + EG.fmt(endVal - startVal, 3);
g.appendChild(valText);
// Category label (left)
const catText = EG.svg('text', {
x: -8, y: y + rowH / 2 + 4,
'text-anchor': 'end', fill: 'var(--text-primary)',
'font-size': '11px'
});
catText.textContent = label.length > 25 ? label.slice(0, 23) + '…' : label;
g.appendChild(catText);
}
// Draw rows
let row = 0;
// Base value
if (needsBase) {
drawBar(row, 0, baseVal, totalColor, 'Base value',
'<b>Base value</b><br>' + EG.fmt(baseVal, 4));
row++;
}
// Contribution bars
for (let i = 0; i < n; i++) {
const startV = running[i];
const endV = running[i + 1];
const color = vals[i] >= 0 ? posColor : negColor;
drawBar(row + i, startV, endV, color, cats[i],
'<b>' + EG.esc(cats[i]) + '</b><br>' +
'Contribution: ' + (vals[i] >= 0 ? '+' : '') + EG.fmt(vals[i], 4) + '<br>' +
'Running: ' + EG.fmt(endV, 4));
// Connector line
if (data.showConnectors && i < n - 1) {
const y1 = (row + i) * rowH + (rowH + barH) / 2;
const y2 = (row + i + 1) * rowH + (rowH - barH) / 2;
g.appendChild(EG.svg('line', {
x1: xScale(endV), y1: y1, x2: xScale(endV), y2: y2,
stroke: 'var(--text-muted)', 'stroke-width': 1,
'stroke-dasharray': '3,2', opacity: 0.5
}));
}
}
row += n;
// Connector to final
if (data.showConnectors && n > 0) {
const lastEnd = running[n];
const y1 = (row - 1) * rowH + (rowH + barH) / 2;
const y2 = row * rowH + (rowH - barH) / 2;
g.appendChild(EG.svg('line', {
x1: xScale(lastEnd), y1: y1, x2: xScale(lastEnd), y2: y2,
stroke: 'var(--text-muted)', 'stroke-width': 1,
'stroke-dasharray': '3,2', opacity: 0.5
}));
}
// Final total
drawBar(row, 0, finalVal, totalColor, 'Prediction',
'<b>Final prediction</b><br>' + EG.fmt(finalVal, 4));
// X axis
const axG = EG.svg('g', {transform: `translate(0,${H + 10})`});
g.appendChild(axG);
axG.appendChild(EG.svg('line', {x1: 0, y1: 0, x2: W, y2: 0, stroke: 'var(--border)'}));
const ticks = EG.niceTicks(xMin, xMax, 6);
ticks.forEach(function(v) {
axG.appendChild(EG.svg('line', {x1: xScale(v), y1: 0, x2: xScale(v), y2: 5, stroke: 'var(--text-muted)'}));
axG.appendChild(EG.svg('text', {x: xScale(v), y: 18, 'text-anchor': 'middle', fill: 'var(--text-secondary)', 'font-size': '10px'})).textContent = EG.fmt(v, 2);
});
axG.appendChild(EG.svg('text', {x: W/2, y: 38, 'text-anchor': 'middle', fill: 'var(--text-secondary)', 'font-size': '11px', 'font-weight': '500'})).textContent = 'Model Output';
// Zero line
if (xMin <= 0 && xMax >= 0) {
g.appendChild(EG.svg('line', {
x1: xScale(0), y1: 0, x2: xScale(0), y2: H,
stroke: 'var(--text-muted)', 'stroke-width': 1, 'stroke-dasharray': '4,3', opacity: 0.4
}));
}
}
"""