from __future__ import annotations
"""ONNX export for sklearn-compatible estimators.
Converts fitted estimators to ONNX format for portable, framework-agnostic
inference. Auto-detects the best conversion backend based on estimator type:
- sklearn / sklearn-compatible models -> skl2onnx
- Tree-based GBDTs (LightGBM, XGBoost, CatBoost) -> skl2onnx (with converters)
- PyTorch ``nn.Module``-backed models -> ``torch.onnx.export``
- hummingbird-ml -> never auto-selected; used only when ``backend="hummingbird"`` is requested explicitly
Examples
--------
>>> import endgame as eg
>>> from sklearn.ensemble import RandomForestClassifier
>>> model = RandomForestClassifier().fit(X_train, y_train)
>>> eg.export_onnx(model, "/tmp/model.onnx", sample_input=X_train[:1])
'/tmp/model.onnx'
"""
import logging
from pathlib import Path
from typing import Any
import numpy as np
from endgame.persistence._detection import find_torch_modules, has_torch_modules
logger = logging.getLogger(__name__)
# File extension used by ``export_onnx`` for the output path.
ONNX_EXT = ".onnx"
# Maps numpy scalar types to the *names* of tensor-type classes (for example
# the classes that ``_infer_initial_types`` imports from
# ``skl2onnx.common.data_types``).
# NOTE(review): no function visible in this module reads this table, and it
# disagrees with the local map inside ``_infer_initial_types``, which sends
# ``np.float64`` to ``FloatTensorType`` -- confirm whether it is still needed.
_NP_DTYPE_TO_ONNX = {
    np.float32: "FloatTensorType",
    np.float64: "DoubleTensorType",
    np.int64: "Int64TensorType",
    np.int32: "Int32TensorType",
}
def _ensure_skl2onnx():
"""Import skl2onnx, raising a clear error if unavailable."""
try:
import skl2onnx
return skl2onnx
except ImportError:
raise ImportError(
"skl2onnx is required for ONNX export of sklearn models. "
"Install it with: pip install skl2onnx"
)
def _ensure_onnx():
"""Import onnx, raising a clear error if unavailable."""
try:
import onnx
return onnx
except ImportError:
raise ImportError(
"onnx is required for ONNX export. "
"Install it with: pip install onnx"
)
def _ensure_hummingbird():
"""Import hummingbird-ml, raising a clear error if unavailable."""
try:
from hummingbird.ml import convert as hb_convert
return hb_convert
except ImportError:
raise ImportError(
"hummingbird-ml is required for this conversion backend. "
"Install it with: pip install hummingbird-ml"
)
def _is_gbdt(estimator: Any) -> bool:
"""Check if the estimator is a tree-based GBDT model.
Detects LightGBM, XGBoost, and CatBoost models, including endgame
wrappers around them.
Args:
estimator: The estimator to check.
Returns:
True if the estimator is a GBDT.
"""
cls_name = type(estimator).__name__
module_name = type(estimator).__module__ or ""
# Direct LightGBM / XGBoost / CatBoost instances
gbdt_modules = ("lightgbm", "xgboost", "catboost")
if any(mod in module_name for mod in gbdt_modules):
return True
# Endgame wrappers
gbdt_wrapper_names = (
"LGBMWrapper", "XGBWrapper", "CatBoostWrapper", "GBDTWrapper",
)
if cls_name in gbdt_wrapper_names:
return True
return False
def _is_sklearn_compatible(estimator: Any) -> bool:
"""Check if the estimator follows the sklearn API.
Args:
estimator: The estimator to check.
Returns:
True if the estimator has ``fit`` and ``predict`` methods.
"""
return hasattr(estimator, "fit") and hasattr(estimator, "predict")
def _unwrap_estimator(estimator: Any) -> Any:
"""Unwrap endgame wrappers to get the underlying library estimator.
Endgame GBDT wrappers store the underlying model in ``model_`` after
fitting. For ONNX conversion, we need the raw LightGBM/XGBoost/CatBoost
object.
Args:
estimator: An endgame wrapper or raw estimator.
Returns:
The underlying estimator, or the input unchanged.
"""
# Endgame GBDT wrappers store the fitted model in model_
if hasattr(estimator, "model_") and _is_gbdt(estimator):
return estimator.model_
return estimator
def _detect_onnx_backend(estimator: Any, preferred: str) -> str:
"""Determine the best ONNX conversion backend for the estimator.
Args:
estimator: The fitted estimator.
preferred: User-requested backend. ``"auto"`` enables detection.
Returns:
Backend name: ``"skl2onnx"``, ``"torch"``, or ``"hummingbird"``.
Raises:
ValueError: If the preferred backend is not recognised.
"""
valid_backends = ("auto", "skl2onnx", "hummingbird", "torch")
if preferred not in valid_backends:
raise ValueError(
f"Unknown ONNX backend '{preferred}'. "
f"Choose from: {valid_backends}"
)
if preferred != "auto":
return preferred
# PyTorch-backed estimators
if has_torch_modules(estimator):
return "torch"
# Everything else goes through skl2onnx (which has converters for
# sklearn, LightGBM, XGBoost, and CatBoost)
return "skl2onnx"
def _infer_initial_types(
    sample_input: np.ndarray | None,
    estimator: Any,
) -> list:
    """Build the ``initial_types`` list required by skl2onnx.

    The result always describes a single input named ``"X"`` with shape
    ``[None, n_features]`` (dynamic batch dimension).

    Args:
        sample_input: A sample input array for shape/dtype inference. A 1-D
            array is treated as a single row. If ``None``, the feature count
            is taken from the estimator's ``n_features_in_`` (or
            ``_n_features_in``) attribute and a float32 tensor is declared.
        estimator: The fitted estimator (used for fallback shape inference).

    Returns:
        A list of ``(name, type)`` pairs suitable for
        ``skl2onnx.convert_sklearn``.

    Raises:
        ImportError: If skl2onnx is not installed.
        ValueError: If the input shape cannot be determined, including when
            ``sample_input`` is a 0-d (scalar) array.
    """
    # Called only for its side effect: a descriptive ImportError when
    # skl2onnx is missing, before the direct import below.
    _ensure_skl2onnx()
    from skl2onnx.common.data_types import (
        FloatTensorType,
        Int32TensorType,
        Int64TensorType,
    )

    type_map = {
        np.float32: FloatTensorType,
        # float64 input is *declared* as a float32 tensor for ONNX
        # compatibility; no data is cast by this function.
        np.float64: FloatTensorType,
        np.int64: Int64TensorType,
        np.int32: Int32TensorType,
    }

    if sample_input is not None:
        arr = np.asarray(sample_input)
        if arr.ndim == 0:
            # A scalar has no feature axis; fail with the documented error
            # type instead of an IndexError on ``arr.shape[1]``.
            raise ValueError(
                "Cannot infer input shape from a 0-d sample_input. "
                "Provide a 1-D or 2-D array."
            )
        if arr.ndim == 1:
            arr = arr.reshape(1, -1)
        n_features = arr.shape[1]
        # Any dtype not in the map (e.g. bool, object) falls back to float32.
        tensor_type = type_map.get(arr.dtype.type, FloatTensorType)
        return [("X", tensor_type([None, n_features]))]

    # Fallback: infer from estimator metadata
    n_features = getattr(estimator, "n_features_in_", None)
    if n_features is None:
        n_features = getattr(estimator, "_n_features_in", None)
    if n_features is not None:
        return [("X", FloatTensorType([None, n_features]))]

    raise ValueError(
        "Cannot infer input shape. Provide a sample_input array or ensure "
        "the estimator has n_features_in_ set (call fit() first)."
    )
