from __future__ import annotations
"""Oblique Decision Trees for Oblique Random Forests.
Oblique decision trees use linear combinations of features for splits,
allowing them to capture diagonal decision boundaries more efficiently
than axis-aligned trees.
These trees are primarily used as base estimators for ObliqueRandomForest
but can be used standalone for interpretable oblique splits.
"""
from collections.abc import Callable
from dataclasses import dataclass
from typing import Optional
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import check_random_state
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y
from endgame.models.trees.oblique_splits import (
ObliqueSplit,
compute_entropy,
compute_gini,
compute_mae,
compute_mse,
find_best_oblique_split,
)
@dataclass
class ObliqueTreeNode:
    """A node in an oblique decision tree.

    Nodes are created as leaves (``is_leaf=True``). The tree builders turn a
    node into an internal node by setting ``is_leaf=False`` and attaching a
    split and both children.

    Attributes
    ----------
    is_leaf : bool
        Whether this is a leaf node.
    split : ObliqueSplit or None
        Split information (None for leaf nodes).
    left : ObliqueTreeNode or None
        Left child: samples for which ``split.apply`` returns True.
    right : ObliqueTreeNode or None
        Right child: samples for which ``split.apply`` returns False.
    value : ndarray or None
        For classification: class distribution at this node (the classifier
        in this module stores normalized class frequencies).
        For regression: a length-1 array holding the node's mean target.
    n_samples : int
        Number of training samples at this node.
    impurity : float
        Impurity (Gini, entropy, MSE or MAE) at this node.
    depth : int
        Depth of this node in the tree (the root has depth 0).
    node_id : int
        Unique identifier for this node; the builders in this module assign
        ids in pre-order (parent, left subtree, right subtree).
    """

    is_leaf: bool = True
    split: ObliqueSplit | None = None
    left: ObliqueTreeNode | None = None
    right: ObliqueTreeNode | None = None
    value: np.ndarray | None = None
    n_samples: int = 0
    impurity: float = 0.0
    depth: int = 0
    node_id: int = 0
