"""Model registry for AutoML.

This module provides a centralized registry of all available models
with metadata about their capabilities, computational costs, and
recommended usage scenarios.
"""
# The docstring must precede the ``__future__`` import: a string placed after
# any statement is just an expression and would not become ``__doc__``.
from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import Any

logger = logging.getLogger(__name__)
@dataclass
class ModelInfo:
    """Information about a model in the registry.

    Attributes
    ----------
    name : str
        Short name for the model (used as the registry key).
    display_name : str
        Human-readable name.
    family : str
        Model family, expected to be a key of ``MODEL_FAMILIES``
        (gbdt, neural, linear, tree, kernel, rules, bayesian, foundation,
        ensemble).
    class_path : str
        Full import path for the model class.
    task_types : list of str
        Supported task types. Registry entries use ``"classification"``
        and/or ``"regression"``; the default is both. An empty list marks
        an entry that is not usable as an estimator (e.g. ``auto_sle``).
    supports_sample_weight : bool
        Whether the model supports sample_weight in fit().
    supports_feature_importance : bool
        Whether the model provides feature_importances_.
    supports_gpu : bool
        Whether the model can use GPU acceleration.
    requires_torch : bool
        Whether PyTorch is required.
    requires_julia : bool
        Whether Julia is required.
    typical_fit_time : str
        Expected fit time category: "fast", "medium", "slow", "very_slow".
    memory_usage : str
        Expected memory usage: "low", "medium", "high".
    interpretable : bool
        Whether the model is considered interpretable.
    handles_categorical : bool
        Whether the model natively handles categorical features.
    handles_missing : bool
        Whether the model natively handles missing values.
    max_samples : int or None
        Recommended maximum samples (None = no limit).
    min_samples : int
        Recommended minimum samples.
    default_params : dict
        Default hyperparameters for the model.
    tuning_space : str or None
        Name of the predefined tuning space (if any).
    required_packages : list of str
        Names of optional third-party packages the model depends on
        (e.g. ``["pygam"]``). Empty when nothing extra is needed.
    notes : str
        Additional notes about the model.
    """

    name: str
    display_name: str
    family: str
    class_path: str
    # default_factory gives every instance its own list/dict, so mutating one
    # entry's defaults never leaks into another entry.
    task_types: list[str] = field(default_factory=lambda: ["classification", "regression"])
    supports_sample_weight: bool = True
    supports_feature_importance: bool = True
    supports_gpu: bool = False
    requires_torch: bool = False
    requires_julia: bool = False
    typical_fit_time: str = "medium"
    memory_usage: str = "medium"
    interpretable: bool = False
    handles_categorical: bool = False
    handles_missing: bool = False
    max_samples: int | None = None
    min_samples: int = 10
    default_params: dict[str, Any] = field(default_factory=dict)
    tuning_space: str | None = None
    required_packages: list[str] = field(default_factory=list)
    notes: str = ""
