Decision Trees
Decision trees are one of the most intuitive machine learning algorithms. They partition data through a series of yes/no questions, creating a structure that mirrors how humans naturally make decisions. Unlike many ML models, a decision tree can be printed, read, and explained to a non-technical stakeholder.
How a Decision Tree Works
              Age > 30?
             /         \
           Yes          No
           /              \
   Income > 60K?     Credit Score > 700?
     /      \           /          \
   Yes       No       Yes           No
    |         |        |             |
 APPROVE    DENY    APPROVE        DENY

The model learns which questions (feature splits) best separate the classes, and at what thresholds to ask them.
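The diagram above is nothing more than nested conditionals. Here is a minimal sketch of that toy tree written by hand. It uses the diagram's illustrative thresholds rather than learned ones, and the function `approve_loan` and its arguments are hypothetical:

```python
def approve_loan(age: float, income: float, credit_score: float) -> str:
    """Hypothetical hand-coded version of the example tree in the diagram."""
    if age > 30:
        # "Yes" branch: applicants over 30 are judged on income
        return "APPROVE" if income > 60_000 else "DENY"
    # "No" branch: applicants 30 or younger are judged on credit score
    return "APPROVE" if credit_score > 700 else "DENY"


print(approve_loan(45, 80_000, 600))   # → APPROVE
print(approve_loan(25, 100_000, 650))  # → DENY
```

Training a tree amounts to choosing these features, thresholds, and leaf labels automatically from data.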
Splitting Criteria
Gini Impurity (Classification, Default in sklearn)
Measures the probability of misclassifying a randomly chosen sample if it were labeled at random according to the node's class distribution.
Gini(node) = 1 - Σ pᵢ², where pᵢ is the proportion of samples in the node that belong to class i.
Perfect node (all one class): Gini = 1 - 1² = 0
Mixed 50/50 node: Gini = 1 - (0.5² + 0.5²) = 0.5
At each node, the algorithm tries every candidate split on every feature and picks the one that minimizes the weighted Gini impurity of the children, where each child's impurity is weighted by its share of the node's samples.
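Here is a minimal from-scratch sketch of the Gini criterion and the exhaustive split search for a single numeric feature. The helper names `gini` and `best_threshold` are illustrative. The sketch assumes that the candidate thresholds are the midpoints between consecutive distinct sorted values. Real libraries repeat this search over every feature using a much faster sorted sweep.

```python
from collections import Counter


def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def best_threshold(values, labels):
    """Try every midpoint threshold on one numeric feature.

    Returns (threshold, weighted_child_gini) for the split `x <= threshold`
    that gives the lowest size-weighted Gini impurity of the two children.
    """
    n = len(labels)
    best = (None, float("inf"))
    distinct = sorted(set(values))
    for lo, hi in zip(distinct, distinct[1:]):
        t = (lo + hi) / 2
        left = [y for x, y in zip(values, labels) if x <= t]
        right = [y for x, y in zip(values, labels) if x > t]
        # Weight each child's impurity by its share of the samples
        score = len(left) / n * gini(left) + len(right) / n * gini(right)
        if score < best[1]:
            best = (t, score)
    return best


values = [1, 2, 3, 10, 11, 12]
labels = [0, 0, 0, 1, 1, 1]
print(gini(labels))                    # → 0.5
print(best_threshold(values, labels))  # → (6.5, 0.0)
```

The tree makes its split at a node by running this search over every feature and keeping the overall minimum.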
Information Gain / Entropy
Based on Shannon entropy:
Entropy = -Σ pᵢ × log₂(pᵢ)
Information Gain = Entropy(parent) - Weighted Entropy(children)
Entropy tends to produce slightly different trees than Gini, but predictive performance is usually similar. In sklearn, select it with criterion='entropy'.
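The two formulas translate directly into code. In this minimal sketch, the helper names `entropy` and `information_gain` are illustrative, and the 10-sample parent split into two 4-to-1 children is a made-up example:

```python
import math
from collections import Counter


def entropy(labels):
    """Shannon entropy in bits: -sum of p_i * log2(p_i) over the classes present."""
    n = len(labels)
    return -sum((c / n) * math.log2(c / n) for c in Counter(labels).values())


def information_gain(parent, children):
    """Entropy of the parent minus the size-weighted entropy of its children."""
    n = len(parent)
    weighted = sum(len(child) / n * entropy(child) for child in children)
    return entropy(parent) - weighted


parent = ["yes"] * 5 + ["no"] * 5
children = [["yes"] * 4 + ["no"], ["yes"] + ["no"] * 4]
print(round(entropy(parent), 3))                      # → 1.0
print(round(information_gain(parent, children), 3))  # → 0.278
```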
MSE (Regression)
Regression trees choose the split that minimizes the weighted mean squared error (MSE) of the child nodes, where each child predicts the mean target value of its samples.
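The regression criterion uses the same search with a different score. In this minimal sketch, the helper names `mse` and `weighted_child_mse` are illustrative. Each child predicts the mean of its targets, and the threshold with the lowest size-weighted MSE is chosen:

```python
def mse(values):
    """Mean squared error of predicting the mean of `values`."""
    if not values:
        return 0.0
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)


def weighted_child_mse(x, y, threshold):
    """Size-weighted MSE of the two children produced by `x <= threshold`."""
    left = [yi for xi, yi in zip(x, y) if xi <= threshold]
    right = [yi for xi, yi in zip(x, y) if xi > threshold]
    n = len(y)
    return len(left) / n * mse(left) + len(right) / n * mse(right)


x = [1, 2, 3, 4]
y = [1.0, 1.0, 5.0, 5.0]
print(mse(y))                         # → 4.0 (parent node)
print(weighted_child_mse(x, y, 2.5))  # → 0.0 (perfect split)
```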
Building a Tree in Practice
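from sklearn.tree import DecisionTreeClassifier, export_text, plot_tree
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# Assumes a feature matrix X, binary labels y, and a list of column names feature_names already exist
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)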
# Train
model = DecisionTreeClassifier(
    max_depth=5,          # Prevent overfitting
    min_samples_leaf=10,  # At least 10 samples in each leaf
    criterion='gini',
    random_state=42
)
model.fit(X_train, y_train)

# Visualize
plt.figure(figsize=(20, 10))
plot_tree(model, feature_names=feature_names, class_names=['No', 'Yes'],
          filled=True, rounded=True, fontsize=10)
plt.show()
# Print text rules
print(export_text(model, feature_names=feature_names))
Overfitting and Pruning
An unconstrained decision tree keeps splitting until every leaf is pure, often ending with leaves that hold a single sample, and so memorizes the training data.
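To see this memorization, the sketch below fits a tree with no limits on synthetic data from make_classification. The `flip_y=0.2` label noise, the sample sizes, and the seeds are arbitrary choices for illustration:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# flip_y=0.2 assigns random labels to 20% of samples, so a perfect fit must memorize noise
X, y = make_classification(n_samples=1000, n_features=10, flip_y=0.2, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=0)

# Default settings: no depth or leaf-size limits
full = DecisionTreeClassifier(random_state=0).fit(X_tr, y_tr)
train_acc = full.score(X_tr, y_tr)
test_acc = full.score(X_te, y_te)
print(f"depth={full.get_depth()}, leaves={full.get_n_leaves()}")
print(f"train accuracy={train_acc:.3f}, test accuracy={test_acc:.3f}")
```

Training accuracy comes out at exactly 1.0, while test accuracy is noticeably lower. That gap is the overfitting the hyperparameters below are meant to control.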
Key hyperparameters to control complexity:
| Parameter | Effect |
|---|---|
| max_depth | Maximum tree depth; usually the most important control |
| min_samples_split | Minimum number of samples a node needs before it can be split |
| min_samples_leaf | Minimum number of samples required in each leaf |
| max_leaf_nodes | Maximum number of leaf nodes |
| min_impurity_decrease | Minimum weighted impurity decrease required to make a split |
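This quick check shows what these limits do. Using arbitrary synthetic data again, the sketch compares an unconstrained tree with one capped by max_depth=4 and min_samples_leaf=20. It reads leaf sizes from the fitted `tree_` object, where leaves are the nodes whose `children_left` is -1:

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=1000, n_features=10, flip_y=0.2, random_state=0)

unconstrained = DecisionTreeClassifier(random_state=0).fit(X, y)
constrained = DecisionTreeClassifier(
    max_depth=4, min_samples_leaf=20, random_state=0
).fit(X, y)

# Leaf nodes have no children; n_node_samples counts training samples per node
is_leaf = constrained.tree_.children_left == -1
leaf_sizes = constrained.tree_.n_node_samples[is_leaf]

for name, m in [("unconstrained", unconstrained), ("constrained", constrained)]:
    print(f"{name}: depth={m.get_depth()}, leaves={m.get_n_leaves()}")
print("smallest constrained leaf:", leaf_sizes.min())
```

With max_depth=4 the tree can have at most 2⁴ = 16 leaves, and min_samples_leaf guarantees that every leaf holds at least 20 training samples.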
Post-pruning: grow the full tree, then prune it back with cost-complexity pruning, controlled by the ccp_alpha parameter:
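import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Compute the pruning path on a fully grown (unconstrained) tree
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X_train, y_train)
alphas = path.ccp_alphas

# Find the best alpha via 5-fold cross-validation
cv_scores = []
for alpha in alphas:
    tree = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)
    scores = cross_val_score(tree, X_train, y_train, cv=5)
    cv_scores.append(scores.mean())

best_alpha = alphas[np.argmax(cv_scores)]
model_pruned = DecisionTreeClassifier(ccp_alpha=best_alpha, random_state=42).fit(X_train, y_train)
Decision Tree Regression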
Decision trees also handle continuous targets. DecisionTreeRegressor splits on squared error (criterion='squared_error') by default:
from sklearn.tree import DecisionTreeRegressor
# Assumes X_train, y_train, X_test hold a regression problem (continuous y)
reg_tree = DecisionTreeRegressor(max_depth=4, random_state=42)
reg_tree.fit(X_train, y_train)
# Each leaf outputs the mean of the training targets in that leaf
y_pred = reg_tree.predict(X_test)
The output is a step function: predictions are constant within each leaf region, which gives decision trees their distinctive “blocky” prediction surfaces.
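The step-function behavior is easy to check. This sketch fits a depth-3 regressor to a noisy sine curve. The data, seed, and depth are arbitrary illustrative choices. It then confirms two things: the predictions over a fine grid take only a handful of distinct values, and each leaf predicts the mean of its training targets (`apply` returns the leaf index for each sample).

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X = rng.uniform(0, 6, size=(200, 1))
y = np.sin(X).ravel() + rng.normal(scale=0.1, size=200)

reg = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)

# A fine grid still yields at most 2**3 = 8 distinct predictions
grid = np.linspace(0, 6, 1000).reshape(-1, 1)
pred = reg.predict(grid)
print("distinct predicted values:", len(np.unique(pred)))

# Compare each leaf's prediction with the mean target of its training samples
leaf_ids = reg.apply(X)
train_pred = reg.predict(X)
for leaf in np.unique(leaf_ids):
    in_leaf = leaf_ids == leaf
    print(leaf, round(train_pred[in_leaf][0], 3), round(y[in_leaf].mean(), 3))
```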
Feature Importance
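Fitted trees provide feature importances through feature_importances_. Each feature's importance is the total impurity reduction contributed by its splits (weighted by the number of samples reaching each split), normalized so that the importances sum to 1: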
import pandas as pd
import matplotlib.pyplot as plt
importance = pd.Series(model.feature_importances_, index=feature_names)
importance.sort_values(ascending=True).plot(kind='barh')
plt.title("Feature Importances from Decision Tree")
plt.show()
Strengths and Weaknesses
Strengths:
- Fully interpretable (you can print the rules)
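- Handles numeric and categorical features in principle (scikit-learn's implementation still expects categorical features to be numerically encoded)
- No feature scaling required, because splits depend only on the ordering of each feature's values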
- Captures non-linear patterns
Weaknesses:
- High variance: small changes in the training data can produce very different trees
- Axis-aligned splits only: a diagonal boundary must be approximated by a staircase of many splits
- A single tree is rarely competitive with tree ensembles
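The sketch below illustrates the axis-aligned limitation. It labels random points in the unit square by the diagonal rule x0 > x1. It then compares a tree fitted on the raw features with one fitted on the single hand-engineered feature x0 - x1, which is an illustrative choice and not something the tree does by itself. The data and seeds are arbitrary.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
X = rng.uniform(0, 1, size=(500, 2))
y = (X[:, 0] > X[:, 1]).astype(int)  # diagonal decision boundary

# Raw features: the boundary is approximated by a staircase of axis-aligned splits
raw = DecisionTreeClassifier(random_state=0).fit(X, y)

# One engineered feature, x0 - x1, makes the boundary axis-aligned again
X_rot = (X[:, 0] - X[:, 1]).reshape(-1, 1)
rotated = DecisionTreeClassifier(random_state=0).fit(X_rot, y)

print("leaves on raw features:", raw.get_n_leaves())
print("leaves on x0 - x1:", rotated.get_n_leaves())
```

The rotated feature needs only one split (two leaves). The raw-feature tree needs many leaves to reach the same training fit.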
The usual remedy is to combine many trees. Random Forests average many decorrelated trees to reduce variance. Gradient Boosting fits trees sequentially, with each tree correcting the errors of the ones before it. Both keep the flexibility of trees, at the cost of much of a single tree's interpretability.