December 10, 2022

Повышаем устойчивость кросс-валидации через сид

Случайными кажутся события, причины которых мы не знаем (Демокрит).

При кросс-валидационной проверке качества модели установка ее случайного инициализатора в целочисленное значение может понизить вариацию в данных. В этом случае на всех сплитах тестируется качество только одной случайной вариации алгоритма. Если же передать объект класса np.random.RandomState, то на каждом fit-е алгоритм будет подбирать разные случайные параметры.

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
import numpy as np
SEED = 0
RNG = np.random.RandomState(SEED)


X, y = make_classification(n_samples=10000, n_features=5, n_informative=2, 
                           n_redundant=0, class_sep = 2, random_state=SEED, shuffle=True, 
                           flip_y=0.3, n_clusters_per_class=2)
part = int(X.shape[0]*0.8)

model = RandomForestClassifier(random_state=SEED)

display(model.fit(X[:part], y[:part]).predict(X[part:]).sum())
display(model.fit(X[:part], y[:part]).predict(X[part:]).sum())

Целочисленный сид создает одинаковые модели и получает равные результаты. Для демонстрационных целей выше выведена сумма предсказаний.

А ниже то же с использованием RNG:

model = RandomForestClassifier(random_state=RNG)

display(model.fit(X[:part], y[:part]).predict(X[part:]).sum())
display(model.fit(X[:part], y[:part]).predict(X[part:]).sum())

Такое поведение обусловлено тем, что после первого вызова fit RNG изменяется. Соответственно, так как при кросс-валидационной оценке на каждом fit-е получается немного иная вариация нашего оценщика, итоговое качество будет статистически устойчивее и меньше зависимо от случайных величин.

С учетом этого, результаты кросс-валидации при задании random_state=RNG и целым - random_state=SEED будут отличаться:

from sklearn.model_selection import cross_val_score

display(cross_val_score(RandomForestClassifier(random_state=RNG), X, y))
display(cross_val_score(RandomForestClassifier(random_state=SEED), X, y))

Таким образом, чтобы повысить статистическую достоверность кросс-валидации, используйте вместо целочисленного сида объект np.random.RandomState.