# Human-readable description for every model family, keyed by the short
# family identifier stored in ``ModelInfo.family``.
MODEL_FAMILIES: dict[str, str] = dict(
    gbdt="Gradient Boosting Decision Trees",
    neural="Neural Networks",
    linear="Linear Models",
    tree="Tree-Based Models",
    kernel="Kernel Methods",
    rules="Rule-Based Models",
    bayesian="Bayesian Models",
    foundation="Foundation Models",
    ensemble="Ensemble Methods",
)
# The main model registry.
#
# Keys are short model names; each key equals the ``name`` field of its
# ``ModelInfo`` entry. Fields not passed explicitly take the ``ModelInfo``
# defaults (both task types, sample-weight and feature-importance support,
# "medium" fit time and memory, ``min_samples=10``, no extra packages).
#
# NOTE(review): many entries point ``class_path`` at a ``*Classifier`` class
# while keeping the default task types (classification AND regression);
# presumably the consumer resolves a regressor counterpart — confirm.
MODEL_REGISTRY: dict[str, ModelInfo] = {
    # ==================== GBDT Models ====================
    "lgbm": ModelInfo(
        name="lgbm",
        display_name="LightGBM",
        family="gbdt",
        class_path="endgame.models.LGBMWrapper",
        supports_gpu=True,
        handles_categorical=True,
        handles_missing=True,
        typical_fit_time="fast",
        memory_usage="medium",
        default_params={"preset": "endgame", "n_estimators": 2000},
        tuning_space="lgbm_standard",
        notes="Default choice for tabular. Fast, accurate, handles missing/categorical.",
    ),
    "xgb": ModelInfo(
        name="xgb",
        display_name="XGBoost",
        family="gbdt",
        class_path="endgame.models.XGBWrapper",
        supports_gpu=True,
        handles_missing=True,
        typical_fit_time="fast",
        memory_usage="medium",
        default_params={"preset": "endgame", "n_estimators": 2000},
        tuning_space="xgb_standard",
        notes="Robust GBDT, slightly different regularization than LightGBM.",
    ),
    "catboost": ModelInfo(
        name="catboost",
        display_name="CatBoost",
        family="gbdt",
        class_path="endgame.models.CatBoostWrapper",
        supports_gpu=True,
        handles_categorical=True,
        handles_missing=True,
        typical_fit_time="medium",
        memory_usage="medium",
        default_params={"preset": "endgame", "iterations": 2000},
        tuning_space="catboost_standard",
        notes="Best native categorical handling. Ordered boosting.",
    ),
    "ngboost": ModelInfo(
        name="ngboost",
        display_name="NGBoost",
        family="gbdt",
        class_path="endgame.models.NGBoostClassifier",  # or NGBoostRegressor
        typical_fit_time="slow",
        memory_usage="high",
        default_params={"n_estimators": 500},
        notes="Natural gradient boosting for probabilistic predictions.",
    ),
    # ==================== Neural Network Models ====================
    "ft_transformer": ModelInfo(
        name="ft_transformer",
        display_name="FT-Transformer",
        family="neural",
        class_path="endgame.models.tabular.FTTransformerClassifier",
        supports_gpu=True,
        requires_torch=True,
        typical_fit_time="slow",
        memory_usage="high",
        min_samples=500,
        default_params={"n_epochs": 100, "batch_size": 256},
        notes="Transformer for tabular. Best deep learning for tabular.",
    ),
    "saint": ModelInfo(
        name="saint",
        display_name="SAINT",
        family="neural",
        class_path="endgame.models.tabular.SAINTClassifier",
        supports_gpu=True,
        requires_torch=True,
        typical_fit_time="slow",
        memory_usage="high",
        min_samples=500,
        default_params={"n_epochs": 100},
        notes="Self-attention + intersample attention transformer.",
    ),
    "tabnet": ModelInfo(
        name="tabnet",
        display_name="TabNet",
        family="neural",
        class_path="endgame.models.neural.TabNetClassifier",
        supports_gpu=True,
        requires_torch=True,
        supports_feature_importance=True,
        typical_fit_time="slow",
        memory_usage="high",
        min_samples=500,
        default_params={"n_steps": 3, "n_d": 32, "n_a": 32},
        required_packages=["pytorch_tabnet"],
        notes="Attention-based with built-in feature selection.",
    ),
    "node": ModelInfo(
        name="node",
        display_name="NODE",
        family="neural",
        class_path="endgame.models.tabular.NODEClassifier",
        supports_gpu=True,
        requires_torch=True,
        typical_fit_time="slow",
        memory_usage="high",
        min_samples=500,
        default_params={},
        notes="Neural Oblivious Decision Ensembles.",
    ),
    "nam": ModelInfo(
        name="nam",
        display_name="Neural Additive Model",
        family="neural",
        class_path="endgame.models.tabular.NAMClassifier",
        supports_gpu=True,
        requires_torch=True,
        interpretable=True,
        typical_fit_time="slow",
        memory_usage="medium",
        min_samples=200,
        max_samples=20000,
        default_params={"n_epochs": 100},
        notes="Interpretable neural network (GAM-like). Slow on large datasets.",
    ),
    "tabular_resnet": ModelInfo(
        name="tabular_resnet",
        display_name="Tabular ResNet",
        family="neural",
        class_path="endgame.models.tabular.TabularResNetClassifier",
        supports_gpu=True,
        requires_torch=True,
        typical_fit_time="medium",
        memory_usage="medium",
        min_samples=500,
        default_params={},
        notes="ResNet architecture for tabular data.",
    ),
    "mlp": ModelInfo(
        name="mlp",
        display_name="MLP",
        family="neural",
        class_path="endgame.models.neural.MLPClassifier",
        supports_gpu=True,
        requires_torch=True,
        typical_fit_time="medium",
        memory_usage="medium",
        min_samples=200,
        default_params={"hidden_dims": [256, 128], "dropout": 0.3},
        notes="Standard multi-layer perceptron.",
    ),
    "embedding_mlp": ModelInfo(
        name="embedding_mlp",
        display_name="Embedding MLP",
        family="neural",
        class_path="endgame.models.neural.EmbeddingMLPClassifier",
        supports_gpu=True,
        requires_torch=True,
        handles_categorical=True,
        typical_fit_time="medium",
        memory_usage="medium",
        min_samples=200,
        default_params={},
        notes="MLP with entity embeddings for categoricals.",
    ),
    # ==================== Tree-Based Models ====================
    "rotation_forest": ModelInfo(
        name="rotation_forest",
        display_name="Rotation Forest",
        family="tree",
        class_path="endgame.models.RotationForestClassifier",
        typical_fit_time="slow",
        memory_usage="high",
        max_samples=50000,
        default_params={"n_estimators": 100, "n_subsets": 3},
        notes="PCA rotation for diverse trees. Good for ensemble diversity.",
    ),
    "c50": ModelInfo(
        name="c50",
        display_name="C5.0",
        family="tree",
        class_path="endgame.models.C50Classifier",
        task_types=["classification"],
        interpretable=True,
        typical_fit_time="fast",
        memory_usage="low",
        default_params={},
        notes="Quinlan's C5.0. Interpretable rules/trees.",
    ),
    "oblique_forest": ModelInfo(
        name="oblique_forest",
        display_name="Oblique Random Forest",
        family="tree",
        class_path="endgame.models.ObliqueRandomForestClassifier",
        typical_fit_time="medium",
        memory_usage="medium",
        max_samples=50000,
        default_params={
            "n_estimators": 100,
            "max_depth": None,
            "n_jobs": 2,
        },
        notes=(
            "Oblique splits (linear combos of features) for diagonal decision "
            "boundaries. Uses Cython-compiled treeple when available; falls "
            "back to pure-Python implementation otherwise."
        ),
    ),
    "extra_oblique_forest": ModelInfo(
        name="extra_oblique_forest",
        display_name="Extra Oblique Random Forest",
        family="tree",
        class_path="endgame.models.ExtraObliqueRandomForestClassifier",
        typical_fit_time="fast",
        memory_usage="medium",
        max_samples=100000,
        default_params={
            "n_estimators": 100,
            "max_depth": None,
            "n_jobs": 2,
        },
        notes=(
            "Extra-trees variant of oblique RF — maximally random oblique splits. "
            "Fastest oblique forest, excellent ensemble diversity. Requires treeple."
        ),
    ),
    "patch_oblique_forest": ModelInfo(
        name="patch_oblique_forest",
        display_name="Patch Oblique Random Forest",
        family="tree",
        class_path="endgame.models.PatchObliqueRandomForestClassifier",
        typical_fit_time="medium",
        memory_usage="medium",
        max_samples=50000,
        default_params={
            "n_estimators": 100,
            "max_depth": None,
            "n_jobs": 2,
        },
        notes=(
            "Oblique splits on contiguous feature patches — effective on spatial "
            "or structured tabular data. Requires treeple."
        ),
    ),
    "honest_forest": ModelInfo(
        name="honest_forest",
        display_name="Honest Forest",
        family="tree",
        class_path="endgame.models.HonestForestClassifier",
        task_types=["classification"],
        typical_fit_time="medium",
        memory_usage="medium",
        max_samples=50000,
        default_params={
            "n_estimators": 100,
            "n_jobs": 2,
            "honest_fraction": 0.5,
        },
        notes=(
            "Honest forests use separate data for tree structure and leaf estimates, "
            "yielding better-calibrated probabilities. Requires treeple."
        ),
    ),
    "evolutionary_tree": ModelInfo(
        name="evolutionary_tree",
        display_name="Evolutionary Tree",
        family="tree",
        class_path="endgame.models.trees.EvolutionaryTreeClassifier",
        interpretable=True,
        typical_fit_time="slow",
        memory_usage="medium",
        max_samples=10000,
        default_params={"n_generations": 100},
        notes="Genetic algorithm optimized trees.",
    ),
    "quantile_forest": ModelInfo(
        name="quantile_forest",
        display_name="Quantile Regression Forest",
        family="tree",
        class_path="endgame.models.QuantileRegressorForest",
        task_types=["regression"],
        typical_fit_time="medium",
        memory_usage="high",
        default_params={"n_estimators": 100},
        notes="For prediction intervals and uncertainty quantification.",
    ),
    # ==================== Linear Models ====================
    "linear": ModelInfo(
        name="linear",
        display_name="Linear Model",
        family="linear",
        class_path="endgame.models.baselines.LinearClassifier",
        interpretable=True,
        typical_fit_time="fast",
        memory_usage="low",
        default_params={"penalty": "l2", "C": 1.0},
        notes="Fast baseline. L1/L2/ElasticNet regularization.",
    ),
    "elm": ModelInfo(
        name="elm",
        display_name="Extreme Learning Machine",
        family="linear",
        class_path="endgame.models.baselines.ELMClassifier",
        typical_fit_time="fast",
        memory_usage="low",
        default_params={"n_hidden": 500},
        notes="Random projection NN. Very fast, good for ensembles.",
    ),
    "mars": ModelInfo(
        name="mars",
        display_name="MARS",
        family="linear",
        class_path="endgame.models.MARSClassifier",
        interpretable=True,
        typical_fit_time="medium",
        memory_usage="low",
        max_samples=50000,
        default_params={},
        notes="Multivariate Adaptive Regression Splines. Interpretable.",
    ),
    # ==================== Kernel Methods ====================
    "svm": ModelInfo(
        name="svm",
        display_name="SVM",
        family="kernel",
        class_path="endgame.models.kernel.SVMClassifier",
        typical_fit_time="slow",
        memory_usage="high",
        max_samples=10000,
        default_params={"kernel": "rbf", "C": 1.0},
        notes="Support Vector Machine. Slow for large data.",
    ),
    "gp": ModelInfo(
        name="gp",
        display_name="Gaussian Process",
        family="kernel",
        class_path="endgame.models.kernel.GPClassifier",
        typical_fit_time="very_slow",
        memory_usage="high",
        max_samples=5000,
        default_params={},
        notes="Bayesian kernel method. Great for small data + uncertainty.",
    ),
    # ==================== Rule-Based Models ====================
    "rulefit": ModelInfo(
        name="rulefit",
        display_name="RuleFit",
        family="rules",
        class_path="endgame.models.RuleFitClassifier",
        interpretable=True,
        typical_fit_time="slow",
        memory_usage="medium",
        max_samples=15000,
        default_params={},
        notes="Rule-based ensemble. Interpretable rules. Slow on large datasets due to RF base.",
    ),
    "furia": ModelInfo(
        name="furia",
        display_name="FURIA",
        family="rules",
        class_path="endgame.models.FURIAClassifier",
        task_types=["classification"],
        interpretable=True,
        typical_fit_time="slow",
        memory_usage="medium",
        max_samples=20000,
        default_params={},
        notes="Fuzzy Unordered Rule Induction. Interpretable fuzzy rules. Slow on large datasets.",
    ),
    # ==================== Bayesian Models ====================
    "tan": ModelInfo(
        name="tan",
        display_name="TAN",
        family="bayesian",
        class_path="endgame.models.bayesian.TANClassifier",
        task_types=["classification"],
        interpretable=True,
        typical_fit_time="fast",
        memory_usage="low",
        default_params={},
        notes="Tree Augmented Naive Bayes. Fast probabilistic.",
    ),
    "eskdb": ModelInfo(
        name="eskdb",
        display_name="ESKDB",
        family="bayesian",
        class_path="endgame.models.bayesian.ESKDBClassifier",
        task_types=["classification"],
        typical_fit_time="medium",
        memory_usage="medium",
        default_params={},
        notes="Ensemble of Selective K-Dependence Bayes.",
    ),
    "kdb": ModelInfo(
        name="kdb",
        display_name="KDB",
        family="bayesian",
        class_path="endgame.models.bayesian.KDBClassifier",
        task_types=["classification"],
        typical_fit_time="fast",
        memory_usage="low",
        default_params={},
        notes="K-Dependence Bayes classifier.",
    ),
    "bart": ModelInfo(
        name="bart",
        display_name="BART",
        family="bayesian",
        class_path="endgame.models.probabilistic.BARTClassifier",
        typical_fit_time="very_slow",
        memory_usage="high",
        max_samples=10000,
        default_params={},
        required_packages=["pymc_bart"],
        notes="Bayesian Additive Regression Trees. MCMC-based.",
    ),
    # ==================== Baselines ====================
    "naive_bayes": ModelInfo(
        name="naive_bayes",
        display_name="Naive Bayes",
        family="bayesian",
        class_path="endgame.models.baselines.NaiveBayesClassifier",
        task_types=["classification"],
        interpretable=True,
        typical_fit_time="fast",
        memory_usage="low",
        default_params={},
        notes="Auto-selecting Naive Bayes (Gaussian/Bernoulli/Multinomial).",
    ),
    "lda": ModelInfo(
        name="lda",
        display_name="LDA",
        family="linear",
        class_path="endgame.models.baselines.LDAClassifier",
        task_types=["classification"],
        interpretable=True,
        typical_fit_time="fast",
        memory_usage="low",
        default_params={},
        notes="Linear Discriminant Analysis.",
    ),
    "qda": ModelInfo(
        name="qda",
        display_name="QDA",
        family="linear",
        class_path="endgame.models.baselines.QDAClassifier",
        task_types=["classification"],
        typical_fit_time="fast",
        memory_usage="low",
        default_params={},
        notes="Quadratic Discriminant Analysis.",
    ),
    "knn": ModelInfo(
        name="knn",
        display_name="KNN",
        family="kernel",
        class_path="endgame.models.baselines.KNNClassifier",
        typical_fit_time="fast",
        memory_usage="high",  # Stores all training data
        max_samples=50000,
        default_params={"n_neighbors": 5},
        notes="K-Nearest Neighbors. Memory scales with data size.",
    ),
    # ==================== Interpretable Models ====================
    "ebm": ModelInfo(
        name="ebm",
        display_name="EBM",
        family="ensemble",
        class_path="endgame.models.EBMClassifier",
        interpretable=True,
        typical_fit_time="slow",
        memory_usage="medium",
        default_params={"interactions": 10},
        required_packages=["interpret"],
        notes="Explainable Boosting Machine. Glass-box interpretability.",
    ),
    # ==================== Foundation Models ====================
    "tabpfn": ModelInfo(
        name="tabpfn",
        display_name="TabPFN",
        family="foundation",
        class_path="endgame.models.tabular.TabPFNClassifier",
        task_types=["classification"],
        requires_julia=False,
        typical_fit_time="fast",
        memory_usage="high",
        max_samples=10000,
        min_samples=10,
        default_params={},
        required_packages=["tabpfn"],
        notes="In-context learning. Zero training time. Great for small data.",
    ),
    # ==================== Additional Neural Models ====================
    "tab_transformer": ModelInfo(
        name="tab_transformer",
        display_name="Tab Transformer",
        family="neural",
        class_path="endgame.models.tabular.TabTransformerClassifier",
        supports_gpu=True,
        requires_torch=True,
        handles_categorical=True,
        typical_fit_time="slow",
        memory_usage="high",
        min_samples=500,
        default_params={"max_epochs": 100, "batch_size": 256},
        required_packages=["pytorch_tabular"],
        notes="Transformer for tabular with column-wise attention.",
    ),
    "gandalf": ModelInfo(
        name="gandalf",
        display_name="GANDALF",
        family="neural",
        class_path="endgame.models.tabular.GANDALFClassifier",
        supports_gpu=True,
        requires_torch=True,
        supports_feature_importance=True,
        typical_fit_time="medium",
        memory_usage="medium",
        min_samples=200,
        default_params={},
        required_packages=["pytorch_tabular"],
        notes="Gated Adaptive Network for Deep Automated Learning of Features.",
    ),
    "modern_nca": ModelInfo(
        name="modern_nca",
        display_name="Modern NCA",
        family="neural",
        class_path="endgame.models.tabular.ModernNCAClassifier",
        supports_gpu=True,
        requires_torch=True,
        typical_fit_time="medium",
        memory_usage="medium",
        min_samples=100,
        default_params={},
        notes="Modern Neighborhood Component Analysis with neural networks.",
    ),
    # ==================== Additional Bayesian Models ====================
    "neural_kdb": ModelInfo(
        name="neural_kdb",
        display_name="Neural KDB",
        family="bayesian",
        class_path="endgame.models.bayesian.NeuralKDBClassifier",
        task_types=["classification"],
        supports_gpu=True,
        requires_torch=True,
        typical_fit_time="medium",
        memory_usage="medium",
        default_params={},
        notes="Neural K-Dependence Bayes classifier with learned embeddings.",
    ),
    # ==================== Additional Ensemble Models ====================
    "ebmc": ModelInfo(
        name="ebmc",
        display_name="EBM (Classification)",
        family="ensemble",
        class_path="endgame.models.EBMClassifier",
        task_types=["classification"],
        interpretable=True,
        typical_fit_time="slow",
        memory_usage="medium",
        default_params={"interactions": 10},
        # Same class as the "ebm" entry, so it needs the same package.
        required_packages=["interpret"],
        notes="Explainable Boosting Machine for classification.",
    ),
    "ebmr": ModelInfo(
        name="ebmr",
        display_name="EBM (Regression)",
        family="ensemble",
        class_path="endgame.models.EBMRegressor",
        task_types=["regression"],
        interpretable=True,
        typical_fit_time="slow",
        memory_usage="medium",
        default_params={"interactions": 10},
        # Regression counterpart of EBMClassifier; same backing package.
        required_packages=["interpret"],
        notes="Explainable Boosting Machine for regression.",
    ),
    # ==================== Interpretable Models ====================
    "corels": ModelInfo(
        name="corels",
        display_name="CORELS",
        family="rules",
        class_path="endgame.models.interpretable.CORELSClassifier",
        task_types=["classification"],
        supports_sample_weight=False,
        supports_feature_importance=False,
        interpretable=True,
        typical_fit_time="medium",
        memory_usage="medium",
        default_params={"max_card": 2, "c": 0.001},
        notes="Rule List Classifier. Greedy sequential covering with beam search.",
    ),
    "node_gam": ModelInfo(
        name="node_gam",
        display_name="NODE-GAM",
        family="neural",
        class_path="endgame.models.interpretable.NodeGAMClassifier",
        supports_gpu=True,
        requires_torch=True,
        interpretable=True,
        typical_fit_time="very_slow",
        memory_usage="high",
        min_samples=100,
        max_samples=10000,
        default_params={"n_trees_per_feature": 32, "depth": 4},
        notes="Neural Oblivious Decision Ensembles as GAMs. Differentiable tree-based GAM. Very slow on large datasets.",
    ),
    "gami_net": ModelInfo(
        name="gami_net",
        display_name="GAMI-Net",
        family="neural",
        class_path="endgame.models.interpretable.GAMINetClassifier",
        supports_gpu=True,
        requires_torch=True,
        interpretable=True,
        typical_fit_time="slow",
        memory_usage="medium",
        min_samples=100,
        max_samples=20000,
        default_params={"interact_num": 10},
        notes="Generalized Additive Models with Structured Interactions. Slow on large datasets.",
    ),
    "slim": ModelInfo(
        name="slim",
        display_name="SLIM",
        family="linear",
        class_path="endgame.models.interpretable.SLIMClassifier",
        task_types=["classification"],
        supports_sample_weight=False,
        interpretable=True,
        typical_fit_time="fast",
        memory_usage="low",
        default_params={"max_coef": 5},
        required_packages=["fasterrisk"],
        notes="Supersparse Linear Integer Models. Produces scorecards with small integers.",
    ),
    # NOTE(review): "slim" declares required_packages=["fasterrisk"] but this
    # entry does not; if FasterRiskClassifier wraps the same package, add it.
    "fasterrisk": ModelInfo(
        name="fasterrisk",
        display_name="FasterRisk",
        family="linear",
        class_path="endgame.models.interpretable.FasterRiskClassifier",
        task_types=["classification"],
        supports_sample_weight=False,
        interpretable=True,
        typical_fit_time="medium",
        memory_usage="medium",
        default_params={"max_coef": 5, "sparsity": 10},
        notes="Fast and accurate risk scores. Optimized SLIM variant.",
    ),
    "gam": ModelInfo(
        name="gam",
        display_name="GAM",
        family="linear",
        class_path="endgame.models.interpretable.GAMClassifier",
        interpretable=True,
        typical_fit_time="medium",
        memory_usage="low",
        default_params={"n_splines": 25},
        required_packages=["pygam"],
        notes="Generalized Additive Models via pyGAM. Smooth shape functions.",
    ),
    "gosdt": ModelInfo(
        name="gosdt",
        display_name="GOSDT",
        family="tree",
        class_path="endgame.models.interpretable.GOSDTClassifier",
        task_types=["classification"],
        interpretable=True,
        typical_fit_time="slow",
        memory_usage="medium",
        max_samples=10000,
        default_params={"regularization": 0.01, "depth_budget": 5},
        required_packages=["gosdt"],
        notes="Globally Optimal Sparse Decision Trees. Provably optimal trees.",
    ),
    # ==================== Symbolic Regression ====================
    "symbolic_regression": ModelInfo(
        name="symbolic_regression",
        display_name="Symbolic Regression",
        family="rules",
        class_path="endgame.models.symbolic.SymbolicClassifier",
        interpretable=True,
        requires_julia=False,
        typical_fit_time="slow",
        memory_usage="medium",
        max_samples=50000,
        default_params={"preset": "default", "operators": "scientific"},
        notes="GP-based symbolic regression. Discovers interpretable equations.",
    ),
    "symbolic_regressor": ModelInfo(
        name="symbolic_regressor",
        display_name="Symbolic Regressor",
        family="rules",
        class_path="endgame.models.symbolic.SymbolicRegressor",
        task_types=["regression"],
        interpretable=True,
        requires_julia=False,
        typical_fit_time="slow",
        memory_usage="medium",
        max_samples=50000,
        default_params={"preset": "default", "operators": "scientific"},
        notes="GP-based symbolic regression for regression tasks.",
    ),
    # ==================== Additional Tree Models ====================
    "adtree": ModelInfo(
        name="adtree",
        display_name="Alternating Decision Tree",
        family="tree",
        class_path="endgame.models.trees.AlternatingDecisionTreeClassifier",
        task_types=["classification"],
        supports_feature_importance=True,
        interpretable=True,
        typical_fit_time="medium",
        memory_usage="medium",
        default_params={"n_iterations": 10},
        notes="Alternating Decision Tree. Interpretable boosted tree ensemble.",
    ),
    "model_tree": ModelInfo(
        name="model_tree",
        display_name="Alternating Model Tree",
        family="tree",
        class_path="endgame.models.trees.AlternatingModelTreeRegressor",
        task_types=["regression"],
        supports_feature_importance=True,
        typical_fit_time="medium",
        memory_usage="medium",
        default_params={},
        notes="Model trees with linear models at leaves. Piecewise linear regression.",
    ),
    "cubist": ModelInfo(
        name="cubist",
        display_name="Cubist",
        family="tree",
        class_path="endgame.models.CubistRegressor",
        task_types=["regression"],
        interpretable=True,
        typical_fit_time="fast",
        memory_usage="low",
        default_params={"committees": 1, "neighbors": 0},
        notes="Rule-based regression (Quinlan). Interpretable piecewise linear.",
    ),
    "c50_ensemble": ModelInfo(
        name="c50_ensemble",
        display_name="C5.0 Ensemble",
        family="tree",
        class_path="endgame.models.C50Ensemble",
        task_types=["classification"],
        supports_feature_importance=True,
        interpretable=True,
        typical_fit_time="fast",
        memory_usage="low",
        default_params={"n_trials": 10},
        notes="Boosted C5.0 ensemble. Interpretable boosted trees.",
    ),
    # ==================== Ordinal Regression ====================
    "ordinal": ModelInfo(
        name="ordinal",
        display_name="Ordinal Classifier",
        family="linear",
        class_path="endgame.models.ordinal.OrdinalClassifier",
        task_types=["classification"],
        interpretable=True,
        typical_fit_time="fast",
        memory_usage="low",
        default_params={"variant": "auto"},
        notes="Auto-selecting ordinal classifier for ordered categorical targets.",
    ),
    "logistic_at": ModelInfo(
        name="logistic_at",
        display_name="Logistic All-Threshold",
        family="linear",
        class_path="endgame.models.ordinal.LogisticAT",
        task_types=["classification"],
        interpretable=True,
        typical_fit_time="fast",
        memory_usage="low",
        default_params={},
        notes="Ordinal regression with all-threshold loss.",
    ),
    "logistic_it": ModelInfo(
        name="logistic_it",
        display_name="Logistic Immediate-Threshold",
        family="linear",
        class_path="endgame.models.ordinal.LogisticIT",
        task_types=["classification"],
        interpretable=True,
        typical_fit_time="fast",
        memory_usage="low",
        default_params={},
        notes="Ordinal regression with immediate-threshold loss.",
    ),
    "logistic_se": ModelInfo(
        name="logistic_se",
        display_name="Logistic Squared-Error",
        family="linear",
        class_path="endgame.models.ordinal.LogisticSE",
        task_types=["classification"],
        interpretable=True,
        typical_fit_time="fast",
        memory_usage="low",
        default_params={},
        notes="Ordinal regression with squared-error loss.",
    ),
    "ordinal_ridge": ModelInfo(
        name="ordinal_ridge",
        display_name="Ordinal Ridge",
        family="linear",
        class_path="endgame.models.ordinal.OrdinalRidge",
        task_types=["classification"],
        interpretable=True,
        typical_fit_time="fast",
        memory_usage="low",
        default_params={},
        notes="Ordinal regression with ridge penalty.",
    ),
    "ordinal_lad": ModelInfo(
        name="ordinal_lad",
        display_name="Ordinal LAD",
        family="linear",
        class_path="endgame.models.ordinal.LAD",
        task_types=["classification"],
        interpretable=True,
        typical_fit_time="fast",
        memory_usage="low",
        default_params={},
        notes="Ordinal regression with least absolute deviation.",
    ),
    # ==================== Additional Bayesian Models ====================
    "auto_sle": ModelInfo(
        name="auto_sle",
        display_name="AutoSLE",
        family="bayesian",
        class_path="endgame.models.bayesian.AutoSLE",
        # Empty task list: not an estimator, so no task can select it.
        task_types=[],
        typical_fit_time="medium",
        memory_usage="medium",
        default_params={},
        notes="Structure learning only — not an sklearn estimator, excluded from AutoML.",
    ),
    "ebmc_classifier": ModelInfo(
        name="ebmc_classifier",
        display_name="EBMC",
        family="bayesian",
        class_path="endgame.models.bayesian.EBMCClassifier",
        task_types=["classification"],
        typical_fit_time="fast",
        memory_usage="low",
        default_params={},
        notes="Efficient Bayesian Multivariate Classifier.",
    ),
    # ==================== Modern Neural Models ====================
    "tabm": ModelInfo(
        name="tabm",
        display_name="TabM",
        family="neural",
        class_path="endgame.models.tabular.tabm.TabMClassifier",
        supports_gpu=True,
        requires_torch=True,
        typical_fit_time="medium",
        memory_usage="medium",
        min_samples=200,
        default_params={},
        notes="TabM: Tabular Model with batch ensembling. Strong SOTA performance.",
    ),
    "realmlp": ModelInfo(
        name="realmlp",
        display_name="RealMLP",
        family="neural",
        class_path="endgame.models.tabular.realmlp.RealMLPClassifier",
        supports_gpu=True,
        requires_torch=True,
        typical_fit_time="medium",
        memory_usage="medium",
        min_samples=200,
        default_params={},
        notes="RealMLP: Well-tuned MLP baseline with modern training recipe.",
    ),
    "grande": ModelInfo(
        name="grande",
        display_name="GRANDE",
        family="neural",
        class_path="endgame.models.tabular.grande.GRANDEClassifier",
        supports_gpu=True,
        requires_torch=True,
        supports_feature_importance=True,
        typical_fit_time="medium",
        memory_usage="medium",
        min_samples=200,
        default_params={},
        notes="GRANDE: Gradient-based decision tree ensembles. Differentiable trees.",
    ),
    "tabdpt": ModelInfo(
        name="tabdpt",
        display_name="TabDPT",
        family="neural",
        class_path="endgame.models.tabular.tabdpt.TabDPTClassifier",
        supports_gpu=True,
        requires_torch=True,
        typical_fit_time="slow",
        memory_usage="high",
        min_samples=200,
        default_params={},
        required_packages=["tabdpt"],
        notes="TabDPT: Tabular Data Pre-Training. In-context learning for tabular.",
    ),
    "tabr": ModelInfo(
        name="tabr",
        display_name="TabR",
        family="neural",
        class_path="endgame.models.tabular.tabr.TabRClassifier",
        supports_gpu=True,
        requires_torch=True,
        typical_fit_time="medium",
        memory_usage="medium",
        min_samples=200,
        default_params={},
        notes="TabR: Tabular Retrieval-augmented learning. kNN-enhanced deep model.",
    ),
    # ==================== Random Forest Variants ====================
    "rf": ModelInfo(
        name="rf",
        display_name="Random Forest",
        family="tree",
        class_path="sklearn.ensemble.RandomForestClassifier",
        supports_feature_importance=True,
        typical_fit_time="medium",
        memory_usage="medium",
        default_params={"n_estimators": 200, "max_depth": None, "n_jobs": -1},
        notes="Standard sklearn Random Forest. Good baseline ensemble.",
    ),
    "extra_trees": ModelInfo(
        name="extra_trees",
        display_name="Extra Trees",
        family="tree",
        class_path="sklearn.ensemble.ExtraTreesClassifier",
        supports_feature_importance=True,
        typical_fit_time="fast",
        memory_usage="medium",
        default_params={"n_estimators": 200, "n_jobs": -1},
        notes="Extremely Randomized Trees. Faster than RF, different bias.",
    ),
    # ==================== Foundation Models (v2) ====================
    # NOTE(review): unlike "tabpfn", the two entries below declare no
    # required_packages; if they wrap the `tabpfn` package, add it — confirm.
    "tabpfn_v2": ModelInfo(
        name="tabpfn_v2",
        display_name="TabPFN v2",
        family="foundation",
        class_path="endgame.models.tabular.tabpfn.TabPFNv2Classifier",
        task_types=["classification"],
        supports_sample_weight=False,
        supports_feature_importance=False,
        handles_categorical=True,
        typical_fit_time="fast",
        memory_usage="high",
        max_samples=10000,
        min_samples=10,
        default_params={},
        notes="TabPFN v2: improved in-context learning. Great for small data.",
    ),
    "tabpfn_25": ModelInfo(
        name="tabpfn_25",
        display_name="TabPFN 2025",
        family="foundation",
        class_path="endgame.models.tabular.tabpfn.TabPFN25Classifier",
        task_types=["classification"],
        supports_sample_weight=False,
        supports_feature_importance=False,
        handles_categorical=True,
        typical_fit_time="fast",
        memory_usage="high",
        max_samples=50000,
        min_samples=10,
        default_params={},
        notes="TabPFN 2025: scales to 50K samples. State-of-the-art ICL.",
    ),
    # ==================== Foundation Models (xRFM) ====================
    "xrfm": ModelInfo(
        name="xrfm",
        display_name="xRFM",
        family="foundation",
        class_path="endgame.models.tabular.xrfm.xRFMClassifier",
        interpretable=True,
        typical_fit_time="medium",
        memory_usage="medium",
        default_params={},
        notes="Explainable Random Feature Model. Interpretable foundation model.",
    ),
    # ==================== Subgroup Discovery ====================
    "prim": ModelInfo(
        name="prim",
        display_name="PRIM",
        family="rules",
        class_path="endgame.models.subgroup.prim.PRIMClassifier",
        task_types=["classification"],
        supports_sample_weight=False,
        interpretable=True,
        typical_fit_time="fast",
        memory_usage="low",
        default_params={},
        notes="Patient Rule Induction Method. Bump hunting for subgroup discovery.",
    ),
}
def register_model(
    name: str,
    display_name: str,
    family: str,
    class_path: str,
    task_types: list[str] | None = None,
    supports_sample_weight: bool = True,
    supports_feature_importance: bool = True,
    supports_gpu: bool = False,
    requires_torch: bool = False,
    requires_julia: bool = False,
    typical_fit_time: str = "medium",
    memory_usage: str = "medium",
    interpretable: bool = False,
    handles_categorical: bool = False,
    handles_missing: bool = False,
    max_samples: int | None = None,
    min_samples: int = 10,
    default_params: dict[str, Any] | None = None,
    tuning_space: str | None = None,
    notes: str = "",
    overwrite: bool = False,
) -> ModelInfo:
    """Register a new model in the registry.

    This function provides a convenient way to add new models to the registry
    at runtime, useful for plugins or custom model extensions.

    Parameters
    ----------
    name : str
        Short name for the model (used as key).
    display_name : str
        Human-readable name.
    family : str
        Model family (gbdt, neural, linear, tree, kernel, rules, bayesian, foundation, ensemble).
        Unknown families are accepted with a warning.
    class_path : str
        Full import path for the model class.
    task_types : list of str, optional
        Supported task types. Default (also used when an empty list is given)
        is ["classification", "regression"]. The list is copied.
    supports_sample_weight : bool, default=True
        Whether the model supports sample_weight in fit().
    supports_feature_importance : bool, default=True
        Whether the model provides feature_importances_.
    supports_gpu : bool, default=False
        Whether the model can use GPU acceleration.
    requires_torch : bool, default=False
        Whether PyTorch is required.
    requires_julia : bool, default=False
        Whether Julia is required.
    typical_fit_time : str, default="medium"
        Expected fit time category: "fast", "medium", "slow", "very_slow".
        Other values are accepted with a warning.
    memory_usage : str, default="medium"
        Expected memory usage: "low", "medium", "high".
        Other values are accepted with a warning.
    interpretable : bool, default=False
        Whether the model is considered interpretable.
    handles_categorical : bool, default=False
        Whether the model natively handles categorical features.
    handles_missing : bool, default=False
        Whether the model natively handles missing values.
    max_samples : int, optional
        Recommended maximum samples (None = no limit).
    min_samples : int, default=10
        Recommended minimum samples.
    default_params : dict, optional
        Default hyperparameters for the model. A shallow copy is stored, so
        later changes to the caller's dict do not leak into the registry.
    tuning_space : str, optional
        Name of the predefined tuning space (if any).
    notes : str, default=""
        Additional notes about the model.
    overwrite : bool, default=False
        If True, overwrite existing entry. If False, raise error if exists.

    Returns
    -------
    ModelInfo
        The registered model info.

    Raises
    ------
    ValueError
        If model already exists and overwrite=False.

    Examples
    --------
    >>> register_model(
    ...     name="my_model",
    ...     display_name="My Custom Model",
    ...     family="neural",
    ...     class_path="mypackage.models.MyModelClassifier",
    ...     supports_gpu=True,
    ...     requires_torch=True,
    ...     typical_fit_time="fast",
    ...     notes="My custom neural network model.",
    ... )
    """
    if name in MODEL_REGISTRY and not overwrite:
        raise ValueError(
            f"Model '{name}' already exists in registry. "
            "Use overwrite=True to replace it."
        )
    if family not in MODEL_FAMILIES:
        logger.warning(
            "Family '%s' not in standard families: %s. Adding anyway.",
            family,
            list(MODEL_FAMILIES.keys()),
        )
    # Filters match these categories by exact string (e.g. list_models checks
    # typical_fit_time against ("slow", "very_slow")), so surface likely typos.
    if typical_fit_time not in ("fast", "medium", "slow", "very_slow"):
        logger.warning(
            "typical_fit_time '%s' for model '%s' is not one of "
            "'fast', 'medium', 'slow', 'very_slow'. Adding anyway.",
            typical_fit_time,
            name,
        )
    if memory_usage not in ("low", "medium", "high"):
        logger.warning(
            "memory_usage '%s' for model '%s' is not one of "
            "'low', 'medium', 'high'. Adding anyway.",
            memory_usage,
            name,
        )
    info = ModelInfo(
        name=name,
        display_name=display_name,
        family=family,
        class_path=class_path,
        # Copy mutable arguments so the registry entry does not alias caller state.
        task_types=list(task_types) if task_types else ["classification", "regression"],
        supports_sample_weight=supports_sample_weight,
        supports_feature_importance=supports_feature_importance,
        supports_gpu=supports_gpu,
        requires_torch=requires_torch,
        requires_julia=requires_julia,
        typical_fit_time=typical_fit_time,
        memory_usage=memory_usage,
        interpretable=interpretable,
        handles_categorical=handles_categorical,
        handles_missing=handles_missing,
        max_samples=max_samples,
        min_samples=min_samples,
        default_params=dict(default_params) if default_params else {},
        tuning_space=tuning_space,
        notes=notes,
    )
    MODEL_REGISTRY[name] = info
    logger.debug("Registered model: %s (%s)", name, display_name)
    return info