def _register_gbdt_converters() -> None:
    """Make skl2onnx aware of the LightGBM, XGBoost and CatBoost estimators.

    For each library, the classifier and regressor classes are registered
    through ``skl2onnx.update_registered_converter`` together with the
    converter function imported from ``onnxmltools`` and skl2onnx's linear
    classifier / regressor shape calculators. A library whose section raises
    ``ImportError`` is skipped with a debug log message, so a missing
    optional dependency never aborts the export at this point.
    """
    # --- LightGBM ---------------------------------------------------------
    try:
        import lightgbm as lgb
        import skl2onnx
        from onnxmltools.convert.lightgbm.operator_converters.LightGbm import (
            convert_lightgbm as lgb_converter,
        )
        from skl2onnx.common.shape_calculator import (
            calculate_linear_classifier_output_shapes as clf_shapes,
            calculate_linear_regressor_output_shapes as reg_shapes,
        )

        register = skl2onnx.update_registered_converter
        register(
            lgb.LGBMClassifier,
            "LightGbmLGBMClassifier",
            clf_shapes,
            lgb_converter,
            options={"zipmap": [True, False, "columns"]},
        )
        register(
            lgb.LGBMRegressor,
            "LightGbmLGBMRegressor",
            reg_shapes,
            lgb_converter,
        )
        logger.debug("Registered LightGBM ONNX converters")
    except ImportError:
        logger.debug("LightGBM ONNX converter not available")

    # --- XGBoost ----------------------------------------------------------
    try:
        import skl2onnx
        import xgboost as xgb
        from onnxmltools.convert.xgboost.operator_converters.XGBoost import (
            convert_xgboost as xgb_converter,
        )
        from skl2onnx.common.shape_calculator import (
            calculate_linear_classifier_output_shapes as clf_shapes,
            calculate_linear_regressor_output_shapes as reg_shapes,
        )

        register = skl2onnx.update_registered_converter
        register(
            xgb.XGBClassifier,
            "XGBoostXGBClassifier",
            clf_shapes,
            xgb_converter,
            options={"zipmap": [True, False, "columns"]},
        )
        register(
            xgb.XGBRegressor,
            "XGBoostXGBRegressor",
            reg_shapes,
            xgb_converter,
        )
        logger.debug("Registered XGBoost ONNX converters")
    except ImportError:
        logger.debug("XGBoost ONNX converter not available")

    # --- CatBoost ---------------------------------------------------------
    # Unlike the two libraries above, the classifier is registered without
    # a ``zipmap`` options entry.
    try:
        import catboost as cb
        import skl2onnx
        from onnxmltools.convert.catboost.operator_converters.CatBoost import (
            convert_catboost as cb_converter,
        )
        from skl2onnx.common.shape_calculator import (
            calculate_linear_classifier_output_shapes as clf_shapes,
            calculate_linear_regressor_output_shapes as reg_shapes,
        )

        register = skl2onnx.update_registered_converter
        register(
            cb.CatBoostClassifier,
            "CatBoostCatBoostClassifier",
            clf_shapes,
            cb_converter,
        )
        register(
            cb.CatBoostRegressor,
            "CatBoostCatBoostRegressor",
            reg_shapes,
            cb_converter,
        )
        logger.debug("Registered CatBoost ONNX converters")
    except ImportError:
        logger.debug("CatBoost ONNX converter not available")
