Decision Trees: Hierarchical Rule-Based Classification and Regression

Understand decision trees: how they split data, Gini impurity, information gain, pruning, and when to use trees vs. ensemble methods for classification and regression tasks.

Decision Trees

Decision trees are one of the most intuitive machine learning algorithms. They partition data through a series of yes/no questions, creating a structure that mirrors how humans naturally make decisions. Unlike many ML models, a decision tree can be printed, read, and explained to a non-technical stakeholder.


How a Decision Tree Works

                  Age > 30?
                 /         \
              Yes           No
              /               \
     Income > 60K?       Credit Score > 700?
       /       \             /        \
     Yes        No         Yes         No
      |          |          |           |
   APPROVE     DENY      APPROVE       DENY

The model learns which questions (feature splits) best separate the classes, and at what thresholds to ask them.
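Every path from the root to a leaf is a conjunction of conditions, so the example tree above is equivalent to ordinary nested if/else rules. A minimal sketch in Python; the function name approve_loan and its parameter names are hypothetical, and the thresholds are taken from the diagram.

```python
# Hypothetical hand-written version of the example tree above.
# Thresholds come from the diagram; the function and parameter names are illustrative.
def approve_loan(age: float, income: float, credit_score: float) -> str:
    """Walk the example tree from the root to a leaf and return the leaf's label."""
    if age > 30:
        # "Yes" branch: the tree asks about income next
        if income > 60_000:
            return "APPROVE"
        return "DENY"
    # "No" branch: the tree asks about credit score next
    if credit_score > 700:
        return "APPROVE"
    return "DENY"


print(approve_loan(age=45, income=75_000, credit_score=650))  # APPROVE
print(approve_loan(age=25, income=90_000, credit_score=680))  # DENY
```

Training a tree amounts to discovering rules like these (which features, in which order, with which thresholds) automatically from data.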


Splitting Criteria

Gini Impurity (Classification, Default in sklearn)

Measures the probability of misclassifying a randomly chosen sample from the node if it were given a random label drawn from the node's class distribution.

Gini(node) = 1 - Σ pᵢ²
Perfect node (all one class): Gini = 1 - 1² = 0
Mixed 50/50 node: Gini = 1 - (0.5² + 0.5²) = 0.5

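At each node, the algorithm evaluates candidate splits on every feature (for a numeric feature, thresholds between consecutive sorted values) and picks the one that minimizes the weighted Gini impurity of the two children, where each child's impurity is weighted by its share of the node's samples.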
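The Gini formula and the exhaustive threshold search described above can be sketched from scratch for a single numeric feature. This is a simplified illustration, not scikit-learn's optimized implementation; the helper names gini and best_gini_split are hypothetical.

```python
import numpy as np


def gini(labels: np.ndarray) -> float:
    """Gini impurity: 1 - sum of squared class proportions."""
    if len(labels) == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(1.0 - np.sum(p ** 2))


def best_gini_split(x: np.ndarray, y: np.ndarray) -> tuple[float, float]:
    """Try every midpoint between consecutive distinct values of one feature.

    Returns (threshold, weighted Gini impurity of the two children).
    """
    values = np.unique(x)  # sorted distinct values
    best_threshold, best_score = float("nan"), float("inf")
    for lo, hi in zip(values[:-1], values[1:]):
        threshold = (lo + hi) / 2
        left, right = y[x <= threshold], y[x > threshold]
        # Weight each child's impurity by its share of the samples
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
        if score < best_score:
            best_threshold, best_score = float(threshold), score
    return best_threshold, best_score


print(gini(np.array([1, 1, 1, 1])))  # 0.0 (perfect node)
print(gini(np.array([0, 1, 0, 1])))  # 0.5 (50/50 node)

x = np.array([2, 3, 10, 19, 20, 25])
y = np.array([0, 0, 0, 1, 1, 1])
print(best_gini_split(x, y))  # (14.5, 0.0)
```

The full algorithm repeats this search over every feature, keeps the overall best split, and recurses into the two children.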

Information Gain / Entropy

Based on Shannon entropy:

Entropy = -Σ pᵢ × log₂(pᵢ)
Information Gain = Entropy(parent) - Weighted Entropy(children)

Entropy and Gini usually select similar splits, so the resulting trees differ slightly but tend to perform similarly. In scikit-learn, set criterion='entropy' to use it.
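To make the two formulas concrete, here is a small sketch that computes entropy (in bits) and the information gain of one candidate split. The helper names entropy and information_gain are hypothetical; the example splits a 50/50 parent into two 3:1 children.

```python
import numpy as np


def entropy(labels: np.ndarray) -> float:
    """Shannon entropy in bits: -sum(p_i * log2(p_i)), written as sum(p_i * log2(1/p_i))."""
    _, counts = np.unique(labels, return_counts=True)  # only classes present, so p_i > 0
    p = counts / counts.sum()
    return float(np.sum(p * np.log2(1.0 / p)))


def information_gain(parent: np.ndarray, left: np.ndarray, right: np.ndarray) -> float:
    """Entropy(parent) minus the size-weighted entropy of the two children."""
    weighted = (len(left) * entropy(left) + len(right) * entropy(right)) / len(parent)
    return entropy(parent) - weighted


parent = np.array([0, 0, 0, 0, 1, 1, 1, 1])  # 50/50 node: 1 bit of entropy
left = np.array([0, 0, 0, 1])                # 3:1 child
right = np.array([0, 1, 1, 1])               # 1:3 child
print(round(entropy(parent), 4))                        # 1.0
print(round(information_gain(parent, left, right), 4))  # 0.1887
```

A split that separated the classes perfectly would reach the maximum gain of 1 bit here.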

MSE (Regression)

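For regression trees, each split is chosen to minimize the weighted mean squared error of the child nodes, where each child predicts the mean of its training targets (criterion='squared_error', the default for scikit-learn's DecisionTreeRegressor).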
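The same threshold search works for regression once the impurity is swapped for MSE. Because each child predicts its own mean, a child's MSE equals the variance of its targets. A minimal sketch; the helper name best_mse_split is hypothetical.

```python
import numpy as np


def best_mse_split(x: np.ndarray, y: np.ndarray) -> tuple[float, float]:
    """Return (threshold, weighted MSE of the two children) for one numeric feature."""
    values = np.unique(x)  # sorted distinct values
    best_threshold, best_score = float("nan"), float("inf")
    for lo, hi in zip(values[:-1], values[1:]):
        threshold = (lo + hi) / 2
        left, right = y[x <= threshold], y[x > threshold]
        # np.var is the mean squared deviation from each child's mean prediction
        score = (len(left) * np.var(left) + len(right) * np.var(right)) / len(y)
        if score < best_score:
            best_threshold, best_score = float(threshold), float(score)
    return best_threshold, best_score


x = np.array([1, 2, 3, 4, 5, 6])
y = np.array([1.0, 1.0, 1.0, 5.0, 5.0, 5.0])
print(best_mse_split(x, y))  # (3.5, 0.0)
```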


Building a Tree in Practice

from sklearn.tree import DecisionTreeClassifier, export_text, plot_tree
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
# Assumes a feature matrix X, binary labels y (0 = No, 1 = Yes),
# and feature_names (a list of column names) are already defined
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
# Train
model = DecisionTreeClassifier(
    max_depth=5,          # Limit depth to prevent overfitting
    min_samples_leaf=10,  # Require at least 10 samples in each leaf
    criterion='gini',
    random_state=42
)
model.fit(X_train, y_train)
print(f"Test accuracy: {model.score(X_test, y_test):.3f}")
# Visualize
plt.figure(figsize=(20, 10))
plot_tree(model, feature_names=feature_names, class_names=['No', 'Yes'],
          filled=True, rounded=True, fontsize=10)
plt.show()
# Print text rules
print(export_text(model, feature_names=feature_names))

Overfitting and Pruning

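An unconstrained decision tree keeps splitting until every leaf is pure (contains a single class) or cannot be split further. Many leaves then hold only a handful of samples, and the tree effectively memorizes the training data, noise included.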

Key hyperparameters to control complexity:

Parameter                Effect
max_depth                Maximum tree depth; often the single most effective control
min_samples_split        Minimum samples a node needs before a split is attempted
min_samples_leaf         Minimum samples required in each leaf node
max_leaf_nodes           Maximum number of leaf nodes
min_impurity_decrease    Minimum weighted impurity decrease required to allow a split
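A quick way to see why these limits matter is to compare an unconstrained tree with a constrained one on noisy data. This sketch uses synthetic data from make_classification with 10% of labels flipped (an assumption for illustration, not data from this article); the specific limits chosen are arbitrary.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic data with 10% of labels flipped, so memorizing noise hurts
X, y = make_classification(n_samples=1000, n_features=10, n_informative=4,
                           flip_y=0.1, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=0)

results = {}
for name, params in [("unconstrained", {}),
                     ("constrained", {"max_depth": 4, "min_samples_leaf": 10})]:
    clf = DecisionTreeClassifier(random_state=0, **params).fit(X_tr, y_tr)
    results[name] = clf
    print(f"{name:>13}: depth={clf.get_depth()}, leaves={clf.get_n_leaves()}, "
          f"train={clf.score(X_tr, y_tr):.3f}, test={clf.score(X_te, y_te):.3f}")
```

Expect the unconstrained tree to fit the training set almost perfectly with many leaves; the gap between its train and test accuracy is the overfitting described above.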

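Post-pruning: grow the full tree, then prune it back with minimal cost-complexity pruning, controlled by the ccp_alpha parameter. Larger alpha values remove more subtrees, producing smaller trees: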

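import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier
# Compute the pruning path of a fully grown (unconstrained) tree
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X_train, y_train)
alphas = np.clip(path.ccp_alphas, 0.0, None)  # ccp_alpha must be non-negative
# Find the best alpha via 5-fold cross-validation
cv_scores = []
for alpha in alphas:
    candidate = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)
    scores = cross_val_score(candidate, X_train, y_train, cv=5)
    cv_scores.append(scores.mean())
best_alpha = alphas[np.argmax(cv_scores)]
model_pruned = DecisionTreeClassifier(ccp_alpha=best_alpha, random_state=42).fit(X_train, y_train)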
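To see what ccp_alpha does to tree size, the sketch below refits a tree at a few points along the pruning path and counts leaves. It uses synthetic make_classification data as an assumption for illustration. The leaf count should shrink as alpha grows; according to the scikit-learn documentation, the last alpha on the path prunes the tree down to a single node.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=500, n_features=8, flip_y=0.1, random_state=1)

# Effective alphas of a fully grown tree, in increasing order
path = DecisionTreeClassifier(random_state=1).cost_complexity_pruning_path(X, y)
alphas = np.clip(path.ccp_alphas, 0.0, None)  # ccp_alpha must be non-negative

# Refit at five evenly spaced positions along the path and record tree size
picks = alphas[np.linspace(0, len(alphas) - 1, 5).astype(int)]
leaf_counts = []
for alpha in picks:
    pruned = DecisionTreeClassifier(ccp_alpha=alpha, random_state=1).fit(X, y)
    leaf_counts.append(pruned.get_n_leaves())
    print(f"ccp_alpha={alpha:.4f} -> {pruned.get_n_leaves()} leaves")
```

Cross-validation, as in the code above, is what decides which point on this size/accuracy trade-off to keep.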

Decision Tree Regression

Decision trees also handle continuous targets, via DecisionTreeRegressor:

from sklearn.tree import DecisionTreeRegressor
# Assumes X_train, X_test and a continuous target y_train are already defined
reg_tree = DecisionTreeRegressor(max_depth=4, random_state=42)
reg_tree.fit(X_train, y_train)
# Each leaf outputs the mean of training samples in that leaf
y_pred = reg_tree.predict(X_test)

The output is a step function — predictions are constant within each leaf region. This gives decision trees their distinctive “blocky” prediction surfaces.
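The step-function behavior is easy to check on a 1-D toy problem. This sketch fits a depth-3 regressor to a noisy sine curve (synthetic data, an assumption for illustration) and predicts on a fine grid: however many grid points you ask for, the model can output at most one value per leaf, and at most 2³ = 8 leaves exist at depth 3.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Smooth 1-D target: y = sin(x) plus a little noise
rng = np.random.default_rng(0)
X = np.sort(rng.uniform(0, 6, size=200)).reshape(-1, 1)
y = np.sin(X).ravel() + rng.normal(scale=0.1, size=200)

reg = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)

# Predict on 1000 grid points: the output only takes one value per leaf
grid = np.linspace(0, 6, 1000).reshape(-1, 1)
pred = reg.predict(grid)
print("leaves:", reg.get_n_leaves())
print("distinct predicted values:", len(np.unique(pred)))
```

Plotting pred against grid would show the "blocky" staircase approximating the sine curve.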


Feature Importance

Decision trees naturally provide feature importances: for each feature, the total weighted impurity reduction from all splits on that feature, normalized so that the importances sum to 1:

import pandas as pd
import matplotlib.pyplot as plt
# Assumes a fitted `model` and the matching `feature_names` list
importance = pd.Series(model.feature_importances_, index=feature_names)
importance.sort_values(ascending=True).plot(kind='barh')
plt.title("Feature Importances from Decision Tree")
plt.tight_layout()
plt.show()
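To make "total impurity reduction" concrete, the sketch below recomputes the importances by hand from a fitted tree's node arrays (tree_.children_left, tree_.impurity, tree_.weighted_n_node_samples, and so on) and compares them with feature_importances_. It uses synthetic make_classification data as an assumption; the two printed rows should agree up to floating-point error.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=500, n_features=6, n_informative=3, random_state=2)
clf = DecisionTreeClassifier(max_depth=4, random_state=2).fit(X, y)

t = clf.tree_  # low-level arrays describing every node of the fitted tree
importances = np.zeros(X.shape[1])
for node in range(t.node_count):
    left, right = t.children_left[node], t.children_right[node]
    if left == -1:  # leaf: no split, so no impurity reduction
        continue
    # Weighted impurity decrease produced by this node's split
    decrease = (t.weighted_n_node_samples[node] * t.impurity[node]
                - t.weighted_n_node_samples[left] * t.impurity[left]
                - t.weighted_n_node_samples[right] * t.impurity[right])
    importances[t.feature[node]] += decrease
importances /= importances.sum()  # normalize so the importances sum to 1

print(np.round(importances, 3))
print(np.round(clf.feature_importances_, 3))
```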

Strengths and Weaknesses

Strengths:

  • Fully interpretable (you can print the rules)
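  • Handles numeric and categorical features (in scikit-learn, categorical columns are usually encoded as numbers first)
  • No feature scaling required, since splits depend only on the ordering of each feature's values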
  • Captures non-linear patterns

Weaknesses:

  • High variance — small data changes can produce very different trees
  • Axis-aligned splits only (diagonal patterns need many splits)
  • A single tree is rarely as accurate as an ensemble method
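The axis-aligned limitation can be demonstrated with a diagonal class boundary. In this sketch (synthetic data, an assumption for illustration), class 1 means x0 > x1. A tree on the raw features must approximate the diagonal with a staircase of many splits, while a single split suffices once the hypothetical engineered feature x0 - x1 is supplied.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Labels depend on a diagonal boundary: class 1 when x0 > x1
rng = np.random.default_rng(3)
X = rng.uniform(0, 1, size=(500, 2))
y = (X[:, 0] > X[:, 1]).astype(int)

# On the raw features, the tree approximates the diagonal with a staircase
raw = DecisionTreeClassifier(random_state=3).fit(X, y)

# With the engineered feature x0 - x1, one axis-aligned split separates the classes
diff = (X[:, 0] - X[:, 1]).reshape(-1, 1)
stump = DecisionTreeClassifier(max_depth=1, random_state=3).fit(diff, y)

print("raw features: leaves =", raw.get_n_leaves())
print("x0 - x1:      leaves =", stump.get_n_leaves(), "accuracy =", stump.score(diff, y))
```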

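The standard remedy for these weaknesses is to use many trees together. Random Forests average many decorrelated trees to reduce variance, and Gradient Boosting adds trees sequentially, each one correcting the errors of the ensemble so far. Both keep the flexibility of trees, although an ensemble of hundreds of trees gives up the single tree's readable rules.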