Функции sklearn, с которыми понимание работы дерева решений сильно облегчится
Рассмотрим пример построения дерева решений и работы модели на примере классификации цветков Ириса:
from sklearn.datasets import load_iris iris_df = load_iris(as_frame=True)['frame'] iris_df.head()
from sklearn.tree import DecisionTreeClassifier model = DecisionTreeClassifier(random_state=0).fit(iris_df.drop(columns='target'), iris_df.target) features_l = iris_df.drop(columns='target').columns.tolist()
Визуализация дерева
В модуле sklearn.tree есть функция plot_tree, с которой можно легко нарисовать дерево, для каждого узла включается признак ветвления, граница, загрязненность, количество примеров всего и их распределение по классам:
from sklearn.tree import plot_tree import matplotlib.pyplot as plt plt.figure(figsize=(18,7)) _ = plot_tree(model, feature_names = features_l)
Экспорт в текст
Есть и текстовое представление того же дерева, которое можно получить при помощи export_text:
from sklearn.tree import export_text print(export_text(model, feature_names=features_l))
Атрибуты работы
Более детальные единицы работы алгоритма извлекаются из свойств атрибута tree_: номера левого и правого узлов ветвления (children_left, children_right) и "загрязненность" (impurities) для каждого узла, номера признаков ветвления (features), соответствующие им границы (threshold). Проще поместить эти элементы в один датафрейм и получится целостная картина работы дерева:
import pandas as pd children_left = model.tree_.children_left children_right = model.tree_.children_right features = model.tree_.feature names = [features_l[i] if i>0 else '' for i in features ] thres = model.tree_.threshold impurities = model.tree_.impurity pd.DataFrame({'children_left':children_left, 'children_right':children_right, 'names':names, 'features':features, 'thresholds':thres, 'impurities':impurities})