June 28, 2022

Разбиение датасета на групповые выборки

Зачастую модели машинного обучения обучаются на наборах векторов, характеризующих однотипные объекты в разные промежутки времени. Например, это может быть история потреблений товаров в филиалах организации. То есть датасет может быть фактически разделен на группы по относимости к филиалу. В этих случаях требования к валидации модели могут быть усилены не просто оценкой прогнозов на будущие периоды (логично модель обучать на прошлых периодах), но и стабильностью работы на новых объектах (которых в обучающей выборки не было). Для этих целей в библиотеке Scikit-learn существует специальный сплиттер GroupKFold.

Создадим демонстрационный датафрейм:

import pandas as pd

df = pd.DataFrame({'id':['id_2', 'id_1', 'id_3', 'id_3', 'id_1', 'id_4', 'id_2'],
                    'value':[23, 44, 21, 221, 2, 21, 22], 
                   },
                  index = ['one', 'two', 'three', 'four', 'five', 'six', 'seven']) 

df

Для разбиения датасета на две группы по колонке id, создадим объект класса GroupKFold с параметром n_splits=2, а затем вызовем его метод split, в параметре groups которого зададим колонку разбиения:

from sklearn.model_selection import GroupKFold
gkfold_splitter = GroupKFold(n_splits=2)

for tr_idx, ts_idx in gkfold_splitter.split(df, groups = df['id']):
    display(df.iloc[tr_idx])

Как видим, группы в получившихся наборах не пересекаются.

Не пропустите ничего интересного и подписывайтесь на страницы канала в других социальных сетях:

Яндекс Дзен

Telegram