[docs]
def unregister_model(name: str) -> bool:
    """Delete a model entry from the registry, if present.

    Parameters
    ----------
    name : str
        Registry key of the model to delete.

    Returns
    -------
    bool
        False when no entry named ``name`` existed (nothing is changed);
        True after the entry has been deleted.
    """
    if name not in MODEL_REGISTRY:
        return False
    del MODEL_REGISTRY[name]
    logger.debug(f"Unregistered model: {name}")
    return True
[docs]
def get_model_info(name: str) -> ModelInfo:
    """Look up the registry entry for a model.

    Parameters
    ----------
    name : str
        Registry key of the model.

    Returns
    -------
    ModelInfo
        The stored metadata for ``name``.

    Raises
    ------
    KeyError
        If ``name`` is not registered; the message lists every known key
        in sorted order.
    """
    try:
        return MODEL_REGISTRY[name]
    except KeyError:
        pass
    # Raised outside the except clause so no implicit exception chaining occurs.
    known = ", ".join(sorted(MODEL_REGISTRY))
    raise KeyError(f"Model '{name}' not found. Available: {known}")
[docs]
def list_models(
    family: str | None = None,
    task_type: str | None = None,
    interpretable_only: bool = False,
    gpu_only: bool = False,
    exclude_slow: bool = False,
    max_samples: int | None = None,
) -> list[str]:
    """Return the sorted names of registered models that pass every filter.

    Each filter is inactive when its argument is falsy.

    Parameters
    ----------
    family : str, optional
        Keep only models whose ``family`` equals this value.
    task_type : str, optional
        Keep only models listing this task type ("classification" or
        "regression") or "both" in ``task_types``.
    interpretable_only : bool, default=False
        Keep only models flagged ``interpretable``.
    gpu_only : bool, default=False
        Keep only models flagged ``supports_gpu``.
    exclude_slow : bool, default=False
        Drop models whose ``typical_fit_time`` is "slow" or "very_slow".
    max_samples : int, optional
        Drop models whose (non-empty) ``max_samples`` limit is below this.

    Returns
    -------
    list of str
        Alphabetically sorted model names.
    """
    slow_tiers = ("slow", "very_slow")

    def _passes(entry) -> bool:
        # One guard per filter; the first failing guard rejects the entry.
        if family and entry.family != family:
            return False
        if task_type and not (
            task_type in entry.task_types or "both" in entry.task_types
        ):
            return False
        if interpretable_only and not entry.interpretable:
            return False
        if gpu_only and not entry.supports_gpu:
            return False
        if exclude_slow and entry.typical_fit_time in slow_tiers:
            return False
        if max_samples and entry.max_samples and max_samples > entry.max_samples:
            return False
        return True

    return sorted(key for key, entry in MODEL_REGISTRY.items() if _passes(entry))
