Classification — KNN, Naive Bayes, Decision Trees
AI Disclosure: This post was written by Claude Opus 4.6. References to “I” refer to the AI author, not the site owner.
AI edit history
| Date | Model | Action |
|---|---|---|
| 2026-03-25 | Claude Opus 4.6 | authored |
- Understand how KNN, Naive Bayes, and Decision Trees classify data
- Implement all three algorithms with scikit-learn
- Compare classifier performance on the same dataset
- Extract and interpret feature importance from Decision Trees
- Know when to reach for which algorithm in practice
The Classification Problem
You have a log line. Is it an error, a warning, or informational? You have server metrics from the last hour. Will this machine fail in the next 24 hours? You have a network packet. Is it normal traffic or anomalous?
These are all classification problems — given a set of features, assign a category. In Part 1 we saw that classification is a core supervised learning task. Now we build real classifiers.
We will cover three algorithms that approach the problem in fundamentally different ways:
K-Nearest Neighbours — looks at what is nearby
Naive Bayes — calculates probabilities
Decision Trees — asks a series of yes/no questions
Each has strengths and blind spots. By the end of this post, you will know which to reach for and why.
The Dataset: Classifying Log Entries
Throughout this post, we will use a synthetic log classification dataset. Each entry has numerical features extracted from a log line, and the label is the severity level: error, warning, or info.
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
np.random.seed(42)
n_samples = 600
# Generate synthetic log features
data = {
'response_time_ms': np.concatenate([
np.random.normal(50, 15, 200), # info: fast responses
np.random.normal(200, 50, 200), # warning: slow responses
np.random.normal(500, 100, 200), # error: very slow or timeouts
]),
'error_code_flag': np.concatenate([
np.random.binomial(1, 0.05, 200), # info: rarely has error codes
np.random.binomial(1, 0.3, 200), # warning: sometimes
np.random.binomial(1, 0.85, 200), # error: almost always
]),
'payload_size_kb': np.concatenate([
np.random.normal(2, 0.5, 200), # info: small payloads
np.random.normal(8, 3, 200), # warning: medium
np.random.normal(25, 10, 200), # error: large (stack traces, dumps)
]),
'retry_count': np.concatenate([
np.random.poisson(0, 200), # info: no retries
np.random.poisson(1, 200), # warning: occasional retries
np.random.poisson(4, 200), # error: many retries
]),
'severity': ['info'] * 200 + ['warning'] * 200 + ['error'] * 200,
}
df = pd.DataFrame(data)
df['response_time_ms'] = df['response_time_ms'].clip(lower=0)
df['payload_size_kb'] = df['payload_size_kb'].clip(lower=0)
print(df.groupby('severity').mean().round(2))
# Split
X = df.drop(columns=['severity'])
y = df['severity']
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
print(f"\nTrain: {len(X_train)}, Test: {len(X_test)}")
print(f"Class distribution:\n{y_train.value_counts()}")

Four features, three classes, 600 samples. Simple enough to understand, complex enough that the algorithms behave differently.
K-Nearest Neighbours (KNN)
How It Works
KNN is the most intuitive classifier: to classify a new data point, find the K closest points in the training data and let them vote.
If K=5 and three of the five nearest neighbours are error, one is warning, and one is info, KNN predicts error. That is the entire algorithm.
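The voting step is simple enough to sketch from scratch — a toy illustration in plain NumPy, not a substitute for scikit-learn's optimised implementation (the cluster positions and labels below are made up):

```python
import numpy as np
from collections import Counter

def knn_predict(X_train, y_train, x_new, k=5):
    """Toy KNN: find the k nearest training points and let them vote."""
    distances = np.sqrt(((X_train - x_new) ** 2).sum(axis=1))  # Euclidean distance to every training point
    nearest = np.argsort(distances)[:k]                        # indices of the k closest
    votes = Counter(y_train[i] for i in nearest)
    return votes.most_common(1)[0][0]                          # majority class wins

# Two tiny clusters: points near the origin are 'info', points near (5, 5) are 'error'
X_toy = np.array([[0.1, 0.2], [0.2, 0.1], [0.0, 0.3],
                  [5.0, 5.1], [5.2, 4.9], [4.8, 5.0]])
y_toy = ['info', 'info', 'info', 'error', 'error', 'error']

print(knn_predict(X_toy, y_toy, np.array([0.2, 0.2]), k=3))  # info
```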
"Closest" means Euclidean distance by default:
distance = sqrt((a₁ - b₁)² + (a₂ - b₂)² + ... + (aₙ - bₙ)²)

This has an important implication: features must be scaled. If response time ranges from 0 to 600 and retry count ranges from 0 to 10, response time will dominate the distance calculation. Standardisation fixes this.
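To see the problem concretely, compare the distance between two log entries before and after standardisation (a sketch with made-up values; the small reference array stands in for a training set):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Two hypothetical log entries: [response_time_ms, retry_count]
a = np.array([400.0, 1.0])
b = np.array([100.0, 4.0])

# Unscaled: response time dominates the distance almost entirely
print(np.linalg.norm(a - b))  # ~300.0 — the retry-count difference barely registers

# Standardise both features against some reference data, then compare again
ref = np.array([[50.0, 0.0], [200.0, 1.0], [500.0, 4.0], [150.0, 2.0]])
scaler = StandardScaler().fit(ref)
a_s, b_s = scaler.transform(np.array([a, b]))
print(np.linalg.norm(a_s - b_s))  # now both features contribute comparably
```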
Implementation
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report
# KNN needs scaled features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
# Train KNN
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train_scaled, y_train)
# Evaluate
y_pred_knn = knn.predict(X_test_scaled)
print("KNN (K=5) Results:")
print(classification_report(y_test, y_pred_knn))

Choosing K
K is a hyperparameter — you choose it, the algorithm does not learn it. The choice matters:
K too small (e.g. K=1) — the model is sensitive to noise. A single mislabelled neighbour changes the prediction.
K too large (e.g. K=100) — the model smooths over real patterns. Class boundaries blur.
Find the best K by testing a range:
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt
k_range = range(1, 31)
k_scores = []
for k in k_range:
knn = KNeighborsClassifier(n_neighbors=k)
scores = cross_val_score(knn, X_train_scaled, y_train, cv=5, scoring='accuracy')
k_scores.append(scores.mean())
plt.plot(k_range, k_scores, marker='o', markersize=3)
plt.xlabel('K')
plt.ylabel('Cross-validation accuracy')
plt.title('KNN: Accuracy vs K')
plt.grid(True, alpha=0.3)
plt.show()
best_k = k_range[np.argmax(k_scores)]
print(f"Best K: {best_k}, Accuracy: {max(k_scores):.3f}")

Strengths and Weaknesses
| Strengths | Weaknesses |
|---|---|
| No training phase — it just stores the data | Slow at prediction time (must scan all training points) |
| Works well with small datasets | Suffers from the curse of dimensionality (many features) |
| Naturally handles multi-class problems | Requires feature scaling |
| Easy to understand and debug | No feature importance — it is a black box at decision time |
KNN works well when you have a moderate dataset, few features, and classes that cluster in feature space. For a fleet of 50 servers with 4–5 metrics each, KNN is a reasonable first choice. For millions of log lines with hundreds of features, it is too slow.
Naive Bayes
Bayes' Theorem Intuition
Naive Bayes flips the question. Instead of "which class is closest?", it asks: "given these features, what is the probability of each class?"
Bayes' theorem:
P(class | features) = P(features | class) × P(class) / P(features)

In plain English: the probability that a log entry is an error, given its features, equals the probability of seeing those features in error logs, times the base rate of errors, divided by the overall probability of those features.
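To make the arithmetic concrete, here is the theorem applied to a single binary feature, with made-up rates for illustration (10% of log lines are errors, 85% of errors carry an error code, 10% of non-errors do):

```python
# Made-up rates for illustration only
p_error = 0.10                 # P(class): base rate of errors
p_flag_given_error = 0.85      # P(feature | class)
p_flag_given_not_error = 0.10  # P(feature | any other class)

# P(feature) via the law of total probability
p_flag = p_flag_given_error * p_error + p_flag_given_not_error * (1 - p_error)

# Bayes' theorem: P(class | feature)
p_error_given_flag = p_flag_given_error * p_error / p_flag
print(f"P(error | error_code_flag=1) = {p_error_given_flag:.3f}")  # 0.486
```

Even though errors are only 10% of traffic, seeing the flag raises the probability to roughly 49% — the base rate P(class) does real work in the formula.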
The "naive" part is the assumption that all features are conditionally independent — that response time and retry count do not influence each other, given the class. This is almost never true in practice, but Naive Bayes works surprisingly well despite the violation.
Why It Works Well for Text and Logs
Naive Bayes excels when:
Features are high-dimensional and sparse — like word counts in log messages or TF-IDF vectors in text
Classes have distinct feature distributions — error logs tend to contain words like "exception", "timeout", "traceback"
Training data is limited — Naive Bayes needs far less data than most classifiers
This makes it a natural fit for log classification, email filtering, and alert categorisation.
Implementation
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import classification_report
# Gaussian Naive Bayes assumes features follow a normal distribution
nb = GaussianNB()
nb.fit(X_train, y_train) # no scaling needed
y_pred_nb = nb.predict(X_test)
print("Naive Bayes Results:")
print(classification_report(y_test, y_pred_nb))

Notice: no scaling required. Naive Bayes works on raw probabilities, not distances. One less thing to worry about.
Naive Bayes for Text-Based Log Classification
Where Naive Bayes really shines is when you classify based on the log message itself, not just extracted numerical features:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import classification_report
# Raw log messages
log_messages = [
"GET /api/health 200 OK 12ms",
"Connection timeout after 30000ms to db-primary",
"WARNING: disk usage at 87% on /dev/sda1",
"NullPointerException in AuthService.validate()",
"POST /api/users 201 Created 45ms",
"WARNING: SSL certificate expires in 7 days",
"FATAL: out of memory, cannot allocate 256MB",
"GET /static/logo.png 304 Not Modified 2ms",
"WARNING: high latency detected on eth0 (>100ms)",
"SocketTimeoutException: connect timed out to redis-01",
"GET /api/metrics 200 OK 8ms",
"WARNING: connection pool near capacity (45/50)",
]
labels = [
'info', 'error', 'warning', 'error', 'info', 'warning',
'error', 'info', 'warning', 'error', 'info', 'warning',
]
# Convert text to TF-IDF features
vectoriser = TfidfVectorizer(lowercase=True, stop_words='english')
X_text = vectoriser.fit_transform(log_messages)
# Train
nb_text = MultinomialNB()
nb_text.fit(X_text, labels)
# Classify new logs
new_logs = [
"CRITICAL: database replication lag exceeds 60s",
"GET /api/orders 200 OK 22ms",
"WARNING: swap usage at 72%",
]
X_new = vectoriser.transform(new_logs)
predictions = nb_text.predict(X_new)
for log, pred in zip(new_logs, predictions):
print(f"[{pred:>7}] {log}")

Multinomial Naive Bayes with TF-IDF is a fast, reliable baseline for any text classification task. It trains in milliseconds and often performs well enough that you do not need a more complex model.
Decision Trees
How Splits Work
A Decision Tree learns a series of if/else rules from the data. At each node, it picks the feature and threshold that best separates the classes.
Consider classifying log entries. The tree might learn:
Is response_time_ms > 300?
├── Yes: Is error_code_flag > 0.5?
│ ├── Yes → error
│ └── No → warning
└── No: Is retry_count > 1.5?
├── Yes → warning
└── No → info

The algorithm tries every possible split and picks the one that creates the purest child nodes.
Gini Impurity vs Information Gain
Two common criteria for measuring split quality:
Gini impurity — measures how often a randomly chosen sample would be misclassified:
Gini = 1 - Σ(pᵢ²)

A pure node (all one class) has Gini = 0. A node with equal representation of three classes has Gini = 0.667.
Information gain (entropy-based) — measures the reduction in uncertainty:
Entropy = -Σ(pᵢ × log₂(pᵢ))

In practice, both criteria produce similar trees. Scikit-learn defaults to Gini.
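Both criteria take only a few lines to compute — a quick sketch to confirm the numbers quoted above:

```python
import numpy as np

def gini(class_counts):
    """Gini impurity of a node, given the sample count for each class."""
    p = np.asarray(class_counts, dtype=float)
    p = p / p.sum()
    return 1.0 - np.sum(p ** 2)

def entropy(class_counts):
    """Shannon entropy (in bits) of a node, given the sample count for each class."""
    p = np.asarray(class_counts, dtype=float)
    p = p / p.sum()
    p = p[p > 0]  # treat 0 * log2(0) as 0
    return -np.sum(p * np.log2(p))

print(gini([50, 0, 0]))       # 0.0 — a pure node
print(gini([20, 20, 20]))     # ~0.667 — three classes, equally mixed
print(entropy([20, 20, 20]))  # ~1.585 — log2(3) bits of uncertainty
```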
Implementation
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report
# Train a Decision Tree
dt = DecisionTreeClassifier(
max_depth=5, # prevent overfitting
min_samples_split=10, # need at least 10 samples to split
random_state=42,
)
dt.fit(X_train, y_train) # no scaling needed
y_pred_dt = dt.predict(X_test)
print("Decision Tree Results:")
print(classification_report(y_test, y_pred_dt))

Like Naive Bayes, Decision Trees do not need feature scaling. They only care about feature ordering, not magnitude.
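You can check the scale-invariance claim directly: multiply a feature by any positive constant and the tree's predictions should not change, because only the ordering of values along each feature matters (a small sanity check on synthetic data):

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(200, 3))
y_demo = (X_demo[:, 0] + X_demo[:, 1] > 0).astype(int)

# Same data with the first feature rescaled by 1000x
X_rescaled = X_demo.copy()
X_rescaled[:, 0] *= 1000

dt_a = DecisionTreeClassifier(max_depth=4, random_state=0).fit(X_demo, y_demo)
dt_b = DecisionTreeClassifier(max_depth=4, random_state=0).fit(X_rescaled, y_demo)

# Split thresholds move, but the partition of the data — and the predictions — do not
same = (dt_a.predict(X_demo) == dt_b.predict(X_rescaled)).all()
print(same)  # True
```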
Visualising the Tree
One of the biggest advantages of Decision Trees is interpretability. You can see exactly what the model learned:
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
plt.figure(figsize=(20, 10))
plot_tree(
dt,
feature_names=X.columns.tolist(),
class_names=dt.classes_.tolist(),
filled=True,
rounded=True,
fontsize=9,
)
plt.title('Decision Tree — Log Severity Classification')
plt.tight_layout()
plt.show()

This produces a visual tree where each node shows the split condition, the Gini impurity, the sample count, and the class distribution. You can hand this to someone who has never heard of machine learning and they will understand what the model does.
Pruning to Prevent Overfitting
An unrestricted Decision Tree will grow until every leaf is pure — it memorises the training data perfectly, including noise. This is overfitting in its most visible form.
# Unrestricted tree — will overfit
dt_overfit = DecisionTreeClassifier(random_state=42)
dt_overfit.fit(X_train, y_train)
print(f"Unrestricted tree depth: {dt_overfit.get_depth()}")
print(f" Train accuracy: {dt_overfit.score(X_train, y_train):.3f}")
print(f" Test accuracy: {dt_overfit.score(X_test, y_test):.3f}")
# Pruned tree
dt_pruned = DecisionTreeClassifier(
max_depth=4,
min_samples_leaf=5,
random_state=42,
)
dt_pruned.fit(X_train, y_train)
print(f"\nPruned tree depth: {dt_pruned.get_depth()}")
print(f" Train accuracy: {dt_pruned.score(X_train, y_train):.3f}")
print(f" Test accuracy: {dt_pruned.score(X_test, y_test):.3f}")

Key pruning parameters:
| Parameter | Effect |
|---|---|
| max_depth | Maximum depth of the tree. The most important control. |
| min_samples_split | Minimum samples required to split a node. Prevents splits on tiny groups. |
| min_samples_leaf | Minimum samples in a leaf node. Ensures predictions are backed by enough data. |
| max_features | Number of features to consider at each split. Adds randomness, reduces overfitting. |
Feature Importance
Decision Trees tell you which features matter most. Importance is measured by how much each feature contributes to reducing impurity across all splits.
import pandas as pd
# Feature importance from the trained tree
importances = pd.Series(
dt.feature_importances_,
index=X.columns
).sort_values(ascending=False)
print("Feature importance:")
print(importances.round(4))
# Visualise
importances.plot(kind='barh', color='#4ae68a')
plt.xlabel('Importance')
plt.title('Decision Tree Feature Importance')
plt.gca().invert_yaxis()
plt.tight_layout()
plt.show()

Feature importance is one of the most practical outputs of a Decision Tree. If response_time_ms has importance 0.55 and payload_size_kb has 0.05, you know where to focus your monitoring.
Comparing All Three Classifiers
Now we put them head to head on the same data:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
# Scale for KNN
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
# Define models
models = {
'KNN (K=7)': KNeighborsClassifier(n_neighbors=7),
'Naive Bayes': GaussianNB(),
'Decision Tree': DecisionTreeClassifier(max_depth=5, random_state=42),
}
# Train and evaluate each
results = {}
for name, model in models.items():
# KNN uses scaled data, others use raw
if 'KNN' in name:
model.fit(X_train_scaled, y_train)
y_pred = model.predict(X_test_scaled)
cv_scores = cross_val_score(model, X_train_scaled, y_train, cv=5)
else:
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
cv_scores = cross_val_score(model, X_train, y_train, cv=5)
results[name] = {
'predictions': y_pred,
'cv_mean': cv_scores.mean(),
'cv_std': cv_scores.std(),
}
print(f"\n{'='*50}")
print(f"{name}")
print(f"{'='*50}")
print(f"CV Accuracy: {cv_scores.mean():.3f} +/- {cv_scores.std():.3f}")
print(classification_report(y_test, y_pred))
# Side-by-side confusion matrices
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
class_names = ['error', 'info', 'warning']
for ax, (name, result) in zip(axes, results.items()):
cm = confusion_matrix(y_test, result['predictions'], labels=class_names)
sns.heatmap(cm, annot=True, fmt='d', cmap='Greens',
xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title(f'{name}\nCV: {result["cv_mean"]:.3f}')
plt.tight_layout()
plt.show()

Which Performs Best and Why
The results will vary with your data, but here is the general pattern for infrastructure classification tasks:
Decision Trees tend to win when features have clear thresholds — response_time > 300 is a natural split point for severity classification.
KNN performs well when classes form distinct clusters in feature space, but slows down with more data.
Naive Bayes holds its own when features are reasonably independent and often outperforms on text-based classification.
The key insight: there is no universally best algorithm. The right choice depends on your data, your features, and your constraints.
Practical Example: Predicting Server Failures
Let us apply all three to a different problem — predicting whether a server will fail in the next 24 hours based on current metrics.
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report
np.random.seed(42)
n = 500
# Simulated server health data
server_data = pd.DataFrame({
'cpu_avg': np.concatenate([
np.random.normal(45, 12, 350), # healthy
np.random.normal(88, 8, 150), # will fail
]),
'mem_pct': np.concatenate([
np.random.normal(55, 15, 350),
np.random.normal(90, 5, 150),
]),
'disk_io_ops': np.concatenate([
np.random.normal(150, 50, 350),
np.random.normal(480, 80, 150),
]),
'error_rate_per_min': np.concatenate([
np.random.poisson(2, 350),
np.random.poisson(25, 150),
]),
'open_connections': np.concatenate([
np.random.normal(120, 30, 350),
np.random.normal(380, 60, 150),
]),
'will_fail': [0] * 350 + [1] * 150,
})
X = server_data.drop(columns=['will_fail'])
y = server_data['will_fail']
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
# Compare all three
classifiers = {
'KNN': (KNeighborsClassifier(n_neighbors=5), X_train_scaled, X_test_scaled),
'Naive Bayes': (GaussianNB(), X_train, X_test),
'Decision Tree': (DecisionTreeClassifier(max_depth=4, random_state=42), X_train, X_test),
}
for name, (model, X_tr, X_te) in classifiers.items():
model.fit(X_tr, y_train)
y_pred = model.predict(X_te)
print(f"\n{name}:")
print(classification_report(y_test, y_pred, target_names=['healthy', 'will_fail']))

For server failure prediction, recall on the "will_fail" class is the metric that matters. Missing a failure (false negative) is far worse than a false alarm. Check which classifier catches the most failures, not just overall accuracy.
Detecting Anomalous Network Traffic
Classification also applies to security. Here we flag network connections as normal or anomalous:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report
np.random.seed(42)
# Simulated network connection features
n_normal, n_anomalous = 800, 200
network_data = pd.DataFrame({
'bytes_sent': np.concatenate([
np.random.lognormal(8, 1, n_normal), # normal: variable
np.random.lognormal(12, 0.5, n_anomalous), # anomalous: large transfers
]),
'bytes_received': np.concatenate([
np.random.lognormal(9, 1, n_normal),
np.random.lognormal(6, 2, n_anomalous), # anomalous: small responses
]),
'duration_sec': np.concatenate([
np.random.exponential(5, n_normal), # normal: short connections
np.random.exponential(120, n_anomalous), # anomalous: long-lived
]),
'unique_ports': np.concatenate([
np.random.poisson(2, n_normal), # normal: few ports
np.random.poisson(15, n_anomalous), # anomalous: port scanning
]),
'is_anomalous': [0] * n_normal + [1] * n_anomalous,
})
X = network_data.drop(columns=['is_anomalous'])
y = network_data['is_anomalous']
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
# Decision Tree is a good fit — interpretable rules for security review
dt = DecisionTreeClassifier(max_depth=4, random_state=42)
dt.fit(X_train, y_train)
y_pred = dt.predict(X_test)
print("Network Anomaly Detection:")
print(classification_report(y_test, y_pred, target_names=['normal', 'anomalous']))
# Show the rules the tree learned
importances = pd.Series(dt.feature_importances_, index=X.columns)
print("\nFeature importance:")
print(importances.sort_values(ascending=False).round(4))

For security applications, Decision Trees have a major advantage: you can show the learned rules to a security team. "Connections lasting over 60 seconds with more than 8 unique ports are flagged as anomalous" is something a human can validate and trust.
When to Use Which Algorithm
| Factor | KNN | Naive Bayes | Decision Tree |
|---|---|---|---|
| Training speed | No training (lazy learner) | Very fast | Fast |
| Prediction speed | Slow (scans all data) | Very fast | Very fast |
| Needs scaling? | Yes | No | No |
| Handles many features? | Poorly (curse of dimensionality) | Well (especially text) | Moderately |
| Interpretable? | Limited | Moderate (probabilities) | Excellent (visual rules) |
| Handles non-linear boundaries? | Yes (with small K) | Limited (assumes distributions) | Yes (axis-aligned splits) |
| Data size sweet spot | Small to medium (<10K samples) | Any size | Any size |
| Best for | Quick prototyping, small datasets, anomaly detection | Text classification, log parsing, high-dimensional sparse data | Structured tabular data, interpretable models, feature selection |
Rules of Thumb
Start with a Decision Tree when you have tabular data with clear features. It gives you interpretability and feature importance for free.
Reach for Naive Bayes when you are classifying text — log messages, alert descriptions, ticket categories. Also consider it when you need a fast baseline with minimal tuning.
Use KNN for quick prototyping or when you have a small, well-defined feature space. Avoid it when prediction latency matters or when you have more than a dozen features.
In all cases, try more than one algorithm on your data. The "best" choice often surprises you.
Next
Part 5: Regression — Linear, Polynomial, and Regularisation — predicting continuous values like CPU usage, disk fill time, and request latency.
Videos:
- StatQuest — KNN — clear, visual explanation of K-Nearest Neighbours.
- StatQuest — Naive Bayes — Bayes’ theorem applied to classification, step by step.
- StatQuest — Decision Trees — how Gini impurity and information gain drive splits.
- 3Blue1Brown — Bayes theorem — the geometry of changing beliefs (complements Naive Bayes).
Reading:
- Scikit-learn — Nearest Neighbours — official KNN documentation.
- Scikit-learn — Decision Trees — tree algorithms, pruning, and visualisation.
Companion: link:/posts/why-maths-for-machine-learning/[Maths for ML] — Part 6 covers Bayes’ theorem from first principles, Part 8 covers information gain.
1. Classify your own logs. Export a sample of logs from your system. Extract numerical features (response time, status code, message length). Train all three classifiers and compare. Which performs best on your data?
2. Tune KNN. Using the log dataset from this post, find the optimal K using cross-validation. Plot accuracy vs K. What happens at K=1? At K=100?
3. Overfit and prune. Train a Decision Tree with no depth limit on the server failure dataset. What is the tree depth? What is the train/test accuracy gap? Then restrict max_depth to 3, 5, and 7 — which gives the best test accuracy?
4. Text classification pipeline. Collect 50–100 log lines from a real system. Manually label them by severity. Build a Naive Bayes classifier with TF-IDF vectorisation. What accuracy do you get? Which log messages does it misclassify?
5. Feature importance audit. Train a Decision Tree on your monitoring metrics. Which features have the highest importance? Does this match your intuition about what drives failures in your environment?