December 16, 2022

Простая визуализация значимости признаков

Только полное осознание своей значимости — поможет вам лучше осознать значимость других (Петр Квятковский).

Рассмотрим, как извлечь и визуализировать значимость признаков для модели машинного обучения. Загрузим тренировочный набор данных:

from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True, as_frame=True)

Делить на выборки датасет не будем, так как для наших целей понадобится только один тренировочный набор. Обучим на нем решающее дерево:

from sklearn.tree import DecisionTreeClassifier
import numpy as np
RNG = np.random.RandomState(0)

model = DecisionTreeClassifier(random_state=RNG)
model.fit(X, y)

Теперь обратимся к свойству или методу (в разных библиотеках по-разному), отражающему значимость признаков для этой обученной модели. Кстати, если не помните точное название, воспользуйтесь методом исследования объектов, о котором я рассказывал ранее, ориентируйтесь на ключевое слово 'feature' или 'importance':

[it for it in dir(model) if 'feat' in it]

Вот оно:

model.feature_importances_

В качестве бонуса нашел свойство, содержащее наименование входных признаков (обычно получал через обращение к атрибуту columns в матрице признаков):

model.feature_names_in_

Теперь можно воспользоваться любым методом отображения столбчатой диаграммы (намеренно буду использовать разные):

import matplotlib.pyplot as plt

plt.figure(figsize=(20,7))
plt.bar(model.feature_names_in_, model.feature_importances_)

Обычно, особенно когда признаков много, их упорядочивают по значимости и выводят несколько ключевых:

idx = np.argsort(model.feature_importances_)[::-1]

plt.figure(figsize=(20,7))
plt.bar(model.feature_names_in_[idx], model.feature_importances_[idx])

А так можно вывести n самых важных:

import seaborn as sns
n=3
plt.figure(figsize=(20,7))
fig = sns.barplot(y=model.feature_importances_[idx][:n], x=model.feature_names_in_[idx][:n])

Если категорий много удобнее выводить горизонтальную диаграмму. Покажем как это делать с Pandas:

import pandas as pd
pd.Series(model.feature_importances_, index=model.feature_names_in_)\
            .sort_values(ascending=True).plot.barh(figsize=(20,7), grid=True)