def _export_skl2onnx(
    estimator: Any,
    path: Path,
    sample_input: np.ndarray | None,
    opset_version: int,
) -> Path:
    """Convert an estimator with skl2onnx and write the model to ``path``.

    Endgame GBDT wrappers are unwrapped first and the GBDT converters are
    registered when needed. Estimators exposing ``predict_proba`` are
    converted with the ``zipmap`` option disabled.

    Args:
        estimator: Fitted sklearn-compatible estimator.
        path: Output file path.
        sample_input: Sample input for shape inference.
        opset_version: ONNX opset version.

    Returns:
        Path to the saved ONNX file (the ``path`` argument).

    Raises:
        RuntimeError: If ``skl2onnx.convert_sklearn`` fails for any reason.
    """
    skl2onnx = _ensure_skl2onnx()
    onnx = _ensure_onnx()

    target = _unwrap_estimator(estimator)
    if _is_gbdt(estimator):
        # LightGBM / XGBoost / CatBoost need their converters registered.
        _register_gbdt_converters()

    # Shape fallback and the model name both come from the object the caller
    # passed in, not from the unwrapped model.
    input_spec = _infer_initial_types(sample_input, estimator)
    model_name = type(estimator).__name__

    try:
        if hasattr(target, "predict_proba"):
            convert_options = {id(target): {"zipmap": False}}
        else:
            convert_options = None
        onnx_model = skl2onnx.convert_sklearn(
            target,
            name=model_name,
            initial_types=input_spec,
            target_opset=opset_version,
            options=convert_options,
        )
    except Exception as exc:
        raise RuntimeError(
            f"skl2onnx conversion failed for {model_name}: {exc}. "
            f"Try backend='hummingbird' as an alternative."
        ) from exc

    onnx.save_model(onnx_model, str(path))
    logger.info("Exported %s to ONNX via skl2onnx: %s", model_name, path)
    return path
def _export_hummingbird(
    estimator: Any,
    path: Path,
    sample_input: np.ndarray | None,
    opset_version: int,
) -> Path:
    """Export using hummingbird-ml and write a plain ONNX file to ``path``.

    Hummingbird converts traditional ML models to tensor computations
    (PyTorch/ONNX). This backend is only used when it is requested
    explicitly.

    Args:
        estimator: Fitted sklearn-compatible estimator. Endgame GBDT
            wrappers are unwrapped before conversion.
        path: Output file path.
        sample_input: Sample input for shape inference. It is converted to
            float32, and a 1-D array is treated as a single row. May be
            ``None``, in which case no test input is passed on.
        opset_version: ONNX opset version, passed to Hummingbird through
            the ``"onnx_target_opset"`` extra-config key.

    Returns:
        Path to the saved ONNX file (the ``path`` argument).

    Raises:
        ImportError: If hummingbird-ml or onnx is not installed.
        RuntimeError: If conversion or serialisation fails.
    """
    hb_convert = _ensure_hummingbird()
    onnx = _ensure_onnx()
    raw_estimator = _unwrap_estimator(estimator)

    if sample_input is not None:
        test_input = np.asarray(sample_input, dtype=np.float32)
        if test_input.ndim == 1:
            test_input = test_input.reshape(1, -1)
    else:
        # Hummingbird can sometimes work without test input, but
        # shape inference is more reliable with one
        test_input = None

    model_name = type(estimator).__name__
    try:
        hb_model = hb_convert(
            raw_estimator,
            "onnx",
            test_input=test_input,
            extra_config={"onnx_target_opset": opset_version},
        )
        # Write the underlying ONNX ModelProto directly. The container's own
        # ``save()`` produces Hummingbird's zipped archive format at a
        # different location, so ``path`` would not be a loadable .onnx file.
        onnx.save_model(hb_model.model, str(path))
    except Exception as exc:
        raise RuntimeError(
            f"Hummingbird conversion failed for {model_name}: {exc}. "
            f"This model may not be supported for ONNX export."
        ) from exc

    logger.info(
        "Exported %s to ONNX via hummingbird: %s", model_name, path
    )
    return path
def _export_torch(
    estimator: Any,
    path: Path,
    sample_input: np.ndarray | None,
    opset_version: int,
) -> Path:
    """Export a PyTorch-backed estimator via ``torch.onnx.export``.

    The first ``nn.Module`` reported by ``find_torch_modules`` is traced
    with the sample input. The module is put in eval mode for the export,
    and the previous train/eval flag of every sub-module is restored
    afterwards, so exporting does not alter the caller's estimator.

    Args:
        estimator: Estimator containing one or more ``nn.Module`` attributes.
        path: Output file path.
        sample_input: Sample input array (required for torch tracing). It is
            converted to float32, and a 1-D array is treated as a single row.
        opset_version: ONNX opset version.

    Returns:
        Path to the saved ONNX file (the ``path`` argument).

    Raises:
        ImportError: If torch or onnx is not installed.
        ValueError: If no sample input is provided or no ``nn.Module`` found.
        RuntimeError: If ``torch.onnx.export`` fails.
    """
    try:
        import torch
    except ImportError as exc:
        # Same style of message as the other backends' import guards.
        raise ImportError(
            "torch is required for ONNX export of PyTorch-backed models. "
            "Install it with: pip install torch"
        ) from exc
    _ensure_onnx()

    if sample_input is None:
        raise ValueError(
            "sample_input is required for PyTorch ONNX export. "
            "Provide a representative input array."
        )

    modules = find_torch_modules(estimator)
    if not modules:
        raise ValueError(
            f"No nn.Module found in {type(estimator).__name__}. "
            f"Use backend='skl2onnx' or 'hummingbird' instead."
        )

    # Use the first module found (primary model)
    attr_name, module = next(iter(modules.items()))
    logger.info("Exporting nn.Module from attribute '%s'", attr_name)

    # Prepare input tensor
    arr = np.asarray(sample_input, dtype=np.float32)
    if arr.ndim == 1:
        arr = arr.reshape(1, -1)
    dummy_input = torch.from_numpy(arr)

    # Move to same device as model
    try:
        device = next(module.parameters()).device
        dummy_input = dummy_input.to(device)
    except StopIteration:
        pass  # No parameters, stay on CPU

    model_name = type(estimator).__name__

    # Remember each sub-module's flag individually: a blanket
    # ``module.train(flag)`` would overwrite deliberately mixed modes
    # (e.g. frozen layers kept in eval inside a training model).
    prior_modes = [(sub, sub.training) for sub in module.modules()]
    module.eval()
    try:
        torch.onnx.export(
            module,
            dummy_input,
            str(path),
            opset_version=opset_version,
            input_names=["X"],
            output_names=["output"],
            dynamic_axes={
                "X": {0: "batch_size"},
                "output": {0: "batch_size"},
            },
        )
    except Exception as exc:
        raise RuntimeError(
            f"torch.onnx.export failed for {model_name}: {exc}"
        ) from exc
    finally:
        for sub, was_training in prior_modes:
            sub.training = was_training

    logger.info("Exported %s to ONNX via torch: %s", model_name, path)
    return path
# Public API
def export_onnx(
    estimator: Any,
    path: str | Path,
    sample_input: np.ndarray | None = None,
    opset_version: int = 15,
    backend: str = "auto",
) -> str:
    """Export a fitted estimator to ONNX format.

    With ``backend="auto"`` the conversion backend is chosen from the
    estimator type:

    - PyTorch ``nn.Module``-backed models -> ``torch.onnx.export``
    - everything else (sklearn models and the tree-based GBDTs LightGBM,
      XGBoost, CatBoost) -> skl2onnx, with GBDT converters from
      onnxmltools registered when needed

    hummingbird-ml is never selected automatically; request it with
    ``backend="hummingbird"``, for example after a skl2onnx failure.

    Args:
        estimator: Fitted sklearn-compatible estimator.
        path: Output file path. If it does not already end in ``.onnx``
            (case-insensitive), the extension is appended to the file name,
            so ``"model.v2"`` becomes ``"model.v2.onnx"``. Missing parent
            directories are created.
        sample_input: Sample input array for shape inference. Required
            for PyTorch models; strongly recommended for all others.
        opset_version: ONNX opset version. Default is 15.
        backend: Conversion backend. ``"auto"`` selects the backend
            based on estimator type. Explicit options: ``"skl2onnx"``,
            ``"hummingbird"``, ``"torch"``.

    Returns:
        Path to the saved ONNX file, as a string.

    Raises:
        ValueError: If the backend is unknown or input shape cannot be
            inferred.
        RuntimeError: If the ONNX conversion fails.

    Examples:
        Export a scikit-learn model::

            >>> from sklearn.ensemble import RandomForestClassifier
            >>> import endgame as eg
            >>> model = RandomForestClassifier(n_estimators=10).fit(X, y)
            >>> eg.export_onnx(model, "rf_model.onnx", sample_input=X[:1])
            'rf_model.onnx'

        Export a LightGBM model::

            >>> from endgame.models.wrappers import LGBMWrapper
            >>> model = LGBMWrapper(task='classification').fit(X, y)
            >>> eg.export_onnx(model, "lgbm.onnx", sample_input=X[:1])
            'lgbm.onnx'

        Export with a specific backend::

            >>> eg.export_onnx(model, "model.onnx", backend='hummingbird')
            'model.onnx'
    """
    # Validate the backend before touching the filesystem so that a bad
    # argument does not leave freshly created directories behind.
    resolved = _detect_onnx_backend(estimator, preferred=backend)

    dest = Path(path)
    if dest.suffix.lower() != ONNX_EXT:
        # Append rather than ``with_suffix``: replacing would silently drop
        # an existing suffix and could collide with another file.
        dest = dest.with_name(dest.name + ONNX_EXT)

    # Ensure parent directory exists
    dest.parent.mkdir(parents=True, exist_ok=True)

    logger.info(
        "Exporting %s with backend '%s' (requested: '%s')",
        type(estimator).__name__,
        resolved,
        backend,
    )

    if resolved == "torch":
        result = _export_torch(estimator, dest, sample_input, opset_version)
    elif resolved == "hummingbird":
        result = _export_hummingbird(
            estimator, dest, sample_input, opset_version
        )
    else:
        result = _export_skl2onnx(
            estimator, dest, sample_input, opset_version
        )
    return str(result)