Разбиение датасета на групповые выборки
Зачастую модели машинного обучения обучаются на наборах векторов, характеризующих однотипные объекты в разные промежутки времени. Например, это может быть история потреблений товаров в филиалах организации. То есть датасет может быть фактически разделен на группы по относимости к филиалу. В этих случаях требования к валидации модели могут быть усилены не просто оценкой прогнозов на будущие периоды (логично модель обучать на прошлых периодах), но и стабильностью работы на новых объектах (которых в обучающей выборки не было). Для этих целей в библиотеке Scikit-learn существует специальный сплиттер GroupKFold.
Создадим демонстрационный датафрейм:
import pandas as pd df = pd.DataFrame({'id':['id_2', 'id_1', 'id_3', 'id_3', 'id_1', 'id_4', 'id_2'], 'value':[23, 44, 21, 221, 2, 21, 22], }, index = ['one', 'two', 'three', 'four', 'five', 'six', 'seven']) df
Для разбиения датасета на две группы по колонке id, создадим объект класса GroupKFold с параметром n_splits=2, а затем вызовем его метод split, в параметре groups которого зададим колонку разбиения:
from sklearn.model_selection import GroupKFold gkfold_splitter = GroupKFold(n_splits=2) for tr_idx, ts_idx in gkfold_splitter.split(df, groups = df['id']): display(df.iloc[tr_idx])
Как видим, группы в получившихся наборах не пересекаются.
Не пропустите ничего интересного и подписывайтесь на страницы канала в других социальных сетях: