"""Interactive Decision Tree Visualization.
Generates gorgeous, self-contained HTML/JavaScript visualizations of decision
trees with expandable/collapsible branches, zoom, pan, and rich node tooltips.
Supports:
- sklearn DecisionTreeClassifier/Regressor
- C5.0 (C50Classifier, C50Ensemble)
- ObliqueDecisionTree (linear combination splits)
- EvolutionaryTree
- Any tree model with a `tree_` attribute following sklearn conventions
Example
-------
>>> from sklearn.tree import DecisionTreeClassifier
>>> from endgame.visualization import TreeVisualizer
>>>
>>> clf = DecisionTreeClassifier(max_depth=4).fit(X, y)
>>> viz = TreeVisualizer(clf, feature_names=['age', 'income', 'score'])
>>> viz.save("tree.html") # Open in browser for interactive visualization
"""
from __future__ import annotations
import html as html_module
import json
from collections.abc import Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import numpy as np
# ---------------------------------------------------------------------------
# Universal tree node representation for visualization
# ---------------------------------------------------------------------------
@dataclass
class VizNode:
    """Model-agnostic description of one tree node.

    Each supported tree model is translated into a hierarchy of ``VizNode``
    objects first; rendering code only ever consumes this representation.
    """

    node_id: int = 0
    is_leaf: bool = True

    # --- Split description (only meaningful on internal nodes) ---
    split_label: str = ""  # readable test, e.g. "age <= 30.5" or "0.3*x1 + 0.7*x3 <= 2.1"
    split_feature: str = ""  # name of the main feature involved (used for coloring)
    split_type: str = "threshold"  # "threshold", "oblique", "subset" or "discrete"

    # --- Statistics ---
    n_samples: int = 0
    impurity: float = 0.0
    impurity_name: str = "gini"  # label shown next to the impurity, e.g. "gini"

    # --- Prediction ---
    predicted_class: str = ""
    predicted_value: float | None = None
    class_distribution: list[float] = field(default_factory=list)
    class_names: list[str] = field(default_factory=list)

    # --- Subtree ---
    children: list[VizNode] = field(default_factory=list)
    child_labels: list[str] = field(default_factory=list)  # one edge label per child

    def to_dict(self) -> dict[str, Any]:
        """Serialize this node and its entire subtree into JSON-ready types.

        Floats are rounded to four decimal places. The ``"predValue"`` key is
        present only when ``predicted_value`` is set.
        """

        def four_places(number) -> float:
            return round(float(number), 4)

        payload: dict[str, Any] = dict(
            id=int(self.node_id),
            leaf=bool(self.is_leaf),
            split=str(self.split_label),
            feature=str(self.split_feature),
            splitType=str(self.split_type),
            samples=int(self.n_samples),
            impurity=four_places(self.impurity),
            impurityName=str(self.impurity_name),
            prediction=str(self.predicted_class),
            classDist=list(map(four_places, self.class_distribution)),
            classNames=list(map(str, self.class_names)),
            children=[subtree.to_dict() for subtree in self.children],
            childLabels=list(map(str, self.child_labels)),
        )
        if self.predicted_value is None:
            return payload
        payload["predValue"] = four_places(self.predicted_value)
        return payload
# ---------------------------------------------------------------------------
# Tree extraction adapters
# ---------------------------------------------------------------------------
def _extract_sklearn_tree(model, feature_names, class_names, is_classifier):
    """Extract tree structure from a sklearn DecisionTreeClassifier/Regressor.

    Parameters
    ----------
    model : estimator
        Fitted estimator exposing a ``tree_`` attribute with the arrays
        ``feature``, ``threshold``, ``children_left``, ``children_right``,
        ``value``, ``n_node_samples`` and ``impurity``.
    feature_names : sequence of str or None
        Display names indexed by feature id. Missing or too-short sequences
        fall back to ``"feature_<idx>"``.
    class_names : sequence of str or None
        Display names indexed by class id. When absent, class indices are
        used; when too short for the predicted class, that class falls back
        to its index.
    is_classifier : bool
        If True, ``tree_.value`` is read as per-class weights; otherwise its
        first element is taken as the regression output.

    Returns
    -------
    VizNode
        The root node (index 0) with its whole subtree attached.
    """
    tree = model.tree_
    features = tree.feature
    thresholds = tree.threshold
    children_left = tree.children_left
    children_right = tree.children_right
    values = tree.value
    n_samples_arr = tree.n_node_samples
    impurities = tree.impurity
    criterion = getattr(model, "criterion", "gini")

    # Use explicit length checks: a bare truth test raises on numpy arrays.
    has_feature_names = feature_names is not None and len(feature_names) > 0
    has_class_names = class_names is not None and len(class_names) > 0

    def build_node(node_idx):
        # A node whose two child pointers are equal has no children.
        is_leaf = bool(children_left[node_idx] == children_right[node_idx])
        node = VizNode(
            node_id=node_idx,
            is_leaf=is_leaf,
            n_samples=int(n_samples_arr[node_idx]),
            impurity=float(impurities[node_idx]),
            impurity_name=criterion,
        )

        if is_classifier:
            class_counts = values[node_idx].flatten()
            total = class_counts.sum()
            pred_idx = int(np.argmax(class_counts))
            node.class_distribution = (class_counts / total).tolist() if total > 0 else []
            # Each node gets its own list so later mutation cannot leak across nodes.
            if has_class_names:
                node.class_names = list(class_names)
            else:
                node.class_names = [str(i) for i in range(len(class_counts))]
            # Guard against a class_names sequence shorter than the class count.
            if pred_idx < len(node.class_names):
                node.predicted_class = node.class_names[pred_idx]
            else:
                node.predicted_class = str(pred_idx)
        else:
            val = float(values[node_idx].flatten()[0])
            node.predicted_value = val
            node.predicted_class = f"{val:.4f}"

        if not is_leaf:
            feat_idx = int(features[node_idx])
            thresh = float(thresholds[node_idx])
            if has_feature_names and 0 <= feat_idx < len(feature_names):
                feat_name = feature_names[feat_idx]
            else:
                feat_name = f"feature_{feat_idx}"
            node.split_label = f"{feat_name} \u2264 {thresh:.4g}"
            node.split_feature = feat_name
            node.split_type = "threshold"
            # The left child is the branch where the "<=" test holds.
            node.children = [
                build_node(int(children_left[node_idx])),
                build_node(int(children_right[node_idx])),
            ]
            node.child_labels = ["True", "False"]
        return node

    return build_node(0)
