Visualizing Decision Trees for Product Purchase Prediction with scikit-learn and dtreeviz
This tutorial explains how to prepare advertising click data, train a decision‑tree classifier, and generate clear visualizations using scikit‑learn and dtreeviz, while also showing how to inspect individual prediction paths and feature importance.
Decision Tree Model
Decision trees are interpretable models for classification and regression. Using scikit-learn and dtreeviz you can visualize them, which helps present modeling results.
Product Purchase Prediction
Data preparation
The example uses an advertising click dataset with features Gender, Age, EstimatedSalary and target Purchased. After removing the user ID column, the data are cleaned, categorical text is encoded to numeric, and the dataset is split into training and test sets.
Model training
<code># import libraries
import pandas as pd
import matplotlib.pyplot as plt
import sklearn.tree
import sklearn.preprocessing
import sklearn.metrics as metrics
# load dataset
dataset = pd.read_csv('data/Social_Network_Ads.csv')
dataset = dataset.drop(columns=['User ID'])
# encode categorical gender
enc = sklearn.preprocessing.OneHotEncoder()
enc.fit(dataset.iloc[:,[0]])
onehotlabels = enc.transform(dataset.iloc[:,[0]]).toarray()
genders = pd.DataFrame({'Female': onehotlabels[:,0], 'Male': onehotlabels[:,1]})
result = pd.concat([genders, dataset.iloc[:,1:]], axis=1, sort=False)
# split data
y = result['Purchased']
X = result.drop(columns=['Purchased'])
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)
# train decision tree
classifier = sklearn.tree.DecisionTreeClassifier(criterion='entropy', random_state=100, max_depth=2)
classifier.fit(X_train, y_train)
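# quick sanity check: evaluate the fitted tree on the held-out test set
y_pred = classifier.predict(X_test)
print('accuracy:', metrics.accuracy_score(y_test, y_pred))
print(metrics.confusion_matrix(y_test, y_pred))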
</code>
Visualization with scikit-learn
Using plot_tree you can draw the tree with feature and class names.
<code>feature_names = ['Female', 'Male', 'Age', 'EstimatedSalary']
class_names = ['No Purchase', 'Purchase']
from sklearn.tree import plot_tree, export_text
plt.figure(figsize=(10, 5))
plot_tree(classifier, feature_names=feature_names, class_names=class_names, filled=True)
plt.show()
# export_text prints the same tree structure as plain-text rules
print(export_text(classifier, feature_names=feature_names))
</code>
Advanced visualization with dtreeviz
dtreeviz produces richer visualizations, including histograms and pie charts, and allows orientation changes or depth range selection.
<code>import dtreeviz
# target_name is the name of the target column; class_names label its two values
viz_model = dtreeviz.model(classifier,
                           X_train=X_train, y_train=y_train,
                           feature_names=feature_names,
                           target_name='Purchased',
                           class_names=['No Purchase', 'Purchase'])
viz_model.view(scale=1)
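# optional view() parameters: left-to-right layout, or only a range of depths
viz_model.view(orientation='LR')
viz_model.view(depth_range_to_display=(0, 1))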
</code>
You can also inspect a single prediction path and feature importance for a specific sample.
<code># take a row of encoded features (dataset itself still holds the raw Gender text and the target)
x0 = X_test.iloc[0]
viz_model.view(x=x0)
print(viz_model.explain_prediction_path(x0))
</code>
The analysis shows that Age has the greatest impact, followed by EstimatedSalary, while Gender contributes little.
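The per-sample view can be cross-checked against the tree's aggregate importances; a minimal sketch using scikit-learn's feature_importances_ attribute and the feature_names list defined earlier:
<code>for name, importance in zip(feature_names, classifier.feature_importances_):
    print(f'{name}: {importance:.3f}')
</code>
The importances sum to 1 and measure how much each feature reduces entropy across the tree's splits.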
Model Perspective
Insights, knowledge, and enjoyment from a mathematical modeling researcher and educator. Hosted by Haihua Wang, a modeling instructor and author of "Clever Use of Chat for Mathematical Modeling", "Modeling: The Mathematics of Thinking", "Mathematical Modeling Practice: A Hands‑On Guide to Competitions", and co‑author of "Mathematical Modeling: Teaching Design and Cases".