Визуализация матрицы расхождений - ключ к пониманию ошибок классификации
«Ошибки — это наука, помогающая нам двигаться вперёд», — говорил Уильям Ченнинг. Визуализация - отличный инструмент, который помогает анализировать данные и выявлять закономерности.
Рассмотрим удобный способ отображения в Python одной из метрик классификации под названием confusion matrix (на русский переводят по-разному - матрица ошибок, неточностей, расхождений или несоответствий).
Сначала загрузим демонстрационный датасет.
import numpy as np import pandas as pd from sklearn.datasets import load_iris from sklearn.linear_model import LogisticRegression df = load_iris(as_frame=True)['frame'] d = {k:v for k, v in enumerate(load_iris().target_names)} df['target'] = df['target'].map(d) df = df.sample(frac=1, random_state=0) df.head()
Имитируем обучение модели и предсказание, не проводя разбиения на выборки, так как наша цель - показать возможности для визуализации:
model = LogisticRegression(random_state=0, max_iter=1000).fit(df.drop(columns='target'), df.target) y_p = model.predict(df.drop(columns='target'))
Теперь получим значения ошибок, воспользовавшись функцией confusion_matrix из модуля sklearn.metrics:
from sklearn.metrics import confusion_matrix labels = np.sort(df['target'].unique()) cm = confusion_matrix(y_true = df['target'], y_pred = y_p, labels=labels, normalize=None) cm
В строке i и колонке j матрицы располагаются значения, соответствующие количеству объектов класса i, которые предсказаны как j. Значения классов берутся, как метки встретившиеся хотя бы раз в y_true или y_pred упорядоченные по возрастанию. Я намеренно передаю метки явно в параметре labels для демонстрации этого поведения (этим же параметром, можно поменять порядок вывода или указать подмножество/список меток для вывода).
matshow
Визуализацию начнем с низкоуровневых способов matplotlib закончим более быстрыми. Первым кандидатом будет matshow модуля matplotlib.pyplot, которая предоставляет полотно, а элементы потребуется добавлять самостоятельно. Ниже показано, как задать размеры картинки, метки классов, а также дополнить ячейки значениями матрицы через текстовые элементы (большинство настроек разбирались ранее здесь и здесь):
import matplotlib.pyplot as plt fig = plt.figure(figsize=(5, 5)) plt.matshow(cm, alpha=1, cmap='coolwarm', fignum=fig.number) plt.xticks(range(len(labels)), labels, rotation='vertical') plt.yticks(range(len(labels)), labels) plt.gca().xaxis.set_ticks_position('bottom') for i in range(cm.shape[0]): for j in range(cm.shape[1]): plt.gca().text(x=j, y=i, s=cm[i][j], ha='center', va='center') plt.xlabel('Predicted label') plt.ylabel('True label')
imshow
Функция imshow работает аналогично, только для управления размером не надо явно передавать номер фигуры (как делали выше используя fignum, так как по умолчанию matshow сама создает фигуру):
fig = plt.figure(figsize=(5, 5)) plt.imshow(cm, cmap='viridis') plt.xticks(range(len(labels)), labels, rotation='vertical') plt.yticks(range(len(labels)), labels) plt.gca().xaxis.set_ticks_position('bottom') for i in range(cm.shape[0]): for j in range(cm.shape[1]): plt.gca().text(x=j, y=i, s=cm[i][j], ha='center', va='center', color="white") plt.xlabel('Predicted label') plt.ylabel('True label')
heatmap
Для отображения с помощью функции heatmap из библиотеки seaborn надо будет внести минимальные правки (закомментировал строки, которые помогут добавить подписи):
import seaborn as sns plt.figure(figsize=(5,5)) # fmt='.1f' sns.heatmap(pd.DataFrame(cm, index=labels, columns=labels), annot=True, fmt='d', cmap='inferno', cbar=False) # plt.xticks(range(len(labels)), list(labels), rotation='vertical') # plt.yticks(range(len(labels)), labels) # plt.xlabel('Predicted label') # plt.ylabel('True label')
ConfusionMatrixDisplay
Об удобстве вывода задумались и разработчики scikit-learn, написав класс ConfusionMatrixDisplay. С помощью его методов, например, from_predictions можно сразу нарисовать картинку:
from sklearn.metrics import ConfusionMatrixDisplay ConfusionMatrixDisplay.from_predictions(y_true=df['target'], y_pred=y_p, labels=labels, colorbar=False, xticks_rotation='vertical', im_kw={'cmap':'viridis'});
Однако для более гибкой настройки лучше создать класс с помощью конструктора, инициализировав параметр confusion_matrix нашей матрицей ошибок, а затем вызвать метод plot объекта с настройками отображения:
cmp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels) fig, ax = plt.subplots(figsize=(5,5)) cmp.plot(ax=ax, xticks_rotation='vertical', colorbar=False, im_kw={'cmap':'coolwarm'}, text_kw={'color':"white"})