1. 分类任务概述
在机器学习领域,分类任务是监督学习中最常见的任务类型之一。与回归任务预测连续值不同,分类任务的目标是预测离散的类别标签。本章我们将重点探讨如何使用各种算法解决分类问题,并以经典的MNIST手写数字数据集作为案例进行详细讲解。
2. MNIST数据集解析
2.1 数据集介绍
MNIST数据集包含70,000张28×28像素的灰度手写数字图像,每张图像对应0-9中的一个数字标签。这个数据集由美国高中生和人口调查局职员手写而成,已成为机器学习领域的"Hello World"基准测试集。
python复制from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1)
X, y = mnist["data"], mnist["target"]
2.2 数据探索与可视化
让我们查看数据的基本结构和可视化部分样本:
python复制print(X.shape) # (70000, 784)
print(y.shape) # (70000,)
import matplotlib.pyplot as plt
some_digit = X[0].reshape(28, 28)
plt.imshow(some_digit, cmap="binary")
plt.axis("off")
plt.show()
print(y[0]) # 查看对应标签
2.3 训练集与测试集划分
MNIST已预先分为60,000张训练图像和10,000张测试图像。为确保数据分布均匀,我们需要打乱训练集:
python复制X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
import numpy as np
shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]
3. 二元分类器实现
3.1 构建5检测器
我们先简化问题,构建一个仅区分数字5和非5的二元分类器:
python复制y_train_5 = (y_train == '5') # True for 5s, False otherwise
y_test_5 = (y_test == '5')
from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)
sgd_clf.predict([some_digit]) # 测试预测
3.2 性能评估方法
3.2.1 交叉验证实现
我们可以手动实现交叉验证来更好地控制过程:
python复制from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone
skfolds = StratifiedKFold(n_splits=3, random_state=42)
for train_index, test_index in skfolds.split(X_train, y_train_5):
clone_clf = clone(sgd_clf)
X_train_folds = X_train[train_index]
y_train_folds = y_train_5[train_index]
X_test_fold = X_train[test_index]
y_test_fold = y_train_5[test_index]
clone_clf.fit(X_train_folds, y_train_folds)
y_pred = clone_clf.predict(X_test_fold)
n_correct = sum(y_pred == y_test_fold)
print(n_correct / len(y_pred))
3.2.2 使用cross_val_score
更简便的方法是使用Scikit-Learn的cross_val_score:
python复制from sklearn.model_selection import cross_val_score
cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring="accuracy")
注意:高准确率可能具有误导性。例如,一个总是预测"非5"的简单分类器在MNIST上也能达到约90%的准确率,因为只有约10%的图像是数字5。
3.3 混淆矩阵分析
混淆矩阵提供了更全面的性能评估:
python复制from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)
confusion_matrix(y_train_5, y_train_pred)
矩阵输出示例:
code复制[[53892, 687]
[ 1891, 3530]]
3.4 精确率与召回率
精确率和召回率提供了更细致的评估:
python复制from sklearn.metrics import precision_score, recall_score
precision = precision_score(y_train_5, y_train_pred) # TP / (TP + FP)
recall = recall_score(y_train_5, y_train_pred) # TP / (TP + FN)
F1分数是精确率和召回率的调和平均:
python复制from sklearn.metrics import f1_score
f1_score(y_train_5, y_train_pred)
3.5 精确率-召回率权衡
通过调整决策阈值可以控制精确率和召回率:
python复制y_scores = sgd_clf.decision_function([some_digit])
threshold = 0
y_some_digit_pred = (y_scores > threshold)
# 可视化权衡曲线
from sklearn.metrics import precision_recall_curve
y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,
method="decision_function")
precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
plt.plot(thresholds, recalls[:-1], "g-", label="Recall")
plt.xlabel("Threshold")
plt.legend(loc="upper left")
plt.ylim([0, 1])
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.show()
3.6 ROC曲线分析
ROC曲线是另一种评估二元分类器的工具:
python复制from sklearn.metrics import roc_curve
fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)
def plot_roc_curve(fpr, tpr, label=None):
plt.plot(fpr, tpr, linewidth=2, label=label)
plt.plot([0, 1], [0, 1], 'k--')
plt.axis([0, 1, 0, 1])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plot_roc_curve(fpr, tpr)
plt.show()
from sklearn.metrics import roc_auc_score
roc_auc_score(y_train_5, y_scores)
4. 多元分类实现
4.1 多类分类策略
对于多类分类问题,常用的策略有:
- OvA(One-versus-All):为每个类训练一个二元分类器
- OvO(One-versus-One):为每对类训练一个分类器
Scikit-Learn会自动检测并采用OvA策略(SVM除外):
python复制sgd_clf.fit(X_train, y_train) # 使用原始标签(0-9)
sgd_clf.predict([some_digit])
4.2 随机森林分类器
随机森林天然支持多类分类:
python复制from sklearn.ensemble import RandomForestClassifier
forest_clf = RandomForestClassifier(random_state=42)
forest_clf.fit(X_train, y_train)
forest_clf.predict([some_digit])
forest_clf.predict_proba([some_digit]) # 查看各类概率
4.3 性能评估
使用交叉验证评估分类器性能:
python复制cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring="accuracy")
# 通过特征缩放提高性能
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))
cross_val_score(sgd_clf, X_train_scaled, y_train, cv=3, scoring="accuracy")
5. 误差分析与模型改进
5.1 混淆矩阵分析
通过混淆矩阵识别分类错误模式:
python复制y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred)
plt.matshow(conf_mx, cmap=plt.cm.gray)
plt.show()
5.2 聚焦特定错误
分析特定数字间的混淆情况,如3和5:
python复制cl_a, cl_b = '3', '5'
X_aa = X_train[(y_train == cl_a) & (y_train_pred == cl_a)]
X_ab = X_train[(y_train == cl_a) & (y_train_pred == cl_b)]
X_ba = X_train[(y_train == cl_b) & (y_train_pred == cl_a)]
X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)]
plt.figure(figsize=(8,8))
plt.subplot(221); plot_digits(X_aa[:25], images_per_row=5)
plt.subplot(222); plot_digits(X_ab[:25], images_per_row=5)
plt.subplot(223); plot_digits(X_ba[:25], images_per_row=5)
plt.subplot(224); plot_digits(X_bb[:25], images_per_row=5)
plt.show()
5.3 改进建议
基于误差分析,可能的改进方向包括:
- 增加更多训练数据(特别是难以区分的数字)
- 设计新的特征(如闭合环的数量)
- 预处理图像(居中、去噪等)
- 尝试不同的分类算法和超参数
6. 多标签与多输出分类
6.1 多标签分类
当需要为每个实例预测多个二元标签时,使用多标签分类:
python复制from sklearn.neighbors import KNeighborsClassifier
y_train_large = (y_train.astype(int) >= 7)
y_train_odd = (y_train.astype(int) % 2 == 1)
y_multilabel = np.c_[y_train_large, y_train_odd]
knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_multilabel)
knn_clf.predict([some_digit]) # 示例预测
6.2 多输出分类
多输出分类是多标签分类的泛化,其中每个标签可以是多类别的:
python复制# 创建带噪声的训练集
noise = np.random.randint(0, 100, (len(X_train), 784))
X_train_mod = X_train + noise
noise = np.random.randint(0, 100, (len(X_test), 784))
X_test_mod = X_test + noise
y_train_mod = X_train
y_test_mod = X_test
# 训练去噪模型
knn_clf.fit(X_train_mod, y_train_mod)
clean_digit = knn_clf.predict([X_test_mod[0]])
plot_digit(clean_digit)
7. 实践建议与总结
- 模型选择:对于简单分类任务,线性模型如SGDClassifier效率高;对于复杂模式,随机森林或K近邻可能更合适
- 评估指标:根据业务需求选择合适的评估指标(精确率、召回率、F1等)
- 误差分析:通过混淆矩阵识别主要错误类型,针对性改进
- 数据增强:通过平移、旋转等变换增加训练数据多样性
- 特征工程:设计领域特定的特征往往能显著提升性能
在实际项目中,建议遵循以下流程:
- 快速构建一个基线模型
- 通过交叉验证评估其性能
- 分析主要错误类型
- 针对性地改进模型和数据
- 反复迭代优化
通过本章内容,我们系统性地学习了分类任务的实现方法、评估技术和改进策略。这些知识将为后续更复杂的机器学习任务奠定坚实基础。