"""Scikit-learn histogram based GBM trees in tabular format."""
from dataclasses import dataclass, field
from typing import Union
import numpy as np
import pandas as pd
from numpy.typing import NDArray
try:
from sklearn.ensemble import ( # type: ignore[import-not-found]
HistGradientBoostingClassifier,
HistGradientBoostingRegressor,
)
except ModuleNotFoundError as err:
raise ImportError(
"scikit-learn must be installed to use functionality in sklearn module"
) from err
from .. import checks
from ..trees import BaseModelTabularTrees, export_tree_data
@dataclass
class ScikitLearnHistTabularTrees(BaseModelTabularTrees):
    """Scikit-Learn HistGradientBoosting trees in tabular format.

    The preferred way to create ScikitLearnHistTabularTrees objects is with the
    from_hist_gradient_booster method.

    Only ``data`` is accepted by the generated ``__init__``; every other field
    is declared with ``init=False`` and is presumably populated from ``data``
    by the base class (confirm against BaseModelTabularTrees).

    """

    data: pd.DataFrame
    """Tree data."""

    tree: NDArray[np.int_] = field(init=False, repr=False)
    """Tree index."""

    node: NDArray[np.int_] = field(init=False, repr=False)
    """Node index in tree."""

    value: NDArray[np.float64] = field(init=False, repr=False)
    """Node prediction."""

    count: NDArray[np.int_] = field(init=False, repr=False)
    """Count of rows in node from training."""

    feature_idx: NDArray[np.int_] = field(init=False, repr=False)
    """Feature index for split."""

    num_threshold: NDArray[np.float64] = field(init=False, repr=False)
    """Split threshold."""

    missing_go_to_left: NDArray[np.int_] = field(init=False, repr=False)
    """Binary indicator if null values go to the left child."""

    left: NDArray[np.int_] = field(init=False, repr=False)
    """Left child index."""

    right: NDArray[np.int_] = field(init=False, repr=False)
    """Right child index."""

    gain: NDArray[np.float64] = field(init=False, repr=False)
    """Gain for split."""

    depth: NDArray[np.int_] = field(init=False, repr=False)
    """Depth of node."""

    is_leaf: NDArray[np.int_] = field(init=False, repr=False)
    """Leaf node indicator."""

    bin_threshold: NDArray[np.int_] = field(init=False, repr=False)
    """Split threshold on the binned feature, as reported by scikit-learn."""

    is_categorical: NDArray[np.int_] = field(init=False, repr=False)
    """Categorical split indicator, as reported by scikit-learn."""

    bitset_idx: NDArray[np.int_] = field(init=False, repr=False)
    """Bitset index for categorical splits, as reported by scikit-learn."""

    @classmethod
    def from_hist_gradient_booster(  # type: ignore[no-any-unimported]
        cls, model: Union[HistGradientBoostingClassifier, HistGradientBoostingRegressor]
    ) -> "ScikitLearnHistTabularTrees":
        """Create ScikitLearnHistTabularTrees from hist gradient booster.

        Parameters
        ----------
        model : Union[HistGradientBoostingClassifier, HistGradientBoostingRegressor]
            Model to extract tree data from.

        Returns
        -------
        trees : ScikitLearnHistTabularTrees
            Model trees in tabular format. When called on a subclass, an
            instance of that subclass is returned.

        Raises
        ------
        ValueError
            If model has not been fitted.

        NotImplementedError
            If the first boosting iteration of model holds more than one tree.

        Examples
        --------
        >>> from sklearn.datasets import load_diabetes
        >>> from sklearn.ensemble import HistGradientBoostingRegressor
        >>> from tabular_trees import ScikitLearnHistTabularTrees
        >>> # load data
        >>> diabetes = load_diabetes()
        >>> # build model
        >>> model = HistGradientBoostingRegressor(max_depth=3, max_iter=10)
        >>> model.fit(diabetes["data"], diabetes["target"])
        HistGradientBoostingRegressor(max_depth=3, max_iter=10)
        >>> # export to ScikitLearnHistTabularTrees
        >>> sklearn_tabular_trees = ScikitLearnHistTabularTrees.from_hist_gradient_booster(model)
        >>> type(sklearn_tabular_trees)
        <class 'tabular_trees.sklearn.sklearn_hist_tabular_trees.ScikitLearnHistTabularTrees'>

        """  # noqa: E501
        # Type validation is delegated to the project's checks helper.
        checks.check_type(
            model,
            (HistGradientBoostingClassifier, HistGradientBoostingRegressor),
            "model",
        )

        if not model._is_fitted():
            raise ValueError("model is not fitted, cannot export trees")

        # Extraction below only reads the first tree of every iteration, so a
        # model with several trees per iteration (presumably multiclass) is
        # rejected up front.
        if len(model._predictors[0]) > 1:
            raise NotImplementedError("model with multiple responses not supported")

        # Go through cls rather than the hard-coded class name so subclasses
        # get an instance of their own type and can override the extraction.
        tree_data = cls._extract_hist_gbm_tree_data(model)

        return cls(tree_data)

    @staticmethod
    def _extract_hist_gbm_tree_data(  # type: ignore[no-any-unimported]
        model: Union[HistGradientBoostingClassifier, HistGradientBoostingRegressor],
    ) -> pd.DataFrame:
        """Extract tree data from HistGradientBoosting model.

        Tree data is pulled from _predictors attributes in
        HistGradientBoostingClassifier or HistGradientBoostingRegressor object.

        Parameters
        ----------
        model : Union[HistGradientBoostingClassifier, HistGradientBoostingRegressor]
            Model to extract tree data from.

        Returns
        -------
        tree_data : pd.DataFrame
            One row per node across all trees, with ``node`` and ``tree``
            columns added to the node records. The row index restarts at zero
            for each tree. The model's starting value is added to the ``value``
            column of every node in tree 0.

        """
        # One frame per boosting iteration: node records plus the tree number,
        # with the positional row index exposed as the "node" column.
        tree_frames = [
            pd.DataFrame(model._predictors[tree_no][0].nodes)
            .assign(tree=tree_no)
            .reset_index()
            .rename(columns={"index": "node"})
            for tree_no in range(model.n_iter_)
        ]

        tree_data = pd.concat(tree_frames, axis=0)

        starting_value = (
            ScikitLearnHistTabularTrees._get_starting_value_hist_gradient_booster(model)
        )

        # Fold the ensemble's starting value into the first tree only.
        in_first_tree = tree_data["tree"] == 0
        tree_data.loc[in_first_tree, "value"] = (
            tree_data.loc[in_first_tree, "value"] + starting_value
        )

        return tree_data

    @staticmethod
    def _get_starting_value_hist_gradient_booster(  # type: ignore[no-any-unimported]
        model: Union[HistGradientBoostingClassifier, HistGradientBoostingRegressor],
    ) -> Union[int, float]:
        """Extract the initial prediction for the ensemble.

        Returns the ``[0][0]`` element of the model's _baseline_prediction
        attribute.

        """
        return model._baseline_prediction[0][0]  # type: ignore[no-any-return]
@export_tree_data.register(HistGradientBoostingClassifier)
@export_tree_data.register(HistGradientBoostingRegressor)
def _export_tree_data__hist_gradient_boosting_model(  # type: ignore[no-any-unimported]
    model: Union[HistGradientBoostingClassifier, HistGradientBoostingRegressor],
) -> ScikitLearnHistTabularTrees:
    """Convert a HistGradientBoosting model into ScikitLearnHistTabularTrees.

    Registered on export_tree_data for both HistGradientBoostingClassifier and
    HistGradientBoostingRegressor; the work is delegated to
    ScikitLearnHistTabularTrees.from_hist_gradient_booster.

    Parameters
    ----------
    model : Union[HistGradientBoostingClassifier, HistGradientBoostingRegressor]
        Model whose trees should be exported.

    Returns
    -------
    trees : ScikitLearnHistTabularTrees
        Model trees in tabular format.

    """
    tabular_trees = ScikitLearnHistTabularTrees.from_hist_gradient_booster(model)
    return tabular_trees