Вскрытие AI модели или как заглянуть внутрь без скальпеля
Как понять причину неполадок в работе AI модели, не вдаваясь в сложности внутреннего механизма работы? Один из способов реализуем в этой статье.
В качестве площадки для экспериментов будем использовать сгенерированный ранее набор данных для предсказания расходов людей по их доходам и размерам дотаций из городского бюджета (подробнее здесь):
Также будем использовать обученную на этих данных модель градиентного бустинга CatBoost, построенную в этом примере.
Основная идея метода заключается в исследовании результатов работы модели при последовательном изменении каждого из параметров объекта в рамках некоторого диапазона. Это поможет локализовать набор факторов, оказывающих реальное влияние на конкретный прогноз.
Предполагается, что в нашем распоряжении имеется матрица с прогнозами, где находятся признаки интересуемого объекта перед предсказанием. На первом шаге загрузим матрицу признаков (df_pred) и выберем нужную строку (etal_str, пусть нулевая):
import pandas as pd from catboost import CatBoostRegressor import numpy as np pred_fn = 'pred_income.xlsx' df_pred = pd.read_excel(pred_fn) etal_str = df_pred.iloc[[0]]
Далее зададим словарь с диапазонами значений параметров объекта (inds_values_d), указав в качестве ключа - индекс фактора и значения - список. Также предусмотрим возможность задания в качестве ключей имен столбцов (names_values_d), но в этом случае они будут преобразовываться в индексы и только в таком виде поступать в последующую обработку:
# selecting column names and values names_values_d = {'зарплата':np.arange(1,100,10), 'сумма_помощи':[70, 4,-1]} colname_pos_d = {it:pos for pos,it in enumerate(df_pred.columns)} inds_values_d = {colname_pos_d[key]:value for key, value in names_values_d.items()} # or selecting columns inds and values # inds_values_d = {0:[1,2,4], 4:[7, 4,-1]}
В блокноте это выглядит следующим образом:
Далее на базе интересуемой строки сформируем дополнительные с одним модифицируемым значением и склеим их в датафрейм:
# generating rows for predict and gathering them together pred_str_l = [] # first row will be unchanged pred_str_l.append(etal_str.copy()) for key in inds_values_d.keys(): for value in inds_values_d[key]: pred_str = etal_str.copy() pred_str.iloc[0,key]=value pred_str['field']=key pred_str['value']=value pred_str_l.append(pred_str) df_synth_pred = pd.concat(pred_str_l, ignore_index=True)
Осталось загрузить модель и получить таблицу с предсказаниями:
reg = CatBoostRegressor() reg.load_model("model") df_synth_pred['pred']=reg.predict(df_synth_pred.drop(['field', 'value'], axis=1)) df_synth_pred[['field','value','pred']+list(df_synth_pred.columns[:-3])]
Как можно заметить, ввиду особенности наших данных диапазон зарплат до 21 не влияет на прогноз (в обучающей выборке такие люди отсутствовали, и модель никаких выводов не сделала).
Теперь объединим сделанное в функцию, которая будет получать строку с признаками объекта, словарь модицикаций, модель, имя выходного файла и список признаков (последние два поля необязательные):
# вход столбец датафрейма со строкой примером, список_стб либо None, словарь с стб и значениями def pred_synth_output(etal_str, col_val_d, model, fn_out=None, col_names=None): if col_names: colname_pos_d = {it:pos for pos,it in enumerate(col_names)} col_val_d = {colname_pos_d[key]:value for key, value in col_val_d.items()} pos_colname_d = {val:key for key, val in colname_pos_d.items()} # generating rows for predict and gathering them together pred_str_l = [] # first row will be unchanged pred_str_l.append(etal_str.copy()) for key in col_val_d.keys(): for value in col_val_d[key]: pred_str = etal_str.copy() pred_str.iloc[0,key]=value pred_str['field']=pos_colname_d[key] if col_names else key pred_str['value']=value pred_str_l.append(pred_str) df_synth_pred = pd.concat(pred_str_l, ignore_index=True) df_synth_pred['pred']=model.predict(df_synth_pred.drop(['field', 'value'], axis=1)) if fn_out: df_synth_pred[['field','value','pred']+list(df_synth_pred.columns[:-3])].to_excel(fn_out, index=False) else: return df_synth_pred[['field','value','pred']+list(df_synth_pred.columns[:-3])]
А так она применяется (отображен не весь вывод):