машинное обучение
January 2, 2022

Машинное обучение на основе правил

Методы машинного обучения известны своей таинственностью, волшебством и сложной интерпретируемостью. Вместе с тем это не всегда так - существуют алгоритмы достаточно мощные и одновременно обладающие простым объяснением.

Одним из таких является классификатор на основе правил. Готовая реализация продвинутых алгоритмов (IREP и RIPPER) имеется в библиотеке wittgenstein, чтобы установить которую достаточно воспользоваться менеджером пакетов (pip install wittgenstein).

В качестве примера возьмем набор для классификации грибов по (съедобен/ядовит) в зависимости от ряда внешних параметров:

import pandas as pd
import wittgenstein as lw

mush_df = pd.read_csv('data/mushrooms.csv')
mush_df.info()

Колонки содержат категориальные данные и в целях оптимизации могут быть преобразованы к этому типу:

for col in mush_df.columns:
    mush_df[col] = mush_df[col].astype('category')
    print(f'столбец {col}, уникальных - {len(mush_df[col].unique())}')

Теперь датафрейм занимает почти в 10 раз меньше места:

Перейдем обучению. В пакете wittgenstein реализованы классы IREP и RIPPER которые представляют одноименные алгоритмы. В качестве параметров в конструктор можно передать:

  • prune_size - доля тренировочного набора, которая участвует в стадии сокращения (корректировки) правила, полученного с помощью других образцов тренировочного набора (по примеру перекрестной проверки);
  • max_rules - максимальное количество правил;
  • max_rule_conds - максимальное количество условий в одном правиле;
  • random_state - инициализатор случайных значений.

Создадим объект классификатора и разделим данные на обучающую и тестовую выборки:

# clf = lw.IREP()
clf = lw.RIPPER()
from sklearn.model_selection import train_test_split
mush_tr, mush_ts = train_test_split(mush_df, test_size=0.2)

Обучение осуществляется путем вызова метода fit. Передать признаки и target в него можно двумя способами:

clf.fit(mush_tr, class_feat='class', pos_class='p')
# clf.fit(mush_tr.drop('class', axis=1), mush_tr['class'].map({'p':1, 'e':0}).astype(int))

В первой форме указывается датафрейм, а затем имя столбца target-а в нем, также если цель строчная (в нашем случае значения p и e), то нужно указать наименование класса 1. Во втором привычный синтаксис sklearn - датафрейм с признаками и отдельно колонка цели (однако ее надо будет дополнительно закодировать).

Предсказания осуществляются методом predict, который содержит опцию возврата не только предсказаний, но и правил для точек с положительным классом (параметр give_reasons=True):

y_pr = clf.predict(mush_ts)
clf.predict(mush_ts, give_reasons=True)[1][:5]

Для оценки классификатора можно воспользоваться стандартными метриками scikit-learn или встроенным методом score:

from sklearn.metrics import classification_report
print(clf.score(mush_ts, y_pr, score_function=classification_report))
print(classification_report(mush_ts['class'].map({'p':True, 'e':False}).astype(bool), y_pr))

У классификатора имеется опция вывода вероятности прогноза с методом predict_proba.

Также для получения списка правил, на основании которых прогнозируется положительный класс (если ни одно не сработало - отрицательный) предусмотрен метод out_model или свойство ruleset_:

# clf.ruleset_
clf.out_model()

V - означает ИЛИ (фактически разделяет правила), а ^ - И (формирует набор условий в рамках одного правила). В нашем случае классификатор со 100% точностью формирует 8 правил для определения ядовитых грибов, остальные - съедобные. Теперь можно идти в лес собирать грибы без опаски!

Не пропустите ничего интересного и подписывайтесь на страницы канала в других социальных сетях:

Instagram

Яндекс Дзен