def _extract_c50_tree(tree_node, feature_names, class_names, node_counter=None):
    """Convert a C5.0 ``TreeNode`` (and everything below it) into ``VizNode`` form.

    Parameters
    ----------
    tree_node : object
        C5.0 node; the attributes read here are ``cases``, ``class_``,
        ``is_leaf``, ``branches``, ``tested_attr``, ``node_type``,
        ``threshold`` and, when present, ``class_dist`` and ``errors``.
    feature_names : sequence of str or None
        Display names by attribute index; ``"feature_<idx>"`` is used otherwise.
    class_names : sequence of str or None
        Display names by class index; indices are used otherwise.
    node_counter : list of int, optional
        Single-element mutable counter shared across the recursion so that
        ids are assigned in pre-order. Created on the first call.

    Returns
    -------
    VizNode
        Converted node with its children attached.
    """
    counter = [0] if node_counter is None else node_counter
    current_id = counter[0]
    counter[0] = current_id + 1

    case_count = int(tree_node.cases) if tree_node.cases else 0

    # Raw per-class counts, normalised to shares when they sum to something positive.
    raw_dist = []
    if hasattr(tree_node, 'class_dist') and len(tree_node.class_dist) > 0:
        raw_dist = tree_node.class_dist.tolist()
    dist_total = sum(raw_dist) if raw_dist else 0
    if dist_total > 0:
        shares = [count / dist_total for count in raw_dist]
    else:
        shares = raw_dist

    if hasattr(tree_node, 'errors'):
        error_score = float(tree_node.errors)
    else:
        error_score = 0.0

    if class_names:
        display_names = list(class_names)
    else:
        display_names = [str(i) for i in range(len(raw_dist))]

    if class_names and tree_node.class_ < len(class_names):
        winner = class_names[tree_node.class_]
    else:
        winner = str(tree_node.class_)

    node = VizNode(
        node_id=current_id,
        is_leaf=tree_node.is_leaf,
        n_samples=case_count,
        impurity=error_score,
        impurity_name="errors",
        class_distribution=shares,
        class_names=display_names,
        predicted_class=winner,
    )

    # Leaves, and internal nodes without branches, have nothing more to add.
    if tree_node.is_leaf or not tree_node.branches:
        return node

    attr = tree_node.tested_attr
    if feature_names and attr is not None and attr < len(feature_names):
        attr_name = feature_names[attr]
    else:
        attr_name = f"feature_{attr}"
    node.split_feature = attr_name

    from endgame.models.trees.c50 import NodeType

    if tree_node.node_type == NodeType.THRESHOLD:
        cut = tree_node.threshold
        node.split_label = f"{attr_name} \u2264 {cut:.4g}"
        node.split_type = "threshold"
        node.child_labels = ["True", "False"]
    elif tree_node.node_type == NodeType.DISCRETE:
        node.split_label = f"{attr_name}"
        node.split_type = "discrete"
        node.child_labels = [f"= {i}" for i in range(len(tree_node.branches))]
    elif tree_node.node_type == NodeType.SUBSET:
        node.split_label = f"{attr_name} \u2208 subset"
        node.split_type = "subset"
        node.child_labels = [f"subset {i}" for i in range(len(tree_node.branches))]
    else:
        node.split_label = attr_name
        node.split_type = "threshold"

    node.children.extend(
        _extract_c50_tree(branch, feature_names, class_names, counter)
        for branch in tree_node.branches
    )

    # Pad with blank labels so every child has an edge-label slot.
    missing = len(node.children) - len(node.child_labels)
    if missing > 0:
        node.child_labels.extend([""] * missing)
    return node
def _extract_oblique_tree(tree_node, feature_names, class_names, is_classifier, node_counter=None):
    """Extract tree structure from an ObliqueTreeNode.

    Parameters
    ----------
    tree_node : object
        Oblique tree node; the attributes read here are ``is_leaf``,
        ``n_samples``, ``impurity``, ``value``, ``split``, ``left`` and
        ``right``. ``split`` is expected to carry ``feature_indices``,
        ``coefficients`` and ``threshold``.
    feature_names : sequence of str or None
        Display names by feature index. Equation terms fall back to
        ``"x<idx>"`` and the coloring feature to ``"feature_<idx>"``.
    class_names : sequence of str or None
        Display names by class index; indices are used when absent or when
        the sequence is too short for the predicted class.
    is_classifier : bool
        If True, ``value`` is read as per-class weights; otherwise its first
        element is the regression output.
    node_counter : list of int, optional
        Single-element mutable counter shared across the recursion so that
        ids are assigned in pre-order. Created on the first call.

    Returns
    -------
    VizNode
        Converted node. Children are followed for every non-leaf node, and
        each child carries the edge label of the branch it came from
        ("True" for ``left``, "False" for ``right``).
    """
    if node_counter is None:
        node_counter = [0]
    node_id = node_counter[0]
    node_counter[0] += 1

    # Use explicit length checks: a bare truth test raises on numpy arrays.
    has_feature_names = feature_names is not None and len(feature_names) > 0
    has_class_names = class_names is not None and len(class_names) > 0

    node = VizNode(
        node_id=node_id,
        is_leaf=tree_node.is_leaf,
        n_samples=int(tree_node.n_samples),
        impurity=float(tree_node.impurity),
        impurity_name="gini",
    )

    if tree_node.value is not None:
        if is_classifier:
            class_counts = tree_node.value.flatten()
            total = class_counts.sum()
            pred_idx = int(np.argmax(class_counts))
            node.class_distribution = (class_counts / total).tolist() if total > 0 else []
            if has_class_names:
                node.class_names = list(class_names)
            else:
                node.class_names = [str(i) for i in range(len(class_counts))]
            # Guard against a class_names sequence shorter than the class count.
            if pred_idx < len(node.class_names):
                node.predicted_class = node.class_names[pred_idx]
            else:
                node.predicted_class = str(pred_idx)
        else:
            val = float(tree_node.value.flatten()[0])
            node.predicted_value = val
            node.predicted_class = f"{val:.4f}"

    if tree_node.is_leaf:
        return node

    split = tree_node.split
    if split is not None:
        # Build the linear-combination label, skipping (near-)zero coefficients.
        terms = []
        for idx, coef in zip(split.feature_indices, split.coefficients):
            if abs(coef) > 1e-10:
                if has_feature_names and idx < len(feature_names):
                    feat_name = feature_names[idx]
                else:
                    feat_name = f"x{idx}"
                terms.append(f"{coef:+.3g}\u00b7{feat_name}")
        equation = " ".join(terms)
        node.split_label = f"{equation} \u2264 {split.threshold:.4g}"
        node.split_type = "oblique"

        # Color by the feature with the largest absolute coefficient.
        # argmax raises on an empty array, so skip when there are none.
        if len(split.coefficients) > 0:
            dominant_idx = split.feature_indices[np.argmax(np.abs(split.coefficients))]
            if has_feature_names and dominant_idx < len(feature_names):
                node.split_feature = feature_names[dominant_idx]
            else:
                node.split_feature = f"feature_{dominant_idx}"

    # Pair each child with its own edge label so a lone right child is
    # still labelled "False" rather than inheriting the first label.
    for branch, edge_label in ((tree_node.left, "True"), (tree_node.right, "False")):
        if branch is not None:
            node.children.append(
                _extract_oblique_tree(branch, feature_names, class_names, is_classifier, node_counter)
            )
            node.child_labels.append(edge_label)
    return node