class ObliqueDecisionTreeClassifier(ClassifierMixin, BaseEstimator):
    """A single oblique decision tree for classification.

    This is the base estimator used by ObliqueRandomForestClassifier.
    Uses linear combinations of features for splits, enabling better
    capture of diagonal decision boundaries.

    Parameters
    ----------
    oblique_method : str, default='ridge'
        Method for finding oblique splits:

        - 'ridge': Ridge regression on class labels (recommended)
        - 'pca': Principal Component Analysis
        - 'lda': Linear Discriminant Analysis
        - 'random': Random projections (fastest)
        - 'svm': Linear SVM hyperplane
        - 'householder': Householder reflections
    criterion : str, default='gini'
        Splitting criterion: 'gini' or 'entropy'.
    max_depth : int, default=None
        Maximum tree depth. None means unlimited.
    min_samples_split : int or float, default=2
        Minimum samples required to split a node.
        If float, fraction of total samples.
    min_samples_leaf : int or float, default=1
        Minimum samples required at a leaf.
        If float, fraction of total samples. Regardless of this value, a
        split is only accepted when both children are non-empty.
    max_features : int, float, str, or None, default=None
        Features to consider per split:

        - int >= 1: Use min(max_features, n_features)
        - float in (0, 1]: Use max(1, int(max_features * n_features))
        - 'sqrt': Use sqrt(n_features)
        - 'log2': Use log2(n_features)
        - None: Use all features
    min_impurity_decrease : float, default=0.0
        Minimum impurity decrease required for split. The decrease is
        measured at the node itself; it is not scaled by the node's share
        of the training set.
    random_state : int, RandomState, or None, default=None
        Random seed.
    ridge_alpha : float, default=1.0
        Ridge regularization for 'ridge' method.
    feature_combinations : int, default=2
        Features per random combination (for 'random' method).

    Attributes
    ----------
    tree_ : ObliqueTreeNode
        The root node of the fitted tree.
    classes_ : ndarray of shape (n_classes,)
        Unique class labels.
    n_classes_ : int
        Number of classes.
    n_features_in_ : int
        Number of features seen during fit.
    feature_importances_ : ndarray of shape (n_features_in_,)
        Impurity-based feature importances.
    n_nodes_ : int
        Number of nodes in the tree.
    """

    _estimator_type = "classifier"

    def __init__(
        self,
        oblique_method: str = "ridge",
        criterion: str = "gini",
        max_depth: int | None = None,
        min_samples_split: int | float = 2,
        min_samples_leaf: int | float = 1,
        max_features: int | float | str | None = None,
        min_impurity_decrease: float = 0.0,
        random_state: int | np.random.RandomState | None = None,
        ridge_alpha: float = 1.0,
        feature_combinations: int = 2,
    ):
        # scikit-learn convention: store constructor arguments unchanged;
        # validation and resolution happen in ``fit``.
        self.oblique_method = oblique_method
        self.criterion = criterion
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.max_features = max_features
        self.min_impurity_decrease = min_impurity_decrease
        self.random_state = random_state
        self.ridge_alpha = ridge_alpha
        self.feature_combinations = feature_combinations

    def _resolve_max_features(self, n_features: int) -> int:
        """Resolve ``max_features`` to a feature count in ``[1, n_features]``.

        Raises
        ------
        ValueError
            If ``max_features`` is an unknown string, an int below 1, a float
            outside ``(0, 1]``, or of an unsupported type.
        """
        max_features = self.max_features
        if max_features is None:
            return n_features
        if isinstance(max_features, str):
            if max_features == "sqrt":
                return max(1, int(np.sqrt(n_features)))
            if max_features == "log2":
                return max(1, int(np.log2(n_features)))
        elif isinstance(max_features, (int, np.integer)):
            if max_features >= 1:
                return min(int(max_features), n_features)
        elif isinstance(max_features, (float, np.floating)):
            # A fraction above 1 would request more features than exist.
            if 0.0 < max_features <= 1.0:
                return max(1, int(max_features * n_features))
        raise ValueError(f"Invalid max_features: {self.max_features}")

    def _resolve_min_samples(self, param: int | float, n_samples: int) -> int:
        """Resolve a min_samples parameter (count or fraction) to an integer."""
        if isinstance(param, (float, np.floating)):
            return max(1, int(param * n_samples))
        return int(param)

    def _get_impurity_func(self) -> Callable:
        """Get the impurity function based on criterion.

        Raises
        ------
        ValueError
            If ``criterion`` is not 'gini' or 'entropy'.
        """
        if self.criterion == "gini":
            return compute_gini
        elif self.criterion == "entropy":
            return compute_entropy
        else:
            raise ValueError(f"Invalid criterion: {self.criterion}")

    def _validate_sample_weight(
        self, sample_weight: np.ndarray | None, n_samples: int
    ) -> np.ndarray | None:
        """Return ``sample_weight`` as a float64 array of shape (n_samples,).

        Returns None when no weights were given.

        Raises
        ------
        ValueError
            If the weights do not have exactly one entry per sample.
        """
        if sample_weight is None:
            return None
        sample_weight = np.asarray(sample_weight, dtype=np.float64)
        if sample_weight.shape != (n_samples,):
            raise ValueError(
                f"sample_weight has shape {sample_weight.shape}, "
                f"expected ({n_samples},)."
            )
        return sample_weight

    def fit(
        self,
        X: np.ndarray,
        y: np.ndarray,
        sample_weight: np.ndarray | None = None,
    ) -> ObliqueDecisionTreeClassifier:
        """Build the oblique decision tree.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : array-like of shape (n_samples,)
            Target class labels.
        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights.

        Returns
        -------
        self : object
            Fitted estimator.

        Raises
        ------
        ValueError
            If ``sample_weight`` does not match the number of samples, or if
            ``max_features`` / ``criterion`` are invalid.
        """
        X, y = check_X_y(X, y)
        n_samples, n_features = X.shape
        # Validate weights before writing any fitted state.
        sample_weight = self._validate_sample_weight(sample_weight, n_samples)
        self.n_features_in_ = n_features

        # Encode labels as 0..n_classes-1 so they can index class counts.
        self._label_encoder = LabelEncoder()
        y_encoded = self._label_encoder.fit_transform(y)
        self.classes_ = self._label_encoder.classes_
        self.n_classes_ = len(self.classes_)

        # Resolve parameters that depend on the data shape.
        self._max_features = self._resolve_max_features(n_features)
        self._min_samples_split = self._resolve_min_samples(
            self.min_samples_split, n_samples
        )
        self._min_samples_leaf = self._resolve_min_samples(
            self.min_samples_leaf, n_samples
        )

        self._rng = check_random_state(self.random_state)
        self._impurity_func = self._get_impurity_func()

        self._node_count = 0
        self.tree_ = self._build_tree(X, y_encoded, sample_weight, depth=0)
        self.n_nodes_ = self._node_count

        self._compute_feature_importances()
        return self

    def _build_tree(
        self,
        X: np.ndarray,
        y: np.ndarray,
        sample_weight: np.ndarray | None,
        depth: int,
    ) -> ObliqueTreeNode:
        """Recursively build the oblique decision tree.

        Parameters
        ----------
        X : ndarray of shape (n_samples, n_features)
            Training data at this node.
        y : ndarray of shape (n_samples,)
            Target values (encoded).
        sample_weight : ndarray or None
            Sample weights.
        depth : int
            Current depth in tree.

        Returns
        -------
        node : ObliqueTreeNode
            The constructed tree node. Node ids are assigned in pre-order
            (this node, then its left subtree, then its right subtree).
        """
        n_samples = len(y)
        # ObliqueTreeNode defaults to is_leaf=True; it becomes an internal
        # node only once a split is accepted below.
        node = ObliqueTreeNode(
            n_samples=n_samples, depth=depth, node_id=self._node_count
        )
        self._node_count += 1

        # Class distribution at this node (weighted when weights are given).
        class_counts = np.bincount(
            y, weights=sample_weight, minlength=self.n_classes_
        )
        total = class_counts.sum()
        if sample_weight is not None and not total > 0:
            # Every sample here has zero weight, so the weighted distribution
            # is undefined: use unweighted class frequencies and stop.
            counts = np.bincount(y, minlength=self.n_classes_)
            node.value = counts / counts.sum()
            node.impurity = self._impurity_func(y, None)
            return node
        node.value = class_counts / total
        node.impurity = self._impurity_func(y, sample_weight)

        if self._should_stop(node, depth, n_samples):
            return node

        best_split = find_best_oblique_split(
            X, y, sample_weight,
            oblique_method=self.oblique_method,
            max_features=self._max_features,
            random_state=self._rng,
            min_samples_leaf=self._min_samples_leaf,
            impurity_func=self._impurity_func,
            include_axis_aligned=True,
            ridge_alpha=self.ridge_alpha,
            feature_combinations=self.feature_combinations,
        )
        if best_split is None:
            return node

        # Parent impurity minus the size-weighted impurity of the children,
        # using the child sizes reported by the split search.
        impurity_decrease = (
            node.impurity
            - (best_split.n_samples_left / n_samples) * best_split.impurity_left
            - (best_split.n_samples_right / n_samples) * best_split.impurity_right
        )
        if impurity_decrease < self.min_impurity_decrease:
            return node

        left_mask = np.asarray(best_split.apply(X), dtype=bool)
        right_mask = ~left_mask
        n_left = int(np.count_nonzero(left_mask))
        n_right = n_samples - n_left
        # Re-check on the realised partition. At least one sample per child is
        # always required: an empty child would hand the other child an
        # unchanged copy of this node's data.
        min_leaf = max(1, self._min_samples_leaf)
        if n_left < min_leaf or n_right < min_leaf:
            return node

        node.is_leaf = False
        node.split = best_split
        if sample_weight is None:
            left_weight = right_weight = None
        else:
            left_weight = sample_weight[left_mask]
            right_weight = sample_weight[right_mask]
        node.left = self._build_tree(
            X[left_mask], y[left_mask], left_weight, depth + 1
        )
        node.right = self._build_tree(
            X[right_mask], y[right_mask], right_weight, depth + 1
        )
        return node

    def _should_stop(self, node: ObliqueTreeNode, depth: int, n_samples: int) -> bool:
        """Check if we should stop splitting."""
        if self.max_depth is not None and depth >= self.max_depth:
            return True
        if n_samples < self._min_samples_split:
            return True
        if node.impurity <= 0:  # Pure node
            return True
        return False

    def _compute_feature_importances(self) -> None:
        """Compute impurity-based feature importances, normalized to sum to 1."""
        importances = np.zeros(self.n_features_in_)
        self._accumulate_importances(self.tree_, importances, self.tree_.n_samples)
        total = np.sum(importances)
        if total > 0:
            importances = importances / total
        self.feature_importances_ = importances

    def _accumulate_importances(
        self,
        node: ObliqueTreeNode,
        importances: np.ndarray,
        total_samples: int,
    ) -> None:
        """Recursively add each split's impurity decrease to ``importances``.

        The decrease is weighted by the node's share of training samples and
        shared among the split's features in proportion to the absolute
        values of their coefficients.
        """
        if node.is_leaf or node.split is None:
            return
        n_left = node.split.n_samples_left
        n_right = node.split.n_samples_right
        n_node = node.n_samples
        impurity_decrease = (
            node.impurity
            - (n_left / n_node) * node.split.impurity_left
            - (n_right / n_node) * node.split.impurity_right
        )
        weighted_decrease = impurity_decrease * (n_node / total_samples)

        coeffs = np.abs(node.split.coefficients)
        coeff_sum = np.sum(coeffs)
        if coeff_sum > 0:
            coeffs = coeffs / coeff_sum
        for idx, feat_idx in enumerate(node.split.feature_indices):
            importances[feat_idx] += weighted_decrease * coeffs[idx]

        self._accumulate_importances(node.left, importances, total_samples)
        self._accumulate_importances(node.right, importances, total_samples)

    def _validate_X(self, X: np.ndarray) -> np.ndarray:
        """Validate prediction input and check its width against ``n_features_in_``.

        Raises
        ------
        ValueError
            If ``X`` does not have ``n_features_in_`` columns.
        """
        X = check_array(X)
        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"X has {X.shape[1]} features, but {type(self).__name__} "
                f"is expecting {self.n_features_in_} features as input."
            )
        return X

    def _route(self, X: np.ndarray):
        """Yield ``(node, rows)`` for every node reached by a row of ``X``.

        ``rows`` is the array of indices of the rows of ``X`` that pass
        through ``node``. Each split is applied once to all rows reaching it
        rather than once per row.
        """
        stack = [(self.tree_, np.arange(X.shape[0]))]
        while stack:
            node, rows = stack.pop()
            if rows.size == 0:
                continue
            yield node, rows
            if node.is_leaf:
                continue
            goes_left = np.asarray(node.split.apply(X[rows]), dtype=bool)
            stack.append((node.right, rows[~goes_left]))
            stack.append((node.left, rows[goes_left]))

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Predict class labels.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        y_pred : ndarray of shape (n_samples,)
            Predicted class labels.
        """
        proba = self.predict_proba(X)
        return self.classes_[np.argmax(proba, axis=1)]

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Predict class probabilities.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        proba : ndarray of shape (n_samples, n_classes)
            Class probabilities (the class distribution of each sample's leaf).
        """
        check_is_fitted(self, ["tree_", "classes_"])
        X = self._validate_X(X)
        proba = np.zeros((X.shape[0], self.n_classes_))
        for node, rows in self._route(X):
            if node.is_leaf:
                proba[rows] = node.value
        return proba

    def _traverse_to_leaf(self, node: ObliqueTreeNode, x: np.ndarray) -> ObliqueTreeNode:
        """Follow splits from ``node`` down to a leaf for a single-row 2-D ``x``."""
        while not node.is_leaf:
            node = node.left if node.split.apply(x)[0] else node.right
        return node

    def apply(self, X: np.ndarray) -> np.ndarray:
        """Return leaf indices for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input samples.

        Returns
        -------
        X_leaves : ndarray of shape (n_samples,)
            Leaf node id for each sample.
        """
        check_is_fitted(self, ["tree_"])
        X = self._validate_X(X)
        leaves = np.zeros(X.shape[0], dtype=np.int64)
        for node, rows in self._route(X):
            if node.is_leaf:
                leaves[rows] = node.node_id
        return leaves

    def decision_path(self, X: np.ndarray) -> np.ndarray:
        """Return decision path through the tree.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input samples.

        Returns
        -------
        indicator : ndarray of shape (n_samples, n_nodes)
            Dense int32 matrix where element [i, j] = 1 if sample i passes
            through node j.
        """
        check_is_fitted(self, ["tree_"])
        X = self._validate_X(X)
        indicator = np.zeros((X.shape[0], self.n_nodes_), dtype=np.int32)
        for node, rows in self._route(X):
            indicator[rows, node.node_id] = 1
        return indicator

    def _trace_path(
        self,
        node: ObliqueTreeNode,
        x: np.ndarray,
        indicator: np.ndarray,
        sample_idx: int,
    ) -> None:
        """Mark every node on the path from ``node`` to a leaf for one sample."""
        while True:
            indicator[sample_idx, node.node_id] = 1
            if node.is_leaf:
                return
            node = node.left if node.split.apply(x)[0] else node.right

    def get_depth(self) -> int:
        """Return the maximum depth of the tree."""
        check_is_fitted(self, ["tree_"])
        return self._get_max_depth(self.tree_)

    def _get_max_depth(self, node: ObliqueTreeNode) -> int:
        """Recursively compute max depth."""
        if node.is_leaf:
            return node.depth
        return max(
            self._get_max_depth(node.left),
            self._get_max_depth(node.right),
        )

    def get_n_leaves(self) -> int:
        """Return the number of leaves."""
        check_is_fitted(self, ["tree_"])
        return self._count_leaves(self.tree_)

    def _count_leaves(self, node: ObliqueTreeNode) -> int:
        """Recursively count leaves."""
        if node.is_leaf:
            return 1
        return self._count_leaves(node.left) + self._count_leaves(node.right)