[docs]
def get_model_class(name: str) -> type:
    """Get the model class for a given name.

    Parameters
    ----------
    name : str
        Model name.

    Returns
    -------
    type
        The model class, resolved by importing the module part of the
        registered ``class_path`` and reading the attribute named by its
        last dotted component.

    Raises
    ------
    KeyError
        If the model is not in the registry (raised by ``get_model_info``).
    ImportError
        If the class path is malformed or the model class cannot be imported.
    """
    import importlib

    info = get_model_info(name)
    module_path, _, class_name = info.class_path.rpartition(".")
    # rpartition leaves module_path empty for "Foo" or ".Foo" and class_name
    # empty for "pkg.". A leading dot would make import_module attempt a
    # relative import, which raises TypeError (not ImportError) without a
    # package, so reject all of these up front with the documented error.
    if not module_path or not class_name or module_path.startswith("."):
        raise ImportError(f"Invalid class path: {info.class_path}")
    try:
        module = importlib.import_module(module_path)
        return getattr(module, class_name)
    except (ImportError, AttributeError) as e:
        raise ImportError(
            f"Could not import {info.class_path}: {e}. "
            f"You may need to install additional dependencies."
        ) from e
def instantiate_model(
    name: str,
    task_type: str = "classification",
    **override_params,
) -> Any:
    """Instantiate a model by name.

    Parameters
    ----------
    name : str
        Model name.
    task_type : str, default="classification"
        Task type ("classification" or "regression").
        For "regression", a registered class whose name contains
        "Classifier" is swapped for the same-module class with "Regressor"
        in its name, when such a class exists. Otherwise the registered
        class is used, and a warning is logged if the swap was attempted.
    **override_params
        Parameters to override the defaults.

    Returns
    -------
    estimator
        Instantiated model.

    Raises
    ------
    KeyError
        If the model is not in the registry.
    ImportError
        If the registered model class cannot be imported.
    """
    import importlib

    info = get_model_info(name)
    model_class = None

    if task_type == "regression":
        module_path, _, class_name = info.class_path.rpartition(".")
        # Rewrite only the class name. Replacing "Classifier" across the
        # whole dotted path could also alter the module part. Empty/relative
        # module paths are left for get_model_class to reject.
        if (
            "Classifier" in class_name
            and module_path
            and not module_path.startswith(".")
        ):
            regressor_name = class_name.replace("Classifier", "Regressor")
            try:
                module = importlib.import_module(module_path)
                model_class = getattr(module, regressor_name)
            except (ImportError, AttributeError):
                # Keep the best-effort fallback to the registered class, but
                # make it visible: the caller asked for regression and will
                # receive the classifier class.
                logger.warning(
                    "No regressor class %s.%s found for model '%s'; "
                    "falling back to %s.",
                    module_path,
                    regressor_name,
                    name,
                    info.class_path,
                )

    if model_class is None:
        # Validates the class path and raises ImportError on failure.
        model_class = get_model_class(name)

    # Copy so overrides never mutate the registry's default_params.
    params = info.default_params.copy()
    params.update(override_params)
    return model_class(**params)