def _extract_evtree(tree_node, feature_names, class_names, is_classifier, node_counter=None):
    """Extract tree structure from an EvolutionaryTree TreeNode.

    Parameters
    ----------
    tree_node : object
        Evolutionary tree node; the attributes read here are ``is_leaf``
        (attribute or zero-argument method), ``value``, ``feature_idx``,
        ``threshold``, ``left``, ``right`` and, when present, ``n_samples``.
    feature_names : sequence of str or None
        Display names by feature index; ``"feature_<idx>"`` is used otherwise.
    class_names : sequence of str or None
        Display names by class index; indices are used when absent or when
        the sequence is too short for the predicted class.
    is_classifier : bool
        If True, ``value`` is read as per-class weights; otherwise its first
        element is the regression output.
    node_counter : list of int, optional
        Single-element mutable counter shared across the recursion so that
        ids are assigned in pre-order. Created on the first call.

    Returns
    -------
    VizNode
        Converted node. Each child carries the edge label of the branch it
        came from ("True" for ``left``, "False" for ``right``).
    """
    if node_counter is None:
        node_counter = [0]
    node_id = node_counter[0]
    node_counter[0] += 1

    # Use explicit length checks: a bare truth test raises on numpy arrays.
    has_feature_names = feature_names is not None and len(feature_names) > 0
    has_class_names = class_names is not None and len(class_names) > 0

    # is_leaf may be exposed either as a method or as a plain attribute.
    is_leaf = tree_node.is_leaf() if callable(tree_node.is_leaf) else tree_node.is_leaf
    n_samp = int(tree_node.n_samples) if hasattr(tree_node, 'n_samples') else 0

    node = VizNode(
        node_id=node_id,
        is_leaf=is_leaf,
        n_samples=n_samp,
        impurity_name="gini",
    )

    if tree_node.value is not None and len(tree_node.value) > 0:
        if is_classifier:
            class_counts = tree_node.value.flatten()
            total = class_counts.sum()
            pred_idx = int(np.argmax(class_counts))
            node.class_distribution = (class_counts / total).tolist() if total > 0 else []
            if has_class_names:
                node.class_names = list(class_names)
            else:
                node.class_names = [str(i) for i in range(len(class_counts))]
            # Guard against a class_names sequence shorter than the class count.
            if pred_idx < len(node.class_names):
                node.predicted_class = node.class_names[pred_idx]
            else:
                node.predicted_class = str(pred_idx)
        else:
            val = float(tree_node.value.flatten()[0])
            node.predicted_value = val
            node.predicted_class = f"{val:.4f}"

    if is_leaf:
        return node

    feat_idx = tree_node.feature_idx
    thresh = tree_node.threshold
    # A None index must not reach the "<" comparison, which would raise.
    if has_feature_names and feat_idx is not None and feat_idx < len(feature_names):
        feat_name = feature_names[feat_idx]
    else:
        feat_name = f"feature_{feat_idx}"
    node.split_label = f"{feat_name} \u2264 {thresh:.4g}"
    node.split_feature = feat_name
    node.split_type = "threshold"

    # Pair each child with its own edge label so a lone right child is
    # still labelled "False" rather than inheriting the first label.
    for branch, edge_label in ((tree_node.left, "True"), (tree_node.right, "False")):
        if branch is not None:
            node.children.append(
                _extract_evtree(branch, feature_names, class_names, is_classifier, node_counter)
            )
            node.child_labels.append(edge_label)
    return node
