"""ROC curve visualizer.
Interactive Receiver Operating Characteristic curves for classification
evaluation. Supports binary and multi-class (One-vs-Rest) with AUC
annotation, operating point markers, and the random baseline.
Example
-------
>>> from endgame.visualization import ROCCurveVisualizer
>>> from sklearn.linear_model import LogisticRegression
>>> clf = LogisticRegression().fit(X_train, y_train)
>>> viz = ROCCurveVisualizer.from_estimator(clf, X_test, y_test)
>>> viz.save("roc.html")
"""
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
import numpy as np
from endgame.visualization._base import BaseVisualizer
class ROCCurveVisualizer(BaseVisualizer):
    """Interactive ROC curve visualizer.

    Parameters
    ----------
    curves : list of dict
        Each dict has keys 'fpr' (list of float), 'tpr' (list of float),
        'auc' (float), 'label' (str). Binary curves built by the classmethod
        constructors also carry 'optimalPoint' (dict with 'fpr', 'tpr',
        'threshold'), which the renderer marks on the chart.
    title : str, optional
        Chart title. An empty title falls back to ``"ROC Curve"``.
    palette : str, default='tableau'
        Color palette.
    width : int, default=650
        Chart width.
    height : int, default=600
        Chart height.
    theme : str, default='dark'
        'dark' or 'light'.
    """

    def __init__(
        self,
        curves: Sequence[dict[str, Any]],
        *,
        title: str = "",
        palette: str = "tableau",
        width: int = 650,
        height: int = 600,
        theme: str = "dark",
    ):
        super().__init__(
            title=title or "ROC Curve",
            palette=palette,
            width=width,
            height=height,
            theme=theme,
        )
        # Copy so later changes to the caller's sequence don't affect the chart.
        self._curves = list(curves)

    # ------------------------------------------------------------------
    # Classmethod constructors
    # ------------------------------------------------------------------

    @classmethod
    def from_estimator(
        cls,
        model: Any,
        X: Any,
        y: Any,
        *,
        class_names: Sequence[str] | None = None,
        **kwargs,
    ) -> ROCCurveVisualizer:
        """Create ROC curves from a fitted classifier.

        For binary classifiers, plots a single curve with its Youden-optimal
        operating point. For multiclass, plots one-vs-rest curves for each
        class plus a micro-averaged curve.

        Parameters
        ----------
        model : estimator
            Fitted sklearn-compatible classifier with ``predict_proba``. When
            it exposes ``classes_``, that order is used to pair probability
            columns with labels; otherwise the sorted unique values of ``y``
            are used.
        X : array-like
            Test features.
        y : array-like
            True labels.
        class_names : list of str, optional
            Display names, one per class, in probability-column order.
            Defaults to the stringified class labels.
        **kwargs
            Forwarded to the constructor (``title``, ``palette``, ...).

        Raises
        ------
        ValueError
            If fewer than two classes are known, or ``predict_proba`` does not
            return one column per class.
        """
        from sklearn.metrics import auc, roc_curve
        from sklearn.preprocessing import label_binarize

        y_arr = np.asarray(y)
        # predict_proba columns follow ``model.classes_``. Deriving the order
        # from ``y`` would misalign columns and labels whenever the test set
        # does not contain every class the model was trained on.
        if hasattr(model, "classes_"):
            classes = np.asarray(model.classes_)
        else:
            classes = np.unique(y_arr)
        n_classes = len(classes)
        if n_classes < 2:
            raise ValueError(
                f"ROC curves require at least two classes; got {n_classes}."
            )
        if class_names is None:
            class_names = [str(c) for c in classes]

        y_proba = np.asarray(model.predict_proba(X))
        if y_proba.ndim != 2 or y_proba.shape[1] != n_classes:
            raise ValueError(
                f"predict_proba returned shape {y_proba.shape}; expected "
                f"(n_samples, {n_classes}) to match the known classes."
            )

        if n_classes == 2:
            # Column 1 is P(classes[1]). Naming it as pos_label lets roc_curve
            # accept labels other than {0, 1} / {-1, 1} (e.g. strings).
            curve = cls._binary_curve(
                y_arr, y_proba[:, 1], "ROC", pos_label=classes[1]
            )
            return cls([curve], **kwargs)

        # Multiclass: one-vs-rest curve per class.
        y_bin = label_binarize(y_arr, classes=classes)
        curves = []
        for i in range(n_classes):
            y_i = y_bin[:, i]
            n_pos = int(y_i.sum())
            # With no positives (or no negatives) for this class in y, the ROC
            # curve is undefined (roc_curve yields NaN rates), so skip it.
            if n_pos == 0 or n_pos == len(y_i):
                continue
            fpr, tpr, _ = roc_curve(y_i, y_proba[:, i])
            roc_auc = auc(fpr, tpr)
            curves.append({
                "fpr": _downsample(fpr),
                "tpr": _downsample(tpr),
                "auc": round(float(roc_auc), 4),
                "label": f"{class_names[i]} (AUC = {roc_auc:.3f})",
            })

        # Micro-average: pool every (sample, class) decision into one curve.
        fpr_micro, tpr_micro, _ = roc_curve(y_bin.ravel(), y_proba.ravel())
        micro_auc = auc(fpr_micro, tpr_micro)
        curves.append({
            "fpr": _downsample(fpr_micro),
            "tpr": _downsample(tpr_micro),
            "auc": round(float(micro_auc), 4),
            "label": f"Micro-avg (AUC = {micro_auc:.3f})",
        })
        return cls(curves, **kwargs)

    @classmethod
    def from_predictions(
        cls,
        y_true: Any,
        y_score: Any,
        *,
        label: str = "Model",
        pos_label: Any = None,
        **kwargs,
    ) -> ROCCurveVisualizer:
        """Create ROC curve from predictions (binary).

        Parameters
        ----------
        y_true : array-like
            True binary labels.
        y_score : array-like
            Predicted probabilities or decision scores for the positive class.
        label : str, default='Model'
            Curve label.
        pos_label : int, float, bool or str, optional
            Label of the positive class. Only needed when ``y_true`` is not
            {0, 1} or {-1, 1}; passed straight to ``roc_curve``.
        **kwargs
            Forwarded to the constructor (``title``, ``palette``, ...).
        """
        curve = cls._binary_curve(
            np.asarray(y_true), np.asarray(y_score), label, pos_label=pos_label
        )
        return cls([curve], **kwargs)

    @staticmethod
    def _binary_curve(
        y_true: Any,
        y_score: Any,
        label: str,
        *,
        pos_label: Any = None,
    ) -> dict[str, Any]:
        """Build one binary ROC curve dict, including its Youden-optimal point."""
        from sklearn.metrics import auc, roc_curve

        fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label=pos_label)
        roc_auc = auc(fpr, tpr)
        # Youden's J statistic: the operating point maximizing TPR - FPR.
        # Computed on the full-resolution arrays, before downsampling.
        best_idx = int(np.argmax(tpr - fpr))
        return {
            "fpr": _downsample(fpr),
            "tpr": _downsample(tpr),
            "auc": round(float(roc_auc), 4),
            "label": f"{label} (AUC = {roc_auc:.3f})",
            "optimalPoint": {
                "fpr": round(float(fpr[best_idx]), 4),
                "tpr": round(float(tpr[best_idx]), 4),
                "threshold": round(float(thresholds[best_idx]), 4),
            },
        }

    # ------------------------------------------------------------------
    # BaseVisualizer interface
    # ------------------------------------------------------------------

    def _build_data(self) -> dict[str, Any]:
        return {"curves": self._curves}

    def _chart_type(self) -> str:
        return "roc_curve"

    def _get_chart_js(self) -> str:
        return _ROC_JS
