Source code for endgame.visualization.error_bars

"""Error bars visualizer.

Interactive model comparison chart with confidence intervals,
ideal for comparing multiple models with mean + CI bars.

Example
-------
>>> from endgame.visualization import ErrorBarsVisualizer
>>> viz = ErrorBarsVisualizer(
...     labels=["LGBM", "XGB", "CatBoost"],
...     means=[0.912, 0.908, 0.915],
...     errors=[0.003, 0.005, 0.004],
...     title="Model Comparison",
... )
>>> viz.save("error_bars.html")
"""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

import numpy as np

from endgame.visualization._base import BaseVisualizer


[docs] class ErrorBarsVisualizer(BaseVisualizer): """Interactive error bar chart for model comparison. Parameters ---------- labels : list of str Model/category names. means : list of float Mean values. errors : list of float or list of tuple of float Symmetric errors (single value) or asymmetric (lo, hi) per point. orientation : str, default='horizontal' 'horizontal' or 'vertical'. sort : bool, default=True Sort by mean value. x_label : str, optional Value axis label. y_label : str, optional Category axis label. title : str, optional Chart title. palette : str, default='tableau' Color palette. width : int, default=800 Chart width. height : int, default=500 Chart height. theme : str, default='dark' 'dark' or 'light'. """ def __init__( self, labels: Sequence[str], means: Sequence[float], errors: Sequence, *, orientation: str = "horizontal", sort: bool = True, x_label: str = "", y_label: str = "", title: str = "", palette: str = "tableau", width: int = 800, height: int = 500, theme: str = "dark", ): super().__init__(title=title, palette=palette, width=width, height=height, theme=theme) self.labels = list(labels) self.means = list(means) self.orientation = orientation self.sort = sort self.x_label = x_label self.y_label = y_label # Normalize errors to (lo, hi) tuples self._errors = [] for e in errors: if isinstance(e, (list, tuple)): self._errors.append((float(e[0]), float(e[1]))) else: self._errors.append((float(e), float(e)))
[docs] @classmethod def from_cv_results( cls, results: dict[str, Sequence[float]], **kwargs, ) -> ErrorBarsVisualizer: """Create from CV results (mean +/- std). Parameters ---------- results : dict of str → list of float Model name → fold scores. **kwargs Additional keyword arguments. """ labels, means, errors = [], [], [] for name, scores in results.items(): arr = np.asarray(scores, dtype=float) labels.append(name) means.append(float(arr.mean())) errors.append(float(arr.std())) kwargs.setdefault("title", "Model Comparison (mean ± std)") kwargs.setdefault("x_label", "Score") return cls(labels, means, errors, **kwargs)
def _build_data(self) -> dict[str, Any]: labels = list(self.labels) means = list(self.means) errors = list(self._errors) if self.sort: indices = sorted(range(len(means)), key=lambda i: means[i]) labels = [labels[i] for i in indices] means = [means[i] for i in indices] errors = [errors[i] for i in indices] return { "labels": labels, "means": means, "errors": errors, "orientation": self.orientation, "xLabel": self.x_label, "yLabel": self.y_label, } def _chart_type(self) -> str: return "error_bars" def _get_chart_js(self) -> str: return _ERROR_BARS_JS
_ERROR_BARS_JS = r""" function renderChart(data, config) { const container = document.getElementById('chart-container'); const isHoriz = data.orientation === 'horizontal'; const margin = isHoriz ? {top: 20, right: 30, bottom: 50, left: 120} : {top: 20, right: 20, bottom: 60, left: 60}; const ctx = EG.createSVG(container, config.width, config.height, margin); const {g, width: W, height: H} = ctx; const palette = config.palette; const labels = data.labels; const means = data.means; const errors = data.errors; const n = labels.length; // Value range let vMin = Infinity, vMax = -Infinity; for (let i = 0; i < n; i++) { const lo = means[i] - errors[i][0]; const hi = means[i] + errors[i][1]; if (lo < vMin) vMin = lo; if (hi > vMax) vMax = hi; } const pad = (vMax - vMin) * 0.1 || 0.01; vMin -= pad; vMax += pad; const catScale = EG.scaleBand(labels, [0, isHoriz ? H : W], 0.25); const valScale = EG.scaleLinear([vMin, vMax], isHoriz ? [0, W] : [H, 0]); const bw = catScale.bandwidth(); if (isHoriz) { EG.drawXAxis(g, valScale, H, data.xLabel); EG.drawYAxis(g, catScale, W, data.yLabel, true); } else { EG.drawXAxis(g, catScale, H, data.xLabel, true); EG.drawYAxis(g, valScale, W, data.yLabel); } for (let i = 0; i < n; i++) { const color = palette[i % palette.length]; const m = means[i]; const eLo = errors[i][0], eHi = errors[i][1]; if (isHoriz) { const cy = catScale(labels[i]) + bw / 2; const xLo = valScale(m - eLo), xHi = valScale(m + eHi), xM = valScale(m); // Error bar line g.appendChild(EG.svg('line', {x1: xLo, y1: cy, x2: xHi, y2: cy, stroke: color, 'stroke-width': 2})); // Caps const capH = bw * 0.25; g.appendChild(EG.svg('line', {x1: xLo, y1: cy-capH, x2: xLo, y2: cy+capH, stroke: color, 'stroke-width': 2})); g.appendChild(EG.svg('line', {x1: xHi, y1: cy-capH, x2: xHi, y2: cy+capH, stroke: color, 'stroke-width': 2})); // Mean dot const dot = EG.svg('circle', {cx: xM, cy: cy, r: 6, fill: color, stroke: 'var(--bg-card)', 'stroke-width': 2}); dot.addEventListener('mouseenter', function(e) { EG.tooltip.show(e, '<b>' + EG.esc(labels[i]) + '</b><br>Mean: ' + EG.fmt(m, 4) + '<br>CI: [' + EG.fmt(m - eLo, 4) + ', ' + EG.fmt(m + eHi, 4) + ']'); }); dot.addEventListener('mouseleave', function() { EG.tooltip.hide(); }); g.appendChild(dot); } else { const cx = catScale(labels[i]) + bw / 2; const yLo = valScale(m - eLo), yHi = valScale(m + eHi), yM = valScale(m); g.appendChild(EG.svg('line', {x1: cx, y1: yLo, x2: cx, y2: yHi, stroke: color, 'stroke-width': 2})); const capW = bw * 0.25; g.appendChild(EG.svg('line', {x1: cx-capW, y1: yLo, x2: cx+capW, y2: yLo, stroke: color, 'stroke-width': 2})); g.appendChild(EG.svg('line', {x1: cx-capW, y1: yHi, x2: cx+capW, y2: yHi, stroke: color, 'stroke-width': 2})); const dot = EG.svg('circle', {cx: cx, cy: yM, r: 6, fill: color, stroke: 'var(--bg-card)', 'stroke-width': 2}); dot.addEventListener('mouseenter', function(e) { EG.tooltip.show(e, '<b>' + EG.esc(labels[i]) + '</b><br>Mean: ' + EG.fmt(m, 4) + '<br>CI: [' + EG.fmt(m - eLo, 4) + ', ' + EG.fmt(m + eHi, 4) + ']'); }); dot.addEventListener('mouseleave', function() { EG.tooltip.hide(); }); g.appendChild(dot); } } } """