tabular_trees.ScikitLearnHistTabularTrees

class tabular_trees.ScikitLearnHistTabularTrees(data)[source]

Bases: BaseModelTabularTrees

Scikit-Learn HistGradientBoosting trees in tabular format.

The preferred way to create ScikitLearnHistTabularTrees objects is with the from_hist_gradient_booster method.

__init__(data)

Methods

__init__(data)

from_hist_gradient_booster(model)

Create ScikitLearnHistTabularTrees from a fitted HistGradientBoosting model.

to_dataframe()

Return data for trees object.

Attributes

data

Tree data.

tree

Tree index.

node

Node index in tree.

value

Node prediction.

count

Count of rows in node from training.

feature_idx

Feature index for split.

num_threshold

Split threshold.

missing_go_to_left

Binary indicator of whether missing values go to the left child.

left

Left child index.

right

Right child index.

gain

Gain for split.

depth

Depth of node.

is_leaf

Leaf node indicator.

bin_threshold

is_categorical

bitset_idx

count

Count of rows in node from training.

data

Tree data.

depth

Depth of node.

feature_idx

Feature index for split.

classmethod from_hist_gradient_booster(model)[source]

Create ScikitLearnHistTabularTrees from a fitted HistGradientBoosting model.

Parameters:

model (Union[HistGradientBoostingClassifier, HistGradientBoostingRegressor]) – Model to extract tree data from.

Returns:

trees – Model trees in tabular format.

Return type:

ScikitLearnHistTabularTrees

Examples

>>> from sklearn.datasets import load_diabetes
>>> from sklearn.ensemble import HistGradientBoostingRegressor
>>> from tabular_trees import ScikitLearnHistTabularTrees
>>> # load data
>>> diabetes = load_diabetes()
>>> # build model
>>> model = HistGradientBoostingRegressor(max_depth=3, max_iter=10)
>>> model.fit(diabetes["data"], diabetes["target"])
HistGradientBoostingRegressor(max_depth=3, max_iter=10)
>>> # export to ScikitLearnHistTabularTrees
>>> sklearn_tabular_trees = ScikitLearnHistTabularTrees.from_hist_gradient_booster(model)
>>> type(sklearn_tabular_trees)
<class 'tabular_trees.sklearn.sklearn_hist_tabular_trees.ScikitLearnHistTabularTrees'>

gain

Gain for split.

is_leaf

Leaf node indicator.

left

Left child index.

missing_go_to_left

Binary indicator of whether missing values go to the left child.

node

Node index in tree.

num_threshold

Split threshold.

right

Right child index.

to_dataframe()

Return data for trees object.

Returns:

trees – Model trees in DataFrame form.

Return type:

pd.DataFrame

tree

Tree index.

value

Node prediction.
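
Examples

To illustrate how the node columns documented above fit together, the following is a minimal sketch that walks a single tree for one input row. It does not use tabular_trees itself: the rows are hand-written, hypothetical values keyed by the documented column names, and the helper predict_one_tree is illustrative only. It assumes that the root is node 0, that a row goes left when its feature value is less than or equal to num_threshold, and that all splits are numerical (categorical splits, which would involve is_categorical and bitset_idx, are ignored).

```python
import math

# Hypothetical rows for one small tree, keyed by the documented column names.
# The values are made up for illustration and do not come from a real model.
rows = [
    {"tree": 0, "node": 0, "value": 0.0, "feature_idx": 2, "num_threshold": 0.5,
     "missing_go_to_left": 1, "left": 1, "right": 2, "is_leaf": 0},
    {"tree": 0, "node": 1, "value": -1.5, "feature_idx": 0, "num_threshold": 0.0,
     "missing_go_to_left": 0, "left": 0, "right": 0, "is_leaf": 1},
    {"tree": 0, "node": 2, "value": 2.0, "feature_idx": 0, "num_threshold": 0.0,
     "missing_go_to_left": 0, "left": 0, "right": 0, "is_leaf": 1},
]


def predict_one_tree(tree_rows, features):
    """Follow left/right child indices from the root until a leaf is reached."""
    nodes = {row["node"]: row for row in tree_rows}
    node = nodes[0]  # assumption: the root is node 0
    while not node["is_leaf"]:
        x = features[node["feature_idx"]]
        if math.isnan(x):
            # Missing values follow the missing_go_to_left indicator.
            go_left = bool(node["missing_go_to_left"])
        else:
            # Assumption: values at or below the threshold go left.
            go_left = x <= node["num_threshold"]
        node = nodes[node["left"] if go_left else node["right"]]
    return node["value"]


print(predict_one_tree(rows, [0.0, 0.0, 0.2]))           # -1.5
print(predict_one_tree(rows, [0.0, 0.0, 0.9]))           # 2.0
print(predict_one_tree(rows, [0.0, 0.0, float("nan")]))  # -1.5
```

With real data, the same walk could be applied to the rows of to_dataframe() for one tree index at a time. Whether summing the resulting leaf values reproduces the model's prediction depends on details such as the baseline prediction and learning rate handling, which this sketch does not cover.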