class ObliqueDecisionTreeRegressor(RegressorMixin, BaseEstimator):
    """A single oblique decision tree for regression.

    This is the base estimator used by ObliqueRandomForestRegressor.
    Uses linear combinations of features for splits, enabling better
    capture of diagonal decision boundaries.

    Parameters
    ----------
    oblique_method : str, default='ridge'
        Method for finding oblique splits:

        - 'ridge': Ridge regression (recommended)
        - 'pca': Principal Component Analysis
        - 'random': Random projections (fastest)
        - 'householder': Householder reflections

        'lda' and 'svm' are classification-only; if given, 'ridge' is used
        instead without a warning.
    criterion : str, default='squared_error'
        Splitting criterion: 'squared_error' or 'absolute_error'.
    max_depth : int, default=None
        Maximum tree depth. None means unlimited.
    min_samples_split : int or float, default=2
        Minimum samples required to split a node.
        If float, fraction of total samples.
    min_samples_leaf : int or float, default=1
        Minimum samples required at a leaf.
        If float, fraction of total samples. Regardless of this value, a
        split is only accepted when both children are non-empty.
    max_features : int, float, str, or None, default=None
        Features to consider per split: an int >= 1, a float fraction in
        (0, 1], 'sqrt', 'log2', or None for all features.
    min_impurity_decrease : float, default=0.0
        Minimum impurity decrease required for split, measured at the node
        itself (not scaled by the node's share of the training set).
    random_state : int, RandomState, or None, default=None
        Random seed.
    ridge_alpha : float, default=1.0
        Ridge regularization for 'ridge' method.
    feature_combinations : int, default=2
        Features per random combination (for 'random' method).

    Attributes
    ----------
    tree_ : ObliqueTreeNode
        The root node of the fitted tree.
    n_features_in_ : int
        Number of features seen during fit.
    feature_importances_ : ndarray of shape (n_features_in_,)
        Impurity-based feature importances.
    n_nodes_ : int
        Number of nodes in the tree.
    """

    _estimator_type = "regressor"

    def __init__(
        self,
        oblique_method: str = "ridge",
        criterion: str = "squared_error",
        max_depth: int | None = None,
        min_samples_split: int | float = 2,
        min_samples_leaf: int | float = 1,
        max_features: int | float | str | None = None,
        min_impurity_decrease: float = 0.0,
        random_state: int | np.random.RandomState | None = None,
        ridge_alpha: float = 1.0,
        feature_combinations: int = 2,
    ):
        # scikit-learn convention: store constructor arguments unchanged;
        # validation and resolution happen in ``fit``.
        self.oblique_method = oblique_method
        self.criterion = criterion
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.max_features = max_features
        self.min_impurity_decrease = min_impurity_decrease
        self.random_state = random_state
        self.ridge_alpha = ridge_alpha
        self.feature_combinations = feature_combinations

    def _resolve_max_features(self, n_features: int) -> int:
        """Resolve ``max_features`` to a feature count in ``[1, n_features]``.

        Raises
        ------
        ValueError
            If ``max_features`` is an unknown string, an int below 1, a float
            outside ``(0, 1]``, or of an unsupported type.
        """
        max_features = self.max_features
        if max_features is None:
            return n_features
        if isinstance(max_features, str):
            if max_features == "sqrt":
                return max(1, int(np.sqrt(n_features)))
            if max_features == "log2":
                return max(1, int(np.log2(n_features)))
        elif isinstance(max_features, (int, np.integer)):
            if max_features >= 1:
                return min(int(max_features), n_features)
        elif isinstance(max_features, (float, np.floating)):
            # A fraction above 1 would request more features than exist.
            if 0.0 < max_features <= 1.0:
                return max(1, int(max_features * n_features))
        raise ValueError(f"Invalid max_features: {self.max_features}")

    def _resolve_min_samples(self, param: int | float, n_samples: int) -> int:
        """Resolve a min_samples parameter (count or fraction) to an integer."""
        if isinstance(param, (float, np.floating)):
            return max(1, int(param * n_samples))
        return int(param)

    def _get_impurity_func(self) -> Callable:
        """Get the impurity function based on criterion.

        Raises
        ------
        ValueError
            If ``criterion`` is not 'squared_error' or 'absolute_error'.
        """
        if self.criterion == "squared_error":
            return compute_mse
        elif self.criterion == "absolute_error":
            return compute_mae
        else:
            raise ValueError(f"Invalid criterion: {self.criterion}")

    def _get_oblique_method(self) -> str:
        """Return the split method to use, mapping 'lda'/'svm' to 'ridge'.

        LDA and SVM splits are classification-only, so they are replaced
        with ridge splits for regression.
        """
        if self.oblique_method in ("lda", "svm"):
            return "ridge"
        return self.oblique_method

    def _validate_sample_weight(
        self, sample_weight: np.ndarray | None, n_samples: int
    ) -> np.ndarray | None:
        """Return ``sample_weight`` as a float64 array of shape (n_samples,).

        Returns None when no weights were given.

        Raises
        ------
        ValueError
            If the weights do not have exactly one entry per sample.
        """
        if sample_weight is None:
            return None
        sample_weight = np.asarray(sample_weight, dtype=np.float64)
        if sample_weight.shape != (n_samples,):
            raise ValueError(
                f"sample_weight has shape {sample_weight.shape}, "
                f"expected ({n_samples},)."
            )
        return sample_weight

    def fit(
        self,
        X: np.ndarray,
        y: np.ndarray,
        sample_weight: np.ndarray | None = None,
    ) -> ObliqueDecisionTreeRegressor:
        """Build the oblique decision tree.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : array-like of shape (n_samples,)
            Target values.
        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights.

        Returns
        -------
        self : object
            Fitted estimator.

        Raises
        ------
        ValueError
            If ``sample_weight`` does not match the number of samples, or if
            ``max_features`` / ``criterion`` are invalid.
        """
        X, y = check_X_y(X, y)
        y = y.astype(np.float64)
        n_samples, n_features = X.shape
        # Validate weights before writing any fitted state.
        sample_weight = self._validate_sample_weight(sample_weight, n_samples)
        self.n_features_in_ = n_features

        # Resolve parameters that depend on the data shape.
        self._max_features = self._resolve_max_features(n_features)
        self._min_samples_split = self._resolve_min_samples(
            self.min_samples_split, n_samples
        )
        self._min_samples_leaf = self._resolve_min_samples(
            self.min_samples_leaf, n_samples
        )
        self._oblique_method = self._get_oblique_method()

        self._rng = check_random_state(self.random_state)
        self._impurity_func = self._get_impurity_func()

        self._node_count = 0
        self.tree_ = self._build_tree(X, y, sample_weight, depth=0)
        self.n_nodes_ = self._node_count

        self._compute_feature_importances()
        return self

    def _build_tree(
        self,
        X: np.ndarray,
        y: np.ndarray,
        sample_weight: np.ndarray | None,
        depth: int,
    ) -> ObliqueTreeNode:
        """Recursively build the oblique decision tree.

        Node ids are assigned in pre-order (this node, then its left subtree,
        then its right subtree). Each node's ``value`` is a length-1 array
        holding the (weighted) mean target of its samples.
        """
        n_samples = len(y)
        # ObliqueTreeNode defaults to is_leaf=True; it becomes an internal
        # node only once a split is accepted below.
        node = ObliqueTreeNode(
            n_samples=n_samples, depth=depth, node_id=self._node_count
        )
        self._node_count += 1

        if sample_weight is not None and not np.sum(sample_weight) > 0:
            # Every sample here has zero weight: np.average would raise
            # ZeroDivisionError, so use the unweighted mean and stop.
            node.value = np.array([np.mean(y)])
            node.impurity = self._impurity_func(y, None)
            return node
        if sample_weight is None:
            node.value = np.array([np.mean(y)])
        else:
            node.value = np.array([np.average(y, weights=sample_weight)])
        node.impurity = self._impurity_func(y, sample_weight)

        if self._should_stop(node, depth, n_samples):
            return node

        best_split = find_best_oblique_split(
            X, y, sample_weight,
            oblique_method=self._oblique_method,
            max_features=self._max_features,
            random_state=self._rng,
            min_samples_leaf=self._min_samples_leaf,
            impurity_func=self._impurity_func,
            include_axis_aligned=True,
            ridge_alpha=self.ridge_alpha,
            feature_combinations=self.feature_combinations,
        )
        if best_split is None:
            return node

        # Parent impurity minus the size-weighted impurity of the children,
        # using the child sizes reported by the split search.
        impurity_decrease = (
            node.impurity
            - (best_split.n_samples_left / n_samples) * best_split.impurity_left
            - (best_split.n_samples_right / n_samples) * best_split.impurity_right
        )
        if impurity_decrease < self.min_impurity_decrease:
            return node

        left_mask = np.asarray(best_split.apply(X), dtype=bool)
        right_mask = ~left_mask
        n_left = int(np.count_nonzero(left_mask))
        n_right = n_samples - n_left
        # Re-check on the realised partition. At least one sample per child is
        # always required: an empty child would hand the other child an
        # unchanged copy of this node's data.
        min_leaf = max(1, self._min_samples_leaf)
        if n_left < min_leaf or n_right < min_leaf:
            return node

        node.is_leaf = False
        node.split = best_split
        if sample_weight is None:
            left_weight = right_weight = None
        else:
            left_weight = sample_weight[left_mask]
            right_weight = sample_weight[right_mask]
        node.left = self._build_tree(
            X[left_mask], y[left_mask], left_weight, depth + 1
        )
        node.right = self._build_tree(
            X[right_mask], y[right_mask], right_weight, depth + 1
        )
        return node

    def _should_stop(self, node: ObliqueTreeNode, depth: int, n_samples: int) -> bool:
        """Check if we should stop splitting."""
        if self.max_depth is not None and depth >= self.max_depth:
            return True
        if n_samples < self._min_samples_split:
            return True
        if node.impurity <= 1e-10:  # Nearly pure node
            return True
        return False

    def _compute_feature_importances(self) -> None:
        """Compute impurity-based feature importances, normalized to sum to 1."""
        importances = np.zeros(self.n_features_in_)
        self._accumulate_importances(self.tree_, importances, self.tree_.n_samples)
        total = np.sum(importances)
        if total > 0:
            importances = importances / total
        self.feature_importances_ = importances

    def _accumulate_importances(
        self,
        node: ObliqueTreeNode,
        importances: np.ndarray,
        total_samples: int,
    ) -> None:
        """Recursively add each split's impurity decrease to ``importances``.

        The decrease is weighted by the node's share of training samples and
        shared among the split's features in proportion to the absolute
        values of their coefficients.
        """
        if node.is_leaf or node.split is None:
            return
        n_left = node.split.n_samples_left
        n_right = node.split.n_samples_right
        n_node = node.n_samples
        impurity_decrease = (
            node.impurity
            - (n_left / n_node) * node.split.impurity_left
            - (n_right / n_node) * node.split.impurity_right
        )
        weighted_decrease = impurity_decrease * (n_node / total_samples)

        coeffs = np.abs(node.split.coefficients)
        coeff_sum = np.sum(coeffs)
        if coeff_sum > 0:
            coeffs = coeffs / coeff_sum
        for idx, feat_idx in enumerate(node.split.feature_indices):
            importances[feat_idx] += weighted_decrease * coeffs[idx]

        self._accumulate_importances(node.left, importances, total_samples)
        self._accumulate_importances(node.right, importances, total_samples)

    def _validate_X(self, X: np.ndarray) -> np.ndarray:
        """Validate prediction input and check its width against ``n_features_in_``.

        Raises
        ------
        ValueError
            If ``X`` does not have ``n_features_in_`` columns.
        """
        X = check_array(X)
        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"X has {X.shape[1]} features, but {type(self).__name__} "
                f"is expecting {self.n_features_in_} features as input."
            )
        return X

    def _route(self, X: np.ndarray):
        """Yield ``(node, rows)`` for every node reached by a row of ``X``.

        ``rows`` is the array of indices of the rows of ``X`` that pass
        through ``node``. Each split is applied once to all rows reaching it
        rather than once per row.
        """
        stack = [(self.tree_, np.arange(X.shape[0]))]
        while stack:
            node, rows = stack.pop()
            if rows.size == 0:
                continue
            yield node, rows
            if node.is_leaf:
                continue
            goes_left = np.asarray(node.split.apply(X[rows]), dtype=bool)
            stack.append((node.right, rows[~goes_left]))
            stack.append((node.left, rows[goes_left]))

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Predict target values.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        y_pred : ndarray of shape (n_samples,)
            Predicted values (the mean target of each sample's leaf).
        """
        check_is_fitted(self, ["tree_"])
        X = self._validate_X(X)
        predictions = np.zeros(X.shape[0])
        for node, rows in self._route(X):
            if node.is_leaf:
                predictions[rows] = node.value[0]
        return predictions

    def _traverse_to_leaf(self, node: ObliqueTreeNode, x: np.ndarray) -> ObliqueTreeNode:
        """Follow splits from ``node`` down to a leaf for a single-row 2-D ``x``."""
        while not node.is_leaf:
            node = node.left if node.split.apply(x)[0] else node.right
        return node

    def apply(self, X: np.ndarray) -> np.ndarray:
        """Return the leaf node id reached by each sample.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input samples.

        Returns
        -------
        X_leaves : ndarray of shape (n_samples,)
            Leaf node id for each sample.
        """
        check_is_fitted(self, ["tree_"])
        X = self._validate_X(X)
        leaves = np.zeros(X.shape[0], dtype=np.int64)
        for node, rows in self._route(X):
            if node.is_leaf:
                leaves[rows] = node.node_id
        return leaves

    def decision_path(self, X: np.ndarray) -> np.ndarray:
        """Return decision path through the tree.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input samples.

        Returns
        -------
        indicator : ndarray of shape (n_samples, n_nodes)
            Dense int32 matrix where element [i, j] = 1 if sample i passes
            through node j.
        """
        check_is_fitted(self, ["tree_"])
        X = self._validate_X(X)
        indicator = np.zeros((X.shape[0], self.n_nodes_), dtype=np.int32)
        for node, rows in self._route(X):
            indicator[rows, node.node_id] = 1
        return indicator

    def get_depth(self) -> int:
        """Return the maximum depth of the tree."""
        check_is_fitted(self, ["tree_"])
        return self._get_max_depth(self.tree_)

    def _get_max_depth(self, node: ObliqueTreeNode) -> int:
        """Recursively compute max depth."""
        if node.is_leaf:
            return node.depth
        return max(
            self._get_max_depth(node.left),
            self._get_max_depth(node.right),
        )

    def get_n_leaves(self) -> int:
        """Return the number of leaves."""
        check_is_fitted(self, ["tree_"])
        return self._count_leaves(self.tree_)

    def _count_leaves(self, node: ObliqueTreeNode) -> int:
        """Recursively count leaves."""
        if node.is_leaf:
            return 1
        return self._count_leaves(node.left) + self._count_leaves(node.right)