машинное обучение
March 8, 2023

Функции sklearn, с которыми понимание работы дерева решений сильно облегчится 

Рассмотрим пример построения дерева решений и работы модели на примере классификации цветков Ириса:

from sklearn.datasets import load_iris
iris_df = load_iris(as_frame=True)['frame']
iris_df.head()

Обучим классификатор:

from sklearn.tree import DecisionTreeClassifier

model = DecisionTreeClassifier(random_state=0).fit(iris_df.drop(columns='target'), iris_df.target)
features_l = iris_df.drop(columns='target').columns.tolist()

Визуализация дерева

В модуле sklearn.tree есть функция plot_tree, с которой можно легко нарисовать дерево, для каждого узла включается признак ветвления, граница, загрязненность, количество примеров всего и их распределение по классам:

from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

plt.figure(figsize=(18,7))
_ = plot_tree(model, feature_names = features_l)

Экспорт в текст

Есть и текстовое представление того же дерева, которое можно получить при помощи export_text:

from sklearn.tree import export_text
print(export_text(model, feature_names=features_l))

Атрибуты работы

Более детальные единицы работы алгоритма извлекаются из свойств атрибута tree_: номера левого и правого узлов ветвления (children_left, children_right) и "загрязненность" (impurities) для каждого узла, номера признаков ветвления (features), соответствующие им границы (threshold). Проще поместить эти элементы в один датафрейм и получится целостная картина работы дерева:

import pandas as pd

children_left = model.tree_.children_left
children_right = model.tree_.children_right
features = model.tree_.feature
names = [features_l[i] if i>0 else '' for i in features ]
thres = model.tree_.threshold
impurities = model.tree_.impurity

pd.DataFrame({'children_left':children_left, 'children_right':children_right, 
              'names':names, 'features':features, 'thresholds':thres, 
              'impurities':impurities})