def _downsample(arr, max_points: int = 500) -> list[float]:
"""Downsample an array for rendering efficiency."""
arr = np.asarray(arr)
if len(arr) <= max_points:
return [round(float(v), 6) for v in arr]
idx = np.linspace(0, len(arr) - 1, max_points, dtype=int)
return [round(float(arr[i]), 6) for i in idx]
# ---------------------------------------------------------------------------
# JavaScript renderer
# ---------------------------------------------------------------------------
_ROC_JS = r"""
function renderChart(data, config) {
const container = document.getElementById('chart-container');
const margin = {top: 20, right: 20, bottom: 55, left: 55};
const ctx = EG.createSVG(container, config.width, config.height, margin);
const {g, width: W, height: H} = ctx;
const palette = config.palette;
const curves = data.curves;
const xScale = EG.scaleLinear([0, 1], [0, W]);
const yScale = EG.scaleLinear([0, 1], [H, 0]);
EG.drawXAxis(g, xScale, H, 'False Positive Rate');
EG.drawYAxis(g, yScale, W, 'True Positive Rate');
// Random baseline (diagonal)
g.appendChild(EG.svg('line', {
x1: xScale(0), y1: yScale(0), x2: xScale(1), y2: yScale(1),
stroke: 'var(--text-muted)', 'stroke-width': 1.5,
'stroke-dasharray': '6,4', opacity: 0.5
}));
// Draw curves
curves.forEach(function(c, ci) {
const color = palette[ci % palette.length];
const n = Math.min(c.fpr.length, c.tpr.length);
let d = '';
for (let i = 0; i < n; i++) {
d += (i === 0 ? 'M' : ' L') + xScale(c.fpr[i]) + ' ' + yScale(c.tpr[i]);
}
const path = EG.svg('path', {
d: d, fill: 'none', stroke: color,
'stroke-width': 2.5, 'stroke-linejoin': 'round'
});
path.addEventListener('mouseenter', function(e) {
path.setAttribute('stroke-width', '4');
EG.tooltip.show(e, '<b>' + EG.esc(c.label) + '</b>');
});
path.addEventListener('mouseleave', function() {
path.setAttribute('stroke-width', '2.5');
EG.tooltip.hide();
});
g.appendChild(path);
// AUC fill (subtle)
let fillD = 'M' + xScale(c.fpr[0]) + ' ' + yScale(c.tpr[0]);
for (let i = 1; i < n; i++) {
fillD += ' L' + xScale(c.fpr[i]) + ' ' + yScale(c.tpr[i]);
}
fillD += ' L' + xScale(c.fpr[n-1]) + ' ' + H + ' L' + xScale(c.fpr[0]) + ' ' + H + ' Z';
g.appendChild(EG.svg('path', {d: fillD, fill: color, opacity: 0.06}));
// Optimal operating point
if (c.optimalPoint) {
const op = c.optimalPoint;
const cx2 = xScale(op.fpr), cy2 = yScale(op.tpr);
const marker = EG.svg('circle', {
cx: cx2, cy: cy2, r: 6,
fill: 'none', stroke: color, 'stroke-width': 2.5
});
marker.addEventListener('mouseenter', function(e) {
marker.setAttribute('r', '8');
EG.tooltip.show(e,
'<b>Optimal Point</b><br>' +
'FPR: ' + EG.fmt(op.fpr, 3) + '<br>' +
'TPR: ' + EG.fmt(op.tpr, 3) + '<br>' +
'Threshold: ' + EG.fmt(op.threshold, 3));
});
marker.addEventListener('mouseleave', function() {
marker.setAttribute('r', '6');
EG.tooltip.hide();
});
g.appendChild(marker);
// Crosshair
g.appendChild(EG.svg('line', {x1: cx2, y1: cy2, x2: cx2, y2: H, stroke: color, 'stroke-width': 1, 'stroke-dasharray': '3,3', opacity: 0.4}));
g.appendChild(EG.svg('line', {x1: 0, y1: cy2, x2: cx2, y2: cy2, stroke: color, 'stroke-width': 1, 'stroke-dasharray': '3,3', opacity: 0.4}));
}
});
// Legend
const items = curves.map(function(c, i) {
return {label: c.label, color: palette[i % palette.length]};
});
EG.drawLegend(container, items);
}
"""