машинное обучение
May 12

Визуализация ошибок, как навигатор к скрытым проблемам модели

Визуализация — это язык, который позволяет нам видеть данные и понимать их смысл. Простой и эффективный способ диагностики результатов работы модели на различных объектах заключается в анализе разницы между прогнозами и целями. Он может показать, что в некоторых группах поведение модели имеет особенности (например, склонность к завышению или занижению прогнозов). Для демонстрации того, как строится такая визуализация загрузим набор данных:

from sklearn.datasets import load_diabetes
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(0)

df, y = load_diabetes(return_X_y=True, as_frame=True)
df['target'] = y
display(df.head())
display(df.shape)

Разобьем датасет на две группы для обучения и оценки:

from sklearn.model_selection import train_test_split

X_tr, X_ts, y_tr, y_ts = train_test_split(df.drop(columns='target').copy(),
                                          df['target'], test_size=0.2)
y_tr.shape[0], y_ts.shape[0]

Теперь построим модель и сделаем предсказания:

from sklearn.linear_model import LinearRegression

model = LinearRegression()

model.fit(X_tr, y_tr)
y_p = model.predict(X_ts)

Аналитический прием, о котором шла речь выше, заключается в визуализации распределения разности между целями и прогнозами (ошибками). Это можно сделать, например, путем построения гистограммы или графика разброса точек с координатами по оси y - ошибки, x - предсказания. Объявим функцию с соответствующими свойствами и вызовем ее:

def plot_residuals(target, predictions, bins_num, figsize=(20, 8), style='seaborn'):

    error = target - predictions
    with plt.style.context(style=style):

      plt.figure(figsize=figsize)
      plt.suptitle(f'Анализ ошибок', fontsize=16)

      plt.subplot(1, 2, 1)
      plt.hist(error, edgecolor='blue', bins=bins_num)
      plt.axvline(x=0, color='black', label='ноль', linestyle='--')
      plt.axvline(x=error.median(), color='red', label='медиана')
      plt.axvline(x=error.mean(), color='orange', label='среднее')
      plt.title(f'Гистограмма ошибок', fontsize=15)
      plt.ylabel('плотность распределения', fontsize=14)
      plt.xlabel('ошибки', fontsize=14)
      plt.legend()

      plt.subplot(1, 2, 2)
      plt.scatter(predictions, error, alpha=0.4)
      plt.axhline(y=0, color='red', label='ноль', linestyle='--')
      plt.title(f'Анализ дисперсии ошибок', fontsize=15)
      plt.ylabel('ошибки', fontsize=14)
      plt.xlabel('предсказания модели', fontsize=14)
plot_residuals(y_ts, y_p, bins_num = 10, figsize=(20, 5), style='bmh')

На графике ошибки распределены равномерно относительно нуля, их среднее и медиана почти совпадают и равны 0.

Аналогичные графики можно построить с библиотекой sklearn (потребуется использовать метод from_predictions класса PredictionErrorDisplay из модуля sklearn.metrics):

from sklearn.metrics import PredictionErrorDisplay

PredictionErrorDisplay.from_predictions(y_ts, y_p)

По оси y можно вывести вместо ошибок реальные значения (цели) против предсказанных по оси x:

PredictionErrorDisplay.from_predictions(y_ts, y_p, kind='actual_vs_predicted')

А теперь для демонстрационных целей добавим выброс в виде новой точки с очень большой целью и снова обучим модель:

model.fit(pd.concat([X_tr, X_tr.iloc[[-1]]]),
          pd.concat([y_tr.to_frame(), pd.Series([1e10]).to_frame('target')])['target'])

y_p = model.predict(X_ts)

plot_residuals(y_ts, y_p, bins_num = 10, figsize=(20, 5), style='bmh')

После обучения на большом выбросе, модель для одних точек стала сильно занижать предсказания, а для других - завышать. Такой эффект может быть вызван не только выбросами в обучающих данных, но и простотой модели, плохими признаками. Например, вы пытаетесь предсказать цену квартиры, основываясь только на площади, и не учитываете особенности района, близость к метро. Тогда, если в датасете имеется перекос в сторону квартир из элитных райнов, фактор площади может остаться недооцененным.

Таким образом, неравномерность распределения ошибок относительно нуля является индикатором того, что модель требует доработки и оптимизации, а причины могут от случая к случаю меняться.