Пошаговый пример линейного дискриминантного анализа в Python
Линейный дискриминантный анализ (LDA) — это метод, используемый для классификации объектов по набору признаков. Рассмотрим пошаговый базовый пример подобной задачи.
Шаг 1. Загрузка необходимых библиотек
Для начала импортируем все необходимые библиотеки, функции и классы:
from sklearn.model_selection import train_test_split from sklearn.model_selection import RepeatedStratifiedKFold from sklearn.model_selection import cross_val_score from sklearn.discriminant_analysis import LinearDiscriminantAnalysis from sklearn import datasets import matplotlib.pyplot as plt import pandas as pd import numpy as np
Шаг 2. Загрузка данных
В качестве примера возьмём набор данных по цветкам ириса из библиотеки sklearn
. В приведённом ниже участке кода показана загрузка этого набора данных и его преобразование в датафрейм.
# Загружаем набор данных по ирисам iris = datasets.load_iris() # Создаём датафрейм df = pd.DataFrame(data = np.c_[iris['data'], iris['target']], columns = iris['feature_names'] + ['target']) df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names) df.columns = ['s_length', 's_width', 'p_length', 'p_width', 'target', 'species'] # Выводим первые 5 строк содержимого датафрейма display(df.head()) s_length s_width p_length p_width target species 0 5.1 3.5 1.4 0.2 0.0 setosa 1 4.9 3.0 1.4 0.2 0.0 setosa 2 4.7 3.2 1.3 0.2 0.0 setosa 3 4.6 3.1 1.5 0.2 0.0 setosa 4 5.0 3.6 1.4 0.2 0.0 setosa # Сколько всего строк в датафрейме? print(len(df.index)) # 150
В наборе данных представлено 150 цветков. Построим модель, которая по характеристикам цветка будет определять, к какому виду он относится.
В качестве факторных переменных используются:
Шаг 3. Обучение модели
Обучим модельку LDA при помощи класса LinearDiscriminantAnalysis из sklearn
:
# Обучаем модель X = df[['s_length', 's_width', 'p_length', 'p_width']] y = df['species'] model = LinearDiscriminantAnalysis() model.fit(X, y)
Шаг 4. Используем модель для прогнозов
После того, как мы обучили модель, можно оценить её качество при помощи стратифицированной k-fold кросс-валидации (repeated stratified k-fold cross validation).
Для примера будем использовать 10 отложенных выборок (фолдов) и параметр repeats=3.
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=42) # Оцениваем модель scores = cross_val_score(model, X, y, scoring='accuracy', cv=cv, n_jobs=-1) print(np.mean(scores)) # 0.9777777777777779
Видим, что средняя точность на 10 фолдах получилась примерно 0.978.
# зададим значения признаков для примера new = [5, 3, 1, 0.4] # прогнозируем, какой класс получим на основе этих значений model.predict([new]) # array(['setosa'], dtype='<U10')
Видим, что модель нвоый объект отнесла к классу setosa.
Шаг 5. Визуализируем результаты
И наконец, мы можем построить визуализацию результатов LDA для того, чтобы понять насколько хорошо модельразделяет три различных вида цветков в нашем наборе данных:
# define data to plot X = iris.data y = iris.target model = LinearDiscriminantAnalysis() data_plot = model.fit(X, y).transform(X) target_names = iris.target_names # создаём график plt.figure() colors = ['red', 'green', 'blue'] lw = 2 for color, i, target_name in zip(colors, [0, 1, 2], target_names): plt.scatter(data_plot[y == i, 0], data_plot[y == i, 1], alpha=.8, color=color, label=target_name) # добавляем легенду plt.legend(loc='best', shadow=False, scatterpoints=1) # отображаем plt.show()
Получаем наглядный результат разделения объектов на 3 класса:
Как вы можете заметить, линейный дискриминантный анализ в Python выполняется крайне просто, особенно если у вас уже есть очищенный и обработанный набор данных :)
Источник: Statology