[docs]
def get_models_by_family() -> dict[str, list[str]]:
    """Group registered model names by their ``family`` attribute.

    Returns
    -------
    dict
        Family name -> alphabetically sorted model names. Families appear in
        the order they are first encountered while iterating the registry.
    """
    grouped: dict[str, list[str]] = {}
    for model_name, entry in MODEL_REGISTRY.items():
        grouped.setdefault(entry.family, []).append(model_name)
    return {fam: sorted(members) for fam, members in grouped.items()}
[docs]
def get_default_portfolio(
    task_type: str = "classification",
    n_samples: int = 10000,
    time_budget: str = "medium",
    gpu_available: bool = False,
) -> list[str]:
    """Get a recommended portfolio of models based on data characteristics.

    Parameters
    ----------
    task_type : str
        Task type.
    n_samples : int
        Number of samples in the dataset.
    time_budget : str
        Time budget: "fast", "medium", "high", "unlimited". Any value other
        than "medium", "high" or "unlimited" is treated as "fast".
    gpu_available : bool
        Whether GPU is available. Only consulted for the "medium" budget.

    Returns
    -------
    list of str
        Recommended model names, in selection order. Candidates that are not
        currently registered (e.g. removed via ``unregister_model``), do not
        support ``task_type``, or have a ``max_samples`` below ``n_samples``
        are omitted.
    """
    # Start with GBDTs (always include)
    portfolio = ["lgbm", "xgb", "catboost"]
    # Add based on data size
    if n_samples < 5000:
        portfolio.extend(["tabpfn", "gp"])
    # Add based on time budget
    if time_budget in ("high", "unlimited"):
        portfolio.extend(["ft_transformer", "saint", "rotation_forest"])
        if task_type == "classification":
            portfolio.extend(["ebm", "ngboost"])
    elif time_budget == "medium":
        portfolio.extend(["linear", "elm"])
        if gpu_available:
            portfolio.append("ft_transformer")
    else:  # fast
        # Just keep LightGBM
        portfolio = ["lgbm"]

    selected = []
    for model_name in portfolio:
        info = MODEL_REGISTRY.get(model_name)
        # Indexing the registry directly would raise KeyError for any
        # candidate that is not registered; skip it instead.
        if info is None:
            logger.debug("Skipping unregistered portfolio model: %s", model_name)
            continue
        # Filter by task type
        if task_type not in info.task_types and "both" not in info.task_types:
            continue
        # Filter by sample size
        if info.max_samples is not None and n_samples > info.max_samples:
            continue
        selected.append(model_name)
    return selected
