"""Lollipop chart visualizer.
A cleaner alternative to bar charts for displaying ranked values, such as
feature importances. Each value is shown as a dot on a stem line, reducing
visual clutter while preserving precision.
Example
-------
>>> from endgame.visualization import LollipopChartVisualizer
>>> viz = LollipopChartVisualizer.from_importances(model, feature_names=feature_names)
>>> viz.save("lollipop.html")
"""
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
import numpy as np
from endgame.visualization._base import BaseVisualizer
[docs]
class LollipopChartVisualizer(BaseVisualizer):
    """Interactive lollipop chart visualizer.

    Each category is drawn as a stem running from ``baseline`` to its value,
    capped with a dot. Hovering a dot shows the label and value.

    Parameters
    ----------
    labels : list of str
        Category labels. Non-string labels are converted with ``str``.
    values : list of float
        Values for each category; must have the same length as ``labels``.
    orientation : {'horizontal', 'vertical'}, default='horizontal'
        'horizontal' (labels on Y) or 'vertical' (labels on X).
    sort : bool, default=False
        Sort by value (descending). Items with equal values keep their
        input order.
    highlight_top : int, optional
        Number of leading items, in display order, drawn with larger dots.
        These are the largest values only when ``sort=True``.
    baseline : float, default=0
        Baseline value where stems start.
    title : str, optional
        Chart title; "Lollipop Chart" is used when empty.
    palette : str, default='tableau'
        Color palette.
    width : int, default=750
        Chart width.
    height : int, default=500
        Chart height.
    theme : str, default='dark'
        'dark' or 'light'.

    Raises
    ------
    ValueError
        If ``orientation`` is not 'horizontal' or 'vertical', or if
        ``labels`` and ``values`` have different lengths.
    """

    def __init__(
        self,
        labels: Sequence[str],
        values: Sequence[float],
        *,
        orientation: str = "horizontal",
        sort: bool = False,
        highlight_top: int | None = None,
        baseline: float = 0,
        title: str = "",
        palette: str = "tableau",
        width: int = 750,
        height: int = 500,
        theme: str = "dark",
    ):
        # The renderer treats anything other than 'horizontal' as vertical,
        # so a typo would silently flip the chart; reject it up front.
        if orientation not in ("horizontal", "vertical"):
            raise ValueError(
                "orientation must be 'horizontal' or 'vertical', "
                f"got {orientation!r}"
            )
        # The renderer calls .length/.slice on every label, so make them str.
        labs = [str(label) for label in labels]
        vals = [float(v) for v in values]
        if len(labs) != len(vals):
            raise ValueError(
                "labels and values must have the same length "
                f"(got {len(labs)} labels and {len(vals)} values)"
            )
        super().__init__(
            title=title or "Lollipop Chart",
            palette=palette,
            width=width,
            height=height,
            theme=theme,
        )
        if sort:
            # sorted() is stable even with reverse=True, so ties keep input order.
            pairs = sorted(zip(labs, vals), key=lambda pair: pair[1], reverse=True)
            labs = [label for label, _ in pairs]
            vals = [value for _, value in pairs]
        self._labels = labs
        self._values = vals
        self.orientation = orientation
        # Store plain Python scalars so NumPy scalars (e.g. np.int64) do not
        # leak into the chart payload (presumably serialized by BaseVisualizer).
        self.highlight_top = None if highlight_top is None else int(highlight_top)
        self.baseline = float(baseline)

    @classmethod
    def from_importances(
        cls,
        model: Any,
        *,
        feature_names: Sequence[str] | None = None,
        top_n: int | None = None,
        **kwargs,
    ) -> LollipopChartVisualizer:
        """Create from model feature importances.

        Parameters
        ----------
        model : estimator
            Fitted model with ``feature_importances_``.
        feature_names : list of str, optional
            Feature names, one per importance. Defaults to
            ``"Feature 0"``, ``"Feature 1"``, ...
        top_n : int, optional
            Show only the top N features by importance. Must be >= 0.
        **kwargs
            Additional keyword arguments forwarded to the constructor.
            ``sort`` defaults to True, ``title`` to "Feature Importances"
            and ``highlight_top`` to 3.

        Returns
        -------
        LollipopChartVisualizer

        Raises
        ------
        ValueError
            If ``feature_names`` does not have one entry per importance, or
            ``top_n`` is negative.
        """
        imp = np.asarray(model.feature_importances_)
        if feature_names is None:
            names = [f"Feature {i}" for i in range(len(imp))]
        else:
            # Materialize first so the indexing below is positional even for
            # label-indexed containers.
            names = list(feature_names)
            if len(names) != len(imp):
                raise ValueError(
                    "feature_names must have one entry per importance "
                    f"(got {len(names)} names and {len(imp)} importances)"
                )
        if top_n is not None:
            if top_n < 0:
                raise ValueError(f"top_n must be non-negative, got {top_n}")
            # Indices of the top_n largest importances, largest first.
            idx = np.argsort(imp)[::-1][:top_n]
            imp = imp[idx]
            names = [names[i] for i in idx]
        kwargs.setdefault("sort", True)
        kwargs.setdefault("title", "Feature Importances")
        kwargs.setdefault("highlight_top", 3)
        return cls(names, imp.tolist(), **kwargs)

    @classmethod
    def from_dict(cls, data: dict[str, float], **kwargs) -> LollipopChartVisualizer:
        """Create from a dictionary.

        Parameters
        ----------
        data : dict of str → float
            Label → value pairs, plotted in the dict's iteration order
            unless ``sort=True`` is passed.
        **kwargs
            Additional keyword arguments forwarded to the constructor.

        Returns
        -------
        LollipopChartVisualizer
        """
        return cls(list(data.keys()), list(data.values()), **kwargs)

    def _build_data(self) -> dict[str, Any]:
        """Return the payload consumed by the JS ``renderChart`` function."""
        return {
            "labels": self._labels,
            "values": self._values,
            "orientation": self.orientation,
            "highlightTop": self.highlight_top,
            "baseline": self.baseline,
        }

    def _chart_type(self) -> str:
        """Return the chart-type identifier for this visualizer."""
        return "lollipop"

    def _get_chart_js(self) -> str:
        """Return the JavaScript source that renders the chart."""
        return _LOLLIPOP_JS
