# Source code for endgame.visualization.lift_chart

"""Lift and cumulative gains chart visualizer.

Interactive lift charts and cumulative gains plots for evaluating the
ranking quality of binary classifiers. Shows how much better the model is
than a random baseline as a growing fraction of the score-ranked
population is examined.

Example
-------
>>> from endgame.visualization import LiftChartVisualizer
>>> viz = LiftChartVisualizer.from_predictions(y_true, y_score)
>>> viz.save("lift.html")
"""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

import numpy as np

from endgame.visualization._base import BaseVisualizer


class LiftChartVisualizer(BaseVisualizer):
    """Interactive lift / cumulative gains chart visualizer.

    Parameters
    ----------
    curves : list of dict
        Each dict has keys 'percentiles' (list of float 0-1),
        'gains' (list of float), 'lift' (list of float), 'label' (str).
    mode : str, default='both'
        'gains' (cumulative gains only), 'lift' (lift only), or
        'both' (two side-by-side panels: gains on the left, lift on
        the right).
    title : str, optional
        Chart title. Falls back to "Lift & Cumulative Gains" when empty.
    palette : str, default='tableau'
        Color palette.
    width : int, default=850
        Chart width.
    height : int, default=550
        Chart height.
    theme : str, default='dark'
        'dark' or 'light'.
    """

    def __init__(
        self,
        curves: Sequence[dict[str, Any]],
        *,
        mode: str = "both",
        title: str = "",
        palette: str = "tableau",
        width: int = 850,
        height: int = 550,
        theme: str = "dark",
    ):
        super().__init__(
            title=title or "Lift & Cumulative Gains",
            palette=palette,
            width=width,
            height=height,
            theme=theme,
        )
        # Copy into a list so later mutation of the caller's sequence
        # does not change what gets rendered.
        self._curves = list(curves)
        self.mode = mode

    @classmethod
    def from_estimator(
        cls,
        model: Any,
        X: Any,
        y: Any,
        *,
        label: str | None = None,
        **kwargs,
    ) -> LiftChartVisualizer:
        """Create from a fitted classifier.

        Parameters
        ----------
        model : estimator
            Fitted classifier with ``predict_proba``. Column 1 of its
            output is used as the positive-class score.
        X : array-like
            Test features.
        y : array-like
            True binary labels.
        label : str, optional
            Model label. Defaults to the model's class name.
        **kwargs
            Forwarded to :meth:`from_predictions`.
        """
        # np.asarray lets array-likes that do not support [:, 1] indexing
        # (for example a DataFrame) work as well.
        y_score = np.asarray(model.predict_proba(X))[:, 1]
        label = label or type(model).__name__
        return cls.from_predictions(np.asarray(y), y_score, label=label, **kwargs)

    @classmethod
    def from_predictions(
        cls,
        y_true: Any,
        y_score: Any,
        *,
        label: str = "Model",
        n_points: int = 100,
        **kwargs,
    ) -> LiftChartVisualizer:
        """Create from predictions.

        Samples are ranked by descending score. For each of ``n_points``
        evenly spaced population fractions ``p`` the top
        ``k = ceil(p * n)`` samples (at least one) are examined:

        * gain = positives among the top ``k`` / all positives
        * lift = gain / (k / n)

        Lift is normalised by the fraction actually examined (``k / n``),
        not by ``p``, so it can never exceed ``1 / prevalence`` even when
        there are fewer samples than evaluation points.

        Parameters
        ----------
        y_true : array-like
            True binary labels.
        y_score : array-like
            Predicted probabilities or decision scores.
        label : str, default='Model'
            Model label.
        n_points : int, default=100
            Number of evaluation points (must be >= 1).
        **kwargs
            Forwarded to the constructor.

        Raises
        ------
        ValueError
            If ``y_true`` and ``y_score`` differ in length, or if
            ``n_points`` is smaller than 1.
        """
        # ravel() so (n, 1) column vectors behave like flat arrays.
        y_true = np.asarray(y_true, dtype=float).ravel()
        y_score = np.asarray(y_score, dtype=float).ravel()
        if y_true.shape[0] != y_score.shape[0]:
            raise ValueError(
                "y_true and y_score must have the same length, got "
                f"{y_true.shape[0]} and {y_score.shape[0]}"
            )
        if n_points < 1:
            raise ValueError(f"n_points must be >= 1, got {n_points}")

        # Stable sort keeps tied scores in input order, so the curve is
        # reproducible regardless of the sort algorithm's internals.
        order = np.argsort(-y_score, kind="stable")
        y_sorted = y_true[order]

        n = y_sorted.shape[0]
        total_pos = y_sorted.sum()

        percentiles = np.linspace(0, 1, n_points + 1)[1:]  # skip 0

        if total_pos > 0:
            # k = ceil(i * n / n_points) via integer arithmetic; the float
            # form ceil(pct * n) overshoots (0.07 * 100 -> 7.000000000000001).
            steps = np.arange(1, n_points + 1, dtype=np.int64)
            top_k = np.maximum(1, -(-(steps * n) // n_points))
            # One cumulative sum replaces a prefix sum per percentile.
            captured = np.cumsum(y_sorted)[top_k - 1]
            gain_values = captured / total_pos
            lift_values = gain_values * n / top_k
        else:
            # No positives (or empty input): nothing can be captured.
            gain_values = np.zeros(n_points)
            lift_values = np.zeros(n_points)

        curve = {
            "percentiles": [round(float(p), 4) for p in percentiles],
            "gains": [round(float(g), 6) for g in gain_values],
            "lift": [round(float(v), 4) for v in lift_values],
            "label": label,
        }
        return cls([curve], **kwargs)

    def _build_data(self) -> dict[str, Any]:
        """Return the payload handed to the JavaScript renderer."""
        return {
            "curves": self._curves,
            "mode": self.mode,
        }

    def _chart_type(self) -> str:
        """Return the identifier of this chart type."""
        return "lift_chart"

    def _get_chart_js(self) -> str:
        """Return the JavaScript source that draws the chart."""
        return _LIFT_JS
_LIFT_JS = r""" function renderChart(data, config) { const container = document.getElementById('chart-container'); const palette = config.palette; const curves = data.curves; const mode = data.mode; const showGains = mode === 'both' || mode === 'gains'; const showLift = mode === 'both' || mode === 'lift'; if (mode === 'both') { // Side by side: gains left, lift right const halfW = Math.floor(config.width / 2); const margin = {top: 20, right: 15, bottom: 55, left: 55}; const svg = EG.svg('svg', {width: config.width, height: config.height}); container.appendChild(svg); // ---- Cumulative Gains (left) ---- const gL = EG.svg('g', {transform: `translate(${margin.left},${margin.top})`}); svg.appendChild(gL); const W1 = halfW - margin.left - margin.right; const H1 = config.height - margin.top - margin.bottom; const xS = EG.scaleLinear([0, 1], [0, W1]); const yS = EG.scaleLinear([0, 1], [H1, 0]); drawAxes(gL, xS, yS, W1, H1, '% Population', 'Cumulative Gains'); // Random baseline gL.appendChild(EG.svg('line', { x1: xS(0), y1: yS(0), x2: xS(1), y2: yS(1), stroke: 'var(--text-muted)', 'stroke-width': 1.5, 'stroke-dasharray': '6,4', opacity: 0.5 })); curves.forEach(function(c, ci) { drawLine(gL, c.percentiles, c.gains, xS, yS, palette[ci % palette.length], c.label); }); gL.appendChild(EG.svg('text', {x: W1/2, y: -5, 'text-anchor':'middle', fill:'var(--text-primary)', 'font-size':'13px', 'font-weight':'600'})).textContent = 'Cumulative Gains'; // ---- Lift Chart (right) ---- const gR = EG.svg('g', {transform: `translate(${halfW + margin.left},${margin.top})`}); svg.appendChild(gR); let maxLift = 1; curves.forEach(function(c) { c.lift.forEach(function(v) { if (v > maxLift) maxLift = v; }); }); maxLift = Math.ceil(maxLift * 1.1); const xS2 = EG.scaleLinear([0, 1], [0, W1]); const yS2 = EG.scaleLinear([0, maxLift], [H1, 0]); drawAxes(gR, xS2, yS2, W1, H1, '% Population', 'Lift'); // Baseline lift = 1 gR.appendChild(EG.svg('line', { x1: 0, y1: yS2(1), x2: W1, y2: yS2(1), stroke: 
'var(--text-muted)', 'stroke-width': 1.5, 'stroke-dasharray': '6,4', opacity: 0.5 })); curves.forEach(function(c, ci) { drawLine(gR, c.percentiles, c.lift, xS2, yS2, palette[ci % palette.length], c.label + ' (Lift)'); }); gR.appendChild(EG.svg('text', {x: W1/2, y: -5, 'text-anchor':'middle', fill:'var(--text-primary)', 'font-size':'13px', 'font-weight':'600'})).textContent = 'Lift Chart'; } else { // Single chart const margin = {top: 20, right: 20, bottom: 55, left: 55}; const ctx = EG.createSVG(container, config.width, config.height, margin); const {g, width: W, height: H} = ctx; if (showGains) { const xS = EG.scaleLinear([0, 1], [0, W]); const yS = EG.scaleLinear([0, 1], [H, 0]); drawAxes(g, xS, yS, W, H, '% Population', 'Cumulative Gains'); g.appendChild(EG.svg('line', {x1:xS(0),y1:yS(0),x2:xS(1),y2:yS(1), stroke:'var(--text-muted)','stroke-width':1.5,'stroke-dasharray':'6,4',opacity:0.5})); curves.forEach(function(c, ci) { drawLine(g, c.percentiles, c.gains, xS, yS, palette[ci % palette.length], c.label); }); } else { let maxLift = 1; curves.forEach(function(c) { c.lift.forEach(function(v) { if (v > maxLift) maxLift = v; }); }); maxLift = Math.ceil(maxLift * 1.1); const xS = EG.scaleLinear([0, 1], [0, W]); const yS = EG.scaleLinear([0, maxLift], [H, 0]); drawAxes(g, xS, yS, W, H, '% Population', 'Lift'); g.appendChild(EG.svg('line', {x1:0,y1:yS(1),x2:W,y2:yS(1), stroke:'var(--text-muted)','stroke-width':1.5,'stroke-dasharray':'6,4',opacity:0.5})); curves.forEach(function(c, ci) { drawLine(g, c.percentiles, c.lift, xS, yS, palette[ci % palette.length], c.label); }); } } // Legend const items = curves.map(function(c, i) { return {label: c.label, color: palette[i % palette.length]}; }); EG.drawLegend(container, items); function drawAxes(g, xS, yS, W, H, xLabel, yLabel) { // X axis var axG = EG.svg('g', {transform: `translate(0,${H})`}); g.appendChild(axG); axG.appendChild(EG.svg('line', {x1:0,y1:0,x2:W,y2:0,stroke:'var(--border)'})); var ticks = 
EG.niceTicks(xS.domain[0], xS.domain[1], 5); ticks.forEach(function(v) { axG.appendChild(EG.svg('text', {x:xS(v), y:18, 'text-anchor':'middle', fill:'var(--text-secondary)', 'font-size':'10px'})).textContent = EG.pct(v); }); axG.appendChild(EG.svg('text', {x:W/2, y:40, 'text-anchor':'middle', fill:'var(--text-secondary)', 'font-size':'11px', 'font-weight':'500'})).textContent = xLabel; // Y axis var yTicks = EG.niceTicks(yS.domain[0], yS.domain[1], 5); yTicks.forEach(function(v) { var y = yS(v); g.appendChild(EG.svg('line', {x1:0,y1:y,x2:W,y2:y,stroke:'var(--grid-line)'})); g.appendChild(EG.svg('text', {x:-8,y:y+4,'text-anchor':'end',fill:'var(--text-secondary)','font-size':'10px'})).textContent = EG.fmt(v, v >= 10 ? 0 : 2); }); g.appendChild(EG.svg('text', {'text-anchor':'middle', fill:'var(--text-secondary)', 'font-size':'11px', 'font-weight':'500', transform:`translate(-40,${(yS.range[0]+yS.range[1])/2}) rotate(-90)`})).textContent = yLabel; } function drawLine(g, xData, yData, xS, yS, color, label) { var d = ''; for (var i = 0; i < xData.length; i++) { d += (i === 0 ? 'M' : ' L') + xS(xData[i]) + ' ' + yS(yData[i]); } // Fill var fillD = d + ' L' + xS(xData[xData.length-1]) + ' ' + yS.range[0] + ' L' + xS(xData[0]) + ' ' + yS.range[0] + ' Z'; g.appendChild(EG.svg('path', {d:fillD, fill:color, opacity:0.06})); var path = EG.svg('path', {d:d, fill:'none', stroke:color, 'stroke-width':2.5}); path.addEventListener('mouseenter', function(e) { path.setAttribute('stroke-width','4'); EG.tooltip.show(e, '<b>'+EG.esc(label)+'</b>'); }); path.addEventListener('mouseleave', function() { path.setAttribute('stroke-width','2.5'); EG.tooltip.hide(); }); g.appendChild(path); } } """