# ==================== Interpretable Models ====================
# Canonical set of registry keys for models whose fitted form a human can
# read directly (rules, additive shape functions, sparse/linear scores,
# small trees, simple probabilistic models), grouped by category below.
INTERPRETABLE_MODELS: set[str] = {
    # Rule lists, rule ensembles and symbolic expressions:
    # CORELS, RuleFit, FURIA, GP-based symbolic regression (two key variants)
    "corels", "rulefit", "furia", "symbolic_regression", "symbolic_regressor",
    # Additive models with per-feature shape functions:
    # pyGAM, Explainable Boosting (generic / classifier / regressor keys),
    # Neural Additive Models, NODE-GAM, GAMI-Net
    "gam", "ebm", "ebmc", "ebmr", "nam", "node_gam", "gami_net",
    # Sparse linear models and scorecards:
    # SLIM, FasterRisk, linear/logistic regression, MARS splines
    "slim", "fasterrisk", "linear", "mars",
    # Readable trees and tree-derived rules:
    # GOSDT, C5.0 (single and boosted), evolutionary trees, ADTree, Cubist
    "gosdt", "c50", "c50_ensemble", "evolutionary_tree", "adtree", "cubist",
    # Linear ordinal regression variants (auto-selector plus specific losses)
    "ordinal", "logistic_at", "logistic_it", "logistic_se", "ordinal_ridge", "ordinal_lad",
    # Simple probabilistic classifiers:
    # Naive Bayes, Tree-Augmented Naive Bayes, Linear Discriminant Analysis
    "naive_bayes", "tan", "lda",
    # Explainable Random Feature Model (foundation family)
    "xrfm",
    # Patient Rule Induction Method (subgroup discovery)
    "prim",
}
def get_interpretable_models(
    task_type: str | None = None,
    exclude_slow: bool = False,
    max_samples: int | None = None,
) -> list[str]:
    """Return registered models flagged ``interpretable``, optionally filtered.

    This is ``list_models`` with ``interpretable_only=True`` and the
    remaining filters forwarded unchanged.

    Parameters
    ----------
    task_type : str, optional
        Restrict to "classification" or "regression" capable models.
    exclude_slow : bool, default=False
        Drop models with a "slow" or "very_slow" typical fit time.
    max_samples : int, optional
        Drop models whose recommended sample limit is below this value.

    Returns
    -------
    list of str
        Sorted names of the matching interpretable models.
    """
    filters = {
        "task_type": task_type,
        "interpretable_only": True,
        "exclude_slow": exclude_slow,
        "max_samples": max_samples,
    }
    return list_models(**filters)