# ---------------------------------------------------------------------------
# Main TreeVisualizer class
# ---------------------------------------------------------------------------
# (Stray "[docs]" source-link marker from the documentation export removed.)
class TreeVisualizer:
    """Interactive decision tree visualizer.

    Generates self-contained HTML files with D3.js-powered interactive
    tree visualizations featuring:

    - Expandable/collapsible branches (click nodes)
    - Zoom in/out and pan (mouse wheel + drag)
    - Rich tooltips with node statistics
    - Color-coded nodes by prediction or impurity
    - Responsive layout

    Parameters
    ----------
    model : estimator
        A fitted tree model. Supports:

        - sklearn DecisionTreeClassifier/Regressor
        - sklearn ensembles (extracts individual trees)
        - C50Classifier / C50Ensemble
        - ObliqueDecisionTreeClassifier/Regressor
        - ObliqueRandomForestClassifier/Regressor
        - EvolutionaryTreeClassifier/Regressor
    feature_names : list of str, optional
        Names for each feature. If None, uses "feature_0", "feature_1", etc.
    class_names : list of str, optional
        Names for each class (classification only).
    tree_index : int, default=0
        For ensemble models, which tree to visualize. An index past the end
        of the ensemble selects the last tree.
    title : str, optional
        Title displayed above the visualization.
    color_by : str, default='prediction'
        How to color nodes: 'prediction' (class color), 'impurity' (heatmap),
        or 'samples' (by sample count).
    max_depth : int, optional
        Maximum depth to display. Deeper nodes are collapsed by default.
    palette : str, default='tableau'
        Color palette: 'tableau', 'viridis', 'pastel', or 'dark'.

    Raises
    ------
    ValueError
        If the model type is not recognised.

    Example
    -------
    >>> from sklearn.datasets import load_iris
    >>> from sklearn.tree import DecisionTreeClassifier
    >>> from endgame.visualization import TreeVisualizer
    >>>
    >>> X, y = load_iris(return_X_y=True)
    >>> clf = DecisionTreeClassifier(max_depth=4, random_state=42).fit(X, y)
    >>> viz = TreeVisualizer(
    ...     clf,
    ...     feature_names=load_iris().feature_names,
    ...     class_names=load_iris().target_names.tolist(),
    ...     title="Iris Decision Tree"
    ... )
    >>> viz.save("iris_tree.html")
    """

    def __init__(
        self,
        model,
        feature_names: Sequence[str] | None = None,
        class_names: Sequence[str] | None = None,
        tree_index: int = 0,
        title: str | None = None,
        color_by: str = "prediction",
        max_depth: int | None = None,
        palette: str = "tableau",
    ):
        self.model = model
        self.feature_names = list(feature_names) if feature_names is not None else None
        self.class_names = list(class_names) if class_names is not None else None
        self.tree_index = tree_index
        self.title = title
        self.color_by = color_by
        self.max_depth = max_depth
        self.palette = palette
        # Convert eagerly so unsupported models fail at construction time.
        self._root = self._extract_tree()

    @staticmethod
    def _clamped_item(items, index):
        """Return ``items[index]``, with an over-large index clamped to the last item."""
        return items[min(index, len(items) - 1)]

    @staticmethod
    def _root_node(estimator):
        """Return the root node of a custom tree, stored as ``tree_`` or ``root_``."""
        return estimator.tree_ if hasattr(estimator, 'tree_') else estimator.root_

    def _detect_model_type(self):
        """Classify ``self.model`` into one of the supported extraction routes.

        Returns
        -------
        str
            One of 'sklearn', 'sklearn_ensemble', 'sklearn_gb_ensemble',
            'c50', 'c50_ensemble', 'oblique', 'oblique_forest' or 'evtree'.

        Raises
        ------
        ValueError
            If no route matches.
        """
        model = self.model
        cls_name = type(model).__name__

        # Name-based ensemble checks come first: these ensembles may also
        # expose ``estimators_`` whose members have a ``tree_`` attribute,
        # which the generic check below would mistake for sklearn trees.
        if cls_name in ('C50Ensemble',):
            return 'c50_ensemble'
        if cls_name in ('ObliqueRandomForestClassifier', 'ObliqueRandomForestRegressor'):
            return 'oblique_forest'

        # Generic ensembles - extract a single tree
        if hasattr(model, 'estimators_'):
            est = model.estimators_
            if hasattr(est, 'shape') and len(est.shape) == 2:
                # GradientBoosting stores trees in a 2D ndarray; skip empty
                # arrays, where indexing [0, 0] would raise.
                if est.shape[0] > 0 and est.shape[1] > 0 and hasattr(est[0, 0], 'tree_'):
                    return 'sklearn_gb_ensemble'
            elif hasattr(est, '__len__') and len(est) > 0:
                if hasattr(est[0], 'tree_'):
                    return 'sklearn_ensemble'

        if cls_name in ('RotationForestClassifier', 'RotationForestRegressor'):
            return 'sklearn_ensemble'

        # Single trees
        if hasattr(model, 'tree_') and hasattr(model.tree_, 'feature'):
            return 'sklearn'
        if cls_name in ('C50Classifier',) and hasattr(model, 'tree_'):
            return 'c50'
        if cls_name in ('ObliqueDecisionTreeClassifier', 'ObliqueDecisionTreeRegressor'):
            return 'oblique'
        if cls_name in ('EvolutionaryTreeClassifier', 'EvolutionaryTreeRegressor'):
            return 'evtree'

        # Fallback: try sklearn-like
        if hasattr(model, 'tree_'):
            return 'sklearn'

        raise ValueError(
            f"Unsupported model type: {cls_name}. "
            "TreeVisualizer supports sklearn trees, C5.0, ObliqueTree, "
            "EvolutionaryTree, and their ensemble variants."
        )

    def _is_classifier(self) -> bool:
        """Check if the model is a classifier (delegates to sklearn)."""
        from sklearn.base import is_classifier

        return is_classifier(self.model)

    def _extract_tree(self) -> VizNode:
        """Extract tree from the fitted model into VizNode format.

        For ensembles, ``tree_index`` selects the member tree; an index past
        the end is clamped to the last tree on every ensemble route.

        Raises
        ------
        ValueError
            If the model type is unsupported.
        """
        model_type = self._detect_model_type()
        is_clf = self._is_classifier()
        names = (self.feature_names, self.class_names)

        if model_type == 'sklearn_gb_ensemble':
            # estimators_ is 2D; column 0 is used and the member trees are
            # always extracted as regressors.
            est = self.model.estimators_
            idx = min(self.tree_index, est.shape[0] - 1)
            return _extract_sklearn_tree(est[idx, 0], *names, False)

        if model_type == 'sklearn_ensemble':
            tree_model = self._clamped_item(self.model.estimators_, self.tree_index)
            if hasattr(tree_model, 'estimators_'):
                # Nested ensemble (e.g., bagging)
                tree_model = tree_model.estimators_[0]
            return _extract_sklearn_tree(tree_model, *names, is_clf)

        if model_type == 'sklearn':
            return _extract_sklearn_tree(self.model, *names, is_clf)

        if model_type == 'c50':
            return _extract_c50_tree(self.model.tree_, *names)

        if model_type == 'c50_ensemble':
            trees = self.model.trees_ if hasattr(self.model, 'trees_') else [self.model.tree_]
            return _extract_c50_tree(self._clamped_item(trees, self.tree_index), *names)

        if model_type == 'oblique':
            return _extract_oblique_tree(self._root_node(self.model), *names, is_clf)

        if model_type == 'oblique_forest':
            tree_model = self._clamped_item(self.model.estimators_, self.tree_index)
            return _extract_oblique_tree(self._root_node(tree_model), *names, is_clf)

        if model_type == 'evtree':
            return _extract_evtree(self._root_node(self.model), *names, is_clf)

        raise ValueError(f"Unknown model type: {model_type}")

    def to_json(self) -> str:
        """Export tree data as JSON string."""
        return json.dumps(self._root.to_dict(), indent=2)

    def _render_html(self, embedded: bool) -> str:
        """Build the HTML document for the extracted tree.

        The title is HTML-escaped before being handed to the generator.
        """
        title = self.title or "Decision Tree Visualization"
        return _generate_html(
            tree_data=self._root.to_dict(),
            title=html_module.escape(title),
            color_by=self.color_by,
            max_depth=self.max_depth,
            palette=self.palette,
            embedded=embedded,
        )

    def save(self, filepath: str | Path, open_browser: bool = False) -> Path:
        """Save interactive visualization as a self-contained HTML file.

        Parameters
        ----------
        filepath : str or Path
            Output file path (should end in .html). If it has no suffix,
            ``.html`` is appended.
        open_browser : bool, default=False
            If True, open the file in the default web browser.

        Returns
        -------
        Path
            The absolute path to the saved file.
        """
        filepath = Path(filepath)
        if not filepath.suffix:
            filepath = filepath.with_suffix('.html')

        filepath.write_text(self._render_html(embedded=False), encoding='utf-8')

        if open_browser:
            import webbrowser

            webbrowser.open(filepath.resolve().as_uri())
        return filepath.resolve()

    def _repr_html_(self) -> str:
        """Jupyter notebook display support."""
        return self._render_html(embedded=True)
