在机器学习入门阶段,许多教程会直接给出Fisher线性判别的代码实现,却很少解释为什么这个算法能有效分类。今天我们将从几何空间的角度,一步步推导Fisher判别的数学原理,并用NumPy手动实现整个过程。不同于简单地调用sklearn,我们将深入算法内核,理解如何通过投影优化找到最佳分类方向。
想象一个三维空间中的两类数据点,我们希望找到一个二维平面,使得投影后的数据点能够最好地被区分。Fisher判别的核心思想就是找到这样一个投影方向,使得不同类别的数据在投影后尽可能分开,而同一类别的数据尽可能聚集。
类间散度(Between-class scatter)衡量的是不同类别均值之间的距离:
python复制# 计算两类均值向量的差
m_diff = m1 - m2
# 类间散度矩阵
S_B = np.outer(m_diff, m_diff.T)
类内散度(Within-class scatter)则衡量同一类别数据的离散程度:
python复制# 对每个样本计算与均值的偏差
deviation = X_class1 - m1
# 类内散度矩阵
S_W = np.dot(deviation.T, deviation)
Fisher准则函数可以表示为:
code复制J(w) = (wᵀS_B w) / (wᵢS_W w)
这个比值越大,表示投影方向w的分类效果越好。我们需要找到使J(w)最大化的w。
通过拉格朗日乘数法,我们可以证明最优投影方向w满足:
code复制S_W⁻¹ S_B w = λw
这实际上是一个广义特征值问题。对于两类情况,可以简化为:
python复制# 计算最优投影方向
w = np.linalg.inv(S_W).dot(m1 - m2)
为了验证这个结果,我们可以计算投影后的数据:
python复制# 将数据投影到w方向
projected_data = np.dot(X, w)
关键理解点:
让我们用经典的鸢尾花数据集来实践这个算法。首先加载并准备数据:
python复制from sklearn.datasets import load_iris
import numpy as np
iris = load_iris()
X = iris.data
y = iris.target
# 只取前两类(setosa和versicolor)做二分类
X = X[y != 2]
y = y[y != 2]
计算类别统计量:
python复制# 分离两类数据
class0 = X[y == 0]
class1 = X[y == 1]
# 计算均值向量
m0 = np.mean(class0, axis=0)
m1 = np.mean(class1, axis=0)
# 计算类内散度矩阵
S_W = np.zeros((4, 4))
for x in class0:
S_W += np.outer(x - m0, x - m0)
for x in class1:
S_W += np.outer(x - m1, x - m1)
求解投影方向并进行分类:
python复制# 计算最优投影方向
w = np.linalg.inv(S_W).dot(m0 - m1)
# 计算投影后的数据
projected = np.dot(X, w)
# 计算分类阈值
w0 = -0.5 * (m0 + m1).dot(w)
将高维数据投影到一维空间后,我们可以直观地看到分类效果:
python复制import matplotlib.pyplot as plt
plt.figure(figsize=(10, 5))
plt.scatter(projected[y == 0], np.zeros_like(projected[y == 0]),
color='red', label='Setosa')
plt.scatter(projected[y == 1], np.zeros_like(projected[y == 1]),
color='blue', label='Versicolor')
plt.axvline(x=-w0, color='green', linestyle='--', label='Decision Boundary')
plt.legend()
plt.title('Fisher Projection of Iris Data')
plt.show()
评估分类准确率:
python复制predictions = (np.dot(X, w) + w0) > 0
accuracy = np.mean(predictions == (y == 0))
print(f"Classification accuracy: {accuracy:.2%}")
虽然我们以二分类为例,但Fisher判别可以扩展到多类情况。关键变化在于类间散度矩阵的定义:
python复制# 多类情况下的类间散度
overall_mean = np.mean(X, axis=0)
S_B = np.zeros((4, 4))
for c in np.unique(y):
class_mean = np.mean(X[y == c], axis=0)
n_c = np.sum(y == c)
S_B += n_c * np.outer(class_mean - overall_mean,
class_mean - overall_mean)
实用建议:
Fisher判别与逻辑回归有着深刻的联系。两者都可以看作线性分类器,但优化目标不同:
| 方法 | 优化目标 | 假设条件 |
|---|---|---|
| Fisher判别 | 最大化类间/类内散度比 | 各类协方差矩阵相同 |
| 逻辑回归 | 最大化似然函数 | 无分布假设 |
在实际项目中,可以两种方法都尝试,比较它们的性能差异。