def get_interpretable_portfolio(
    task_type: str = "classification",
    n_samples: int = 10000,
    time_budget: str = "medium",
) -> list[str]:
    """Suggest an ordered, duplicate-free set of interpretable models.

    Parameters
    ----------
    task_type : str
        "classification" selects the classification candidate lists; any
        other value selects the regression lists.
    n_samples : int
        Dataset size; candidates whose ``max_samples`` is below it are dropped.
    time_budget : str
        "high"/"unlimited" and "medium" add extra candidates; any other value
        (e.g. "fast") keeps only the core models.

    Returns
    -------
    list of str
        Candidate names that are registered, support ``task_type`` (or
        "both"), and fit the sample-size limit, in first-seen order.
    """
    is_classification = task_type == "classification"

    # Core interpretable models (always considered)
    if is_classification:
        candidates = ["ebm", "gam", "linear"]
    else:
        candidates = ["ebmr", "gam", "linear", "mars"]

    # Budget-dependent extras
    if time_budget in ("high", "unlimited"):
        if is_classification:
            extras = ["rulefit", "furia", "corels", "gosdt", "nam", "gami_net", "node_gam"]
        else:
            extras = ["nam", "gami_net", "node_gam", "symbolic_regressor"]
    elif time_budget == "medium":
        extras = ["rulefit", "slim", "node_gam"] if is_classification else ["rulefit", "node_gam"]
    else:
        extras = []
    candidates += extras

    def _eligible(key: str) -> bool:
        # Unregistered names are skipped rather than raising.
        if key not in MODEL_REGISTRY:
            return False
        entry = MODEL_REGISTRY[key]
        if task_type not in entry.task_types and "both" not in entry.task_types:
            return False
        return entry.max_samples is None or n_samples <= entry.max_samples

    # dict.fromkeys keeps the first occurrence of each name, preserving order.
    return list(dict.fromkeys(key for key in candidates if _eligible(key)))