Какие есть функции скоринга в Scikit-learn и как сделать свою
На простом примере рассмотрим, где найти метрики измерения качества модели и как создать собственный "оценщик", поддерживаемый в функциях кросс-валидации и подбора гиперпараметров модели из sklearn. Сначала создадим демонстрационный набор данных:
import pandas as pd
from sklearn.datasets import make_classification
import numpy as np
np.random.seed(0)
ar, y = make_classification(n_samples=10000, n_features=3, n_informative=2,
n_redundant=0, class_sep = 0.2, shuffle=False,
flip_y=0, n_clusters_per_class=2)
df = pd.DataFrame(ar, columns = ['feat1', 'feat2', 'feat3'])
df['y'] = y
# мешаем
df = df.sample(frac=1).reset_index(drop=True)
dffrom sklearn.model_selection import train_test_split
X_cv, X_ts, y_cv, y_ts = train_test_split(df.drop(columns='y').copy(),
df['y'], test_size=0.2)
X_cv.shape, X_ts.shapefrom sklearn.tree import DecisionTreeClassifier clf = DecisionTreeClassifier(max_depth = 9, min_samples_split=4, random_state=0)
Для вывода метрик качества на кросс-валидации можно воспользоваться функцией cross_validate:
from sklearn.model_selection import cross_validate cross_validate(clf, X_cv, y_cv, scoring='f1')
В параметре scoring передаются метрики для оценивания, полный их список в sklearn можно найти здесь.
Следует отметить, что в cross_validate, GridSearchCV, RandomizedSearchCV в scoring можно передавать и списки метрик:
res = cross_validate(clf, X_cv, y_cv, scoring=['f1', 'accuracy']) res
Вместе с тем зачастую возникает необходимость создать пользовательский метод оценки качества и передать его в перечисленные функции. Для этого следует воспользоваться make_scorer, в которую передается кастомный оценщик и направление оптимизации (большие или меньшие значения). Допустим, в качестве пользовательской метрики вы хотите использовать формулу f1+0.5*accuracy:
from sklearn.metrics import f1_score, accuracy_score, make_scorer
def f1_acc(target, pred):
return f1_score(target, pred)+1/2*accuracy_score(target, pred)
scorer_ = make_scorer(f1_acc, greater_is_better=True)
cross_validate(clf, X_cv, y_cv, scoring=scorer_)res['test_accuracy']*0.5+res['test_f1']
Не пропустите ничего интересного и подписывайтесь на страницы канала в других социальных сетях: