Source code for endgame.visualization.regression_report

"""Regression report — comprehensive single-page evaluation.

Generates a self-contained HTML report with performance metrics,
predicted vs actual scatter, residual analysis, feature importances,
and model interpretability (decision tree rules, linear coefficients).

Example
-------
>>> from endgame.visualization import RegressionReport
>>> report = RegressionReport(model, X_test, y_test, feature_names=fnames)
>>> report.save("regression_report.html", open_browser=True)
"""

from __future__ import annotations

import html as html_module
import json
from collections.abc import Sequence
from pathlib import Path
from typing import Any

import numpy as np
from sklearn.metrics import (
    explained_variance_score,
    max_error,
    mean_absolute_error,
    mean_absolute_percentage_error,
    mean_squared_error,
    median_absolute_error,
    r2_score,
)

from endgame.visualization._palettes import DEFAULT_CATEGORICAL, get_palette
from endgame.visualization._report_template import render_report
from endgame.visualization.classification_report import (
    _extract_linear_coefs,
    _extract_tree_rules,
    _is_decision_tree,
    _is_linear,
)


[docs] class RegressionReport: """Comprehensive regression model evaluation report. Generates a multi-section HTML report with metrics, charts, and model interpretability for any sklearn-compatible regressor. Parameters ---------- model : estimator Fitted sklearn-compatible regressor. X : array-like Test features. y : array-like True target values. feature_names : list of str, optional Feature names. model_name : str, optional Display name for the model. dataset_name : str, optional Display name for the dataset. palette : str, default='tableau' Color palette. theme : str, default='dark' 'dark' or 'light'. Examples -------- >>> from sklearn.ensemble import RandomForestRegressor >>> reg = RandomForestRegressor().fit(X_train, y_train) >>> report = RegressionReport(reg, X_test, y_test) >>> report.save("report.html") """ def __init__( self, model: Any, X: Any, y: Any, *, feature_names: Sequence[str] | None = None, model_name: str | None = None, dataset_name: str | None = None, palette: str = DEFAULT_CATEGORICAL, theme: str = "dark", ): self.model = model self.X = np.asarray(X) self.y = np.asarray(y).ravel().astype(float) self.feature_names = list(feature_names) if feature_names is not None else None self.model_name = model_name or type(model).__name__ self.dataset_name = dataset_name or "" self.palette = palette self.theme = theme # Predictions self.y_pred = np.asarray(model.predict(self.X)).ravel().astype(float) self.residuals = self.y - self.y_pred # Compute metrics self._metrics = self._compute_metrics() def _compute_metrics(self) -> dict[str, Any]: m = {} m["mae"] = round(float(mean_absolute_error(self.y, self.y_pred)), 4) m["mse"] = round(float(mean_squared_error(self.y, self.y_pred)), 4) m["rmse"] = round(float(np.sqrt(m["mse"])), 4) m["r2"] = round(float(r2_score(self.y, self.y_pred)), 4) m["explained_var"] = round(float(explained_variance_score(self.y, self.y_pred)), 4) m["median_ae"] = round(float(median_absolute_error(self.y, self.y_pred)), 4) m["max_error"] = round(float(max_error(self.y, self.y_pred)), 4) try: m["mape"] = round(float(mean_absolute_percentage_error(self.y, self.y_pred)), 4) except Exception: pass # Adjusted R² n = len(self.y) p = self.X.shape[1] if self.X.ndim > 1 else 1 if n - p - 1 > 0: m["adj_r2"] = round(1 - (1 - m["r2"]) * (n - 1) / (n - p - 1), 4) m["n_samples"] = n m["n_features"] = p # Residual statistics m["residual_mean"] = round(float(np.mean(self.residuals)), 4) m["residual_std"] = round(float(np.std(self.residuals)), 4) return m @property def metrics(self) -> dict[str, Any]: """Access computed metrics dictionary.""" return self._metrics
[docs] def save(self, filepath: str | Path, open_browser: bool = False) -> Path: """Save report as self-contained HTML.""" filepath = Path(filepath) if not filepath.suffix: filepath = filepath.with_suffix(".html") html = self._render() filepath.write_text(html, encoding="utf-8") if open_browser: import webbrowser webbrowser.open(filepath.resolve().as_uri()) return filepath.resolve()
def _repr_html_(self) -> str: """Jupyter inline display.""" return self._render() def _render(self) -> str: colors = get_palette(self.palette) m = self._metrics parts = [self.model_name] if self.dataset_name: parts.append(self.dataset_name) parts.append(f"{m['n_samples']} samples · {m['n_features']} features") subtitle = html_module.escape(" — ".join(parts)) # Metrics panel metrics_cards = [ ("R²", f"{m['r2']:.4f}"), ("RMSE", f"{m['rmse']:.4f}"), ("MAE", f"{m['mae']:.4f}"), ("Median AE", f"{m['median_ae']:.4f}"), ("Max Error", f"{m['max_error']:.4f}"), ("Explained Var", f"{m['explained_var']:.4f}"), ] if "adj_r2" in m: metrics_cards.append(("Adj R²", f"{m['adj_r2']:.4f}")) if "mape" in m: metrics_cards.append(("MAPE", f"{m['mape']:.2%}")) metrics_html = "\n".join( f'<div class="metric-card"><div class="metric-value">{val}</div>' f'<div class="metric-label">{lbl}</div></div>' for lbl, val in metrics_cards ) sections = [] chart_w, chart_h = 600, 420 # 1. Predicted vs Actual sections.append(self._section_pred_vs_actual(chart_w, chart_h, colors)) # 2. Residual Distribution sections.append(self._section_residual_hist(chart_w, chart_h, colors)) # 3. Residuals vs Predicted sections.append(self._section_residuals_vs_predicted(chart_w, chart_h, colors)) # 4. Residuals vs Index (order) sections.append(self._section_residuals_vs_index(chart_w, chart_h, colors)) # 5. Feature importances if hasattr(self.model, "feature_importances_"): sections.append(self._section_importances(chart_w, chart_h, colors)) # 6. QQ plot (residuals) sections.append(self._section_qq(chart_w, chart_h, colors)) footer_html = self._build_interpretability_footer() return render_report( title="Regression Report", subtitle=subtitle, theme=self.theme, metrics_html=metrics_html, sections=sections, footer_html=footer_html, ) # ------------------------------------------------------------------ # Chart sections # ------------------------------------------------------------------ def _section_pred_vs_actual(self, w, h, colors): # Subsample for performance n = len(self.y) max_pts = 1000 if n > max_pts: idx = np.random.choice(n, max_pts, replace=False) else: idx = np.arange(n) y_t = self.y[idx] y_p = self.y_pred[idx] # Regression line lo = float(min(y_t.min(), y_p.min())) hi = float(max(y_t.max(), y_p.max())) data = { "yTrue": [round(float(v), 6) for v in y_t], "yPred": [round(float(v), 6) for v in y_p], "lo": round(lo, 6), "hi": round(hi, 6), "r2": self._metrics["r2"], } config = {"width": w, "height": h, "palette": colors} return { "title": "Predicted vs Actual", "chart_id": "predact", "width": w, "height": h, "data_json": json.dumps(data), "config_json": json.dumps(config), "chart_js": _PREDACT_JS, } def _section_residual_hist(self, w, h, colors): n_bins = 40 counts, edges = np.histogram(self.residuals, bins=n_bins) bins = [(edges[i] + edges[i + 1]) / 2 for i in range(n_bins)] data = { "bins": [round(float(b), 6) for b in bins], "counts": [int(c) for c in counts], "mean": round(float(np.mean(self.residuals)), 4), "std": round(float(np.std(self.residuals)), 4), } config = {"width": w, "height": h, "palette": colors} return { "title": "Residual Distribution", "chart_id": "reshist", "width": w, "height": h, "data_json": json.dumps(data), "config_json": json.dumps(config), "chart_js": _RESHIST_JS, } def _section_residuals_vs_predicted(self, w, h, colors): n = len(self.y_pred) max_pts = 1000 if n > max_pts: idx = np.random.choice(n, max_pts, replace=False) else: idx = np.arange(n) data = { "yPred": [round(float(v), 6) for v in self.y_pred[idx]], "residuals": [round(float(v), 6) for v in self.residuals[idx]], } config = {"width": w, "height": h, "palette": colors} return { "title": "Residuals vs Predicted", "chart_id": "respred", "width": w, "height": h, "data_json": json.dumps(data), "config_json": json.dumps(config), "chart_js": _RESPRED_JS, } def _section_residuals_vs_index(self, w, h, colors): n = len(self.residuals) max_pts = 1000 if n > max_pts: idx = np.random.choice(n, max_pts, replace=False) idx.sort() else: idx = np.arange(n) data = { "indices": [int(i) for i in idx], "residuals": [round(float(self.residuals[i]), 6) for i in idx], } config = {"width": w, "height": h, "palette": colors} return { "title": "Residuals vs Sample Index", "chart_id": "residx", "width": w, "height": h, "data_json": json.dumps(data), "config_json": json.dumps(config), "chart_js": _RESIDX_JS, } def _section_importances(self, w, h, colors): imp = self.model.feature_importances_ names = self.feature_names or [f"Feature {i}" for i in range(len(imp))] top_n = min(20, len(imp)) idx = np.argsort(imp)[::-1][:top_n] data = { "labels": [names[i] for i in idx], "values": [round(float(imp[i]), 6) for i in idx], } config = {"width": w, "height": h, "palette": colors} return { "title": f"Feature Importances (Top {top_n})", "chart_id": "imp", "width": w, "height": h, "data_json": json.dumps(data), "config_json": json.dumps(config), "chart_js": _IMP_SECTION_JS, } def _section_qq(self, w, h, colors): """Normal Q-Q plot of residuals.""" sorted_res = np.sort(self.residuals) n = len(sorted_res) # Theoretical quantiles from scipy.stats import norm theoretical = norm.ppf(np.linspace(1 / (n + 1), n / (n + 1), n)) # Subsample max_pts = 500 if n > max_pts: idx = np.linspace(0, n - 1, max_pts, dtype=int) else: idx = np.arange(n) data = { "theoretical": [round(float(theoretical[i]), 4) for i in idx], "observed": [round(float(sorted_res[i]), 4) for i in idx], } config = {"width": w, "height": h, "palette": colors} return { "title": "Q-Q Plot (Residuals)", "chart_id": "qq", "width": w, "height": h, "data_json": json.dumps(data), "config_json": json.dumps(config), "chart_js": _QQ_JS, } def _build_interpretability_footer(self) -> str: parts = [] if _is_decision_tree(self.model): rules = _extract_tree_rules(self.model, self.feature_names, None) if rules: parts.append('<div class="interp-section">') parts.append("<h2>Decision Tree Rules</h2>") parts.append('<ol class="rules-list">') for rule in rules[:30]: parts.append(f"<li>{html_module.escape(rule)}</li>") if len(rules) > 30: parts.append(f"<li>... and {len(rules) - 30} more rules</li>") parts.append("</ol></div>") if _is_linear(self.model): coefs = _extract_linear_coefs(self.model, self.feature_names) if coefs: parts.append('<div class="interp-section">') parts.append("<h2>Model Coefficients (Top 20 by |coef|)</h2>") parts.append('<ol class="rules-list">') for name, coef in coefs[:20]: sign = "+" if coef >= 0 else "" parts.append(f"<li>{html_module.escape(name)}: {sign}{coef:.4f}</li>") if hasattr(self.model, "intercept_"): intercept = float(np.asarray(self.model.intercept_).ravel()[0]) parts.append(f"<li>Intercept: {intercept:.4f}</li>") parts.append("</ol></div>") # Residual summary parts.append('<div class="report-footer">') parts.append("<h3>Residual Statistics</h3>") parts.append("<pre>") m = self._metrics parts.append(f"Mean: {m['residual_mean']:.4f}") parts.append(f"Std Dev: {m['residual_std']:.4f}") parts.append(f"Min: {float(np.min(self.residuals)):.4f}") parts.append(f"Max: {float(np.max(self.residuals)):.4f}") parts.append(f"Median: {float(np.median(self.residuals)):.4f}") parts.append("</pre></div>") return "\n".join(parts)
# =================================================================== # Section JavaScript # =================================================================== _PREDACT_JS = r""" function renderChart_predact(data, config, container) { const margin={top:10,right:15,bottom:50,left:55}; const W=config.width,H=config.height; const svg=EG.svg('svg',{width:W,height:H}); container.appendChild(svg); const g=EG.svg('g',{transform:`translate(${margin.left},${margin.top})`}); svg.appendChild(g); const iW=W-margin.left-margin.right,iH=H-margin.top-margin.bottom; const lo=data.lo,hi=data.hi; const pad=(hi-lo)*0.05||1; const xS=EG.scaleLinear([lo-pad,hi+pad],[0,iW]); const yS=EG.scaleLinear([lo-pad,hi+pad],[iH,0]); EG.drawXAxis(g,xS,iH,'Actual'); EG.drawYAxis(g,yS,iW,'Predicted'); // Diagonal y=x g.appendChild(EG.svg('line',{x1:xS(lo-pad),y1:yS(lo-pad),x2:xS(hi+pad),y2:yS(hi+pad),stroke:'var(--text-muted)','stroke-width':1.5,'stroke-dasharray':'6,4',opacity:0.5})); // Points const color=config.palette[0]; for(let i=0;i<data.yTrue.length;i++){ const dot=EG.svg('circle',{cx:xS(data.yTrue[i]),cy:yS(data.yPred[i]),r:3,fill:color,opacity:0.5}); dot.addEventListener('mouseenter',e=>{dot.setAttribute('r','5');dot.setAttribute('opacity','1');EG.tooltip.show(e,'Actual: '+EG.fmt(data.yTrue[i],3)+'<br>Predicted: '+EG.fmt(data.yPred[i],3)+'<br>Error: '+EG.fmt(data.yPred[i]-data.yTrue[i],3));}); dot.addEventListener('mouseleave',()=>{dot.setAttribute('r','3');dot.setAttribute('opacity','0.5');EG.tooltip.hide();}); g.appendChild(dot); } // R² annotation g.appendChild(EG.svg('text',{x:iW-5,y:16,'text-anchor':'end',fill:'var(--text-primary)','font-size':'13px','font-weight':'600'})).textContent='R² = '+data.r2.toFixed(4); } """ _RESHIST_JS = r""" function renderChart_reshist(data, config, container) { const margin={top:10,right:15,bottom:50,left:50}; const W=config.width,H=config.height; const svg=EG.svg('svg',{width:W,height:H}); container.appendChild(svg); const g=EG.svg('g',{transform:`translate(${margin.left},${margin.top})`}); svg.appendChild(g); const iW=W-margin.left-margin.right,iH=H-margin.top-margin.bottom; const bins=data.bins,counts=data.counts; const maxC=Math.max.apply(null,counts)||1; const xMin=bins[0],xMax=bins[bins.length-1]; const xPad=(xMax-xMin)*0.05||1; const xS=EG.scaleLinear([xMin-xPad,xMax+xPad],[0,iW]); const yS=EG.scaleLinear([0,maxC*1.1],[iH,0]); EG.drawXAxis(g,xS,iH,'Residual'); EG.drawYAxis(g,yS,iW,'Count'); const barW=iW/bins.length*0.85; const color=config.palette[0]; bins.forEach((b,i)=>{ const x=xS(b)-barW/2; const bH=iH-yS(counts[i]); const rect=EG.svg('rect',{x:x,y:iH-bH,width:barW,height:Math.max(bH,0),fill:color,opacity:0.7,rx:2}); rect.addEventListener('mouseenter',e=>{rect.setAttribute('opacity','1');EG.tooltip.show(e,'Residual ≈ '+EG.fmt(b,3)+'<br>Count: '+counts[i]);}); rect.addEventListener('mouseleave',()=>{rect.setAttribute('opacity','0.7');EG.tooltip.hide();}); g.appendChild(rect); }); // Zero line if(xMin-xPad<=0 && xMax+xPad>=0){ g.appendChild(EG.svg('line',{x1:xS(0),y1:0,x2:xS(0),y2:iH,stroke:'var(--text-muted)','stroke-width':1.5,'stroke-dasharray':'4,3',opacity:0.5})); } // Annotation g.appendChild(EG.svg('text',{x:iW-5,y:16,'text-anchor':'end',fill:'var(--text-secondary)','font-size':'11px'})).textContent='μ='+data.mean.toFixed(3)+' σ='+data.std.toFixed(3); } """ _RESPRED_JS = r""" function renderChart_respred(data, config, container) { const margin={top:10,right:15,bottom:50,left:55}; const W=config.width,H=config.height; const svg=EG.svg('svg',{width:W,height:H}); container.appendChild(svg); const g=EG.svg('g',{transform:`translate(${margin.left},${margin.top})`}); svg.appendChild(g); const iW=W-margin.left-margin.right,iH=H-margin.top-margin.bottom; const yp=data.yPred,res=data.residuals; let xMin=Infinity,xMax=-Infinity,yMin=Infinity,yMax=-Infinity; yp.forEach(v=>{if(v<xMin)xMin=v;if(v>xMax)xMax=v;}); res.forEach(v=>{if(v<yMin)yMin=v;if(v>yMax)yMax=v;}); const xPad=(xMax-xMin)*0.05||1,yPad=(yMax-yMin)*0.05||1; const xS=EG.scaleLinear([xMin-xPad,xMax+xPad],[0,iW]); const yS=EG.scaleLinear([yMin-yPad,yMax+yPad],[iH,0]); EG.drawXAxis(g,xS,iH,'Predicted'); EG.drawYAxis(g,yS,iW,'Residual'); // Zero line g.appendChild(EG.svg('line',{x1:0,y1:yS(0),x2:iW,y2:yS(0),stroke:'var(--text-muted)','stroke-width':1.5,'stroke-dasharray':'6,4',opacity:0.5})); const color=config.palette[0]; for(let i=0;i<yp.length;i++){ const dot=EG.svg('circle',{cx:xS(yp[i]),cy:yS(res[i]),r:3,fill:color,opacity:0.5}); dot.addEventListener('mouseenter',e=>{dot.setAttribute('r','5');dot.setAttribute('opacity','1');EG.tooltip.show(e,'Predicted: '+EG.fmt(yp[i],3)+'<br>Residual: '+EG.fmt(res[i],3));}); dot.addEventListener('mouseleave',()=>{dot.setAttribute('r','3');dot.setAttribute('opacity','0.5');EG.tooltip.hide();}); g.appendChild(dot); } } """ _RESIDX_JS = r""" function renderChart_residx(data, config, container) { const margin={top:10,right:15,bottom:50,left:55}; const W=config.width,H=config.height; const svg=EG.svg('svg',{width:W,height:H}); container.appendChild(svg); const g=EG.svg('g',{transform:`translate(${margin.left},${margin.top})`}); svg.appendChild(g); const iW=W-margin.left-margin.right,iH=H-margin.top-margin.bottom; const idx=data.indices,res=data.residuals; const xS=EG.scaleLinear([idx[0],idx[idx.length-1]],[0,iW]); let yMin=Infinity,yMax=-Infinity; res.forEach(v=>{if(v<yMin)yMin=v;if(v>yMax)yMax=v;}); const yPad=(yMax-yMin)*0.05||1; const yS=EG.scaleLinear([yMin-yPad,yMax+yPad],[iH,0]); EG.drawXAxis(g,xS,iH,'Sample Index'); EG.drawYAxis(g,yS,iW,'Residual'); g.appendChild(EG.svg('line',{x1:0,y1:yS(0),x2:iW,y2:yS(0),stroke:'var(--text-muted)','stroke-width':1.5,'stroke-dasharray':'6,4',opacity:0.5})); const color=config.palette[0]; for(let i=0;i<idx.length;i++){ g.appendChild(EG.svg('circle',{cx:xS(idx[i]),cy:yS(res[i]),r:2.5,fill:color,opacity:0.45})); } } """ _QQ_JS = r""" function renderChart_qq(data, config, container) { const margin={top:10,right:15,bottom:50,left:55}; const W=config.width,H=config.height; const svg=EG.svg('svg',{width:W,height:H}); container.appendChild(svg); const g=EG.svg('g',{transform:`translate(${margin.left},${margin.top})`}); svg.appendChild(g); const iW=W-margin.left-margin.right,iH=H-margin.top-margin.bottom; const th=data.theoretical,obs=data.observed; let xMin=Infinity,xMax=-Infinity,yMin=Infinity,yMax=-Infinity; th.forEach(v=>{if(v<xMin)xMin=v;if(v>xMax)xMax=v;}); obs.forEach(v=>{if(v<yMin)yMin=v;if(v>yMax)yMax=v;}); const xPad=(xMax-xMin)*0.05||1,yPad=(yMax-yMin)*0.05||1; const xS=EG.scaleLinear([xMin-xPad,xMax+xPad],[0,iW]); const yS=EG.scaleLinear([yMin-yPad,yMax+yPad],[iH,0]); EG.drawXAxis(g,xS,iH,'Theoretical Quantiles'); EG.drawYAxis(g,yS,iW,'Observed Residuals'); // Reference line (fit through Q1 and Q3) const lo=Math.min(xMin-xPad,yMin-yPad),hi=Math.max(xMax+xPad,yMax+yPad); // Simple diagonal for normal reference g.appendChild(EG.svg('line',{x1:xS(xMin-xPad),y1:yS(xMin-xPad),x2:xS(xMax+xPad),y2:yS(xMax+xPad),stroke:'var(--text-muted)','stroke-width':1.5,'stroke-dasharray':'6,4',opacity:0.5})); const color=config.palette[0]; for(let i=0;i<th.length;i++){ g.appendChild(EG.svg('circle',{cx:xS(th[i]),cy:yS(obs[i]),r:3,fill:color,opacity:0.5})); } } """ # Reuse importance JS from classification report _IMP_SECTION_JS = r""" function renderChart_imp(data, config, container) { const margin={top:10,right:30,bottom:30,left:140}; const W=config.width,H=config.height; const svg=EG.svg('svg',{width:W,height:H}); container.appendChild(svg); const g=EG.svg('g',{transform:`translate(${margin.left},${margin.top})`}); svg.appendChild(g); const iW=W-margin.left-margin.right,iH=H-margin.top-margin.bottom; const n=data.labels.length; const rowH=iH/n; const maxV=Math.max.apply(null,data.values)||1; const xS=EG.scaleLinear([0,maxV],[0,iW]); for(let i=0;i<n;i++){ const y=i*rowH,v=data.values[i]; const color=config.palette[i%config.palette.length]; const bW=xS(v); const rect=EG.svg('rect',{x:0,y:y+2,width:Math.max(bW,2),height:rowH-4,fill:color,rx:3,opacity:0.8}); rect.addEventListener('mouseenter',e=>{rect.setAttribute('opacity','1');EG.tooltip.show(e,'<b>'+EG.esc(data.labels[i])+'</b><br>'+EG.fmt(v,4));}); rect.addEventListener('mouseleave',()=>{rect.setAttribute('opacity','0.8');EG.tooltip.hide();}); g.appendChild(rect); g.appendChild(EG.svg('text',{x:bW+5,y:y+rowH/2+4,fill:'var(--text-secondary)','font-size':'10px'})).textContent=EG.fmt(v,4); g.appendChild(EG.svg('text',{x:-6,y:y+rowH/2+4,'text-anchor':'end',fill:'var(--text-primary)','font-size':'11px'})).textContent=data.labels[i].length>20?data.labels[i].slice(0,18)+'…':data.labels[i]; } } """