# Client-side renderer for LollipopChartVisualizer, returned by `_get_chart_js`.
# `renderChart(data, config)` reads data.labels, data.values, data.orientation,
# data.highlightTop and data.baseline (the keys built by `_build_data`), plus
# config.palette / config.width / config.height. It draws into the element with
# id 'chart-container' using a global `EG` helper namespace. That namespace is
# presumably injected by the BaseVisualizer page template — confirm there.
#
# Rendering behaviour:
# - The value axis spans the data and the baseline, padded by 8% of that range
#   (0.01 if the range is zero).
# - Each item gets a stem from the baseline to its value and a hoverable dot
#   whose tooltip shows the escaped label and the value.
# - A dashed reference line marks the baseline.
# - The first `highlightTop` items in display order get larger dots (radius 7
#   vs 5). In horizontal mode they also get bold labels.
# - Axis labels longer than 20 chars (horizontal) or 15 chars (vertical) are
#   truncated with an ellipsis.
# NOTE: this is a runtime string; any edit inside it changes the rendered chart.
_LOLLIPOP_JS = r"""
function renderChart(data, config) {
const container = document.getElementById('chart-container');
const palette = config.palette;
const labels = data.labels;
const values = data.values;
const n = labels.length;
const horiz = data.orientation === 'horizontal';
const baseline = data.baseline;
const hlTop = data.highlightTop || 0;
if (horiz) {
const margin = {top: 15, right: 30, bottom: 45, left: 140};
const ctx = EG.createSVG(container, config.width, config.height, margin);
const {g, width: W, height: H} = ctx;
let vMin = Math.min(baseline, Math.min.apply(null, values));
let vMax = Math.max(baseline, Math.max.apply(null, values));
const pad = (vMax - vMin) * 0.08 || 0.01;
vMin -= pad; vMax += pad;
const xScale = EG.scaleLinear([vMin, vMax], [0, W]);
const bx = xScale(baseline);
const rowH = H / n;
EG.drawXAxis(g, xScale, H, 'Value');
// Baseline line
g.appendChild(EG.svg('line', {x1: bx, y1: 0, x2: bx, y2: H, stroke: 'var(--text-muted)', 'stroke-width': 1, 'stroke-dasharray': '4,3', opacity: 0.5}));
for (let i = 0; i < n; i++) {
const y = i * rowH + rowH / 2;
const vx = xScale(values[i]);
const color = palette[i % palette.length];
const isHighlighted = i < hlTop;
const dotR = isHighlighted ? 7 : 5;
// Stem
g.appendChild(EG.svg('line', {
x1: bx, y1: y, x2: vx, y2: y,
stroke: color, 'stroke-width': 2
}));
// Dot
const dot = EG.svg('circle', {
cx: vx, cy: y, r: dotR,
fill: color, stroke: 'var(--bg-card)', 'stroke-width': 2
});
dot.addEventListener('mouseenter', function(e) {
dot.setAttribute('r', String(dotR + 2));
EG.tooltip.show(e, '<b>' + EG.esc(labels[i]) + '</b><br>Value: ' + EG.fmt(values[i], 4));
});
dot.addEventListener('mouseleave', function() {
dot.setAttribute('r', String(dotR));
EG.tooltip.hide();
});
g.appendChild(dot);
// Label
g.appendChild(EG.svg('text', {
x: -8, y: y + 4, 'text-anchor': 'end',
fill: 'var(--text-primary)', 'font-size': '11px',
'font-weight': isHighlighted ? '600' : '400'
})).textContent = labels[i].length > 20 ? labels[i].slice(0, 18) + '…' : labels[i];
}
} else {
// Vertical orientation
const margin = {top: 15, right: 20, bottom: 65, left: 55};
const ctx = EG.createSVG(container, config.width, config.height, margin);
const {g, width: W, height: H} = ctx;
let vMin = Math.min(baseline, Math.min.apply(null, values));
let vMax = Math.max(baseline, Math.max.apply(null, values));
const pad = (vMax - vMin) * 0.08 || 0.01;
vMin -= pad; vMax += pad;
const yScale = EG.scaleLinear([vMin, vMax], [H, 0]);
const by = yScale(baseline);
const colW = W / n;
EG.drawYAxis(g, yScale, W, 'Value');
// Baseline line
g.appendChild(EG.svg('line', {x1: 0, y1: by, x2: W, y2: by, stroke: 'var(--text-muted)', 'stroke-width': 1, 'stroke-dasharray': '4,3', opacity: 0.5}));
for (let i = 0; i < n; i++) {
const x = i * colW + colW / 2;
const vy = yScale(values[i]);
const color = palette[i % palette.length];
const isHighlighted = i < hlTop;
const dotR = isHighlighted ? 7 : 5;
g.appendChild(EG.svg('line', {x1: x, y1: by, x2: x, y2: vy, stroke: color, 'stroke-width': 2}));
const dot = EG.svg('circle', {cx: x, cy: vy, r: dotR, fill: color, stroke: 'var(--bg-card)', 'stroke-width': 2});
dot.addEventListener('mouseenter', function(e) {
dot.setAttribute('r', String(dotR + 2));
EG.tooltip.show(e, '<b>' + EG.esc(labels[i]) + '</b><br>Value: ' + EG.fmt(values[i], 4));
});
dot.addEventListener('mouseleave', function() { dot.setAttribute('r', String(dotR)); EG.tooltip.hide(); });
g.appendChild(dot);
// Rotated label
var txt = EG.svg('text', {x: x, y: H + 12, 'text-anchor': 'end', fill: 'var(--text-secondary)', 'font-size': '10px', transform: 'rotate(-35,' + x + ',' + (H + 12) + ')'});
txt.textContent = labels[i].length > 15 ? labels[i].slice(0, 13) + '…' : labels[i];
g.appendChild(txt);
}
}
}
"""