Простая визуализация значимости признаков
Только полное осознание своей значимости — поможет вам лучше осознать значимость других (Петр Квятковский).
Рассмотрим, как извлечь и визуализировать значимость признаков для модели машинного обучения. Загрузим тренировочный набор данных:
from sklearn.datasets import load_iris X, y = load_iris(return_X_y=True, as_frame=True)
Делить на выборки датасет не будем, так как для наших целей понадобится только один тренировочный набор. Обучим на нем решающее дерево:
from sklearn.tree import DecisionTreeClassifier import numpy as np RNG = np.random.RandomState(0) model = DecisionTreeClassifier(random_state=RNG) model.fit(X, y)
Теперь обратимся к свойству или методу (в разных библиотеках по-разному), отражающему значимость признаков для этой обученной модели. Кстати, если не помните точное название, воспользуйтесь методом исследования объектов, о котором я рассказывал ранее, ориентируйтесь на ключевое слово 'feature' или 'importance':
[it for it in dir(model) if 'feat' in it]
model.feature_importances_
В качестве бонуса нашел свойство, содержащее наименование входных признаков (обычно получал через обращение к атрибуту columns в матрице признаков):
model.feature_names_in_
Теперь можно воспользоваться любым методом отображения столбчатой диаграммы (намеренно буду использовать разные):
import matplotlib.pyplot as plt plt.figure(figsize=(20,7)) plt.bar(model.feature_names_in_, model.feature_importances_)
Обычно, особенно когда признаков много, их упорядочивают по значимости и выводят несколько ключевых:
idx = np.argsort(model.feature_importances_)[::-1] plt.figure(figsize=(20,7)) plt.bar(model.feature_names_in_[idx], model.feature_importances_[idx])
А так можно вывести n самых важных:
import seaborn as sns n=3 plt.figure(figsize=(20,7)) fig = sns.barplot(y=model.feature_importances_[idx][:n], x=model.feature_names_in_[idx][:n])
Если категорий много удобнее выводить горизонтальную диаграмму. Покажем как это делать с Pandas:
import pandas as pd pd.Series(model.feature_importances_, index=model.feature_names_in_)\ .sort_values(ascending=True).plot.barh(figsize=(20,7), grid=True)