Source code for endgame.models.tabular.node

from __future__ import annotations

"""NODE: Neural Oblivious Decision Ensembles.

NODE is a differentiable ensemble of oblivious decision trees (ODTs).
It bridges the gap between gradient boosting and neural networks.

References
----------
- Popov et al. "Neural Oblivious Decision Ensembles for Deep Learning
  on Tabular Data" (2020)
"""

from typing import Any

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.preprocessing import LabelEncoder, StandardScaler

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from torch.utils.data import DataLoader, TensorDataset
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False


def _check_torch():
    if not HAS_TORCH:
        raise ImportError("PyTorch is required for NODE. Install with: pip install torch")


def entmax15(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Entmax 1.5 activation (sparse softmax).

    Entmax with alpha=1.5 produces sparse outputs while being differentiable.
    """
    # Simplified implementation using sparsemax approximation
    # Full entmax requires iterative algorithm
    return F.softmax(x * 1.5, dim=dim)


def entmoid15(x: torch.Tensor) -> torch.Tensor:
    """Entmoid 1.5 - entmax-based sigmoid."""
    return torch.sigmoid(x * 1.5)


class _DenseODSTBlock(nn.Module):
    """Dense block of Oblivious Decision Trees.

    Each tree makes decisions based on feature comparisons and
    returns a weighted combination of leaf values.
    """

    def __init__(
        self,
        n_features: int,
        n_trees: int = 1024,
        tree_depth: int = 6,
        choice_function: str = "entmax15",
        bin_function: str = "entmoid15",
    ):
        super().__init__()

        self.n_features = n_features
        self.n_trees = n_trees
        self.tree_depth = tree_depth
        self.n_leaves = 2 ** tree_depth

        # Feature selection weights for each internal node
        # Shape: (n_trees, tree_depth, n_features)
        self.feature_selection = nn.Parameter(
            torch.zeros(n_trees, tree_depth, n_features)
        )
        nn.init.xavier_uniform_(self.feature_selection)

        # Thresholds for comparisons
        # Shape: (n_trees, tree_depth)
        self.thresholds = nn.Parameter(torch.zeros(n_trees, tree_depth))

        # Leaf values (responses)
        # Shape: (n_trees, n_leaves)
        self.leaf_values = nn.Parameter(torch.zeros(n_trees, self.n_leaves))
        nn.init.xavier_uniform_(self.leaf_values)

        # Choice and bin functions
        if choice_function == "entmax15":
            self.choice_fn = entmax15
        elif choice_function == "softmax":
            self.choice_fn = lambda x, dim=-1: F.softmax(x, dim=dim)
        else:
            self.choice_fn = entmax15

        if bin_function == "entmoid15":
            self.bin_fn = entmoid15
        elif bin_function == "sigmoid":
            self.bin_fn = torch.sigmoid
        else:
            self.bin_fn = entmoid15

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through ODT block.

        Parameters
        ----------
        x : Tensor of shape (batch, n_features)

        Returns
        -------
        Tensor of shape (batch, n_trees)
            Tree outputs.
        """
        batch_size = x.shape[0]

        # Feature selection with soft attention
        # (n_trees, tree_depth, n_features) -> softmax over features
        feature_weights = self.choice_fn(self.feature_selection, dim=-1)

        # Compute selected features for each node
        # x: (batch, n_features)
        # feature_weights: (n_trees, tree_depth, n_features)
        # Result: (batch, n_trees, tree_depth)
        selected_features = torch.einsum('bf,tdf->btd', x, feature_weights)

        # Compare with thresholds
        # thresholds: (n_trees, tree_depth)
        # node_decisions: (batch, n_trees, tree_depth) - soft decisions [0, 1]
        node_decisions = self.bin_fn(selected_features - self.thresholds)

        # Convert decisions to leaf indices
        # Each leaf corresponds to a binary path through the tree
        # leaf_probs: (batch, n_trees, n_leaves)
        leaf_probs = self._compute_leaf_probabilities(node_decisions)

        # Weighted sum of leaf values
        # leaf_values: (n_trees, n_leaves)
        # output: (batch, n_trees)
        output = torch.einsum('btl,tl->bt', leaf_probs, self.leaf_values)

        return output

    def _compute_leaf_probabilities(self, decisions: torch.Tensor) -> torch.Tensor:
        """Compute soft probabilities for each leaf using vectorized operations.

        Parameters
        ----------
        decisions : Tensor of shape (batch, n_trees, tree_depth)
            Soft decisions at each internal node.

        Returns
        -------
        Tensor of shape (batch, n_trees, n_leaves)
            Probability of reaching each leaf.
        """
        batch_size, n_trees, depth = decisions.shape
        device = decisions.device

        # Pre-compute binary path matrix for all leaves (done once, cached)
        # Shape: (n_leaves, depth) - binary representation of each leaf index
        if not hasattr(self, '_path_matrix') or self._path_matrix.device != device:
            leaf_indices = torch.arange(self.n_leaves, device=device)
            # For each depth d, check if bit (depth-1-d) is set
            self._path_matrix = torch.zeros(self.n_leaves, depth, device=device)
            for d in range(depth):
                self._path_matrix[:, d] = (leaf_indices >> (depth - 1 - d)) & 1

        # Vectorized computation:
        # decisions: (batch, n_trees, depth)
        # path_matrix: (n_leaves, depth)
        #
        # For each leaf, we need product over depth of:
        #   decision[d] if path[d] == 1 else (1 - decision[d])
        #
        # This equals: decision^path * (1-decision)^(1-path)
        # In log space: path * log(decision) + (1-path) * log(1-decision)

        # Clamp for numerical stability
        decisions_clamped = torch.clamp(decisions, 1e-6, 1 - 1e-6)

        # Compute in log space for stability
        # log_decisions: (batch, n_trees, depth)
        log_d = torch.log(decisions_clamped)
        log_1_minus_d = torch.log(1 - decisions_clamped)

        # path_matrix: (n_leaves, depth) -> expand for broadcasting
        # Result: (batch, n_trees, n_leaves)
        log_leaf_probs = torch.einsum('btd,ld->btl', log_d, self._path_matrix) + \
                         torch.einsum('btd,ld->btl', log_1_minus_d, 1 - self._path_matrix)

        leaf_probs = torch.exp(log_leaf_probs)

        return leaf_probs


class _NODEModule(nn.Module):
    """PyTorch NODE module."""

    def __init__(
        self,
        n_features: int,
        n_classes: int,
        n_layers: int = 2,
        n_trees: int = 1024,
        tree_depth: int = 6,
        choice_function: str = "entmax15",
        bin_function: str = "entmoid15",
        is_regression: bool = False,
    ):
        super().__init__()

        self.n_layers = n_layers
        self.is_regression = is_regression

        # Input normalization
        self.input_bn = nn.BatchNorm1d(n_features)

        # NODE layers
        self.layers = nn.ModuleList()
        current_dim = n_features

        for i in range(n_layers):
            self.layers.append(
                _DenseODSTBlock(
                    n_features=current_dim,
                    n_trees=n_trees,
                    tree_depth=tree_depth,
                    choice_function=choice_function,
                    bin_function=bin_function,
                )
            )
            current_dim = n_trees

        # Output layer
        output_dim = 1 if is_regression else n_classes
        self.output = nn.Linear(current_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize input
        x = self.input_bn(x)

        # NODE layers
        for layer in self.layers:
            x = layer(x)

        # Output
        return self.output(x)


[docs] class NODEClassifier(ClassifierMixin, BaseEstimator): """NODE: Neural Oblivious Decision Ensembles for classification. A differentiable ensemble of oblivious decision trees. Bridges the gap between gradient boosting and neural networks. Parameters ---------- n_layers : int, default=1 Number of dense NODE layers. Start with 1 for most datasets. n_trees : int, default=128 Number of trees per layer. 64-256 works well for most datasets. tree_depth : int, default=4 Depth of each oblivious tree. 3-5 works well; deeper trees risk overfitting. choice_function : str, default='softmax' Soft choice function: 'entmax15', 'softmax'. Softmax is more stable. bin_function : str, default='sigmoid' Binning function: 'entmoid15', 'sigmoid'. Sigmoid is more stable. learning_rate : float, default=0.01 Learning rate. Higher values (0.01-0.1) often work better for NODE. weight_decay : float, default=1e-5 L2 regularization. n_epochs : int, default=100 Maximum training epochs. batch_size : int, default=128 Training batch size. Smaller batches (64-256) often work better. early_stopping : int, default=20 Early stopping patience. max_grad_norm : float, default=1.0 Maximum gradient norm for clipping. validation_fraction : float, default=0.1 Fraction of training data to use for validation when eval_set not provided. Set to 0 to disable internal validation split (not recommended). device : str, default='auto' Device: 'cuda', 'cpu', 'auto'. random_state : int, optional Random seed. verbose : bool, default=False Verbose output. Attributes ---------- classes_ : ndarray Class labels. model_ : _NODEModule Fitted model. history_ : dict Training history. Examples -------- >>> clf = NODEClassifier(n_layers=1, n_trees=128, tree_depth=4) >>> clf.fit(X_train, y_train, eval_set=(X_val, y_val)) >>> proba = clf.predict_proba(X_test) Notes ----- NODE works best with: - An eval_set for early stopping (or validation_fraction > 0) - Higher learning rates (0.01-0.1) than typical neural networks - Smaller batch sizes (64-256) - Fewer/shallower trees than you might expect (start small) When using sklearn's cross_val_score (which doesn't support eval_set), the model will automatically create an internal validation split using validation_fraction of the training data. """ _estimator_type = "classifier" def __init__( self, n_layers: int = 1, n_trees: int = 128, tree_depth: int = 4, choice_function: str = "softmax", bin_function: str = "sigmoid", learning_rate: float = 0.01, weight_decay: float = 1e-5, n_epochs: int = 100, batch_size: int = 128, early_stopping: int = 20, max_grad_norm: float = 1.0, validation_fraction: float = 0.1, device: str = "auto", random_state: int | None = None, verbose: bool = False, ): self.n_layers = n_layers self.n_trees = n_trees self.tree_depth = tree_depth self.choice_function = choice_function self.bin_function = bin_function self.learning_rate = learning_rate self.weight_decay = weight_decay self.n_epochs = n_epochs self.batch_size = batch_size self.early_stopping = early_stopping self.max_grad_norm = max_grad_norm self.validation_fraction = validation_fraction self.device = device self.random_state = random_state self.verbose = verbose self.classes_: np.ndarray | None = None self.n_classes_: int = 0 self.model_: _NODEModule | None = None self._device = None self._label_encoder: LabelEncoder | None = None self._scaler: StandardScaler | None = None self.history_: dict[str, list[float]] = {"train_loss": [], "val_loss": []} self._is_fitted: bool = False def _log(self, msg: str): if self.verbose: print(f"[NODE] {msg}") def _get_device(self): _check_torch() if self.device == "auto": return torch.device("cuda" if torch.cuda.is_available() else "cpu") return torch.device(self.device) def _set_seed(self): if self.random_state is not None: torch.manual_seed(self.random_state) np.random.seed(self.random_state)
[docs] def fit( self, X, y, eval_set: tuple[Any, Any] | None = None, **fit_params, ) -> NODEClassifier: """Fit the NODE classifier. Parameters ---------- X : array-like Training features. y : array-like Training labels. eval_set : tuple, optional Validation set for early stopping. Returns ------- self """ _check_torch() self._set_seed() self._device = self._get_device() X = np.asarray(X, dtype=np.float32) y = np.asarray(y) # Encode labels self._label_encoder = LabelEncoder() y_encoded = self._label_encoder.fit_transform(y) self.classes_ = self._label_encoder.classes_ self.n_classes_ = len(self.classes_) # Scale features self._scaler = StandardScaler() X_scaled = self._scaler.fit_transform(X) X_scaled = np.nan_to_num(X_scaled, nan=0.0) x_tensor = torch.tensor(X_scaled, dtype=torch.float32) y_tensor = torch.tensor(y_encoded, dtype=torch.long) # Create model self.model_ = _NODEModule( n_features=X.shape[1], n_classes=self.n_classes_, n_layers=self.n_layers, n_trees=self.n_trees, tree_depth=self.tree_depth, choice_function=self.choice_function, bin_function=self.bin_function, is_regression=False, ).to(self._device) # Create internal validation split if no eval_set provided if eval_set is None and self.validation_fraction > 0: from sklearn.model_selection import train_test_split n_val = int(len(X_scaled) * self.validation_fraction) if n_val >= 1: try: X_train_internal, X_val_internal, y_train_internal, y_val_internal = train_test_split( X_scaled, y_encoded, test_size=self.validation_fraction, stratify=y_encoded, random_state=self.random_state, ) except ValueError: X_train_internal, X_val_internal, y_train_internal, y_val_internal = train_test_split( X_scaled, y_encoded, test_size=self.validation_fraction, random_state=self.random_state, ) x_tensor = torch.tensor(X_train_internal, dtype=torch.float32) y_tensor = torch.tensor(y_train_internal, dtype=torch.long) eval_set = (X_val_internal, y_val_internal) # Data loader train_dataset = TensorDataset(x_tensor, y_tensor) train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) # Validation val_loader = None if eval_set is not None: X_val, y_val = eval_set # Check if validation data needs scaling (external eval_set) if isinstance(y_val[0], (str, np.str_)) or (hasattr(y_val, 'dtype') and y_val.dtype.kind in ('U', 'S', 'O')): # External eval_set with original labels X_val = self._scaler.transform(np.asarray(X_val, dtype=np.float32)) X_val = np.nan_to_num(X_val, nan=0.0) y_val_encoded = self._label_encoder.transform(y_val) elif X_val.shape[1] == X_scaled.shape[1] and np.allclose(X_val.mean(), 0, atol=1): # Already scaled internal split y_val_encoded = y_val else: # External eval_set with numerical labels X_val = self._scaler.transform(np.asarray(X_val, dtype=np.float32)) X_val = np.nan_to_num(X_val, nan=0.0) y_val_encoded = self._label_encoder.transform(y_val) if hasattr(self._label_encoder, 'transform') else y_val val_dataset = TensorDataset( torch.tensor(X_val, dtype=torch.float32), torch.tensor(y_val_encoded, dtype=torch.long), ) val_loader = DataLoader(val_dataset, batch_size=self.batch_size) # Optimizer optimizer = optim.AdamW( self.model_.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay, ) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.n_epochs) best_val_loss = float("inf") best_state = None patience_counter = 0 self._log(f"Training on {self._device}...") for epoch in range(self.n_epochs): self.model_.train() train_loss = 0.0 n_batches = 0 for x_batch, y_batch in train_loader: x_batch = x_batch.to(self._device) y_batch = y_batch.to(self._device) optimizer.zero_grad() logits = self.model_(x_batch) loss = F.cross_entropy(logits, y_batch) loss.backward() # Gradient clipping for stability if self.max_grad_norm > 0: torch.nn.utils.clip_grad_norm_(self.model_.parameters(), self.max_grad_norm) optimizer.step() train_loss += loss.item() n_batches += 1 train_loss /= n_batches self.history_["train_loss"].append(train_loss) if val_loader is not None: self.model_.eval() val_loss = 0.0 n_val_batches = 0 with torch.no_grad(): for x_batch, y_batch in val_loader: x_batch = x_batch.to(self._device) y_batch = y_batch.to(self._device) logits = self.model_(x_batch) loss = F.cross_entropy(logits, y_batch) val_loss += loss.item() n_val_batches += 1 val_loss /= n_val_batches self.history_["val_loss"].append(val_loss) if val_loss < best_val_loss: best_val_loss = val_loss best_state = {k: v.cpu().clone() for k, v in self.model_.state_dict().items()} patience_counter = 0 else: patience_counter += 1 if self.verbose and (epoch + 1) % 10 == 0: self._log(f"Epoch {epoch+1}/{self.n_epochs}: train={train_loss:.4f}, val={val_loss:.4f}") if patience_counter >= self.early_stopping: self._log(f"Early stopping at epoch {epoch + 1}") break scheduler.step() if best_state is not None: self.model_.load_state_dict(best_state) self._is_fitted = True return self
[docs] def predict_proba(self, X) -> np.ndarray: """Predict class probabilities.""" if not self._is_fitted: raise RuntimeError("NODEClassifier has not been fitted.") X = np.asarray(X, dtype=np.float32) X_scaled = np.nan_to_num(self._scaler.transform(X), nan=0.0) self.model_.eval() all_proba = [] with torch.no_grad(): for start in range(0, len(X), self.batch_size): end = min(start + self.batch_size, len(X)) x_batch = torch.tensor(X_scaled[start:end], dtype=torch.float32).to(self._device) logits = self.model_(x_batch) proba = F.softmax(logits, dim=1) all_proba.append(proba.cpu().numpy()) return np.vstack(all_proba)
[docs] def predict(self, X) -> np.ndarray: """Predict class labels.""" proba = self.predict_proba(X) return self._label_encoder.inverse_transform(np.argmax(proba, axis=1))
[docs] class NODERegressor(BaseEstimator, RegressorMixin): """NODE for regression. Same architecture as NODEClassifier but with MSE loss. See NODEClassifier for parameter descriptions. """ _estimator_type = "regressor" def __init__( self, n_layers: int = 2, n_trees: int = 256, tree_depth: int = 4, choice_function: str = "softmax", bin_function: str = "sigmoid", learning_rate: float = 1e-3, weight_decay: float = 1e-5, n_epochs: int = 100, batch_size: int = 512, early_stopping: int = 15, max_grad_norm: float = 1.0, device: str = "auto", random_state: int | None = None, verbose: bool = False, ): self.n_layers = n_layers self.n_trees = n_trees self.tree_depth = tree_depth self.choice_function = choice_function self.bin_function = bin_function self.learning_rate = learning_rate self.weight_decay = weight_decay self.n_epochs = n_epochs self.batch_size = batch_size self.early_stopping = early_stopping self.max_grad_norm = max_grad_norm self.device = device self.random_state = random_state self.verbose = verbose self.model_ = None self._device = None self._scaler = None self._target_scaler = None self.history_ = {"train_loss": [], "val_loss": []} self._is_fitted = False
[docs] def fit(self, X, y, eval_set=None, **fit_params) -> NODERegressor: """Fit NODE regressor.""" _check_torch() if self.random_state is not None: torch.manual_seed(self.random_state) np.random.seed(self.random_state) self._device = torch.device( "cuda" if self.device == "auto" and torch.cuda.is_available() else "cpu" if self.device == "auto" else self.device ) X = np.asarray(X, dtype=np.float32) y = np.asarray(y).reshape(-1, 1).astype(np.float32) self._scaler = StandardScaler() X_scaled = np.nan_to_num(self._scaler.fit_transform(X), nan=0.0) self._target_scaler = StandardScaler() y_scaled = self._target_scaler.fit_transform(y).ravel() x_tensor = torch.tensor(X_scaled, dtype=torch.float32) y_tensor = torch.tensor(y_scaled, dtype=torch.float32) self.model_ = _NODEModule( n_features=X.shape[1], n_classes=1, n_layers=self.n_layers, n_trees=self.n_trees, tree_depth=self.tree_depth, choice_function=self.choice_function, bin_function=self.bin_function, is_regression=True, ).to(self._device) train_dataset = TensorDataset(x_tensor, y_tensor) train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) optimizer = optim.AdamW( self.model_.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay, ) for epoch in range(self.n_epochs): self.model_.train() for x_batch, y_batch in train_loader: x_batch = x_batch.to(self._device) y_batch = y_batch.to(self._device) optimizer.zero_grad() pred = self.model_(x_batch).squeeze() loss = F.mse_loss(pred, y_batch) loss.backward() optimizer.step() self._is_fitted = True return self
[docs] def predict(self, X) -> np.ndarray: """Predict target values.""" if not self._is_fitted: raise RuntimeError("NODERegressor has not been fitted.") X = np.asarray(X, dtype=np.float32) X_scaled = np.nan_to_num(self._scaler.transform(X), nan=0.0) self.model_.eval() all_pred = [] with torch.no_grad(): for start in range(0, len(X), self.batch_size): end = min(start + self.batch_size, len(X)) x_batch = torch.tensor(X_scaled[start:end], dtype=torch.float32).to(self._device) pred = self.model_(x_batch).squeeze() all_pred.append(pred.cpu().numpy()) pred = np.concatenate(all_pred) return self._target_scaler.inverse_transform(pred.reshape(-1, 1)).ravel()