Какие есть функции скоринга в Scikit-learn и как сделать свою
На простом примере рассмотрим, где найти метрики измерения качества модели и как создать собственный "оценщик", поддерживаемый в функциях кросс-валидации и подбора гиперпараметров модели из sklearn. Сначала создадим демонстрационный набор данных:
import pandas as pd from sklearn.datasets import make_classification import numpy as np np.random.seed(0) ar, y = make_classification(n_samples=10000, n_features=3, n_informative=2, n_redundant=0, class_sep = 0.2, shuffle=False, flip_y=0, n_clusters_per_class=2) df = pd.DataFrame(ar, columns = ['feat1', 'feat2', 'feat3']) df['y'] = y # мешаем df = df.sample(frac=1).reset_index(drop=True) df
from sklearn.model_selection import train_test_split X_cv, X_ts, y_cv, y_ts = train_test_split(df.drop(columns='y').copy(), df['y'], test_size=0.2) X_cv.shape, X_ts.shape
from sklearn.tree import DecisionTreeClassifier clf = DecisionTreeClassifier(max_depth = 9, min_samples_split=4, random_state=0)
Для вывода метрик качества на кросс-валидации можно воспользоваться функцией cross_validate:
from sklearn.model_selection import cross_validate cross_validate(clf, X_cv, y_cv, scoring='f1')
В параметре scoring передаются метрики для оценивания, полный их список в sklearn можно найти здесь.
Следует отметить, что в cross_validate, GridSearchCV, RandomizedSearchCV в scoring можно передавать и списки метрик:
res = cross_validate(clf, X_cv, y_cv, scoring=['f1', 'accuracy']) res
Вместе с тем зачастую возникает необходимость создать пользовательский метод оценки качества и передать его в перечисленные функции. Для этого следует воспользоваться make_scorer, в которую передается кастомный оценщик и направление оптимизации (большие или меньшие значения). Допустим, в качестве пользовательской метрики вы хотите использовать формулу f1+0.5*accuracy:
from sklearn.metrics import f1_score, accuracy_score, make_scorer def f1_acc(target, pred): return f1_score(target, pred)+1/2*accuracy_score(target, pred) scorer_ = make_scorer(f1_acc, greater_is_better=True) cross_validate(clf, X_cv, y_cv, scoring=scorer_)
res['test_accuracy']*0.5+res['test_f1']
Не пропустите ничего интересного и подписывайтесь на страницы канала в других социальных сетях: