# Source code for endgame.models.tabular.modern_nca

from __future__ import annotations

"""Modern Neighborhood Component Analysis.

A kNN-based approach with learned distance metric that's surprisingly
competitive with gradient boosting on many tabular tasks.

References
----------
- Goldberger et al. "Neighbourhood Components Analysis" (2004)
- Movshovitz-Attias et al. "No Fuss Distance Metric Learning Using Proxies" (2017)
"""

from typing import Any

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.preprocessing import LabelEncoder, StandardScaler

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from torch.utils.data import DataLoader, TensorDataset
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False


def _check_torch():
    """Raise ``ImportError`` unless the optional PyTorch import succeeded.

    Relies on the module-level ``HAS_TORCH`` flag set by the guarded import
    at the top of the file.
    """
    if HAS_TORCH:
        return
    raise ImportError("PyTorch is required for ModernNCA. Install with: pip install torch")


# The base class is chosen from HAS_TORCH so that importing this module does
# not raise NameError on ``nn`` when the guarded torch import above failed.
class _EmbeddingNetwork(nn.Module if HAS_TORCH else object):
    """MLP that maps feature vectors to L2-normalized embeddings.

    The network is a stack of ``Linear -> BatchNorm1d -> ReLU -> Dropout``
    blocks (one per entry of ``hidden_dims``) followed by a final ``Linear``
    projection to ``embedding_dim``.

    Parameters
    ----------
    n_features : int
        Number of input features.
    embedding_dim : int, default=128
        Size of the output embedding.
    hidden_dims : list of int, optional
        Widths of the hidden blocks. ``None`` selects ``[256, 256]``; an
        empty list yields a single linear projection.
    dropout : float, default=0.1
        Dropout probability applied after every hidden block.

    Raises
    ------
    ImportError
        If PyTorch is not installed.
    """

    def __init__(
        self,
        n_features: int,
        embedding_dim: int = 128,
        hidden_dims: list[int] | None = None,
        dropout: float = 0.1,
    ):
        # Fail with the module's friendly message instead of a NameError on
        # ``nn`` when torch is unavailable.
        _check_torch()
        super().__init__()

        if hidden_dims is None:
            hidden_dims = [256, 256]

        layers = []
        in_dim = n_features

        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(in_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            in_dim = hidden_dim

        # Final projection into the embedding space (no activation).
        layers.append(nn.Linear(in_dim, embedding_dim))

        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` and L2-normalize each row.

        Unit-length rows make the dot product of two embeddings equal to
        their cosine similarity.
        """
        embeddings = self.network(x)
        return F.normalize(embeddings, p=2, dim=1)


class ModernNCAClassifier(ClassifierMixin, BaseEstimator):
    """Modern Neighborhood Component Analysis classifier.

    A kNN-based approach with learned distance metric using neural networks.
    Surprisingly competitive with gradient boosting on many tasks.

    The model learns an embedding space where samples of the same class
    are close together and samples of different classes are far apart.
    At inference, it uses kNN in this learned space.

    Parameters
    ----------
    n_neighbors : int, default=32
        Number of neighbors for kNN prediction.
    embedding_dim : int, default=128
        Dimension of learned embedding space.
    hidden_dims : List[int], default=[256, 256]
        Hidden layer dimensions for embedding network. An empty list
        means no hidden layers (a single linear projection).
    temperature : float, default=0.1
        Softmax temperature for neighbor weighting.
    dropout : float, default=0.1
        Dropout rate in embedding network.
    learning_rate : float, default=1e-3
        Learning rate.
    weight_decay : float, default=1e-5
        L2 regularization.
    n_epochs : int, default=100
        Training epochs.
    batch_size : int, default=256
        Batch size. Must be at least 2, because the NCA loss compares
        every sample with the other samples of its batch.
    early_stopping : int, default=15
        Early stopping patience (only used when ``eval_set`` is given).
    device : str, default='auto'
        Device. ``'auto'`` picks CUDA when available, otherwise CPU.
    random_state : int, optional
        Random seed.
    verbose : bool, default=False
        Verbose output.

    Attributes
    ----------
    classes_ : ndarray
        Class labels.
    model_ : _EmbeddingNetwork
        Fitted embedding network.
    train_embeddings_ : ndarray
        Embeddings of training data.
    train_labels_ : ndarray
        Training labels (encoded as integers).
    history_ : dict
        Training history of the most recent ``fit`` call.

    Examples
    --------
    >>> clf = ModernNCAClassifier(n_neighbors=32, embedding_dim=128)
    >>> clf.fit(X_train, y_train)
    >>> proba = clf.predict_proba(X_test)

    Notes
    -----
    This approach is particularly effective when:

    - The decision boundary is locally smooth
    - Class separation can benefit from learned features
    - You want probabilistic predictions based on neighborhood
    """

    _estimator_type = "classifier"

    def __init__(
        self,
        n_neighbors: int = 32,
        embedding_dim: int = 128,
        hidden_dims: list[int] | None = None,
        temperature: float = 0.1,
        dropout: float = 0.1,
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-5,
        n_epochs: int = 100,
        batch_size: int = 256,
        early_stopping: int = 15,
        device: str = "auto",
        random_state: int | None = None,
        verbose: bool = False,
    ):
        self.n_neighbors = n_neighbors
        self.embedding_dim = embedding_dim
        # ``is None`` rather than ``or``: an explicit empty list is a valid
        # request for "no hidden layers" and must not be replaced.
        self.hidden_dims = [256, 256] if hidden_dims is None else hidden_dims
        self.temperature = temperature
        self.dropout = dropout
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.early_stopping = early_stopping
        self.device = device
        self.random_state = random_state
        self.verbose = verbose

        self.classes_: np.ndarray | None = None
        self.n_classes_: int = 0
        self.model_: _EmbeddingNetwork | None = None
        self._device = None
        self._label_encoder: LabelEncoder | None = None
        self._scaler: StandardScaler | None = None
        self.train_embeddings_: np.ndarray | None = None
        self.train_labels_: np.ndarray | None = None
        self._X_train: np.ndarray | None = None
        self.history_: dict[str, list[float]] = {"train_loss": [], "val_loss": []}
        self._is_fitted: bool = False

    def _log(self, msg: str):
        """Print ``msg`` with a ``[ModernNCA]`` prefix when ``verbose`` is set."""
        if self.verbose:
            print(f"[ModernNCA] {msg}")

    def _get_device(self):
        """Resolve the ``device`` parameter to a ``torch.device``."""
        _check_torch()
        if self.device == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.device(self.device)

    def _set_seed(self):
        """Seed torch's and NumPy's global RNGs when ``random_state`` is given."""
        if self.random_state is not None:
            torch.manual_seed(self.random_state)
            np.random.seed(self.random_state)

    def _nca_loss(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """NCA loss: softmax over pairwise distances.

        Maximizes the probability that a point is correctly classified by
        its neighbors in embedding space.

        The batch must hold at least two samples: with a single sample the
        only candidate neighbor is the masked-out sample itself, so the
        softmax row is all ``-inf`` and the result is NaN.
        """
        batch_size = embeddings.shape[0]

        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2*a.b, and the embedding network
        # L2-normalizes its output, so the squared distance reduces to
        # 2 * (1 - cosine_similarity).
        similarities = torch.mm(embeddings, embeddings.t())
        distances = 2 * (1 - similarities)

        # Negative, temperature-scaled distances act as softmax logits.
        scaled_distances = -distances / self.temperature

        # A sample may not pick itself as a neighbor.
        mask = torch.eye(batch_size, device=embeddings.device).bool()
        scaled_distances = scaled_distances.masked_fill(mask, float('-inf'))

        # Row i: probability that sample i selects each other sample.
        neighbor_probs = F.softmax(scaled_distances, dim=1)

        # Pairwise "same class" indicator, with the diagonal excluded.
        labels_match = labels.unsqueeze(0) == labels.unsqueeze(1)  # (batch, batch)
        labels_match = labels_match.float()
        labels_match = labels_match.masked_fill(mask, 0)

        # Probability of selecting a neighbor that shares the sample's class.
        correct_probs = (neighbor_probs * labels_match).sum(dim=1)

        # Avoid log(0) for samples with no same-class neighbor in the batch.
        correct_probs = correct_probs.clamp(min=1e-10)

        return -torch.log(correct_probs).mean()

    def _run_epoch(self, loader: DataLoader, optimizer=None) -> float:
        """Run one pass over ``loader`` and return the mean batch loss.

        The model is trained when ``optimizer`` is given and only evaluated
        otherwise. Batches with fewer than two samples are skipped:
        ``BatchNorm1d`` cannot train on them and the NCA loss is NaN for
        them. Returns NaN when no usable batch was seen.
        """
        training = optimizer is not None
        self.model_.train(training)

        total_loss = 0.0
        n_batches = 0
        with torch.set_grad_enabled(training):
            for x_batch, y_batch in loader:
                if x_batch.shape[0] < 2:
                    continue

                x_batch = x_batch.to(self._device)
                y_batch = y_batch.to(self._device)

                if training:
                    optimizer.zero_grad()

                loss = self._nca_loss(self.model_(x_batch), y_batch)

                if training:
                    loss.backward()
                    optimizer.step()

                total_loss += loss.item()
                n_batches += 1

        return total_loss / n_batches if n_batches else float("nan")

    def fit(
        self,
        X,
        y,
        eval_set: tuple[Any, Any] | None = None,
        **fit_params,
    ) -> ModernNCAClassifier:
        """Fit the ModernNCA classifier.

        Parameters
        ----------
        X : array-like
            Training features.
        y : array-like
            Training labels.
        eval_set : tuple, optional
            Validation set ``(X_val, y_val)``. When given, the weights with
            the lowest validation loss are restored after training and
            ``early_stopping`` is applied.
        **fit_params
            Accepted for API compatibility; not used.

        Returns
        -------
        self

        Raises
        ------
        ImportError
            If PyTorch is not installed.
        ValueError
            If fewer than two training samples are given or
            ``batch_size < 2``.
        """
        _check_torch()
        self._set_seed()
        self._device = self._get_device()

        # Start from a clean slate so a refit does not append to, or rely
        # on, the state of a previous fit.
        self._is_fitted = False
        self.history_ = {"train_loss": [], "val_loss": []}

        X = np.asarray(X, dtype=np.float32)
        y = np.asarray(y)

        if len(X) < 2 or self.batch_size < 2:
            raise ValueError(
                "ModernNCAClassifier needs at least 2 training samples and "
                "batch_size >= 2: the NCA loss compares each sample with the "
                "other samples in its batch."
            )

        # Encode labels
        self._label_encoder = LabelEncoder()
        y_encoded = self._label_encoder.fit_transform(y)
        self.classes_ = self._label_encoder.classes_
        self.n_classes_ = len(self.classes_)

        # Scale features; NaNs left after scaling are replaced by 0.
        self._scaler = StandardScaler()
        X_scaled = self._scaler.fit_transform(X)
        X_scaled = np.nan_to_num(X_scaled, nan=0.0)

        # Store for prediction
        self._X_train = X_scaled
        self.train_labels_ = y_encoded

        x_tensor = torch.tensor(X_scaled, dtype=torch.float32)
        y_tensor = torch.tensor(y_encoded, dtype=torch.long)

        # Create model
        self.model_ = _EmbeddingNetwork(
            n_features=X.shape[1],
            embedding_dim=self.embedding_dim,
            hidden_dims=self.hidden_dims,
            dropout=self.dropout,
        ).to(self._device)

        # Data loader
        train_dataset = TensorDataset(x_tensor, y_tensor)
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)

        # Validation
        val_loader = None
        if eval_set is not None:
            X_val, y_val = eval_set
            X_val = self._scaler.transform(np.asarray(X_val, dtype=np.float32))
            X_val = np.nan_to_num(X_val, nan=0.0)
            y_val_encoded = self._label_encoder.transform(y_val)

            val_dataset = TensorDataset(
                torch.tensor(X_val, dtype=torch.float32),
                torch.tensor(y_val_encoded, dtype=torch.long),
            )
            val_loader = DataLoader(val_dataset, batch_size=self.batch_size)

        # Optimizer
        optimizer = optim.AdamW(
            self.model_.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.n_epochs)

        best_val_loss = float("inf")
        best_state = None
        patience_counter = 0

        self._log(f"Training on {self._device}...")

        for epoch in range(self.n_epochs):
            train_loss = self._run_epoch(train_loader, optimizer)
            self.history_["train_loss"].append(train_loss)

            if val_loader is not None:
                val_loss = self._run_epoch(val_loader)
                self.history_["val_loss"].append(val_loss)

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    # Snapshot on CPU so the best weights survive later updates.
                    best_state = {k: v.cpu().clone() for k, v in self.model_.state_dict().items()}
                    patience_counter = 0
                else:
                    patience_counter += 1

                if self.verbose and (epoch + 1) % 10 == 0:
                    self._log(f"Epoch {epoch+1}/{self.n_epochs}: train={train_loss:.4f}, val={val_loss:.4f}")

                if patience_counter >= self.early_stopping:
                    self._log(f"Early stopping at epoch {epoch + 1}")
                    break

            scheduler.step()

        if best_state is not None:
            self.model_.load_state_dict(best_state)

        # Compute training embeddings for kNN
        self._compute_train_embeddings()

        self._is_fitted = True
        return self

    def _embed_scaled(self, X_scaled: np.ndarray) -> np.ndarray:
        """Embed already-scaled features in chunks of ``batch_size`` rows."""
        self.model_.eval()
        chunks = []

        with torch.no_grad():
            for start in range(0, len(X_scaled), self.batch_size):
                x_batch = torch.tensor(
                    X_scaled[start:start + self.batch_size], dtype=torch.float32
                ).to(self._device)
                chunks.append(self.model_(x_batch).cpu().numpy())

        return np.vstack(chunks)

    def _compute_train_embeddings(self):
        """Compute and store embeddings for all training samples."""
        self.train_embeddings_ = self._embed_scaled(self._X_train)

    def _embed(self, X: np.ndarray) -> np.ndarray:
        """Scale raw features with the fitted scaler and embed them."""
        X_scaled = self._scaler.transform(X)
        X_scaled = np.nan_to_num(X_scaled, nan=0.0)
        return self._embed_scaled(X_scaled)

    def predict_proba(self, X) -> np.ndarray:
        """Predict class probabilities using kNN in embedding space.

        Each query is embedded, its ``n_neighbors`` most similar training
        embeddings are selected, and their labels vote with
        softmax(similarity / temperature) weights.

        Parameters
        ----------
        X : array-like
            Input features.

        Returns
        -------
        ndarray of shape (n_samples, n_classes)
            Class probabilities; columns follow ``classes_``.

        Raises
        ------
        RuntimeError
            If the classifier has not been fitted.
        """
        if not self._is_fitted:
            raise RuntimeError("ModernNCAClassifier has not been fitted.")

        X = np.asarray(X, dtype=np.float32)
        test_embeddings = self._embed(X)
        train_embeddings = self.train_embeddings_

        k = min(self.n_neighbors, len(train_embeddings))
        n_samples = len(test_embeddings)
        proba = np.zeros((n_samples, self.n_classes_))

        # Work on chunks of query rows so that only a
        # (batch_size, n_train) similarity block is alive at any time.
        for start in range(0, n_samples, self.batch_size):
            chunk = test_embeddings[start:start + self.batch_size]

            # Embeddings are unit-length, so the dot product is the cosine
            # similarity (cosine distance = 1 - similarity).
            similarities = np.dot(chunk, train_embeddings.T)

            # argsort is ascending: the last k columns are the k most similar.
            top_k_idx = np.argsort(similarities, axis=1)[:, -k:]
            top_k_sim = np.take_along_axis(similarities, top_k_idx, axis=1)
            top_k_labels = self.train_labels_[top_k_idx]

            # Softmax weighting; the row max is subtracted for numerical
            # stability.
            # NOTE(review): training uses logits -2 * (1 - s) / temperature,
            # i.e. 2 * s / temperature up to a constant, while inference
            # uses s / temperature. Confirm the factor of 2 is intended.
            scaled_sim = top_k_sim / self.temperature
            scaled_sim -= scaled_sim.max(axis=1, keepdims=True)
            weights = np.exp(scaled_sim)
            weights /= weights.sum(axis=1, keepdims=True)

            # Weighted voting: unbuffered add so repeated (row, label)
            # pairs accumulate instead of overwriting each other.
            rows = np.broadcast_to(
                np.arange(start, start + len(chunk))[:, None], top_k_labels.shape
            )
            np.add.at(proba, (rows, top_k_labels), weights)

        # Ensure probabilities are valid (clamp to [0, 1] and normalize)
        proba = np.clip(proba, 0.0, 1.0)
        row_sums = proba.sum(axis=1, keepdims=True)
        # Avoid division by zero
        row_sums = np.where(row_sums == 0, 1.0, row_sums)
        proba = proba / row_sums

        return proba

    def predict(self, X) -> np.ndarray:
        """Predict class labels (in the original label space)."""
        proba = self.predict_proba(X)
        return self._label_encoder.inverse_transform(np.argmax(proba, axis=1))

    def transform(self, X) -> np.ndarray:
        """Transform features to embedding space.

        Parameters
        ----------
        X : array-like
            Input features.

        Returns
        -------
        ndarray
            Learned embeddings.

        Raises
        ------
        RuntimeError
            If the classifier has not been fitted.
        """
        if not self._is_fitted:
            raise RuntimeError("ModernNCAClassifier has not been fitted.")

        return self._embed(np.asarray(X, dtype=np.float32))