Classification — KNN, Naive Bayes, Decision Trees

🤖

AI Disclosure: This post was written by Claude Opus 4.6. References to “I” refer to the AI author, not the site owner.

AI edit history
Date       | Model           | Action
2026-03-25 | Claude Opus 4.6 | authored
🎯 What You Will Learn
  • Understand how KNN, Naive Bayes, and Decision Trees classify data
  • Implement all three algorithms with scikit-learn
  • Compare classifier performance on the same dataset
  • Extract and interpret feature importance from Decision Trees
  • Know when to reach for which algorithm in practice
📋 Prerequisites
  • link:/posts/what-is-machine-learning/[Part 1: What Is Machine Learning?] — supervised learning, features, labels.
  • link:/posts/data-preprocessing-and-evaluation/[Part 2: Data Pre-processing and Evaluation] — train/test splits, scaling, evaluation metrics.
  • Part 3: Python ML Toolkit — working Python environment with scikit-learn, pandas, and numpy.

The Classification Problem

You have a log line. Is it an error, a warning, or informational? You have server metrics from the last hour. Will this machine fail in the next 24 hours? You have a network packet. Is it normal traffic or anomalous?

These are all classification problems — given a set of features, assign a category. In Part 1 we saw that classification is a core supervised learning task. Now we build real classifiers.

We will cover three algorithms that approach the problem in fundamentally different ways:

  • K-Nearest Neighbours — looks at what is nearby

  • Naive Bayes — calculates probabilities

  • Decision Trees — asks a series of yes/no questions

Each has strengths and blind spots. By the end of this post, you will know which to reach for and why.

The Dataset: Classifying Log Entries

Throughout this post, we will use a synthetic log classification dataset. Each entry has numerical features extracted from a log line, and the label is the severity level: error, warning, or info.

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

np.random.seed(42)
n_samples = 600

# Generate synthetic log features
data = {
    'response_time_ms': np.concatenate([
        np.random.normal(50, 15, 200),     # info: fast responses
        np.random.normal(200, 50, 200),    # warning: slow responses
        np.random.normal(500, 100, 200),   # error: very slow or timeouts
    ]),
    'error_code_flag': np.concatenate([
        np.random.binomial(1, 0.05, 200),  # info: rarely has error codes
        np.random.binomial(1, 0.3, 200),   # warning: sometimes
        np.random.binomial(1, 0.85, 200),  # error: almost always
    ]),
    'payload_size_kb': np.concatenate([
        np.random.normal(2, 0.5, 200),     # info: small payloads
        np.random.normal(8, 3, 200),       # warning: medium
        np.random.normal(25, 10, 200),     # error: large (stack traces, dumps)
    ]),
    'retry_count': np.concatenate([
        np.random.poisson(0, 200),         # info: no retries
        np.random.poisson(1, 200),         # warning: occasional retries
        np.random.poisson(4, 200),         # error: many retries
    ]),
    'severity': ['info'] * 200 + ['warning'] * 200 + ['error'] * 200,
}

df = pd.DataFrame(data)
df['response_time_ms'] = df['response_time_ms'].clip(lower=0)
df['payload_size_kb'] = df['payload_size_kb'].clip(lower=0)

print(df.groupby('severity').mean().round(2))

# Split
X = df.drop(columns=['severity'])
y = df['severity']

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print(f"\nTrain: {len(X_train)}, Test: {len(X_test)}")
print(f"Class distribution:\n{y_train.value_counts()}")

Four features, three classes, 600 samples. Simple enough to understand, complex enough that the algorithms behave differently.

K-Nearest Neighbours (KNN)

How It Works

KNN is the most intuitive classifier: to classify a new data point, find the K closest points in the training data and let them vote.

If K=5 and three of the five nearest neighbours are error, one is warning, and one is info, KNN predicts error. That is the entire algorithm.

"Closest" means Euclidean distance by default:

distance = sqrt((a₁ - b₁)² + (a₂ - b₂)² + ... + (aₙ - bₙ)²)
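The whole algorithm fits in a few lines of numpy. A minimal sketch with hypothetical toy points (not the log dataset): compute the Euclidean distance to every training point, take the K closest, and let them vote.

```python
from collections import Counter

import numpy as np

# Toy training data: two features per point, with severity labels
X_train_toy = np.array([[1.0, 0.0], [1.2, 0.1], [5.0, 3.0], [5.2, 2.8], [9.0, 6.0]])
y_train_toy = ['info', 'info', 'warning', 'warning', 'error']

def knn_predict(x_new, X, y, k=3):
    # Euclidean distance from x_new to every training point
    distances = np.sqrt(((X - x_new) ** 2).sum(axis=1))
    # Indices of the k closest points
    nearest = np.argsort(distances)[:k]
    # Majority vote among their labels
    votes = Counter(y[i] for i in nearest)
    return votes.most_common(1)[0][0]

print(knn_predict(np.array([1.1, 0.05]), X_train_toy, y_train_toy))  # → info
```

Scikit-learn's version adds efficient neighbour search structures (KD-trees, ball trees), but the decision rule is exactly this vote.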

This has an important implication: features must be scaled. If response time ranges from 0 to 600 and retry count ranges from 0 to 10, response time will dominate the distance calculation. Standardisation fixes this.
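A quick numeric check makes the point. With hypothetical raw values, a 100 ms difference in response time swamps a 5-retry difference; after standardisation, both features contribute on a comparable scale:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Two hypothetical log entries: [response_time_ms, retry_count]
a = np.array([500.0, 0.0])
b = np.array([400.0, 5.0])

raw_dist = np.linalg.norm(a - b)
print(f"Raw distance: {raw_dist:.1f}")  # dominated by the 100 ms gap

# Standardise using a small sample so both features get unit variance
sample = np.array([[50, 0], [200, 1], [400, 5], [500, 0], [600, 8]], dtype=float)
scaler = StandardScaler().fit(sample)
a_s, b_s = scaler.transform([a, b])

scaled_dist = np.linalg.norm(a_s - b_s)
print(f"Scaled distance: {scaled_dist:.2f}")  # retries now matter too
```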

Implementation

from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report

# KNN needs scaled features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Train KNN
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train_scaled, y_train)

# Evaluate
y_pred_knn = knn.predict(X_test_scaled)
print("KNN (K=5) Results:")
print(classification_report(y_test, y_pred_knn))

Choosing K

K is a hyperparameter — you choose it, the algorithm does not learn it. The choice matters:

  • K too small (e.g. K=1) — the model is sensitive to noise. A single mislabelled neighbour changes the prediction.

  • K too large (e.g. K=100) — the model smooths over real patterns. Class boundaries blur.

Find the best K by testing a range:

from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt

k_range = range(1, 31)
k_scores = []

for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)
    scores = cross_val_score(knn, X_train_scaled, y_train, cv=5, scoring='accuracy')
    k_scores.append(scores.mean())

plt.plot(k_range, k_scores, marker='o', markersize=3)
plt.xlabel('K')
plt.ylabel('Cross-validation accuracy')
plt.title('KNN: Accuracy vs K')
plt.grid(True, alpha=0.3)
plt.show()

best_k = k_range[np.argmax(k_scores)]
print(f"Best K: {best_k}, Accuracy: {max(k_scores):.3f}")

Strengths and Weaknesses

Strengths                                   | Weaknesses
No training phase — it just stores the data | Slow at prediction time (must scan all training points)
Works well with small datasets              | Suffers from the curse of dimensionality (many features)
Naturally handles multi-class problems      | Requires feature scaling
Easy to understand and debug                | No feature importance — it is a black box at decision time

KNN works well when you have a moderate dataset, few features, and classes that cluster in feature space. For a fleet of 50 servers with 4–5 metrics each, KNN is a reasonable first choice. For millions of log lines with hundreds of features, it is too slow.

Naive Bayes

Bayes' Theorem Intuition

Naive Bayes flips the question. Instead of "which class is closest?", it asks: "given these features, what is the probability of each class?"

Bayes' theorem:

P(class | features) = P(features | class) × P(class) / P(features)

In plain English: the probability that a log entry is an error, given its features, equals the probability of seeing those features in error logs, times the base rate of errors, divided by the overall probability of those features.

The "naive" part is the assumption that all features are conditionally independent — that response time and retry count do not influence each other, given the class. This is almost never true in practice, but Naive Bayes works surprisingly well despite the violation.
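For a single feature, Gaussian Naive Bayes reduces to a few lines of arithmetic. A minimal sketch with hypothetical per-class means, standard deviations, and priors for response time, comparing unnormalised posteriors:

```python
import math

def gaussian_pdf(x, mean, std):
    """Likelihood of x under a normal distribution."""
    return math.exp(-((x - mean) ** 2) / (2 * std ** 2)) / (std * math.sqrt(2 * math.pi))

# Hypothetical per-class distributions of response_time_ms, with equal priors
classes = {
    'info':    {'mean': 50,  'std': 15,  'prior': 1 / 3},
    'warning': {'mean': 200, 'std': 50,  'prior': 1 / 3},
    'error':   {'mean': 500, 'std': 100, 'prior': 1 / 3},
}

x = 320  # response time of a new log entry

# P(class | x) ∝ P(x | class) × P(class) — the denominator is shared, so skip it
scores = {c: gaussian_pdf(x, p['mean'], p['std']) * p['prior'] for c, p in classes.items()}
prediction = max(scores, key=scores.get)
print(prediction)
```

With more features, the "naive" assumption just means multiplying one such likelihood per feature before applying the prior.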

Why It Works Well for Text and Logs

Naive Bayes excels when:

  • Features are high-dimensional and sparse — like word counts in log messages or TF-IDF vectors in text

  • Classes have distinct feature distributions — error logs tend to contain words like "exception", "timeout", "traceback"

  • Training data is limited — Naive Bayes needs far less data than most classifiers

This makes it a natural fit for log classification, email filtering, and alert categorisation.

Implementation

from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import classification_report

# Gaussian Naive Bayes assumes features follow a normal distribution
nb = GaussianNB()
nb.fit(X_train, y_train)  # no scaling needed

y_pred_nb = nb.predict(X_test)
print("Naive Bayes Results:")
print(classification_report(y_test, y_pred_nb))

Notice: no scaling required. Naive Bayes works on raw probabilities, not distances. One less thing to worry about.

Naive Bayes for Text-Based Log Classification

Where Naive Bayes really shines is when you classify based on the log message itself, not just extracted numerical features:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import classification_report

# Raw log messages
log_messages = [
    "GET /api/health 200 OK 12ms",
    "Connection timeout after 30000ms to db-primary",
    "WARNING: disk usage at 87% on /dev/sda1",
    "NullPointerException in AuthService.validate()",
    "POST /api/users 201 Created 45ms",
    "WARNING: SSL certificate expires in 7 days",
    "FATAL: out of memory, cannot allocate 256MB",
    "GET /static/logo.png 304 Not Modified 2ms",
    "WARNING: high latency detected on eth0 (>100ms)",
    "SocketTimeoutException: connect timed out to redis-01",
    "GET /api/metrics 200 OK 8ms",
    "WARNING: connection pool near capacity (45/50)",
]
labels = [
    'info', 'error', 'warning', 'error', 'info', 'warning',
    'error', 'info', 'warning', 'error', 'info', 'warning',
]

# Convert text to TF-IDF features
vectoriser = TfidfVectorizer(lowercase=True, stop_words='english')
X_text = vectoriser.fit_transform(log_messages)

# Train
nb_text = MultinomialNB()
nb_text.fit(X_text, labels)

# Classify new logs
new_logs = [
    "CRITICAL: database replication lag exceeds 60s",
    "GET /api/orders 200 OK 22ms",
    "WARNING: swap usage at 72%",
]
X_new = vectoriser.transform(new_logs)
predictions = nb_text.predict(X_new)

for log, pred in zip(new_logs, predictions):
    print(f"[{pred:>7}] {log}")

Multinomial Naive Bayes with TF-IDF is a fast, reliable baseline for any text classification task. It trains in milliseconds and often performs well enough that you do not need a more complex model.
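In production you would typically bundle the vectoriser and classifier into a single scikit-learn Pipeline, so the vectorisation always travels with the model. A sketch reusing a subset of the log messages above (reproduced here so the snippet stands alone):

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline

log_messages = [
    "GET /api/health 200 OK 12ms",
    "Connection timeout after 30000ms to db-primary",
    "WARNING: disk usage at 87% on /dev/sda1",
    "NullPointerException in AuthService.validate()",
    "POST /api/users 201 Created 45ms",
    "WARNING: SSL certificate expires in 7 days",
]
labels = ['info', 'error', 'warning', 'error', 'info', 'warning']

# One object to fit, persist, and deploy — no separate vectoriser to keep in sync
clf = make_pipeline(TfidfVectorizer(lowercase=True), MultinomialNB())
clf.fit(log_messages, labels)

print(clf.predict(["WARNING: disk usage at 91% on /dev/sdb1"]))
```

The pipeline object can be pickled as a unit, which avoids the classic bug of deploying a model with a mismatched vocabulary.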

Decision Trees

How Splits Work

A Decision Tree learns a series of if/else rules from the data. At each node, it picks the feature and threshold that best separates the classes.

Consider classifying log entries. The tree might learn:

Is response_time_ms > 300?
├── Yes: Is error_code_flag > 0.5?
│   ├── Yes → error
│   └── No → warning
└── No: Is retry_count > 1.5?
    ├── Yes → warning
    └── No → info

The algorithm tries every possible split and picks the one that creates the purest child nodes.

Gini Impurity vs Information Gain

Two common criteria for measuring split quality:

Gini impurity — measures how often a randomly chosen sample would be misclassified:

Gini = 1 - Σ(pᵢ²)

A pure node (all one class) has Gini = 0. A node with equal representation of three classes has Gini = 0.667.

Information gain (entropy-based) — measures the reduction in uncertainty:

Entropy = -Σ(pᵢ × log₂(pᵢ))

In practice, both criteria produce similar trees. Scikit-learn defaults to Gini.
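Both criteria are one-liners; this sketch reproduces the 0.667 figure quoted above:

```python
import math

def gini(proportions):
    """Gini impurity: chance of misclassifying a randomly drawn sample."""
    return 1 - sum(p ** 2 for p in proportions)

def entropy(proportions):
    """Shannon entropy in bits: uncertainty of the class distribution."""
    return -sum(p * math.log2(p) for p in proportions if p > 0)

pure = [1.0]                   # all one class
mixed = [1 / 3, 1 / 3, 1 / 3]  # three classes, equal shares

print(f"Gini pure:     {gini(pure):.3f}")     # 0.000
print(f"Gini mixed:    {gini(mixed):.3f}")    # 0.667
print(f"Entropy mixed: {entropy(mixed):.3f}")  # 1.585 bits
```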

Implementation

from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report

# Train a Decision Tree
dt = DecisionTreeClassifier(
    max_depth=5,          # prevent overfitting
    min_samples_split=10, # need at least 10 samples to split
    random_state=42,
)
dt.fit(X_train, y_train)  # no scaling needed

y_pred_dt = dt.predict(X_test)
print("Decision Tree Results:")
print(classification_report(y_test, y_pred_dt))

Like Naive Bayes, Decision Trees do not need feature scaling. They only care about feature ordering, not magnitude.
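Because each split compares a feature against a threshold, any order-preserving transform leaves the learned predictions unchanged. A small demonstration on synthetic data (not the log dataset):

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
X = rng.uniform(0, 1000, size=(200, 2))  # two raw features
y = (X[:, 0] > 400).astype(int)          # class determined by a threshold

# Same data under a monotone transform — log compresses the scale
# but preserves the ordering of every feature value
X_log = np.log1p(X)

tree_raw = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
tree_log = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_log, y)

same = (tree_raw.predict(X) == tree_log.predict(X_log)).all()
print(f"Identical predictions after log transform: {same}")
```

The split thresholds differ between the two trees, but the partitions they induce — and therefore the predictions — are the same.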

Visualising the Tree

One of the biggest advantages of Decision Trees is interpretability. You can see exactly what the model learned:

from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

plt.figure(figsize=(20, 10))
plot_tree(
    dt,
    feature_names=X.columns.tolist(),
    class_names=dt.classes_.tolist(),
    filled=True,
    rounded=True,
    fontsize=9,
)
plt.title('Decision Tree — Log Severity Classification')
plt.tight_layout()
plt.show()

This produces a visual tree where each node shows the split condition, the Gini impurity, the sample count, and the class distribution. You can hand this to someone who has never heard of machine learning and they will understand what the model does.

Pruning to Prevent Overfitting

An unrestricted Decision Tree will grow until every leaf is pure — it memorises the training data perfectly, including noise. This is overfitting in its most visible form.

# Unrestricted tree — will overfit
dt_overfit = DecisionTreeClassifier(random_state=42)
dt_overfit.fit(X_train, y_train)

print(f"Unrestricted tree depth: {dt_overfit.get_depth()}")
print(f"  Train accuracy: {dt_overfit.score(X_train, y_train):.3f}")
print(f"  Test accuracy:  {dt_overfit.score(X_test, y_test):.3f}")

# Pruned tree
dt_pruned = DecisionTreeClassifier(
    max_depth=4,
    min_samples_leaf=5,
    random_state=42,
)
dt_pruned.fit(X_train, y_train)

print(f"\nPruned tree depth: {dt_pruned.get_depth()}")
print(f"  Train accuracy: {dt_pruned.score(X_train, y_train):.3f}")
print(f"  Test accuracy:  {dt_pruned.score(X_test, y_test):.3f}")

Key pruning parameters:

Parameter         | Effect
max_depth         | Maximum depth of the tree. The most important control.
min_samples_split | Minimum samples required to split a node. Prevents splits on tiny groups.
min_samples_leaf  | Minimum samples in a leaf node. Ensures predictions are backed by enough data.
max_features      | Number of features to consider at each split. Adds randomness, reduces overfitting.

Feature Importance

Decision Trees tell you which features matter most. Importance is measured by how much each feature contributes to reducing impurity across all splits.

import pandas as pd

# Feature importance from the trained tree
importances = pd.Series(
    dt.feature_importances_,
    index=X.columns
).sort_values(ascending=False)

print("Feature importance:")
print(importances.round(4))

# Visualise
importances.plot(kind='barh', color='#4ae68a')
plt.xlabel('Importance')
plt.title('Decision Tree Feature Importance')
plt.gca().invert_yaxis()
plt.tight_layout()
plt.show()

Feature importance is one of the most practical outputs of a Decision Tree. If response_time_ms has importance 0.55 and payload_size_kb has 0.05, you know where to focus your monitoring.

Comparing All Three Classifiers

Now we put them head to head on the same data:

from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Scale for KNN
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Define models
models = {
    'KNN (K=7)': KNeighborsClassifier(n_neighbors=7),
    'Naive Bayes': GaussianNB(),
    'Decision Tree': DecisionTreeClassifier(max_depth=5, random_state=42),
}

# Train and evaluate each
results = {}
for name, model in models.items():
    # KNN uses scaled data, others use raw
    if 'KNN' in name:
        model.fit(X_train_scaled, y_train)
        y_pred = model.predict(X_test_scaled)
        cv_scores = cross_val_score(model, X_train_scaled, y_train, cv=5)
    else:
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        cv_scores = cross_val_score(model, X_train, y_train, cv=5)

    results[name] = {
        'predictions': y_pred,
        'cv_mean': cv_scores.mean(),
        'cv_std': cv_scores.std(),
    }

    print(f"\n{'='*50}")
    print(f"{name}")
    print(f"{'='*50}")
    print(f"CV Accuracy: {cv_scores.mean():.3f} +/- {cv_scores.std():.3f}")
    print(classification_report(y_test, y_pred))

# Side-by-side confusion matrices
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
class_names = ['error', 'info', 'warning']

for ax, (name, result) in zip(axes, results.items()):
    cm = confusion_matrix(y_test, result['predictions'], labels=class_names)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Greens',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    ax.set_title(f'{name}\nCV: {result["cv_mean"]:.3f}')

plt.tight_layout()
plt.show()

Which Performs Best and Why

The results will vary with your data, but here is the general pattern for infrastructure classification tasks:

  • Decision Trees tend to win when features have clear thresholds — response_time > 300 is a natural split point for severity classification.

  • KNN performs well when classes form distinct clusters in feature space, but slows down with more data.

  • Naive Bayes holds its own when features are reasonably independent and often outperforms on text-based classification.

The key insight: there is no universally best algorithm. The right choice depends on your data, your features, and your constraints.

Practical Example: Predicting Server Failures

Let us apply all three to a different problem — predicting whether a server will fail in the next 24 hours based on current metrics.

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report

np.random.seed(42)
n = 500

# Simulated server health data
server_data = pd.DataFrame({
    'cpu_avg': np.concatenate([
        np.random.normal(45, 12, 350),   # healthy
        np.random.normal(88, 8, 150),    # will fail
    ]),
    'mem_pct': np.concatenate([
        np.random.normal(55, 15, 350),
        np.random.normal(90, 5, 150),
    ]),
    'disk_io_ops': np.concatenate([
        np.random.normal(150, 50, 350),
        np.random.normal(480, 80, 150),
    ]),
    'error_rate_per_min': np.concatenate([
        np.random.poisson(2, 350),
        np.random.poisson(25, 150),
    ]),
    'open_connections': np.concatenate([
        np.random.normal(120, 30, 350),
        np.random.normal(380, 60, 150),
    ]),
    'will_fail': [0] * 350 + [1] * 150,
})

X = server_data.drop(columns=['will_fail'])
y = server_data['will_fail']

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Compare all three
classifiers = {
    'KNN': (KNeighborsClassifier(n_neighbors=5), X_train_scaled, X_test_scaled),
    'Naive Bayes': (GaussianNB(), X_train, X_test),
    'Decision Tree': (DecisionTreeClassifier(max_depth=4, random_state=42), X_train, X_test),
}

for name, (model, X_tr, X_te) in classifiers.items():
    model.fit(X_tr, y_train)
    y_pred = model.predict(X_te)
    print(f"\n{name}:")
    print(classification_report(y_test, y_pred, target_names=['healthy', 'will_fail']))

For server failure prediction, recall on the "will_fail" class is the metric that matters. Missing a failure (false negative) is far worse than a false alarm. Check which classifier catches the most failures, not just overall accuracy.
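You can pull that number out directly rather than eyeballing the report. A sketch with hypothetical labels for ten servers showing why accuracy hides missed failures:

```python
from sklearn.metrics import accuracy_score, recall_score

# Hypothetical ground truth and predictions (1 = will fail)
y_true = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]
y_pred = [0, 0, 0, 0, 0, 0, 0, 1, 0, 0]  # catches only one of three failures

print(f"Accuracy:           {accuracy_score(y_true, y_pred):.2f}")  # 0.80 — looks fine
print(f"Recall (will_fail): {recall_score(y_true, y_pred):.2f}")    # 0.33 — two failures missed
```

An 80% accurate model that misses two out of three failures is worse, for this task, than a noisier model with high recall.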

Detecting Anomalous Network Traffic

Classification also applies to security. Here we flag network connections as normal or anomalous:

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report

np.random.seed(42)

# Simulated network connection features
n_normal, n_anomalous = 800, 200

network_data = pd.DataFrame({
    'bytes_sent': np.concatenate([
        np.random.lognormal(8, 1, n_normal),     # normal: variable
        np.random.lognormal(12, 0.5, n_anomalous), # anomalous: large transfers
    ]),
    'bytes_received': np.concatenate([
        np.random.lognormal(9, 1, n_normal),
        np.random.lognormal(6, 2, n_anomalous),  # anomalous: small responses
    ]),
    'duration_sec': np.concatenate([
        np.random.exponential(5, n_normal),       # normal: short connections
        np.random.exponential(120, n_anomalous),  # anomalous: long-lived
    ]),
    'unique_ports': np.concatenate([
        np.random.poisson(2, n_normal),           # normal: few ports
        np.random.poisson(15, n_anomalous),       # anomalous: port scanning
    ]),
    'is_anomalous': [0] * n_normal + [1] * n_anomalous,
})

X = network_data.drop(columns=['is_anomalous'])
y = network_data['is_anomalous']

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Decision Tree is a good fit — interpretable rules for security review
dt = DecisionTreeClassifier(max_depth=4, random_state=42)
dt.fit(X_train, y_train)

y_pred = dt.predict(X_test)
print("Network Anomaly Detection:")
print(classification_report(y_test, y_pred, target_names=['normal', 'anomalous']))

# Show the rules the tree learned
importances = pd.Series(dt.feature_importances_, index=X.columns)
print("\nFeature importance:")
print(importances.sort_values(ascending=False).round(4))

For security applications, Decision Trees have a major advantage: you can show the learned rules to a security team. "Connections lasting over 60 seconds with more than 8 unique ports are flagged as anomalous" is something a human can validate and trust.
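Scikit-learn can print those rules as plain text with export_text. A standalone sketch, refitted on small synthetic data shaped like two of the network features above:

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text

rng = np.random.default_rng(42)
X = pd.DataFrame({
    'duration_sec': np.concatenate([rng.exponential(5, 100), rng.exponential(120, 30)]),
    'unique_ports': np.concatenate([rng.poisson(2, 100), rng.poisson(15, 30)]),
})
y = np.array([0] * 100 + [1] * 30)  # 1 = anomalous

dt = DecisionTreeClassifier(max_depth=2, random_state=42).fit(X, y)

# Plain-text rules a reviewer can read without any tooling
rules = export_text(dt, feature_names=list(X.columns))
print(rules)
```

The output is an indented if/else listing of thresholds and leaf classes — exactly the artefact you would paste into a security review ticket.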

When to Use Which Algorithm

Factor                         | KNN                                                  | Naive Bayes                                                    | Decision Tree
Training speed                 | No training (lazy learner)                           | Very fast                                                      | Fast
Prediction speed               | Slow (scans all data)                                | Very fast                                                      | Very fast
Needs scaling?                 | Yes                                                  | No                                                             | No
Handles many features?         | Poorly (curse of dimensionality)                     | Well (especially text)                                         | Moderately
Interpretable?                 | Limited                                              | Moderate (probabilities)                                       | Excellent (visual rules)
Handles non-linear boundaries? | Yes (with small K)                                   | Limited (assumes distributions)                                | Yes (axis-aligned splits)
Data size sweet spot           | Small to medium (<10K samples)                       | Any size                                                       | Any size
Best for                       | Quick prototyping, small datasets, anomaly detection | Text classification, log parsing, high-dimensional sparse data | Structured tabular data, interpretable models, feature selection

Rules of Thumb

Start with a Decision Tree when you have tabular data with clear features. It gives you interpretability and feature importance for free.

Reach for Naive Bayes when you are classifying text — log messages, alert descriptions, ticket categories. Also consider it when you need a fast baseline with minimal tuning.

Use KNN for quick prototyping or when you have a small, well-defined feature space. Avoid it when prediction latency matters or when you have more than a dozen features.

In all cases, try more than one algorithm on your data. The "best" choice often surprises you.

Next

Part 5: Regression — Linear, Polynomial, and Regularisation — predicting continuous values like CPU usage, disk fill time, and request latency.

📚 Resources

Videos:

  • StatQuest — KNN — clear, visual explanation of K-Nearest Neighbours.
  • StatQuest — Naive Bayes — Bayes’ theorem applied to classification, step by step.
  • StatQuest — Decision Trees — how Gini impurity and information gain drive splits.
  • 3Blue1Brown — Bayes theorem — the geometry of changing beliefs (complements Naive Bayes).

Reading:

  • Scikit-learn — Nearest Neighbours — official KNN documentation.
  • Scikit-learn — Decision Trees — tree algorithms, pruning, and visualisation.

Companion: link:/posts/why-maths-for-machine-learning/[Maths for ML] — Part 6 covers Bayes’ theorem from first principles, Part 8 covers information gain.

🔬 Try It Yourself

1. Classify your own logs. Export a sample of logs from your system. Extract numerical features (response time, status code, message length). Train all three classifiers and compare. Which performs best on your data?

2. Tune KNN. Using the log dataset from this post, find the optimal K using cross-validation. Plot accuracy vs K. What happens at K=1? At K=100?

3. Overfit and prune. Train a Decision Tree with no depth limit on the server failure dataset. What is the tree depth? What is the train/test accuracy gap? Then restrict max_depth to 3, 5, and 7 — which gives the best test accuracy?

4. Text classification pipeline. Collect 50–100 log lines from a real system. Manually label them by severity. Build a Naive Bayes classifier with TF-IDF vectorisation. What accuracy do you get? Which log messages does it misclassify?

5. Feature importance audit. Train a Decision Tree on your monitoring metrics. Which features have the highest importance? Does this match your intuition about what drives failures in your environment?

