November 24, 2023
Random Forest with Confusion Matrix and Decision Tree Plot
# Train a Decision Tree and a Random Forest classifier on the same train/test
# split, print accuracy and a classification report for each, and plot their
# confusion matrices plus the fitted decision tree.

# Import necessary libraries
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree  # Added plot_tree


def _report_and_plot(heading, title, y_true, y_pred):
    """Print evaluation metrics and show the confusion matrix as a heatmap.

    Args:
        heading: Line printed before the metrics (e.g. the model name).
        title: Title of the confusion-matrix figure.
        y_true: Ground-truth labels of the evaluated samples.
        y_pred: Labels predicted by the model for the same samples.

    Returns:
        A ``(accuracy, cm)`` tuple: the accuracy score and the confusion
        matrix computed from ``y_true`` and ``y_pred``.
    """
    accuracy = accuracy_score(y_true, y_pred)
    print(heading)
    print("Accuracy:", accuracy)
    print("Classification Report:")
    print(classification_report(y_true, y_pred))

    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(5, 4))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False)
    # confusion_matrix puts true labels on rows and predictions on columns;
    # label the axes so the heatmap cannot be read the wrong way round.
    plt.xlabel("Predicted label")
    plt.ylabel("True label")
    plt.title(title)
    plt.show()
    return accuracy, cm


# Load the dataset
# NOTE: despite its name, `url` is an absolute local filesystem path.
url = '/home/ignis/400/pima_rf.csv'
data = pd.read_csv(url)

# Display the first few rows of the dataset
print(data.head())

# Define features (X) and target variable (y)
X = data.drop('Outcome', axis=1)
y = data['Outcome']

# Split the dataset into training and testing sets
# (20% held out for testing; fixed seed so the split is reproducible)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Build a Decision Tree model
dt_model = DecisionTreeClassifier(random_state=42)
dt_model.fit(X_train, y_train)

# Make predictions on the test set
dt_predictions = dt_model.predict(X_test)

# Evaluate Decision Tree performance and plot its confusion matrix
dt_accuracy, dt_cm = _report_and_plot(
    "Decision Tree Classifier:",
    "Decision Tree Confusion Matrix",
    y_test,
    dt_predictions,
)

# Plot the Decision Tree
plt.figure(figsize=(15, 10))
plot_tree(
    dt_model,
    # Pass a plain list: some scikit-learn releases reject a pandas Index here.
    feature_names=list(X.columns),
    # Derive the names from the fitted classes (sorted ascending) so the
    # labels always line up with what the tree actually learned.
    class_names=[str(label) for label in dt_model.classes_],
    filled=True,
    rounded=True,
)
plt.title("Decision Tree Visualization")
plt.show()

# Build a Random Forest model
rf_model = RandomForestClassifier(random_state=42)
rf_model.fit(X_train, y_train)

# Make predictions on the test set
rf_predictions = rf_model.predict(X_test)

# Evaluate Random Forest performance and plot its confusion matrix
rf_accuracy, rf_cm = _report_and_plot(
    "\nRandom Forest Classifier:",
    "Random Forest Confusion Matrix",
    y_test,
    rf_predictions,
)