Dec 05, 2025
Continuing our journey through machine learning, today we turn to decision trees, a powerful and intuitive algorithm that mimics human decision-making. Imagine you’re playing a game of “20 Questions” where you try to guess what animal someone is thinking of. You might ask: “Does it live in water?” If yes, you’ve ruled out all land animals. Then: “Does it have scales?” This narrows it down further. Each question splits the possibilities into smaller groups until you identify the answer. This is exactly how decision trees work: they ask a series of yes/no questions about data features, splitting the dataset at each step until they can make accurate predictions.
In an era where “black box” models like deep neural networks dominate headlines, decision trees stand out as beautifully transparent. When a decision tree makes a prediction, you can trace the exact path it took: which questions it asked and which answers led to the conclusion. This interpretability makes decision trees invaluable in scenarios where understanding the reasoning behind predictions is crucial, for example in domains where explanations are legally or ethically required.
Decision trees offer several unique advantages, which we’ll look at in detail later in this post.
A decision tree consists of:
1. Root Node: The starting point containing all training data
2. Internal Nodes (Decision Nodes): Each represents a question about a feature
3. Branches (Edges): Represent the answer to the question (Yes/No, or multiple categories)
4. Leaf Nodes (Terminal Nodes): The final decision/prediction
Let’s say we’re predicting whether a customer will buy a product based on age and income:
                [Root: All Customers]
                          |
                 Is Income > $50k?
                  /              \
                Yes               No
                 |                 |
         [Is Age > 35?]     [Predict: No Buy]
            /       \
          Yes        No
           |          |
    [Predict: Buy]  [Predict: No Buy]
Reading this tree: start at the root, answer each question in turn, and follow the matching branch until you reach a leaf, which gives the prediction. A 40-year-old customer earning $60k, for example, answers “Yes” at the income question and “Yes” at the age question, ending in the “Buy” leaf. The same logic can be written as nested conditionals, as in the sketch below.
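To make this concrete, here is a small illustrative sketch (the function name predict_purchase and the exact dollar threshold are mine) that encodes the tree above as nested conditionals:

def predict_purchase(income, age):
    """Walk the example tree and return 'Buy' or 'No Buy'."""
    if income > 50_000:       # root question: Is Income > $50k?
        if age > 35:          # internal node: Is Age > 35?
            return "Buy"
        return "No Buy"       # Income > $50k but Age <= 35
    return "No Buy"           # Income <= $50k

print(predict_purchase(income=60_000, age=40))  # Buy
print(predict_purchase(income=45_000, age=50))  # No Buy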
We can represent a decision tree as a function:
\[f(x) = \sum_{m=1}^{M} c_m \cdot \mathbb{1}(x \in R_m)\]
Where \(M\) is the number of leaves, \(R_m\) is the region of feature space carved out by the path to leaf \(m\), \(c_m\) is the constant prediction (a class label or a mean value) assigned to that region, and \(\mathbb{1}(\cdot)\) is the indicator function.
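As a concrete instance, the customer tree sketched above partitions the (Income, Age) space into \(M = 3\) regions, one per leaf:

\[
\begin{aligned}
R_1 &= \{\text{Income} \le \$50k\}, & c_1 &= \text{No Buy},\\
R_2 &= \{\text{Income} > \$50k,\ \text{Age} \le 35\}, & c_2 &= \text{No Buy},\\
R_3 &= \{\text{Income} > \$50k,\ \text{Age} > 35\}, & c_3 &= \text{Buy}.
\end{aligned}
\]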
Now let’s dive into the magic of how decision trees actually make decisions.
Building a decision tree involves recursive partitioning: choose the best split for the current node, divide the data into subsets according to that split, and repeat the procedure on each subset until a stopping condition is met.
This is a greedy algorithm: at each step, we make the locally optimal choice without looking ahead.
A good split should separate the classes as cleanly as possible, producing child nodes that are more homogeneous (purer) than the parent node.
Suppose we have this dataset of 10 people:
| Age | Income | Bought |
|---|---|---|
| 25 | 40k | No |
| 30 | 45k | No |
| 35 | 52k | Yes |
| 40 | 65k | Yes |
| 45 | 70k | Yes |
| 28 | 50k | No |
| 50 | 75k | Yes |
| 22 | 35k | No |
| 33 | 58k | Yes |
| 42 | 62k | Yes |
Current distribution: 6 “Yes”, 4 “No” (impure)
Candidate split: Income > 55k — this sends 5 people (5 “Yes”, 0 “No”) to one branch and 5 people (1 “Yes”, 4 “No”) to the other.
This split significantly reduces impurity, making it a good choice.
To find the best split, we need to quantify “impurity” or “disorder” in a node. To do this, three main metrics are used:
The Intuition: Gini impurity measures the probability of incorrectly classifying a randomly chosen element if we randomly assign a label according to the distribution in the node.
Mathematical Definition:
\[\text{Gini}(S) = 1 - \sum_{i=1}^{C} p_i^2\]
Where \(C\) is the number of classes and \(p_i\) is the proportion of examples in node \(S\) that belong to class \(i\).
Properties: Gini impurity is 0 for a perfectly pure node, reaches its maximum of \(1 - 1/C\) (0.5 for binary classification) when the classes are evenly mixed, and is cheap to compute because it involves no logarithms.
Example Calculation:
For a node with 6 “Yes” and 4 “No”:
\[p_{\text{Yes}} = \frac{6}{10} = 0.6, \quad p_{\text{No}} = \frac{4}{10} = 0.4\] \[\text{Gini} = 1 - (0.6^2 + 0.4^2) = 1 - (0.36 + 0.16) = 0.48\]Interpretation: 0.48 indicates high impurity (close to maximum of 0.5).
The Intuition: Entropy comes from information theory and measures the average amount of information needed to identify the class of an example.
Mathematical Definition:
\[\text{Entropy}(S) = -\sum_{i=1}^{C} p_i \log_2(p_i)\]
Where \(C\) is the number of classes, \(p_i\) is the proportion of examples in class \(i\), and terms with \(p_i = 0\) contribute 0 by convention.
Properties: entropy is 0 for a perfectly pure node and reaches its maximum of \(\log_2 C\) (1.0 for binary classification) when the classes are evenly mixed.
Example Calculation:
Same node (6 “Yes”, 4 “No”):
\(\text{Entropy} = -(0.6 \log_2(0.6) + 0.4 \log_2(0.4))\) \(= -(0.6 \times (-0.737) + 0.4 \times (-1.322))\) \(= -(-0.442 - 0.529) = 0.971\)
Interpretation: 0.971 out of maximum 1.0 indicates high disorder.
The Intuition: Simply the fraction of examples that would be misclassified if we assign the majority class to all examples.
Mathematical Definition:
\[\text{Error}(S) = 1 - \max_i(p_i)\]Example Calculation:
Same node (6 “Yes”, 4 “No”):
\[\text{Error} = 1 - \max(0.6, 0.4) = 1 - 0.6 = 0.4\]
When to Use Each Metric: in practice Gini and entropy produce very similar trees; Gini is slightly cheaper to compute (no logarithm) and is the default in CART and scikit-learn, entropy is natural when you want an information-theoretic interpretation, and misclassification error is too insensitive for choosing splits but is often used when pruning an already-grown tree.
For binary classification, here is how the metrics compare as the class-1 probability \(p_1\) varies:

| Probability $p_1$ | Gini | Entropy |
|---|---|---|
| 0.0 | 0.0 | 0.0 |
| 0.1 | 0.18 | 0.469 |
| 0.2 | 0.32 | 0.722 |
| 0.3 | 0.42 | 0.881 |
| 0.4 | 0.48 | 0.971 |
| 0.5 | 0.5 | 1.0 |
Key observation: Both reach maximum at 50-50 split, both reach 0 at pure nodes. Entropy is slightly more sensitive to probability changes.
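These values are easy to verify numerically. Here is a small sketch in plain NumPy (the helper name impurity_metrics is mine) that computes all three measures for the 6-“Yes”/4-“No” node used above:

import numpy as np

def impurity_metrics(class_counts):
    """Return (Gini, entropy, misclassification error) for a node."""
    p = np.asarray(class_counts, dtype=float)
    p = p / p.sum()
    gini = 1.0 - np.sum(p ** 2)
    nonzero = p[p > 0]                      # 0 * log2(0) is defined as 0
    entropy = -np.sum(nonzero * np.log2(nonzero))
    error = 1.0 - p.max()
    return gini, entropy, error

print(impurity_metrics([6, 4]))   # ≈ (0.48, 0.971, 0.4)
print(impurity_metrics([9, 1]))   # ≈ (0.18, 0.469, 0.1)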
Information Gain measures how much a split reduces impurity:
\[\text{IG}(S, A) = \text{Impurity}(S) - \sum_{v \in \text{Values}(A)} \frac{|S_v|}{|S|} \text{Impurity}(S_v)\]
Where \(|S|\) is the number of examples in \(S\), \(\text{Values}(A)\) is the set of values attribute \(A\) can take, and \(S_v\) (with size \(|S_v|\)) is the subset of \(S\) on which \(A\) takes value \(v\).
Dataset (10 examples): 6 “Yes”, 4 “No”
Parent Gini: 0.48
Option 1: Split on Income > 55k
Left (Income ≤ 55k): 5 examples (1 Yes, 4 No) \(\text{Gini}_{\text{left}} = 1 - (0.2^2 + 0.8^2) = 1 - 0.68 = 0.32\)
Right (Income > 55k): 5 examples (5 Yes, 0 No) \(\text{Gini}_{\text{right}} = 1 - (1^2 + 0^2) = 0\)
Weighted Gini after split: \(\text{Gini}_{\text{split}} = \frac{5}{10}(0.32) + \frac{5}{10}(0) = 0.16\)
Gini Gain: \(\text{Gain} = 0.48 - 0.16 = 0.32\)
Option 2: Split on Age > 35
Left (Age ≤ 35): 6 examples (2 Yes, 4 No) \(\text{Gini}_{\text{left}} = 1 - (0.333^2 + 0.667^2) = 0.444\)
Right (Age > 35): 4 examples (4 Yes, 0 No) \(\text{Gini}_{\text{right}} = 0\)
Weighted Gini after split: \(\text{Gini}_{\text{split}} = \frac{6}{10}(0.444) + \frac{4}{10}(0) = 0.267\)
Gini Gain: \(\text{Gain} = 0.48 - 0.267 = 0.213\)
Decision: Choose Income > 55k (Gini Gain of 0.32 > 0.213)
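You can reproduce this comparison with a few lines of NumPy. The sketch below uses the 10-person toy dataset from the table above (variable names are mine):

import numpy as np

def gini(y):
    """Gini impurity of a label array."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def gini_gain(feature, y, threshold):
    """Impurity reduction from splitting at `feature > threshold`."""
    right = feature > threshold
    left = ~right
    n = len(y)
    weighted = (left.sum() / n) * gini(y[left]) + (right.sum() / n) * gini(y[right])
    return gini(y) - weighted

age    = np.array([25, 30, 35, 40, 45, 28, 50, 22, 33, 42])
income = np.array([40, 45, 52, 65, 70, 50, 75, 35, 58, 62])  # in $k
bought = np.array([0, 0, 1, 1, 1, 0, 1, 0, 1, 1])            # 1 = Yes

print(gini_gain(income, bought, 55))  # ≈ 0.32 (the chosen split)
print(gini_gain(age, bought, 35))     # ≈ 0.213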
Problem with Information Gain: It favors features with many distinct values.
Example: An “ID” feature with unique values for each example would perfectly split the data (Gini = 0 for each child), but it’s useless for generalization!
Solution - Gain Ratio:
\[\text{GainRatio}(S, A) = \frac{\text{IG}(S, A)}{\text{SplitInfo}(S, A)}\]Where Split Information penalizes features with many splits:
\[\text{SplitInfo}(S, A) = -\sum_{v \in \text{Values}(A)} \frac{|S_v|}{|S|} \log_2 \frac{|S_v|}{|S|}\]This normalizes the information gain, making comparisons fairer across features with different numbers of values.
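Here is a minimal sketch of the gain ratio for a binary threshold split (C4.5 applies it to multi-way categorical splits as well; the helper names are mine):

import numpy as np

def entropy(y):
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

def info_gain(feature, y, threshold):
    """Entropy-based information gain of the split `feature > threshold`."""
    right = feature > threshold
    n = len(y)
    children = ((~right).sum() / n) * entropy(y[~right]) + (right.sum() / n) * entropy(y[right])
    return entropy(y) - children

def split_info(feature, threshold):
    """Entropy of the partition sizes; larger for many-way or lopsided splits."""
    right = feature > threshold
    fractions = np.array([(~right).mean(), right.mean()])
    fractions = fractions[fractions > 0]
    return -np.sum(fractions * np.log2(fractions))

def gain_ratio(feature, y, threshold):
    return info_gain(feature, y, threshold) / split_info(feature, threshold)

income = np.array([40, 45, 52, 65, 70, 50, 75, 35, 58, 62])
bought = np.array([0, 0, 1, 1, 1, 0, 1, 0, 1, 1])
print(gain_ratio(income, bought, 55))  # ≈ 0.61 / 1.0 for an even 5/5 split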
Algorithm: Developed by Ross Quinlan in 1986
function ID3(examples, attributes):
if all examples have same class:
return leaf node with that class
if attributes is empty:
return leaf with majority class
best_attribute = attribute with highest information gain
tree = new decision node for best_attribute
for each value v of best_attribute:
examples_v = subset where best_attribute = v
if examples_v is empty:
add leaf with majority class
else:
subtree = ID3(examples_v, attributes - {best_attribute})
add branch to tree for value v with subtree
return tree
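The pseudocode above translates almost line-for-line into Python. This is a minimal sketch for categorical features only (examples as dicts, tree as nested dicts; names like information_gain are mine, and the empty-subset branch never triggers here because branch values are drawn from the examples themselves):

import math
from collections import Counter

def entropy(examples, target):
    """Shannon entropy of the target labels in a list of example dicts."""
    counts = Counter(e[target] for e in examples)
    n = len(examples)
    return -sum((c / n) * math.log2(c / n) for c in counts.values())

def information_gain(examples, attribute, target):
    """Entropy reduction from splitting on a categorical attribute."""
    n = len(examples)
    remainder = 0.0
    for value in {e[attribute] for e in examples}:
        subset = [e for e in examples if e[attribute] == value]
        remainder += len(subset) / n * entropy(subset, target)
    return entropy(examples, target) - remainder

def id3(examples, attributes, target):
    labels = [e[target] for e in examples]
    if len(set(labels)) == 1:                  # all examples share one class
        return labels[0]
    if not attributes:                         # no attributes left: majority class
        return Counter(labels).most_common(1)[0][0]
    best = max(attributes, key=lambda a: information_gain(examples, a, target))
    tree = {best: {}}
    for value in {e[best] for e in examples}:  # one branch per observed value
        subset = [e for e in examples if e[best] == value]
        remaining = [a for a in attributes if a != best]
        tree[best][value] = id3(subset, remaining, target)
    return tree

data = [
    {"income": "high", "age": "young", "buy": "No"},
    {"income": "high", "age": "old",   "buy": "Yes"},
    {"income": "low",  "age": "young", "buy": "No"},
    {"income": "low",  "age": "old",   "buy": "No"},
]
print(id3(data, ["income", "age"], "buy"))
# e.g. {'income': {'high': {'age': {'old': 'Yes', 'young': 'No'}}, 'low': 'No'}}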
Limitations: ID3 handles only categorical attributes, has no built-in pruning (so it tends to overfit), cannot cope with missing values, and its information-gain criterion is biased toward attributes with many distinct values.
Improvements: Quinlan’s successor, C4.5, addresses these issues with the gain ratio, support for continuous attributes, post-pruning, and missing-value handling.
Handling Continuous Attributes:
For a continuous feature like Age, sort the observed values, take a candidate threshold between each pair of adjacent distinct values (typically the midpoint), evaluate each candidate as a binary split (Age ≤ t vs. Age > t), and keep the best one, as in the sketch below.
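Here is a tiny sketch of that idea (midpoints between consecutive sorted distinct values serve as candidate thresholds; the variable names are mine):

import numpy as np

ages = np.array([25, 30, 35, 40, 45, 28, 50, 22, 33, 42])

unique_ages = np.unique(ages)                          # sorted, duplicates removed
candidates = (unique_ages[:-1] + unique_ages[1:]) / 2  # midpoints of adjacent pairs

print(candidates)  # [23.5 26.5 29.  31.5 34.  37.5 41.  43.5 47.5]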
Algorithm: Developed by Breiman, Friedman, Olshen, and Stone (1984)
Key Differences: CART always builds binary trees (every split has exactly two branches), uses Gini impurity for classification and MSE for regression instead of entropy, supports both classification and regression, handles continuous features natively, and prunes the grown tree with cost-complexity pruning.
CART Algorithm:
function CART(examples):
if stopping criterion met:
return leaf node with prediction
best_split = None
best_gain = 0
for each feature:
for each possible split point:
gain = calculate_gini_gain(split)
if gain > best_gain:
best_gain = gain
best_split = split
    left_node = CART(examples where best_split.feature <= best_split.threshold)
    right_node = CART(examples where best_split.feature > best_split.threshold)
return decision node with best_split, left_node, right_node
Stopping Criteria: tree growth stops at a node when a maximum depth is reached, the node contains fewer samples than a minimum split size, the node is pure (all targets identical), or no candidate split reduces impurity by more than a minimum threshold.
Prediction: Instead of a class label, predict a continuous value
Leaf Node Values: each leaf predicts the mean of the target values of the training examples that reach it (or the median, when optimizing MAE).
Mean Squared Error (MSE):
\[\text{MSE}(S) = \frac{1}{|S|} \sum_{i \in S} (y_i - \bar{y})^2\]Where $\bar{y}$ is the mean of $y$ values in node $S$.
MSE Reduction (equivalent to information gain):
\[\text{MSE}_{\text{reduction}} = \text{MSE}(S) - \left(\frac{|S_L|}{|S|}\text{MSE}(S_L) + \frac{|S_R|}{|S|}\text{MSE}(S_R)\right)\]
Mean Absolute Error (MAE) (alternative):
\[\text{MAE}(S) = \frac{1}{|S|} \sum_{i \in S} |y_i - \text{median}(S)|\]
Task: Predict house price based on size
| Size (sq ft) | Price ($k) |
|---|---|
| 800 | 150 |
| 1000 | 180 |
| 1200 | 220 |
| 1500 | 280 |
| 2000 | 350 |
| 2500 | 400 |
Current MSE (all data): the mean price is 263.33, so \(\text{MSE} \approx 8,022.22\)
Try split: Size ≤ 1200
Left (Size ≤ 1200): [150, 180, 220], mean 183.33, \(\text{MSE}_{\text{left}} \approx 822.22\)
Right (Size > 1200): [280, 350, 400], mean 343.33, \(\text{MSE}_{\text{right}} \approx 2,422.22\)
Weighted MSE after split: \(\text{MSE}_{\text{split}} = \frac{3}{6}(822.22) + \frac{3}{6}(2,422.22) = 1,622.22\)
MSE Reduction: \(8,022.22 - 1,622.22 = 6,400.00 \quad \text{(Excellent split!)}\)
Regression trees create piecewise constant predictions:
Price = 183.33 if Size ≤ 1200
Price = 343.33 if Size > 1200
This creates a step function rather than smooth curve. For smoother predictions, ensemble methods (Random Forests, Gradient Boosting) work better.
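These numbers are easy to check with NumPy (a throwaway verification sketch; variable names are mine):

import numpy as np

prices = np.array([150, 180, 220, 280, 350, 400], dtype=float)
sizes  = np.array([800, 1000, 1200, 1500, 2000, 2500])

def mse(y):
    """Mean squared error around the node mean."""
    return np.mean((y - y.mean()) ** 2)

left, right = prices[sizes <= 1200], prices[sizes > 1200]
weighted = len(left) / len(prices) * mse(left) + len(right) / len(prices) * mse(right)

print(round(mse(prices), 2))             # 8022.22
print(round(weighted, 2))                # 1622.22
print(round(mse(prices) - weighted, 2))  # 6400.0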
import numpy as np
from collections import Counter
class Node:
"""Represents a node in the decision tree."""
def __init__(
self,
feature_index=None,
threshold=None,
left=None,
right=None,
value=None,
impurity=None,
n_samples=None
):
# Internal node properties
self.feature_index = feature_index # Index of feature to split on
self.threshold = threshold # Threshold value for split
self.left = left # Left child node
self.right = right # Right child node
# Leaf node properties
self.value = value # Predicted class/value for leaf
# Node statistics
self.impurity = impurity # Gini or MSE
self.n_samples = n_samples # Number of samples in node
class DecisionTreeClassifier:
"""Decision Tree Classifier using Gini impurity."""
def __init__(
self,
max_depth=None,
min_samples_split=2,
min_samples_leaf=1,
min_impurity_decrease=0.0
):
self.max_depth = max_depth
self.min_samples_split = min_samples_split
self.min_samples_leaf = min_samples_leaf
self.min_impurity_decrease = min_impurity_decrease
self.root = None
def fit(self, X, y):
"""Build decision tree classifier."""
self.n_classes = len(np.unique(y))
self.n_features = X.shape[1]
self.root = self._grow_tree(X, y, depth=0)
return self
def _gini(self, y):
"""Calculate Gini impurity."""
n_samples = len(y)
if n_samples == 0:
return 0
# Count occurrences of each class
class_counts = np.bincount(y)
probabilities = class_counts / n_samples
# Gini = 1 - sum(p_i^2)
gini = 1.0 - np.sum(probabilities ** 2)
return gini
def _best_split(self, X, y):
"""Find the best split for a node."""
n_samples, n_features = X.shape
if n_samples <= 1:
return None, None
# Current impurity
parent_impurity = self._gini(y)
best_gain = 0.0
best_feature = None
best_threshold = None
# Try each feature
for feature_idx in range(n_features):
# Get unique values and sort them
thresholds = np.unique(X[:, feature_idx])
# Try each unique value as threshold
for threshold in thresholds:
# Split data
left_mask = X[:, feature_idx] <= threshold
right_mask = ~left_mask
# Skip if split doesn't satisfy min_samples_leaf
if (np.sum(left_mask) < self.min_samples_leaf or
np.sum(right_mask) < self.min_samples_leaf):
continue
# Calculate weighted impurity of children
n_left = np.sum(left_mask)
n_right = np.sum(right_mask)
left_impurity = self._gini(y[left_mask])
right_impurity = self._gini(y[right_mask])
weighted_impurity = (
(n_left / n_samples) * left_impurity +
(n_right / n_samples) * right_impurity
)
# Calculate information gain
gain = parent_impurity - weighted_impurity
# Update best split if this is better
if gain > best_gain:
best_gain = gain
best_feature = feature_idx
best_threshold = threshold
# Check minimum impurity decrease
if best_gain < self.min_impurity_decrease:
return None, None
return best_feature, best_threshold
def _grow_tree(self, X, y, depth=0):
"""Recursively grow the decision tree."""
n_samples = len(y)
n_classes = len(np.unique(y))
# Calculate current node impurity
impurity = self._gini(y)
# Determine majority class for this node
class_counts = np.bincount(y, minlength=self.n_classes)
predicted_class = np.argmax(class_counts)
# Create node
node = Node(
value=predicted_class,
impurity=impurity,
n_samples=n_samples
)
# Stopping criteria
if (depth >= self.max_depth if self.max_depth else False):
return node
if n_samples < self.min_samples_split:
return node
if n_classes == 1: # Pure node
return node
# Find best split
feature_idx, threshold = self._best_split(X, y)
if feature_idx is None: # No valid split found
return node
# Split data
left_mask = X[:, feature_idx] <= threshold
right_mask = ~left_mask
# Recursively build left and right subtrees
node.feature_index = feature_idx
node.threshold = threshold
node.left = self._grow_tree(X[left_mask], y[left_mask], depth + 1)
node.right = self._grow_tree(X[right_mask], y[right_mask], depth + 1)
return node
def _traverse_tree(self, x, node):
"""Traverse tree to make prediction for a single sample."""
# If leaf node, return prediction
if node.feature_index is None:
return node.value
# Traverse left or right based on feature value
if x[node.feature_index] <= node.threshold:
return self._traverse_tree(x, node.left)
else:
return self._traverse_tree(x, node.right)
def predict(self, X):
"""Predict class labels for samples."""
return np.array([self._traverse_tree(x, self.root) for x in X])
def predict_proba(self, X):
"""Predict class probabilities for samples."""
# For simplicity, return one-hot encoded predictions
# A full implementation would track class distributions in leaves
predictions = self.predict(X)
n_samples = len(predictions)
proba = np.zeros((n_samples, self.n_classes))
proba[np.arange(n_samples), predictions] = 1
return proba
# Generate sample data
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
X, y = make_classification(
n_samples=200,
n_features=4,
n_informative=3,
n_redundant=1,
n_classes=2,
random_state=42
)
# Split data
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.3, random_state=42
)
# Train decision tree
tree = DecisionTreeClassifier(
max_depth=5,
min_samples_split=10,
min_samples_leaf=5
)
tree.fit(X_train, y_train)
# Make predictions
y_pred = tree.predict(X_test)
accuracy = np.mean(y_pred == y_test)
print(f"Accuracy: {accuracy:.2%}") # Output: ~85-95%
class DecisionTreeRegressor:
"""Decision Tree Regressor using MSE."""
def __init__(
self,
max_depth=None,
min_samples_split=2,
min_samples_leaf=1
):
self.max_depth = max_depth
self.min_samples_split = min_samples_split
self.min_samples_leaf = min_samples_leaf
self.root = None
def _mse(self, y):
"""Calculate Mean Squared Error."""
if len(y) == 0:
return 0
mean_y = np.mean(y)
return np.mean((y - mean_y) ** 2)
def _best_split(self, X, y):
"""Find best split for regression."""
n_samples, n_features = X.shape
if n_samples <= 1:
return None, None
parent_mse = self._mse(y)
best_gain = 0.0
best_feature = None
best_threshold = None
for feature_idx in range(n_features):
thresholds = np.unique(X[:, feature_idx])
for threshold in thresholds:
left_mask = X[:, feature_idx] <= threshold
right_mask = ~left_mask
if (np.sum(left_mask) < self.min_samples_leaf or
np.sum(right_mask) < self.min_samples_leaf):
continue
n_left = np.sum(left_mask)
n_right = np.sum(right_mask)
left_mse = self._mse(y[left_mask])
right_mse = self._mse(y[right_mask])
weighted_mse = (
(n_left / n_samples) * left_mse +
(n_right / n_samples) * right_mse
)
gain = parent_mse - weighted_mse
if gain > best_gain:
best_gain = gain
best_feature = feature_idx
best_threshold = threshold
return best_feature, best_threshold
def _grow_tree(self, X, y, depth=0):
"""Recursively grow regression tree."""
n_samples = len(y)
# Leaf value is mean of target values
predicted_value = np.mean(y)
node = Node(value=predicted_value, n_samples=n_samples)
# Stopping criteria
if (depth >= self.max_depth if self.max_depth else False):
return node
if n_samples < self.min_samples_split:
return node
# Find best split
feature_idx, threshold = self._best_split(X, y)
if feature_idx is None:
return node
# Split and recurse
left_mask = X[:, feature_idx] <= threshold
right_mask = ~left_mask
node.feature_index = feature_idx
node.threshold = threshold
node.left = self._grow_tree(X[left_mask], y[left_mask], depth + 1)
node.right = self._grow_tree(X[right_mask], y[right_mask], depth + 1)
return node
def fit(self, X, y):
"""Build regression tree."""
self.root = self._grow_tree(X, y)
return self
def _traverse_tree(self, x, node):
"""Traverse tree for prediction."""
if node.feature_index is None:
return node.value
if x[node.feature_index] <= node.threshold:
return self._traverse_tree(x, node.left)
else:
return self._traverse_tree(x, node.right)
def predict(self, X):
"""Predict continuous values."""
return np.array([self._traverse_tree(x, self.root) for x in X])
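To close the loop, here is the regressor applied to the house-price toy data from earlier. With a depth of 1 (a single split) it should recover the two leaf means computed by hand above (≈183.33 and ≈343.33); this is just an illustrative usage sketch:

# Fit the from-scratch regressor on the 6-house toy dataset
X_houses = np.array([[800], [1000], [1200], [1500], [2000], [2500]])
y_prices = np.array([150, 180, 220, 280, 350, 400], dtype=float)

reg = DecisionTreeRegressor(max_depth=1)  # a single split: a "decision stump"
reg.fit(X_houses, y_prices)

print(reg.predict(np.array([[1100], [2200]])))  # ≈ [183.33 343.33]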
1. Interpretability
2. No Data Preprocessing
3. Non-Linear Relationships
4. Fast Prediction
5. Feature Selection
6. Handles Missing Values (in some implementations, e.g. CART with surrogate splits; the basic version implemented above does not)
1. High Variance (Overfitting)
2. Greedy Learning
3. Bias Toward Features with Many Values
4. Difficulty with XOR and Diagonal Boundaries
5. Unstable
6. Biased with Imbalanced Data
7. Extrapolation Problems (Regression)
Good For: problems where interpretability and explainability matter, data with mixed numeric and categorical features, capturing non-linear relationships, and building quick, transparent baselines.
Not Ideal For: tasks requiring smooth predictions or extrapolation beyond the training range, very high-dimensional sparse data, or squeezing maximum predictive accuracy out of a single model (ensembles are usually better).
While a single decision tree may seem simple compared to modern deep learning models, understanding decision trees is crucial: they offer unmatched interpretability in domains where explainability is legally or ethically required, and they are the building blocks of the ensemble methods (Random Forests, Gradient Boosting) that dominate applied machine learning today.
Decision trees are not just another algorithm: they’re a fundamental way of thinking about how machines can learn to make decisions. Master them, and you’ll have intuition that carries through to the most advanced ensemble methods used in industry today.
For hands-on practice, check out the companion notebooks - Decision Trees Tutorial