一、前言:为何要学交叉验证与网格搜索?
大家好!在机器学习的道路上,我们经常面临一个难题:模型调参。比如在 KNN 算法中,选择多少个邻居(n_neighbors)直接影响预测效果。
• 蛮力猜测:就像在厨房随便“加盐加辣椒”,不仅费时费力,还可能把菜搞砸。
• 交叉验证 + 网格搜索:更像是让你请来一位“大厨”,提前试好所有配方,帮你挑选出最完美的“调料搭配”。
交叉验证与网格搜索的组合,能让你在众多超参数组合中自动挑选出最佳方案,从而让模型预测达到“哇塞,这也太准了吧!”的境界。
二、概念扫盲:交叉验证 & 网格搜索
1. 交叉验证(Cross-Validation)
核心思路:
• 分组品尝:将整个数据集平均分成若干份(比如分成 5 份,即“5折交叉验证”)。
• 轮流担任评委:每次选取其中一份作为“验证集”(就像让这部分数据来“评委打分”),剩下的作为“训练集”来训练模型。
• 集体评定:重复多次,每一份都轮流担任验证集,然后把所有“评分”取平均,作为模型在数据集上的最终表现。
好处:
• 每个样本都有机会既当“选手”又当“评委”,使得评估结果更稳定、可靠。
• 避免单一划分带来的偶然性,确保你调出来的参数在不同数据切分下都表现良好。
2. 网格搜索(Grid Search)
核心思路:
• 列出所有可能:将你想尝试的超参数组合“罗列成一个表格(网格)”。
• 自动试菜:每种组合都进行一次完整的模型训练和评估,记录下它们的表现。
• 选出最佳配方:最后找出在交叉验证中表现最好的超参数组合。
好处:
• 自动化、系统化地寻找最佳参数组合,避免你手动“胡乱猜测”。
• 和交叉验证结合后,每个参数组合都经过了多次评估,结果更稳健。
3. 网格搜索 + 交叉验证
这两者结合就像“炼丹”高手的秘诀:
• 交叉验证解决了“数据切分”的问题,让评估更准确;
• 网格搜索解决了“超参数组合”问题,帮你遍历所有可能性。
合体后,你就能轻松找到最优超参数,让模型发挥出最佳性能!
三、案例一:鸢尾花数据集 + KNN + 交叉验证网格搜索
3.1 数据集介绍
• 数据来源:scikit-learn 内置的 load_iris
• 特征:萼片长度、萼片宽度、花瓣长度、花瓣宽度
• 目标:根据花的外部特征预测其所属的鸢尾花种类
3.2 代码示例
下面代码展示如何在鸢尾花数据集上使用 KNN 算法,并通过 GridSearchCV(交叉验证+网格搜索)自动调优 n_neighbors 参数:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
def iris_knn_cv():
"""
使用KNN算法在鸢尾花数据集上进行分类,并通过网格搜索+交叉验证寻找最优超参数。
"""
# 1. 加载数据
iris = load_iris()
X = iris.data # 特征矩阵,包含四个特征
y = iris.target # 标签,分别代表三种鸢尾花
# 2. 划分训练集和测试集
# test_size=0.2 表示 20% 的数据用于测试,保证测试结果具有代表性
# random_state=22 固定随机数种子,确保每次运行划分一致
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=22
)
# 3. 数据标准化
# 标准化可使各特征均值为0、方差为1,消除量纲影响(对于基于距离的KNN非常重要)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
# 4. 构建KNN模型及参数调优
knn = KNeighborsClassifier() # 初始化KNN模型
# 4.1 设置网格搜索参数范围:尝试不同的邻居数
param_grid = {
'n_neighbors': [1, 3, 5, 7, 9]
}
# 4.2 进行网格搜索 + 交叉验证(5折交叉验证)
grid_search = GridSearchCV(
estimator=knn, # 待调参的模型
param_grid=param_grid, # 超参数候选列表
cv=5, # 5折交叉验证:将训练集分为5个子集,每次用1个子集验证,其余4个训练
scoring='accuracy', # 以准确率作为评估指标
n_jobs=-1 # 使用所有CPU核心并行计算
)
grid_search.fit(X_train_scaled, y_train) # 自动遍历各参数组合并评估
# 4.3 输出网格搜索结果
print("最佳交叉验证分数:", grid_search.best_score_)
print("最优超参数组合:", grid_search.best_params_)
print("最优模型:", grid_search.best_estimator_)
# 5. 模型评估:用测试集评估最优模型的泛化能力
best_model = grid_search.best_estimator_
y_pred = best_model.predict(X_test_scaled)
acc = accuracy_score(y_test, y_pred)
print("在测试集上的准确率:{:.2f}%".format(acc * 100))
# 6. 可视化(选做):可进一步绘制混淆矩阵或学习曲线
# 直接调用函数进行测试
if __name__ == "__main__":
iris_knn_cv()
输出:
3.3 结果解读
• 最佳交叉验证分数:表示在5折交叉验证过程中,所有参数组合中平均准确率最高的值。
• 最优超参数组合:显示在候选参数 [1, 3, 5, 7, 9] 中哪个 n_neighbors 的效果最好。
• 测试集准确率:验证模型在未见数据上的表现,反映其泛化能力。
通过这个案例,你可以看到交叉验证网格搜索如何自动帮你“挑菜”选料,让 KNN 模型在鸢尾花分类任务上达到最佳表现。
四、案例二:手写数字数据集 + KNN + 交叉验证网格搜索
4.1 数据集介绍
• 数据来源:scikit-learn 内置的 load_digits
• 特征:每张 8×8 像素的手写数字图像被拉伸成64维特征向量
• 目标:识别图片中数字所属类别(0~9)
4.2 代码示例
下面代码展示如何在手写数字数据集上使用 KNN 算法,并通过交叉验证网格搜索调优参数:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits # 导入手写数字数据集(内置于 scikit-learn)
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler # 导入数据标准化工具
from sklearn.neighbors import KNeighborsClassifier # 导入KNN分类器
from sklearn.metrics import accuracy_score, confusion_matrix
import seaborn as sns # 导入 seaborn,用于绘制更美观的图表
def digits_knn_cv():
"""
使用KNN算法在手写数字数据集上进行分类,并通过网格搜索+交叉验证寻找最优超参数。
"""
# 1. 加载数据
digits = load_digits() # 从scikit-learn加载内置手写数字数据集
X = digits.data # 特征数据,形状为 (1797, 64),每一行对应一张图片的64个像素值
y = digits.target # 目标标签,共10个类别(数字 0 到 9)
# 2. 数据可视化:展示前5张图片及其标签
# 创建一个1行5列的子图区域,图像尺寸为10x2英寸
fig, axes = plt.subplots(1, 5, figsize=(10, 2))
for i in range(5):
# 显示第 i 张图片,使用灰度图(cmap='gray')
axes[i].imshow(digits.images[i], cmap='gray')
# 设置每个子图的标题,显示该图片对应的标签
axes[i].set_title("Label: {}".format(digits.target[i]))
# 关闭坐标轴显示(避免坐标信息干扰视觉效果)
axes[i].axis('off')
plt.suptitle("手写数字数据集示例") # 为整个图表添加一个总标题
plt.show() # 显示图表
# 3. 数据划分 + 标准化
# 将数据划分为训练集和测试集,其中测试集占20%,random_state保证每次划分一致
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
# 初始化标准化工具,将特征数据转换为均值为0、方差为1的标准正态分布
scaler = StandardScaler()
# 仅在训练集上拟合标准化参数,并转换训练集数据
X_train_scaled = scaler.fit_transform(X_train)
# 使用相同的转换参数转换测试集数据(避免数据泄露)
X_test_scaled = scaler.transform(X_test)
# 4. 构建KNN模型及网格搜索调参
knn = KNeighborsClassifier() # 初始化KNN分类器,暂未指定 n_neighbors 参数
# 定义一个字典,列出希望尝试的超参数组合
# 这里我们测试不同邻居数的效果:[1, 3, 5, 7, 9]
param_grid = {
'n_neighbors': [1, 3, 5, 7, 9]
}
# 初始化网格搜索对象,结合交叉验证
grid_search = GridSearchCV(
estimator=knn, # 需要调参的KNN模型
param_grid=param_grid, # 超参数候选组合
cv=5, # 5折交叉验证,将训练数据分成5份,每次用4份训练,1份验证
scoring='accuracy', # 使用准确率作为模型评估指标
n_jobs=-1 # 并行计算,使用所有可用的CPU核心加速计算
)
# 在标准化后的训练集上进行网格搜索,自动尝试所有参数组合,并进行交叉验证
grid_search.fit(X_train_scaled, y_train)
# 5. 输出网格搜索调参结果
# 打印在交叉验证中获得的最佳平均准确率
print("手写数字 - 最佳交叉验证分数:", grid_search.best_score_)
# 打印获得最佳结果时所使用的超参数组合,例如 {'n_neighbors': 3}
print("手写数字 - 最优超参数组合:", grid_search.best_params_)
# 打印最佳模型对象,该模型已使用最优参数重新训练
best_model = grid_search.best_estimator_
# 6. 模型评估:用测试集评估模型效果
# 使用最优模型对测试集进行预测
y_pred = best_model.predict(X_test_scaled)
# 计算测试集上的准确率
acc = accuracy_score(y_test, y_pred)
print("手写数字 - 测试集准确率:{:.2f}%".format(acc * 100))
# 7. 可视化混淆矩阵(直观展示各数字分类效果)
# 混淆矩阵能够显示真实标签与预测标签之间的对应关系
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(6, 5))
# 使用 seaborn 的 heatmap 绘制混淆矩阵,annot=True 表示在每个单元格中显示数字
sns.heatmap(cm, annot=True, cmap='Blues', fmt='d')
plt.title("手写数字 - 混淆矩阵")
plt.xlabel("预测值")
plt.ylabel("真实值")
plt.show()
# 直接调用函数进行测试
if __name__ == "__main__":
digits_knn_cv()
输出:
4.3 结果解读
• 最优 n_neighbors:通过交叉验证,我们找到了在候选参数中使模型表现最佳的邻居数量。
• 测试集准确率:在手写数字识别任务上,通常准确率能达到90%以上,证明 KNN 在小数据集上也能表现不错。
• 混淆矩阵:直观展示哪些数字容易混淆(例如数字“3”和“5”),便于进一步分析和改进。
混淆矩阵图的含义与作用
1. 横纵坐标的含义
• 行(纵轴)代表真实标签(真实的数字 0~9)。
• 列(横轴)代表模型预测的标签(预测的数字 0~9)。
2. 数值和颜色深浅
• 单元格 (i, j) 内的数值表示:真实类别为 i 的样本中,有多少被预测为 j。
• 越靠近对角线(i = j)代表预测正确的数量;
• 离对角线越远,说明模型将真实类别 i 的样本错误地预测成类别 j。
• 热力图中颜色越深表示数量越多,浅色则表示数量少。
3. 作用
• 评估模型分类效果:如果对角线上的数值高且远离对角线的数值低,说明模型分类准确度高;反之,说明某些类别容易被混淆。
• 发现易混淆的类别:通过观察非对角线位置是否有较大的数值,可以知道哪些数字最容易被误判。例如,模型可能经常把“3”预测成“5”,这能提示我们在后续改进中加强这两个类别的区分。
• 比单纯的准确率更全面:准确率只能告诉你模型整体正确率,而混淆矩阵能告诉你哪类错误最多,便于更有针对性地提升模型性能。
五、总结 & 彩蛋
1. 交叉验证的价值
• 有效避免过拟合,通过多次分组验证,使得模型评估更稳健。
2. 网格搜索的强大
• 自动遍历所有超参数组合,省去手动调参的烦恼,快速锁定“最佳拍档”。
3. KNN 的局限
• 虽然简单易用,但在大规模、高维数据中计算量较大,且对异常值较敏感。
4. 后续进阶
• 可以尝试随机搜索(RandomizedSearchCV)或贝叶斯优化,甚至转向更复杂的模型如 CNN 进行数字识别。
结语
如果你觉得本篇文章对你有所帮助,请记得点赞、收藏、转发和评论哦!你的支持是我继续创作的最大动力。让我们一起在机器学习的道路上不断探索、不断进步,早日成为调参界的“神仙”!
祝学习愉快,炼丹顺利~