Classifying Iris Flowers with Decision Trees

Classifying Iris flower species using decision trees is a classical educational example in machine learning. The Iris dataset is small, well-understood, and helps illustrate how classification algorithms work in practice. By mastering this, you gain insight into tree-based models, feature selection, overfitting vs pruning, interpretability, and model evaluation — all of which are foundational to more advanced methods like random forests and gradient boosting.

Why is it important to learn this?

  • Decision trees are intuitive and easy to visualize — good for learning and explaining models.
  • They form the building blocks of ensemble methods like random forests and boosted trees.
  • Many interview and exam questions use the Iris dataset as a toy problem.
  • Understanding splitting criteria (entropy, Gini, information gain) is core knowledge in ML theory.
  • You will sharpen your ability to implement algorithms, interpret results, and reason about generalization.

The Iris Dataset — a primer

The Iris dataset was introduced by statistician Ronald Fisher in 1936. ([Wikipedia][1]) It contains 150 samples of iris flowers, with 50 samples from each of three species:

  • Iris setosa
  • Iris versicolor
  • Iris virginica

Each sample is described by four features:

  1. Sepal length (cm)
  2. Sepal width (cm)
  3. Petal length (cm)
  4. Petal width (cm)

The target is the species label (categorical). Because the dataset is clean, balanced, small, and widely used, it’s ideal to experiment with. ([GeeksforGeeks][2])

One interesting fact is that one species (setosa) is linearly separable from the other two in this feature space, but versicolor and virginica overlap more. ([GeeksforGeeks][3])

So when you train a decision tree, it may first split to separate species that are easier to isolate, then refine inside overlapping groups.
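
As a quick orientation, here is a minimal sketch (assuming scikit-learn is installed; the filename is just a suggestion) that loads the dataset and confirms its size and class balance:

inspect_iris.py

from collections import Counter
from sklearn.datasets import load_iris

iris = load_iris()
print(iris.data.shape)           # (150, 4): 150 samples, 4 features
print(iris.feature_names)        # sepal/petal length and width, in cm
print(Counter(iris.target))      # 50 samples per class label (0, 1, 2)
print(list(iris.target_names))   # ['setosa', 'versicolor', 'virginica']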


What is a Decision Tree?

A decision tree is a supervised machine learning model that uses a tree-like structure of decisions (splits) to classify inputs (or to predict numerical values in regression). ([scikit-learn][4])

  • Each internal node tests a feature (e.g. “petal length ≤ 2.45?”).
  • Each branch corresponds to an outcome (yes or no, or categorical splits).
  • Each leaf node gives a predicted class (for classification) or value (for regression).

The model works by asking a sequence of questions: based on the answer to one, you go down to the next node, and so on until you reach a leaf and output the prediction.

Why decision trees:

  • They are interpretable (you can read off the rules).
  • They require little data preprocessing (can handle categorical features, no need to scale).
  • They can capture non-linear decision boundaries.
  • But they are prone to overfitting, so you need to control depth or prune.

One key internal concept is how to decide which feature to split on at each node. Common criteria: entropy/information gain, Gini impurity, or variance reduction (for regression). ([scikit-learn][4])

Entropy & Information Gain (for classification)

  • Entropy measures the impurity or disorder in a set of examples.
The entropy formula is:
$$
H(S) = -\sum_{c \in C} p(c) \log_2 p(c)
$$

where $p(c)$ is the proportion of class $c$ in the set.

  • When you split on a feature $A$, you partition $S$ into subsets $S_a$ for each possible value or threshold of $A$. You compute the weighted average entropy of the subsets.

  • Information Gain is:
$$
IG(S, A) = H(S) - \sum_{a} \frac{|S_a|}{|S|} \, H(S_a)
$$

You choose the feature that gives the highest information gain (greatest reduction in entropy). ([Wikipedia][5])

Alternatively, Gini impurity is:

$$
Gini(S) = 1 - \sum_{c \in C} p(c)^2
$$

You pick splits that maximize reduction in Gini impurity.
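
To make these formulas concrete, here is a small, self-contained sketch that computes the entropy of the full Iris label set and the information gain of the root split petal_length ≤ 2.45 (threshold chosen for illustration):

import math
from collections import Counter
from sklearn.datasets import load_iris

def entropy(labels):
    total = len(labels)
    return -sum((c / total) * math.log2(c / total) for c in Counter(labels).values())

iris = load_iris()
X, y = iris.data, iris.target

# Partition labels by the candidate split: petal length (feature index 2) <= 2.45 cm
left = [yi for xi, yi in zip(X, y) if xi[2] <= 2.45]
right = [yi for xi, yi in zip(X, y) if xi[2] > 2.45]

parent = entropy(y)
weighted = (len(left) / len(y)) * entropy(left) + (len(right) / len(y)) * entropy(right)
print("Parent entropy:", round(parent, 3))               # log2(3) ≈ 1.585 for three balanced classes
print("Information gain:", round(parent - weighted, 3))  # high, because setosa is fully isolated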

Overfitting and Pruning

If you allow your tree to grow without limits, it might perfectly classify the training data — but perform poorly on new data (overfitting).

To control that, you can:

  • Limit max depth
  • Require a minimum number of samples per leaf
  • Use pruning after full growth (e.g. cost-complexity pruning, reduced-error pruning) to cut back weak splits ([Wikipedia][6])

Scikit-learn’s implementation supports max_depth, min_samples_leaf, and cost complexity pruning (ccp_alpha). ([scikit-learn][4])
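
Here is a minimal sketch of how those controls look in code; the specific parameter values are illustrative, not tuned:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X_train, X_test, y_train, y_test = train_test_split(
    *load_iris(return_X_y=True), test_size=0.3, random_state=0
)

# Pre-pruning: cap the depth and require a minimum leaf size
shallow = DecisionTreeClassifier(max_depth=3, min_samples_leaf=5, random_state=0)
shallow.fit(X_train, y_train)

# Post-pruning: compute the cost-complexity path, then refit with a chosen ccp_alpha
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_train, y_train)
pruned = DecisionTreeClassifier(ccp_alpha=path.ccp_alphas[-2], random_state=0)  # second-largest alpha, purely for illustration
pruned.fit(X_train, y_train)

print("Shallow tree test accuracy:", shallow.score(X_test, y_test))
print("Pruned tree test accuracy:", pruned.score(X_test, y_test))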


Step-by-Step: How a Tree Classifies Iris Flowers

Here is a conceptual step-by-step:

  1. Root Node. All 150 Iris samples (with 3 classes) sit in the root. You compute the entropy of this set, then consider splitting on each of the four features at various candidate thresholds. Suppose splitting on petal length at, say, petal_length ≤ 2.45 yields the highest information gain.

  2. First Split

    • If petal_length ≤ 2.45 → One subset
    • Else → Another subset

    You check class proportions in each subset. The ≤ subset might contain mostly setosa examples. If so, that branch might end quickly in a leaf predicting setosa.

  3. Recursive Splitting. For the other branch (petal_length > 2.45), you again compute the entropy inside that subset and consider further splits on any of the features at new thresholds (in practice, petal width is often chosen next).

  4. Continue Splitting Until a Stopping Criterion. Either you reach pure subsets (all one class), further splitting does not improve the gain beyond a threshold, or you hit max_depth or minimum-sample constraints.

  5. Prediction. To classify a new sample, you follow its decision path (e.g. “petal_length ≤ 2.45? No → petal_width ≤ 1.75? Yes → predict versicolor”).

Visualization of decision boundaries shows piecewise rectangular or axis-aligned splits in feature space. ([scikit-learn][7])
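
A minimal sketch of such a plot, restricted to the two petal features so the regions can be drawn in 2D (matplotlib is assumed to be available):

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data[:, 2:4]  # petal length, petal width
y = iris.target

clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)

# Evaluate the tree on a dense grid to reveal its axis-aligned, rectangular regions
xx, yy = np.meshgrid(
    np.linspace(X[:, 0].min() - 0.5, X[:, 0].max() + 0.5, 300),
    np.linspace(X[:, 1].min() - 0.5, X[:, 1].max() + 0.5, 300),
)
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

plt.contourf(xx, yy, Z, alpha=0.3)
plt.scatter(X[:, 0], X[:, 1], c=y, edgecolor="k")
plt.xlabel("petal length (cm)")
plt.ylabel("petal width (cm)")
plt.show()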

Internally, scikit-learn’s tree stores arrays of feature indices, thresholds, child pointers, impurity values, and leaf values. ([scikit-learn][8])
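
A minimal sketch of reading those arrays directly (the exact contents of tree_.value vary by scikit-learn version, holding class counts or class proportions):

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

t = clf.tree_
for node in range(t.node_count):
    if t.children_left[node] == -1:  # -1 marks a leaf in the child arrays
        print(f"node {node}: leaf, value = {t.value[node][0]}")
    else:
        name = iris.feature_names[t.feature[node]]
        print(f"node {node}: split on '{name}' <= {t.threshold[node]:.2f}, "
              f"impurity = {t.impurity[node]:.3f}")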


Flow of Decision

Here is a simple sketch of a decision tree for Iris classification:


Below is a sample diagram in Mermaid syntax:

graph TD
    A["petal_length ≤ 2.45?"] -->|yes| B["Leaf: setosa"]
    A -->|no| C["petal_width ≤ 1.75?"]
    C -->|yes| D["Leaf: versicolor"]
    C -->|no| E["Leaf: virginica"]

Interpretation:

  • At the root, test petal_length ≤ 2.45.

    • If yes, you go to a leaf predicting setosa.

    • If no, you go to a second node testing petal_width ≤ 1.75.

      • If yes → versicolor
      • If no → virginica

You can extend this further (multiple splits, more depth) depending on data.


Three Example Programs

Below are three example programs in different styles (scikit-learn, a hand-coded rule tree in plain Python, and a from-scratch implementation) for classifying Iris with decision trees. Use them as learning aids.

Example 1: Using Scikit-Learn (typical, simplest)

example1_sklearn.py

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn.metrics import classification_report, accuracy_score


def main():
    iris = load_iris()
    X, y = iris.data, iris.target
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=42
    )
    clf = DecisionTreeClassifier(max_depth=3, random_state=1)
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print("Accuracy:", accuracy_score(y_test, y_pred))
    print(classification_report(y_test, y_pred, target_names=iris.target_names))
    # print textual tree
    print(export_text(clf, feature_names=iris.feature_names))


if __name__ == "__main__":
    main()

What this demonstrates:

  • Loading data
  • Splitting train/test
  • Training a DecisionTreeClassifier
  • Evaluating accuracy & classification report
  • Printing the textual form of the tree
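
If you also want a graphical rendering of the fitted tree, scikit-learn provides plot_tree (matplotlib assumed to be available); a minimal sketch:

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=1).fit(iris.data, iris.target)

plt.figure(figsize=(10, 6))
plot_tree(clf, feature_names=iris.feature_names, class_names=list(iris.target_names),
          filled=True, rounded=True)
plt.show()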

Example 2: Manual Decision Tree (toy, rule-based)

This example is minimal and only handles a few nodes manually — for educational purposes.

example2_manual.py

from sklearn.datasets import load_iris


def predict_rule(sample):
    # sample: [sepal_len, sepal_wid, petal_len, petal_wid]
    if sample[2] <= 2.45:
        return 0  # setosa
    else:
        # further split
        if sample[3] <= 1.75:
            return 1  # versicolor
        else:
            return 2  # virginica


def main():
    iris = load_iris()
    X, y = iris.data, iris.target
    correct = 0
    for xi, yi in zip(X, y):
        if predict_rule(xi) == yi:
            correct += 1
    print("Accuracy on entire dataset (rule-based):", correct / len(y))
    # Test on a custom sample
    test = [5.1, 3.5, 1.4, 0.2]
    print("Prediction for", test, "->", iris.target_names[predict_rule(test)])


if __name__ == "__main__":
    main()

What this demonstrates:

  • Hardcoding a decision tree manually for the Iris data
  • Applying it to all samples and custom inputs
  • Illustrates how splits map to if-else logic

You can expand this toy tree by manually adding more splits (e.g. additional thresholds) if needed.


Example 3: Building a Decision Tree from Scratch (simplified)

This is more advanced: a minimal implementation of ID3-style tree building. (Not optimized, for learning.)

example3_id3.py

import math
from collections import Counter
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split


def entropy(labels):
    total = len(labels)
    counts = Counter(labels)
    ent = 0.0
    for cnt in counts.values():
        p = cnt / total
        ent -= p * math.log2(p)
    return ent


def information_gain(dataset, labels, feature_index, threshold):
    # split into two groups: ≤ threshold and > threshold
    left_labels, right_labels = [], []
    for xi, yi in zip(dataset, labels):
        if xi[feature_index] <= threshold:
            left_labels.append(yi)
        else:
            right_labels.append(yi)
    if not left_labels or not right_labels:
        return 0
    total = len(labels)
    ent_before = entropy(labels)
    ent_after = (len(left_labels) / total) * entropy(left_labels) + \
                (len(right_labels) / total) * entropy(right_labels)
    return ent_before - ent_after


class Node:
    def __init__(self, *, feature_index=None, threshold=None, left=None, right=None, value=None):
        self.feature_index = feature_index
        self.threshold = threshold
        self.left = left
        self.right = right
        self.value = value  # if leaf, the class label


def build_tree(dataset, labels, depth=0, max_depth=3):
    # If pure or reached max depth, return a majority-vote leaf
    if len(set(labels)) == 1 or depth >= max_depth:
        majority = Counter(labels).most_common(1)[0][0]
        return Node(value=majority)
    # find best split
    best_gain = 0.0
    best_feat = None
    best_thresh = None
    n_features = len(dataset[0])
    for feat in range(n_features):
        # pick candidate thresholds (distinct values)
        d_vals = sorted(set(x[feat] for x in dataset))
        for t in d_vals:
            gain = information_gain(dataset, labels, feat, t)
            if gain > best_gain:
                best_gain = gain
                best_feat = feat
                best_thresh = t
    if best_gain == 0:
        majority = Counter(labels).most_common(1)[0][0]
        return Node(value=majority)
    # partition
    left_data, left_labels, right_data, right_labels = [], [], [], []
    for xi, yi in zip(dataset, labels):
        if xi[best_feat] <= best_thresh:
            left_data.append(xi)
            left_labels.append(yi)
        else:
            right_data.append(xi)
            right_labels.append(yi)
    left_node = build_tree(left_data, left_labels, depth + 1, max_depth)
    right_node = build_tree(right_data, right_labels, depth + 1, max_depth)
    return Node(feature_index=best_feat, threshold=best_thresh, left=left_node, right=right_node)


def predict(tree, sample):
    if tree.value is not None:
        return tree.value
    if sample[tree.feature_index] <= tree.threshold:
        return predict(tree.left, sample)
    else:
        return predict(tree.right, sample)


def main():
    iris = load_iris()
    X_train, X_test, y_train, y_test = train_test_split(
        iris.data, iris.target, test_size=0.3, random_state=0
    )
    tree = build_tree(list(X_train), list(y_train), max_depth=4)
    correct = 0
    for xi, yi in zip(X_test, y_test):
        if predict(tree, xi) == yi:
            correct += 1
    print("Test accuracy (from-scratch):", correct / len(y_test))
    sample = [6.0, 2.9, 4.5, 1.5]
    print("Prediction for", sample, "->", iris.target_names[predict(tree, sample)])


if __name__ == "__main__":
    main()

What this demonstrates:

  • Basic tree construction using information gain
  • Recursion, splitting, majority voting
  • Making predictions
  • A bridge from theory to code

You can expand or optimize this (e.g. choose best thresholds, pruning) as exercise.
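
As one starting point for that exercise, a common refinement is to use midpoints between consecutive distinct feature values as candidate thresholds rather than the raw values themselves; the helper below (candidate_thresholds is our own name, not part of the original program) could replace the d_vals loop in build_tree:

def candidate_thresholds(dataset, feature_index):
    # Midpoints between consecutive distinct values give cleaner split points
    values = sorted(set(x[feature_index] for x in dataset))
    return [(a + b) / 2 for a, b in zip(values, values[1:])]

# Inside build_tree, the inner loop would become:
#     for t in candidate_thresholds(dataset, feat):
#         gain = information_gain(dataset, labels, feat, t)
#         ...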


How to Remember Core Concepts (Interview / Exam Tips)

To internalize and recall easily, here are strategies and mnemonics:

  1. Mnemonic for the steps of building a decision tree: “E G S P”

    • E: Compute Entropy (or impurity)
    • G: Compute Gain for candidate splits
    • S: Select best feature and threshold
    • P: Partition and recurse (or Prune / Propagate)
  2. Visualize the tree as a set of if-else rules. Always imagine the resulting tree as nested if-else statements. Many questions ask you to interpret the splits or write out the rules.

  3. Associate Iris features with the easiest splits. In Iris, petal length is the best first split (because setosa tends to have very small petals). Remember “petal length splits early” as a heuristic.

  4. Practice drawing small trees by hand. Use a tiny subset (e.g. only two features), draw the splits, compute entropy manually, and choose splits. This improves clarity under exam conditions.

  5. Flashcards for formulae. A card with “Entropy formula” on one side and “Information gain = parent − weighted children” on the back; another for “Gini impurity”, and so on.

  6. Compare decision trees to human decision-making. When diagnosing, doctors ask sequential binary questions (e.g. “Is fever > 38 °C? Yes → test A; No → test B”). This analogy helps anchor how tree logic proceeds.

  7. Teach it to someone else or write blog posts. Explaining forces clarity; even rewriting your own example code helps retention.

  8. Do past interview / exam problems. Many ML interviews and exams ask “given this small table, draw the decision tree” or “compute entropy and choose the best split”. Practicing those cements the skill.

  9. Build a mind map of pitfalls. Keep in mind: overfitting, pruning, and bias toward features with many levels. Place these in your mental map so you can anticipate tricky questions.


Why This Concept Shows Up Frequently

  • It’s conceptually simple yet rich: you can ask deep questions about splits, bias, overfitting, pruning, complexity, etc.
  • The Iris dataset is small, clean, and known — it’s standard in textbooks and courses.
  • Decision trees serve as a gateway to more advanced topics (ensemble methods, feature importance, tree boosting).
  • Many benchmarks, competitions, and papers build more complex tree-based learners — so understanding the base concept is vital.
  • It provides a bridge between discrete, logic-based thinking and statistical learning, connecting rule-based AI and probabilistic models.