Note_Tech

All technological notes.


Project maintained by simonangel-fong Hosted on GitHub Pages — Theme by mattgraham

Machine Learning - Decision Tree



Example: Decision Tree

Predict whether a customer will go to a comedy show (Go: YES/NO) based on four attributes: Age, Experience, Rank, and Nationality.


1. Load Data

import pandas as pd

FILE_PATH = "./data_decision_tree.csv"
file_data = pd.read_csv(FILE_PATH)

# file_data.shape
# file_data.columns
# file_data.info()
# print(file_data)

file_data.head()
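The CSV file itself is not included in this note. To try the steps without it, `pd.read_csv` also accepts a file-like object, so a few rows can be read from an in-memory string. This is a minimal sketch. The rows below are hypothetical and are not the contents of `data_decision_tree.csv`. Only the column names match the ones used later in this note.

```python
import io

import pandas as pd

# Hypothetical rows for illustration only; NOT the real data_decision_tree.csv.
# Column names match the ones used in the rest of this note.
SAMPLE_CSV = """Age,Experience,Rank,Nationality,Go
30,5,6,UK,NO
45,20,9,USA,YES
28,2,4,N,NO
50,15,8,UK,YES
"""

# read_csv accepts any file-like object, so StringIO stands in for FILE_PATH
sample_data = pd.read_csv(io.StringIO(SAMPLE_CSV))

print(sample_data.shape)
print(sample_data.head())
```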


2. Clean Data

# DecisionTreeClassifier needs numeric input, so map each nationality to an integer
map_nationality = {'UK': 0, 'USA': 1, 'N': 2}
file_data['Nationality'] = file_data['Nationality'].map(map_nationality)

# Map the target column 'Go' to numbers: YES -> 1, NO -> 0
map_go = {'YES': 1, 'NO': 0}
file_data['Go'] = file_data['Go'].map(map_go)

print(file_data)
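`Series.map()` with a dict turns any value that is not a key into NaN, so a typo or an unexpected category in the CSV would silently become a missing value. A quick check is sketched below, using a hypothetical `'FR'` entry that the mapping does not cover:

```python
import pandas as pd

# Hypothetical column containing a value ('FR') that is missing from the mapping
nationality = pd.Series(['UK', 'USA', 'N', 'FR'])
map_nationality = {'UK': 0, 'USA': 1, 'N': 2}

encoded = nationality.map(map_nationality)

# Values without a key in the dict come back as NaN; list them before training
unmapped = nationality[encoded.isna()]
print(unmapped.tolist())
```

If `unmapped` is not empty, extend the mapping (or clean the data) before fitting the model.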

3. Split feature columns and the target column

features = ['Age', 'Experience', 'Rank', 'Nationality']

feature_list = file_data[features]
target_list = file_data['Go']

# print(feature_list)
# print(target_list)
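scikit-learn estimators expect a 2-D feature matrix and a 1-D target. Indexing a DataFrame with a list of column names returns a DataFrame, while indexing with a single name returns a Series. The sketch below shows the resulting shapes on a few hypothetical, already-encoded rows (not the note's dataset):

```python
import pandas as pd

# Hypothetical, already-encoded rows (not the note's dataset)
data = pd.DataFrame({
    'Age': [30, 45, 28],
    'Experience': [5, 20, 2],
    'Rank': [6, 9, 4],
    'Nationality': [0, 1, 2],
    'Go': [0, 1, 0],
})

features = ['Age', 'Experience', 'Rank', 'Nationality']

X = data[features]   # list of names -> 2-D DataFrame, shape (n_samples, n_features)
y = data['Go']       # single name -> 1-D Series, shape (n_samples,)

print(X.shape, y.shape)
```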

4. Create Decision Tree Model and Fit Data

from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

import matplotlib.pyplot as plt

predict_model = DecisionTreeClassifier()
predict_model = predict_model.fit(feature_list.values, target_list.values)

# sklearn.tree.plot_tree(): Plot a decision tree.
#   decision_tree: The decision tree to be plotted.
#   feature_names: Names of each of the features.
tree.plot_tree(predict_model, feature_names=features)

plt.show()

(Figure: the fitted decision tree, as drawn by tree.plot_tree)
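Each box in the plotted tree reports a `gini` value. This is the Gini impurity, the default split criterion of `DecisionTreeClassifier` (`criterion="gini"`): 1 minus the sum of the squared class proportions in a node. When growing the tree, it picks the split whose child nodes have the lowest sample-weighted impurity. Below is a plain-Python sketch of that arithmetic on hypothetical label lists. The helper names `gini` and `split_gini` are illustrative and are not scikit-learn functions.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of one node: 1 - sum of squared class proportions."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def split_gini(left, right):
    """Impurity of a split: each child's impurity weighted by its share of samples."""
    n = len(left) + len(right)
    return len(left) / n * gini(left) + len(right) / n * gini(right)


# Hypothetical node with six "Go" labels (1 = YES, 0 = NO), half of each class
node = [1, 1, 0, 0, 0, 1]
print(gini(node))                                   # 0.5
print(split_gini([1, 1, 1], [0, 0, 0]))             # 0.0, a perfect split
print(round(split_gini([1, 1, 0], [0, 0, 1]), 4))   # 0.4444, a weaker split
```

The perfect split drops the impurity from 0.5 to 0, so a tree would prefer it over the second candidate.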


5. Predict Values

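# Feature order must match `features`: Age=40, Experience=10, Rank=7, Nationality=1 (USA)
predict_value = predict_model.predict([[40, 10, 7, 1]])

# print(predict_value)
print("Yes" if predict_value[0] else "No")      # e.g. No; can differ between runs since random_state is not fixed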
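Because `random_state` is left unset above, `DecisionTreeClassifier` may break ties between equally good splits differently on each run, which can change the prediction. The minimal sketch below uses hypothetical toy rows rather than the note's CSV. It shows that a fixed `random_state` makes two fits agree, and that `predict_proba` returns the class probabilities behind the Yes/No answer.

```python
from sklearn.tree import DecisionTreeClassifier

# Hypothetical toy rows: [Age, Experience, Rank, Nationality]; 1 = YES, 0 = NO
X = [[30, 5, 6, 0], [45, 20, 9, 1], [28, 2, 4, 2], [50, 15, 8, 0]]
y = [0, 1, 0, 1]

# A fixed random_state makes tie-breaking between equally good splits repeatable,
# so the same data always yields the same tree and the same predictions.
model_a = DecisionTreeClassifier(random_state=42).fit(X, y)
model_b = DecisionTreeClassifier(random_state=42).fit(X, y)

sample = [[40, 10, 7, 1]]
pred_a = model_a.predict(sample)
pred_b = model_b.predict(sample)

# predict_proba returns one probability per class, in the order of model_a.classes_
proba = model_a.predict_proba(sample)
print(model_a.classes_, proba)
print("Yes" if pred_a[0] else "No")
```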
