"""Module containing BoosterString class."""
from collections import OrderedDict
import lightgbm as lgb
class BoosterString:
    """String version of a LightGBM Booster.

    The booster's text dump is held as a list of rows (one per line), together
    with the indexes of the rows that delimit its sections.

    Attributes
    ----------
    rows : list[str]
        Copy of the rows the object was created from.
    row_markers : OrderedDict
        Maps a section marker name (``"end_header"``, ``"end_of_trees"``,
        ``"feature_importances"``, ``"parameters"``, ``"end_of_parameters"``,
        ``"pandas_categorical"``) to the index of the matching row.
    tree_rows : OrderedDict
        Maps each tree number to the index of its ``Tree=`` row.

    """

    # Separator used to split a model string into rows and to join them back.
    new_line = "\n"

    def __init__(self, rows: list[str]):
        """Initialise the BoosterString from a list of strings.

        Parameters
        ----------
        rows : list[str]
            List of strings defining a Booster.

        Raises
        ------
        ValueError
            If the rows cannot be converted to a booster, or if one of the
            expected marker rows cannot be found.

        """
        # Copy so later changes to the caller's list do not affect this object.
        self.rows = rows[:]

        # Validate up front: any failure to build a booster is reported as a
        # ValueError, with the underlying error chained as the cause.
        try:
            self.to_booster()
        except Exception as err:
            raise ValueError("supplied rows do not produce a valid booster") from err

        self.row_markers, self.tree_rows = self._gather_line_markers()

    @classmethod
    def from_booster(cls, booster: "lgb.Booster") -> "BoosterString":
        """Create BoosterString from a lgb.Booster object.

        Parameters
        ----------
        booster : lgb.Booster
            Booster whose ``model_to_string()`` output is split into rows.

        Returns
        -------
        model : BoosterString
            Model as a BoosterString object (an instance of ``cls``, so
            subclasses get an instance of their own type).

        """
        booster_string: str = booster.model_to_string()
        return cls(booster_string.split(cls.new_line))

    def to_booster(self) -> "lgb.Booster":
        """Convert the BoosterString back to a Booster.

        Returns
        -------
        model : lgb.Booster
            BoosterString as lgb.Booster object.

        """
        booster_string = self.new_line.join(self.rows)
        return lgb.Booster(model_str=booster_string)

    def _gather_line_markers(self) -> tuple[OrderedDict, OrderedDict]:
        """Find specific lines in the booster string data.

        Returns
        -------
        row_markers : OrderedDict
            Marker name mapped to the index of the row starting with the
            corresponding text.
        tree_rows : OrderedDict
            Tree number mapped to the index of that tree's ``Tree=`` row.

        Raises
        ------
        ValueError
            If any of the expected rows cannot be found.

        """
        # (marker name, text the row must start with), in the order the rows
        # are searched for after the trees.
        section_markers = [
            ("end_of_trees", "end of trees"),
            ("feature_importances", "feature_importances:"),
            ("parameters", "parameters:"),
            ("end_of_parameters", "end of parameters"),
            ("pandas_categorical", "pandas_categorical:"),
        ]

        tree_sizes_line = self._find_text_in_booster_data("tree_sizes=")
        row_markers = OrderedDict()
        row_markers["end_header"] = tree_sizes_line

        n_trees = self._get_number_trees_from_tree_sizes_line(
            self.rows[tree_sizes_line]
        )

        # Each search resumes just after the previous hit, so rows are matched
        # strictly in order of appearance.
        tree_rows = OrderedDict()
        search_from = tree_sizes_line
        for _ in range(n_trees):
            tree_line = self._find_text_in_booster_data("Tree=", search_from)
            tree_number = self._get_tree_number_from_line(self.rows[tree_line])
            tree_rows[tree_number] = tree_line
            search_from = tree_line + 1

        for marker_name, marker_text in section_markers:
            marker_line = self._find_text_in_booster_data(marker_text, search_from)
            row_markers[marker_name] = marker_line
            search_from = marker_line + 1

        return row_markers, tree_rows

    def _find_text_in_booster_data(self, text: str, start: int = 0) -> int:
        """Return the index of the first row at or after start beginning with text.

        Raises
        ------
        ValueError
            If no such row exists.

        """
        for row_number, row in enumerate(self.rows[start:], start):
            if row.startswith(text):
                return row_number

        raise ValueError(
            f"unable to find row starting with text '{text}' in booster_data "
            f"starting from index {start}"
        )

    def _get_number_of_rows(self) -> int:
        """Return the number of rows held."""
        return len(self.rows)

    def _extract_rows(self, start: int, end: int) -> list[str]:
        """Extract booster rows within range (start inclusive, end exclusive)."""
        return self.rows[start:end]

    def _get_number_trees_from_tree_sizes_line(self, tree_sizes_line: str) -> int:
        """Get the number of trees in the booster from the tree sizes line.

        The sizes are the whitespace-separated values after the first ``=``;
        an empty list of sizes gives 0.
        """
        _, _, sizes = tree_sizes_line.partition("=")
        # split() with no argument ignores repeated/trailing whitespace and
        # returns [] for an empty string.
        return len(sizes.split())

    def _get_tree_number_from_line(self, tree_line: str) -> int:
        """Extract the tree index from the first line in a tree section."""
        return int(tree_line.removeprefix("Tree="))