# ---------------------------------------------------------------------------
# HTML/JS/CSS generation
# ---------------------------------------------------------------------------
# Named color palettes offered through the ``palette`` option. Each entry
# maps a palette name to a list of ten hex colors.
_PALETTES = {
    "tableau": [
        "#4e79a7", "#f28e2b", "#e15759", "#76b7b2", "#59a14f",
        "#edc948", "#b07aa1", "#ff9da7", "#9c755f", "#bab0ac",
    ],
    "viridis": [
        "#440154", "#482777", "#3e4989", "#31688e", "#26828e",
        "#1f9e89", "#35b779", "#6ece58", "#b5de2b", "#fde725",
    ],
    "pastel": [
        "#8dd3c7", "#ffffb3", "#bebada", "#fb8072", "#80b1d3",
        "#fdb462", "#b3de69", "#fccde5", "#d9d9d9", "#bc80bd",
    ],
    "dark": [
        "#1b9e77", "#d95f02", "#7570b3", "#e7298a", "#66a61e",
        "#e6ab02", "#a6761d", "#666666", "#e41a1c", "#377eb8",
    ],
}
def _generate_html(
tree_data: dict,
title: str,
color_by: str,
max_depth: int | None,
palette: str,
embedded: bool = False,
) -> str:
"""Generate the complete self-contained HTML visualization."""
palette_colors = _PALETTES.get(palette, _PALETTES["tableau"])
tree_json_str = json.dumps(tree_data)
colors_json = json.dumps(palette_colors)
max_depth_js = str(max_depth) if max_depth is not None else "null"
# Height for embedded (Jupyter) vs standalone
container_height = "600px" if embedded else "100vh"
html = f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{title}</title>
<style>
* {{ margin: 0; padding: 0; box-sizing: border-box; }}
body {{
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
background: #0f1117;
color: #e0e0e0;
overflow: hidden;
}}
#app {{
width: 100vw;
height: {container_height};
display: flex;
flex-direction: column;
}}
/* Header */
.header {{
display: flex;
align-items: center;
justify-content: space-between;
padding: 12px 24px;
background: linear-gradient(135deg, #1a1d29, #252836);
border-bottom: 1px solid rgba(255,255,255,0.06);
z-index: 100;
flex-shrink: 0;
}}
.header h1 {{
font-size: 18px;
font-weight: 600;
color: #f0f0f0;
letter-spacing: -0.3px;
}}
.controls {{
display: flex;
gap: 8px;
align-items: center;
}}
.btn {{
padding: 6px 14px;
border: 1px solid rgba(255,255,255,0.12);
border-radius: 6px;
background: rgba(255,255,255,0.05);
color: #c0c0c0;
font-size: 13px;
cursor: pointer;
transition: all 0.2s;
user-select: none;
}}
.btn:hover {{
background: rgba(255,255,255,0.1);
color: #fff;
border-color: rgba(255,255,255,0.2);
}}
.btn.active {{
background: rgba(78, 121, 167, 0.3);
border-color: #4e79a7;
color: #a0c4e8;
}}
.stats {{
font-size: 12px;
color: #888;
padding: 0 12px;
}}
/* Tree container */
#tree-container {{
flex: 1;
overflow: hidden;
position: relative;
cursor: grab;
}}
#tree-container:active {{ cursor: grabbing; }}
#tree-svg {{
width: 100%;
height: 100%;
}}
/* Links */
.link {{
fill: none;
stroke: rgba(255,255,255,0.12);
stroke-width: 1.5px;
transition: stroke 0.3s, stroke-width 0.3s;
}}
.link:hover {{
stroke: rgba(255,255,255,0.3);
stroke-width: 2.5px;
}}
/* Nodes */
.node {{
cursor: pointer;
}}
.node-rect {{
rx: 8;
ry: 8;
stroke-width: 1.5px;
transition: all 0.3s;
filter: drop-shadow(0 2px 6px rgba(0,0,0,0.4));
}}
.node:hover .node-rect {{
stroke-width: 2.5px;
filter: drop-shadow(0 4px 12px rgba(0,0,0,0.6));
}}
.node.collapsed .node-rect {{
stroke-dasharray: 4 2;
}}
.node-label {{
font-size: 11px;
fill: #f0f0f0;
text-anchor: middle;
pointer-events: none;
font-weight: 500;
}}
.node-sublabel {{
font-size: 9.5px;
fill: #aaa;
text-anchor: middle;
pointer-events: none;
}}
.edge-label {{
font-size: 9px;
fill: #888;
text-anchor: middle;
pointer-events: none;
font-weight: 500;
}}
/* Collapse indicator */
.collapse-badge {{
fill: rgba(255,255,255,0.15);
stroke: rgba(255,255,255,0.2);
stroke-width: 1px;
}}
.collapse-text {{
font-size: 9px;
fill: #ccc;
text-anchor: middle;
dominant-baseline: central;
pointer-events: none;
font-weight: 600;
}}
/* Tooltip */
.tooltip {{
position: absolute;
background: linear-gradient(135deg, #252836, #1e2130);
border: 1px solid rgba(255,255,255,0.12);
border-radius: 10px;
padding: 14px 18px;
pointer-events: none;
opacity: 0;
transition: opacity 0.2s;
z-index: 1000;
min-width: 220px;
max-width: 360px;
box-shadow: 0 8px 32px rgba(0,0,0,0.5);
backdrop-filter: blur(10px);
}}
.tooltip.visible {{ opacity: 1; }}
.tooltip h3 {{
font-size: 13px;
font-weight: 600;
color: #f0f0f0;
margin-bottom: 8px;
border-bottom: 1px solid rgba(255,255,255,0.08);
padding-bottom: 6px;
}}
.tooltip .row {{
display: flex;
justify-content: space-between;
font-size: 12px;
padding: 2px 0;
}}
.tooltip .row .label {{ color: #888; }}
.tooltip .row .value {{ color: #ddd; font-weight: 500; }}
/* Distribution bar */
.dist-bar {{
display: flex;
height: 8px;
border-radius: 4px;
overflow: hidden;
margin: 8px 0 4px;
}}
.dist-bar .segment {{
transition: width 0.3s;
}}
.dist-legend {{
display: flex;
flex-wrap: wrap;
gap: 6px;
margin-top: 4px;
}}
.dist-legend .item {{
display: flex;
align-items: center;
gap: 4px;
font-size: 10px;
color: #aaa;
}}
.dist-legend .swatch {{
width: 8px;
height: 8px;
border-radius: 2px;
}}
/* Minimap */
.minimap {{
position: absolute;
bottom: 16px;
right: 16px;
width: 180px;
height: 120px;
background: rgba(15, 17, 23, 0.85);
border: 1px solid rgba(255,255,255,0.1);
border-radius: 8px;
overflow: hidden;
z-index: 50;
}}
.minimap-viewport {{
stroke: #4e79a7;
stroke-width: 1.5px;
fill: rgba(78, 121, 167, 0.1);
}}
/* Search */
.search-box {{
position: relative;
}}
.search-box input {{
padding: 6px 12px;
border: 1px solid rgba(255,255,255,0.12);
border-radius: 6px;
background: rgba(255,255,255,0.05);
color: #e0e0e0;
font-size: 13px;
width: 160px;
outline: none;
transition: all 0.2s;
}}
.search-box input:focus {{
border-color: #4e79a7;
background: rgba(255,255,255,0.08);
width: 200px;
}}
.search-box input::placeholder {{ color: #666; }}
/* Zoom indicator */
.zoom-level {{
position: absolute;
bottom: 16px;
left: 16px;
background: rgba(15, 17, 23, 0.8);
border: 1px solid rgba(255,255,255,0.1);
border-radius: 6px;
padding: 6px 12px;
font-size: 11px;
color: #888;
z-index: 50;
}}
</style>
</head>
<body>
<div id="app">
<div class="header">
<h1>{title}</h1>
<div class="controls">
<div class="search-box">
<input type="text" id="search-input" placeholder="Search features..." />
</div>
<button class="btn" onclick="expandAll()" title="Expand all nodes">Expand All</button>
<button class="btn" onclick="collapseAll()" title="Collapse all nodes">Collapse All</button>
<button class="btn" onclick="resetZoom()" title="Reset view">Reset View</button>
<button class="btn" id="btn-color" onclick="cycleColor()" title="Change color mode">Color: prediction</button>
<span class="stats" id="stats-text"></span>
</div>
</div>
<div id="tree-container">
<svg id="tree-svg"></svg>
<div class="tooltip" id="tooltip"></div>
<div class="zoom-level" id="zoom-level">100%</div>
</div>
</div>
<script>
// ===== D3.js v7 (minified inline) =====
// Using a lightweight subset of D3 functionality built from scratch
// to keep the file self-contained without external CDN dependencies.
</script>
<script>
(function() {{
"use strict";
// ===== Data =====
const treeData = {tree_json_str};
const PALETTE = {colors_json};
const MAX_DEPTH = {max_depth_js};
const COLOR_MODES = ['prediction', 'impurity', 'samples'];
let colorMode = '{color_by}';
// ===== State =====
let root;
let nodeId = 0;
let transform = {{ x: 0, y: 0, k: 1 }};
let dragging = false;
let dragStart = {{ x: 0, y: 0 }};
let selectedNode = null;
// Layout constants
const NODE_W = 160;
const NODE_H = 52;
const H_GAP = 24;
const V_GAP = 72;
const ANIMATION_MS = 400;
// ===== Tree data processing =====
function processNode(data, depth, parent) {{
const node = {{
id: nodeId++,
data: data,
depth: depth,
parent: parent,
children: [],
_children: null,
x: 0,
y: 0,
collapsed: false,
}};
if (data.children && data.children.length > 0) {{
for (const childData of data.children) {{
node.children.push(processNode(childData, depth + 1, node));
}}
// Auto-collapse beyond max depth
if (MAX_DEPTH !== null && depth >= MAX_DEPTH) {{
node._children = node.children;
node.children = [];
node.collapsed = true;
}}
}}
return node;
}}
function countDescendants(node) {{
if (!node.children || node.children.length === 0) {{
if (!node._children || node._children.length === 0) return 1;
}}
let count = 0;
const kids = node.children.length > 0 ? node.children : (node._children || []);
for (const c of kids) count += countDescendants(c);
return Math.max(count, 1);
}}
function countVisible(node) {{
if (node.children.length === 0) return 1;
let count = 0;
for (const c of node.children) count += countVisible(c);
return Math.max(count, 1);
}}
// ===== Layout (Reingold-Tilford inspired) =====
function layoutTree(root) {{
// First pass: compute leaf counts for spacing
assignY(root, 0);
// Second pass: assign X based on leaf positions
let leafIndex = 0;
assignX(root);
function assignY(node, depth) {{
node.y = depth * (NODE_H + V_GAP);
node.depth = depth;
for (const child of node.children) {{
assignY(child, depth + 1);
}}
}}
function assignX(node) {{
if (node.children.length === 0) {{
node.x = leafIndex * (NODE_W + H_GAP);
leafIndex++;
return;
}}
for (const child of node.children) {{
assignX(child);
}}
// Center parent above children
const first = node.children[0];
const last = node.children[node.children.length - 1];
node.x = (first.x + last.x) / 2;
}}
}}
// ===== Rendering =====
const svg = document.getElementById('tree-svg');
const container = document.getElementById('tree-container');
const tooltip = document.getElementById('tooltip');
const zoomLabel = document.getElementById('zoom-level');
const statsText = document.getElementById('stats-text');
function createSVGElement(tag) {{
return document.createElementNS('http://www.w3.org/2000/svg', tag);
}}
function clearSVG() {{
while (svg.firstChild) svg.removeChild(svg.firstChild);
}}
function getNodeColor(node) {{
const d = node.data;
if (colorMode === 'prediction') {{
if (d.classDist && d.classDist.length > 0) {{
const maxIdx = d.classDist.indexOf(Math.max(...d.classDist));
return PALETTE[maxIdx % PALETTE.length];
}}
return PALETTE[0];
}} else if (colorMode === 'impurity') {{
// Heatmap: low impurity = blue, high = red
const imp = Math.min(d.impurity, 0.5) / 0.5;
const r = Math.round(30 + imp * 200);
const g = Math.round(80 - imp * 50);
const b = Math.round(180 - imp * 150);
return `rgb(${{r}},${{g}},${{b}})`;
}} else {{ // samples
// Size-based: more samples = brighter
const maxSamples = root.data.samples || 1;
const ratio = Math.min((d.samples || 0) / maxSamples, 1);
const intensity = Math.round(40 + ratio * 160);
return `rgb(${{Math.round(30 + ratio * 60)}}, ${{intensity}}, ${{Math.round(120 + ratio * 80)}})`;
}}
}}
function getNodeStroke(node) {{
const color = getNodeColor(node);
return color;
}}
function getNodeFill(node) {{
const color = getNodeColor(node);
// Parse and darken
if (color.startsWith('rgb')) {{
const m = color.match(/\\d+/g);
if (m) {{
return `rgba(${{Math.round(m[0]*0.25)}},${{Math.round(m[1]*0.25)}},${{Math.round(m[2]*0.25)}},0.85)`;
}}
}}
// Hex
const r = parseInt(color.slice(1,3), 16);
const g = parseInt(color.slice(3,5), 16);
const b = parseInt(color.slice(5,7), 16);
return `rgba(${{Math.round(r*0.25)}},${{Math.round(g*0.25)}},${{Math.round(b*0.25)}},0.85)`;
}}
function render() {{
clearSVG();
layoutTree(root);
const g = createSVGElement('g');
g.setAttribute('id', 'tree-group');
svg.appendChild(g);
// Apply transform
updateTransform();
// Draw links first (under nodes)
drawLinks(g, root);
// Draw nodes on top
drawNodes(g, root);
updateStats();
}}
function drawLinks(parent, node) {{
for (let i = 0; i < node.children.length; i++) {{
const child = node.children[i];
const link = createSVGElement('path');
const x1 = node.x + NODE_W / 2;
const y1 = node.y + NODE_H;
const x2 = child.x + NODE_W / 2;
const y2 = child.y;
const midY = (y1 + y2) / 2;
link.setAttribute('d', `M${{x1}},${{y1}} C${{x1}},${{midY}} ${{x2}},${{midY}} ${{x2}},${{y2}}`);
link.setAttribute('class', 'link');
parent.appendChild(link);
// Edge label
const edgeLabels = node.data.childLabels || [];
if (edgeLabels[i]) {{
const label = createSVGElement('text');
label.setAttribute('x', (x1 + x2) / 2);
label.setAttribute('y', midY - 4);
label.setAttribute('class', 'edge-label');
label.textContent = edgeLabels[i];
parent.appendChild(label);
}}
drawLinks(parent, child);
}}
}}
function drawNodes(parent, node) {{
const g = createSVGElement('g');
g.setAttribute('class', 'node' + (node.collapsed ? ' collapsed' : ''));
g.setAttribute('transform', `translate(${{node.x}},${{node.y}})`);
// Node rectangle
const rect = createSVGElement('rect');
rect.setAttribute('width', NODE_W);
rect.setAttribute('height', NODE_H);
rect.setAttribute('class', 'node-rect');
rect.setAttribute('fill', getNodeFill(node));
rect.setAttribute('stroke', getNodeStroke(node));
g.appendChild(rect);
// Main label (split condition or prediction)
const label = createSVGElement('text');
label.setAttribute('x', NODE_W / 2);
label.setAttribute('y', 20);
label.setAttribute('class', 'node-label');
const labelText = node.data.leaf
? (node.data.prediction || 'leaf')
: truncateText(node.data.split || 'split', 22);
label.textContent = labelText;
g.appendChild(label);
// Sub-label (samples / impurity)
const sublabel = createSVGElement('text');
sublabel.setAttribute('x', NODE_W / 2);
sublabel.setAttribute('y', 38);
sublabel.setAttribute('class', 'node-sublabel');
sublabel.textContent = `n=${{node.data.samples}}`;
if (!node.data.leaf) {{
sublabel.textContent += ` | ${{node.data.impurityName}}=${{node.data.impurity.toFixed(3)}}`;
}}
g.appendChild(sublabel);
// Collapse indicator badge
if (node.collapsed && node._children && node._children.length > 0) {{
const badgeR = 10;
const badge = createSVGElement('circle');
badge.setAttribute('cx', NODE_W / 2);
badge.setAttribute('cy', NODE_H + 8);
badge.setAttribute('r', badgeR);
badge.setAttribute('class', 'collapse-badge');
g.appendChild(badge);
const badgeText = createSVGElement('text');
badgeText.setAttribute('x', NODE_W / 2);
badgeText.setAttribute('y', NODE_H + 8);
badgeText.setAttribute('class', 'collapse-text');
const hiddenCount = countDescendants({{ children: node._children, _children: null }});
badgeText.textContent = `+${{hiddenCount}}`;
g.appendChild(badgeText);
}}
// Click handler: expand/collapse
g.addEventListener('click', (e) => {{
e.stopPropagation();
toggleNode(node);
}});
// Hover handlers for tooltip
g.addEventListener('mouseenter', (e) => showTooltip(e, node));
g.addEventListener('mouseleave', hideTooltip);
parent.appendChild(g);
// Recurse for visible children
for (const child of node.children) {{
drawNodes(parent, child);
}}
}}
function truncateText(text, maxLen) {{
    // Return `text` unchanged when it fits in `maxLen` characters; otherwise
    // keep the first maxLen - 1 characters and append one ellipsis character.
    return text.length > maxLen
        ? text.slice(0, maxLen - 1) + '\u2026'
        : text;
}}
// ===== Node toggle =====
function toggleNode(node) {{
    // Flip `node` between expanded and collapsed, then re-render.
    // While collapsed, the children are parked on `_children` and `children`
    // is an empty array. A node that is not collapsed and has no children is
    // left as it is, but the render still runs.
    if (!node.collapsed) {{
        if (node.children.length > 0) {{
            // Collapse
            node._children = node.children;
            node.children = [];
            node.collapsed = true;
        }}
    }} else {{
        // Expand
        node.children = node._children || [];
        node._children = null;
        node.collapsed = false;
    }}
    render();
}}
function expandAll() {{
    // Restore every parked `_children` list throughout the tree, then
    // re-render and refit the view.
    const unfold = (current) => {{
        const parked = current._children;
        if (parked) {{
            current.children = parked;
            current._children = null;
            current.collapsed = false;
        }}
        // Descend only after restoring, so freshly revealed subtrees are
        // expanded as well.
        for (const kid of current.children) unfold(kid);
    }};
    unfold(root);
    render();
    fitToView();
}}
function collapseAll() {{
    // Collapse every visible subtree below the root (the root itself stays
    // expanded), then re-render and refit the view.
    const fold = (current) => {{
        // Children first, so nested subtrees are already folded by the time
        // their parent is.
        for (const kid of current.children) fold(kid);
        if (current.children.length === 0) return;
        current._children = current.children;
        current.children = [];
        current.collapsed = true;
    }};
    for (const top of root.children) fold(top);
    render();
    fitToView();
}}
// ===== Tooltip =====
function showTooltip(event, node) {{
const d = node.data;
let html = `<h3>${{d.leaf ? 'Leaf Node' : 'Decision Node'}}</h3>`;
if (!d.leaf) {{
html += `<div class="row"><span class="label">Split</span><span class="value">${{escapeHtml(d.split)}}</span></div>`;
if (d.splitType === 'oblique') {{
html += `<div class="row"><span class="label">Type</span><span class="value">Oblique (multi-feature)</span></div>`;
}}
}}
html += `<div class="row"><span class="label">Samples</span><span class="value">${{d.samples.toLocaleString()}}</span></div>`;
html += `<div class="row"><span class="label">${{capitalize(d.impurityName)}}</span><span class="value">${{d.impurity.toFixed(4)}}</span></div>`;
html += `<div class="row"><span class="label">Prediction</span><span class="value">${{escapeHtml(d.prediction)}}</span></div>`;
if (d.predValue !== undefined) {{
html += `<div class="row"><span class="label">Value</span><span class="value">${{d.predValue.toFixed(4)}}</span></div>`;
}}
// Class distribution bar
if (d.classDist && d.classDist.length > 0) {{
html += '<div class="dist-bar">';
for (let i = 0; i < d.classDist.length; i++) {{
const pct = (d.classDist[i] * 100).toFixed(1);
html += `<div class="segment" style="width:${{pct}}%;background:${{PALETTE[i % PALETTE.length]}}"></div>`;
}}
html += '</div>';
html += '<div class="dist-legend">';
for (let i = 0; i < d.classDist.length; i++) {{
const name = d.classNames[i] || `Class ${{i}}`;
const pct = (d.classDist[i] * 100).toFixed(1);
html += `<span class="item"><span class="swatch" style="background:${{PALETTE[i % PALETTE.length]}}"></span>${{escapeHtml(name)}}: ${{pct}}%</span>`;
}}
html += '</div>';
}}
tooltip.innerHTML = html;
tooltip.classList.add('visible');
const rect = container.getBoundingClientRect();
let left = event.clientX - rect.left + 16;
let top = event.clientY - rect.top - 10;
// Keep tooltip in bounds
const tw = tooltip.offsetWidth;
const th = tooltip.offsetHeight;
if (left + tw > rect.width - 16) left = event.clientX - rect.left - tw - 16;
if (top + th > rect.height - 16) top = rect.height - th - 16;
if (top < 8) top = 8;
tooltip.style.left = left + 'px';
tooltip.style.top = top + 'px';
}}
function hideTooltip() {{
// Hide the tooltip by removing the 'visible' class that showTooltip adds.
tooltip.classList.remove('visible');
}}
function escapeHtml(str) {{
    // Escape `str` for insertion into markup: write it into a detached
    // element as plain text and read back the serialised HTML.
    const scratch = document.createElement('div');
    scratch.textContent = str;
    return scratch.innerHTML;
}}
function capitalize(s) {{
    // Upper-case the first character of `s`; the rest is left untouched.
    // An empty string comes back empty.
    const head = s.charAt(0);
    const tail = s.slice(1);
    return head.toUpperCase() + tail;
}}
// ===== Zoom & Pan =====
function updateTransform() {{
    // Apply the current pan (x, y) and zoom (k) to the tree group, when it
    // exists, and refresh the zoom percentage label.
    const tx = transform.x;
    const ty = transform.y;
    const zoom = transform.k;
    const treeGroup = document.getElementById('tree-group');
    if (treeGroup) {{
        treeGroup.setAttribute('transform', `translate(${{tx}},${{ty}}) scale(${{zoom}})`);
    }}
    zoomLabel.textContent = Math.round(zoom * 100) + '%';
}}
container.addEventListener('wheel', (e) => {{
e.preventDefault();
const rect = container.getBoundingClientRect();
const mx = e.clientX - rect.left;
const my = e.clientY - rect.top;
const delta = e.deltaY > 0 ? 0.9 : 1.1;
const newK = Math.max(0.05, Math.min(5, transform.k * delta));
// Zoom toward cursor
transform.x = mx - (mx - transform.x) * (newK / transform.k);
transform.y = my - (my - transform.y) * (newK / transform.k);
transform.k = newK;
updateTransform();
}}, {{ passive: false }});
container.addEventListener('mousedown', (evt) => {{
    // A press that lands on a node element does not start a pan.
    if (evt.target.closest('.node')) return;
    // Store the pointer position relative to the current pan, so mousemove
    // can recover the new pan by simple subtraction.
    dragStart.x = evt.clientX - transform.x;
    dragStart.y = evt.clientY - transform.y;
    dragging = true;
}});
// Move and release are tracked on window so a drag keeps working when the
// pointer leaves the container.
window.addEventListener('mousemove', (evt) => {{
    if (dragging) {{
        transform.x = evt.clientX - dragStart.x;
        transform.y = evt.clientY - dragStart.y;
        updateTransform();
    }}
}});
window.addEventListener('mouseup', () => {{
    dragging = false;
}});
// Touch support
let lastTouchDist = 0;
container.addEventListener('touchstart', (e) => {{
if (e.touches.length === 1) {{
dragging = true;
dragStart.x = e.touches[0].clientX - transform.x;
dragStart.y = e.touches[0].clientY - transform.y;
}} else if (e.touches.length === 2) {{
lastTouchDist = Math.hypot(
e.touches[0].clientX - e.touches[1].clientX,
e.touches[0].clientY - e.touches[1].clientY
);
}}
}}, {{ passive: true }});
container.addEventListener('touchmove', (e) => {{
e.preventDefault();
if (e.touches.length === 1 && dragging) {{
transform.x = e.touches[0].clientX - dragStart.x;
transform.y = e.touches[0].clientY - dragStart.y;
updateTransform();
}} else if (e.touches.length === 2) {{
const dist = Math.hypot(
e.touches[0].clientX - e.touches[1].clientX,
e.touches[0].clientY - e.touches[1].clientY
);
const delta = dist / lastTouchDist;
transform.k = Math.max(0.05, Math.min(5, transform.k * delta));
lastTouchDist = dist;
updateTransform();
}}
}}, {{ passive: false }});
container.addEventListener('touchend', () => {{ dragging = false; }});
function fitToView() {{
    // Choose a zoom and pan that centre every visible node inside the
    // container, with 60px of padding and a zoom cap of 1.5.
    const rect = container.getBoundingClientRect();
    // A hidden or not-yet-laid-out container measures 0x0. Fitting to it
    // would give a zoom of 0, which hides the tree and breaks ratio-based
    // zooming (newK / k), so keep the current transform instead.
    if (rect.width <= 0 || rect.height <= 0) return;

    // Find bounds of all visible nodes
    let minX = Infinity, maxX = -Infinity, minY = Infinity, maxY = -Infinity;
    function traverse(node) {{
        minX = Math.min(minX, node.x);
        maxX = Math.max(maxX, node.x + NODE_W);
        minY = Math.min(minY, node.y);
        maxY = Math.max(maxY, node.y + NODE_H);
        for (const c of node.children) traverse(c);
    }}
    traverse(root);

    const padding = 60;
    const treeW = maxX - minX + padding * 2;
    const treeH = maxY - minY + padding * 2;
    const scaleX = rect.width / treeW;
    const scaleY = rect.height / treeH;
    const k = Math.min(scaleX, scaleY, 1.5);
    transform.k = k;
    // Centre the scaled bounding box; the vertical position is nudged down
    // by 20px.
    transform.x = (rect.width - (maxX - minX) * k) / 2 - minX * k;
    transform.y = (rect.height - (maxY - minY) * k) / 2 - minY * k + 20;
    updateTransform();
}}
// Publish the view actions on `window` so code outside this closure can call
// them (presumably inline handlers in the page markup; confirm against the
// HTML). Note that resetZoom is an alias for fitToView.
window.resetZoom = fitToView;
window.expandAll = expandAll;
window.collapseAll = collapseAll;
// ===== Color mode cycling =====
window.cycleColor = function() {{
    // Step to the next entry of COLOR_MODES, wrapping around at the end,
    // update the button caption and redraw the tree.
    const nextIdx = (COLOR_MODES.indexOf(colorMode) + 1) % COLOR_MODES.length;
    colorMode = COLOR_MODES[nextIdx];
    const button = document.getElementById('btn-color');
    button.textContent = 'Color: ' + colorMode;
    render();
    // Preserve current transform after re-render
    updateTransform();
}};
// ===== Search =====
const searchInput = document.getElementById('search-input');
// Live search: nodes whose main label or sub-label contains the query
// (case-insensitive) are emphasised and all other nodes are dimmed.
searchInput.addEventListener('input', (e) => {{
    const query = e.target.value.toLowerCase().trim();
    if (!query) {{
        // Reset highlight: restore full opacity and drop the inline stroke
        // width so the stylesheet value applies again.
        document.querySelectorAll('.node-rect').forEach((r) => {{
            r.style.opacity = '1';
            r.style.strokeWidth = '';
        }});
        return;
    }}
    document.querySelectorAll('.node').forEach((nodeEl) => {{
        const rect = nodeEl.querySelector('.node-rect');
        if (!rect) return;
        const labels = nodeEl.querySelectorAll('.node-label, .node-sublabel');
        let match = false;
        labels.forEach((l) => {{
            if (l.textContent.toLowerCase().includes(query)) match = true;
        }});
        rect.style.opacity = match ? '1' : '0.2';
        rect.style.strokeWidth = match ? '3px' : '1.5px';
    }});
}});
// ===== Stats =====
function updateStats() {{
    // Recount the visible part of the tree and write the summary into the
    // stats text element.
    // Only `children` is followed, so nodes parked on `_children` are not
    // counted, and a collapsed node (whose `children` is empty) is counted
    // as a leaf.
    let visibleNodes = 0, leaves = 0, maxDepth = 0;
    function visit(node, depth) {{
        visibleNodes++;
        if (node.children.length === 0) leaves++;
        maxDepth = Math.max(maxDepth, depth);
        for (const c of node.children) visit(c, depth + 1);
    }}
    visit(root, 0);
    statsText.textContent = `${{visibleNodes}} nodes | ${{leaves}} leaves | depth ${{maxDepth}}`;
}}
// ===== Keyboard shortcuts =====
document.addEventListener('keydown', (e) => {{
if (e.target === searchInput) return;
if (e.key === '+' || e.key === '=') {{
transform.k = Math.min(5, transform.k * 1.2);
updateTransform();
}} else if (e.key === '-') {{
transform.k = Math.max(0.05, transform.k / 1.2);
updateTransform();
}} else if (e.key === '0') {{
fitToView();
}} else if (e.key === 'e') {{
expandAll();
}} else if (e.key === 'c') {{
collapseAll();
}} else if (e.key === '/') {{
e.preventDefault();
searchInput.focus();
}}
}});
// ===== Init =====
// Build the working tree from the embedded data and draw it once.
root = processNode(treeData, 0, null);
render();
// Fit on the next animation frame, after the initial render.
requestAnimationFrame(() => fitToView());
// Refit whenever the window is resized.
window.addEventListener('resize', () => fitToView());
}})();
</script>
</body>
</html>"""
return html