machine learning
September 23, 2023

Предсказание цены автомобиля с помощью методов машинного обучения

В прошлом посте провел разведочный анализ данных по автомобилям собранным с сайта.

В этом посте применим методы машинного обучения чтобы попытаться предсказать цену автомобиля.

Вот несколько алгоритмов, которые могут подойти для этой задачи:

  1. Линейная регрессия (Linear Regression):
    • Простой и интерпретируемый метод.
    • Хорошо работает, если существует линейная зависимость между признаками и целевой переменной.
  2. Решающие деревья (Decision Trees) и Случайный лес (Random Forest):
    • Могут улавливать нелинейные зависимости.
    • Random Forest обычно предоставляет более точные прогнозы, чем отдельное решающее дерево, за счет усреднения прогнозов множества деревьев.
  3. Градиентный бустинг (Gradient Boosting), например, XGBoost или LightGBM:
    • Эффективные алгоритмы, которые часто показывают высокую производительность в задачах регрессии.
    • Они строят ансамбль деревьев последовательно, каждое следующее дерево пытается исправить ошибки предыдущих.
  4. Нейронные сети (Neural Networks):
    • Могут быть полезными, если у вас большое количество данных.
    • Способны улавливать сложные нелинейные зависимости.

Линейная регрессия

Линейная регрессия предполагает, что мы хотим прогнозировать одну переменную на основе других переменных.

# Разделение данных на обучающие и тестовые 
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
# Создание и обучение модели 
LR = LinearRegression() 
LR.fit(X_train, y_train)
# Прогнозирование 
y_pred = LR.predict(X_test)
# Оценка модели 
mse = mean_squared_error(y_test, y_pred) 
rmse = mean_squared_error(y_test, y_pred, squared=False) 
mae = mean_absolute_error(y_test, y_pred) 
r2 = r2_score(y_test, y_pred) 
adj_r2 = 1 - (1-r2)*(len(y_test)-1)/(len(y_test)-X_test.shape[1]-1) 
print(f'Mean Squared Error: {mse}') 
print(f'Root Mean Squared Error: {rmse}') 
print(f'Mean Absolute Error: {mae}') 
print(f'R^2: {r2}') 
print(f'Adjusted R^2: {adj_r2}')
Mean Squared Error: 16902445328.964172 
Root Mean Squared Error: 130009.40477120943 
Mean Absolute Error: 102777.23299243717 
R^2: 0.3387269809586828 Adjusted 
R^2: 0.3318098991695476

интерпретируем метрики:

  1. Mean Squared Error (MSE): 16,902,445,328.96
    MSE является мерой качества, где меньшее значение MSE указывает на лучшее качество. Это значение довольно высокое, что может указывать на наличие больших ошибок между фактическими и прогнозируемыми значениями.
  2. Root Mean Squared Error (RMSE): 130,009.40
    RMSE интерпретируется в тех же единицах измерения, что и исходные данные (в данном случае, цена). Модель ошибается в среднем на 130,009.40 рублей при прогнозировании цены.
  3. Mean Absolute Error (MAE): 102,777.23
    MAE представляет собой среднюю абсолютную ошибку между прогнозируемыми и фактическими значениями. Это говорит о том, что модель в среднем ошибается на 102,777.23 единиц.
  4. R^2: 0.3387
    R^2 измеряет долю дисперсии зависимой переменной, объясненную моделью. Значение 0.3387 говорит о том, модель объясняет только 33.87% дисперсии в данных. Это довольно низкое значение, что указывает на то, что модель может быть не очень хорошо подобрана или что может быть много нерассмотренных или нерелевантных признаков.
  5. Adjusted R^2: 0.3318
    Это корректировка R^2, учитывающая количество признаков в модели. Поскольку оно близко к обычному R^2

В целом, на основе представленных метрик, модель не идеальна и может потребовать доработки. Это может включать добавление новых признаков, преобразование существующих признаков, проверку на наличие выбросов в данных или использование другой модели для прогнозирования.

Настройка модели

Проверяем на наличием выбросов.

Красным выделены значения которые считаются выбросами
#Есть выбросы их надо удалить
for column in ['price', 'Year','Mileage','Power']:
  # если столбец числовой  
  Q1 = data[column].quantile(0.25)  
  Q3 = data[column].quantile(0.75)  IQR = Q3 - Q1
  # Границы выбросов  
  lower_bound = Q1 - 1.5 * IQR  
  upper_bound = Q3 + 1.5 * IQR
  # Отфильтровать выбросы  
  data = data[(data[column] >= lower_bound) & (data[column] <= upper_bound)]

Удаление выбросов не помогло.

Полный код есть на github