【机器学习】超参数调优指南:交叉验证,网格搜索,混淆矩阵——基于鸢尾花与数字识别案例的深度解析

一、前言:为何要学交叉验证与网格搜索?

大家好!在机器学习的道路上,我们经常面临一个难题:模型调参。比如在 KNN 算法中,选择多少个邻居(n_neighbors)直接影响预测效果。

蛮力猜测:就像在厨房随便“加盐加辣椒”,不仅费时费力,还可能把菜搞砸。

交叉验证 + 网格搜索:更像是让你请来一位“大厨”,提前试好所有配方,帮你挑选出最完美的“调料搭配”。

交叉验证与网格搜索的组合,能让你在众多超参数组合中自动挑选出最佳方案,从而让模型预测达到“哇塞,这也太准了吧!”的境界。


二、概念扫盲:交叉验证 & 网格搜索

1. 交叉验证(Cross-Validation)

核心思路:

分组品尝:将整个数据集平均分成若干份(比如分成 5 份,即“5折交叉验证”)。

轮流担任评委:每次选取其中一份作为“验证集”(就像让这部分数据来“评委打分”),剩下的作为“训练集”来训练模型。

集体评定:重复多次,每一份都轮流担任验证集,然后把所有“评分”取平均,作为模型在数据集上的最终表现。

好处:

• 每个样本都有机会既当“选手”又当“评委”,使得评估结果更稳定、可靠。

• 避免单一划分带来的偶然性,确保你调出来的参数在不同数据切分下都表现良好。


2. 网格搜索(Grid Search)

核心思路:

列出所有可能:将你想尝试的超参数组合“罗列成一个表格(网格)”。

自动试菜:每种组合都进行一次完整的模型训练和评估,记录下它们的表现。

选出最佳配方:最后找出在交叉验证中表现最好的超参数组合。

好处:

• 自动化、系统化地寻找最佳参数组合,避免你手动“胡乱猜测”。

• 和交叉验证结合后,每个参数组合都经过了多次评估,结果更稳健。

3. 网格搜索 + 交叉验证

这两者结合就像“炼丹”高手的秘诀:

交叉验证解决了“数据切分”的问题,让评估更准确;

网格搜索解决了“超参数组合”问题,帮你遍历所有可能性。

合体后,你就能轻松找到最优超参数,让模型发挥出最佳性能!


三、案例一:鸢尾花数据集 + KNN + 交叉验证网格搜索

3.1 数据集介绍

数据来源:scikit-learn 内置的 load_iris

特征:萼片长度、萼片宽度、花瓣长度、花瓣宽度

目标:根据花的外部特征预测其所属的鸢尾花种类


3.2 代码示例

下面代码展示如何在鸢尾花数据集上使用 KNN 算法,并通过 GridSearchCV(交叉验证+网格搜索)自动调优 n_neighbors 参数:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

def iris_knn_cv():
    """
    使用KNN算法在鸢尾花数据集上进行分类,并通过网格搜索+交叉验证寻找最优超参数。
    """
    # 1. 加载数据
    iris = load_iris()
    X = iris.data      # 特征矩阵,包含四个特征
    y = iris.target    # 标签,分别代表三种鸢尾花

    # 2. 划分训练集和测试集
    # test_size=0.2 表示 20% 的数据用于测试,保证测试结果具有代表性
    # random_state=22 固定随机数种子,确保每次运行划分一致
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=22
    )

    # 3. 数据标准化
    # 标准化可使各特征均值为0、方差为1,消除量纲影响(对于基于距离的KNN非常重要)
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # 4. 构建KNN模型及参数调优
    knn = KNeighborsClassifier()  # 初始化KNN模型

    # 4.1 设置网格搜索参数范围:尝试不同的邻居数
    param_grid = {
        'n_neighbors': [1, 3, 5, 7, 9]
    }

    # 4.2 进行网格搜索 + 交叉验证(5折交叉验证)
    grid_search = GridSearchCV(
        estimator=knn,        # 待调参的模型
        param_grid=param_grid,  # 超参数候选列表
        cv=5,                 # 5折交叉验证:将训练集分为5个子集,每次用1个子集验证,其余4个训练
        scoring='accuracy',   # 以准确率作为评估指标
        n_jobs=-1             # 使用所有CPU核心并行计算
    )
    grid_search.fit(X_train_scaled, y_train)  # 自动遍历各参数组合并评估

    # 4.3 输出网格搜索结果
    print("最佳交叉验证分数:", grid_search.best_score_)
    print("最优超参数组合:", grid_search.best_params_)
    print("最优模型:", grid_search.best_estimator_)

    # 5. 模型评估:用测试集评估最优模型的泛化能力
    best_model = grid_search.best_estimator_
    y_pred = best_model.predict(X_test_scaled)
    acc = accuracy_score(y_test, y_pred)
    print("在测试集上的准确率:{:.2f}%".format(acc * 100))

    # 6. 可视化(选做):可进一步绘制混淆矩阵或学习曲线

# 直接调用函数进行测试
if __name__ == "__main__":
    iris_knn_cv()

输出: 

3.3 结果解读

最佳交叉验证分数:表示在5折交叉验证过程中,所有参数组合中平均准确率最高的值。

最优超参数组合:显示在候选参数 [1, 3, 5, 7, 9] 中哪个 n_neighbors 的效果最好。

测试集准确率:验证模型在未见数据上的表现,反映其泛化能力。

通过这个案例,你可以看到交叉验证网格搜索如何自动帮你“挑菜”选料,让 KNN 模型在鸢尾花分类任务上达到最佳表现。


四、案例二:手写数字数据集 + KNN + 交叉验证网格搜索

4.1 数据集介绍

数据来源:scikit-learn 内置的 load_digits

特征:每张 8×8 像素的手写数字图像被拉伸成64维特征向量

目标:识别图片中数字所属类别(0~9)

4.2 代码示例

下面代码展示如何在手写数字数据集上使用 KNN 算法,并通过交叉验证网格搜索调优参数:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits        # 导入手写数字数据集(内置于 scikit-learn)
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler  # 导入数据标准化工具
from sklearn.neighbors import KNeighborsClassifier  # 导入KNN分类器
from sklearn.metrics import accuracy_score, confusion_matrix
import seaborn as sns  # 导入 seaborn,用于绘制更美观的图表

def digits_knn_cv():
    """
    使用KNN算法在手写数字数据集上进行分类,并通过网格搜索+交叉验证寻找最优超参数。
    """
    # 1. 加载数据
    digits = load_digits()  # 从scikit-learn加载内置手写数字数据集
    X = digits.data         # 特征数据,形状为 (1797, 64),每一行对应一张图片的64个像素值
    y = digits.target       # 目标标签,共10个类别(数字 0 到 9)
    
    # 2. 数据可视化:展示前5张图片及其标签
    # 创建一个1行5列的子图区域,图像尺寸为10x2英寸
    fig, axes = plt.subplots(1, 5, figsize=(10, 2))
    for i in range(5):
        # 显示第 i 张图片,使用灰度图(cmap='gray')
        axes[i].imshow(digits.images[i], cmap='gray')
        # 设置每个子图的标题,显示该图片对应的标签
        axes[i].set_title("Label: {}".format(digits.target[i]))
        # 关闭坐标轴显示(避免坐标信息干扰视觉效果)
        axes[i].axis('off')
    plt.suptitle("手写数字数据集示例")  # 为整个图表添加一个总标题
    plt.show()  # 显示图表
    
    # 3. 数据划分 + 标准化
    # 将数据划分为训练集和测试集,其中测试集占20%,random_state保证每次划分一致
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    # 初始化标准化工具,将特征数据转换为均值为0、方差为1的标准正态分布
    scaler = StandardScaler()
    # 仅在训练集上拟合标准化参数,并转换训练集数据
    X_train_scaled = scaler.fit_transform(X_train)
    # 使用相同的转换参数转换测试集数据(避免数据泄露)
    X_test_scaled = scaler.transform(X_test)
    
    # 4. 构建KNN模型及网格搜索调参
    knn = KNeighborsClassifier()  # 初始化KNN分类器,暂未指定 n_neighbors 参数
    # 定义一个字典,列出希望尝试的超参数组合
    # 这里我们测试不同邻居数的效果:[1, 3, 5, 7, 9]
    param_grid = {
        'n_neighbors': [1, 3, 5, 7, 9]
    }
    # 初始化网格搜索对象,结合交叉验证
    grid_search = GridSearchCV(
        estimator=knn,         # 需要调参的KNN模型
        param_grid=param_grid, # 超参数候选组合
        cv=5,                  # 5折交叉验证,将训练数据分成5份,每次用4份训练,1份验证
        scoring='accuracy',    # 使用准确率作为模型评估指标
        n_jobs=-1              # 并行计算,使用所有可用的CPU核心加速计算
    )
    # 在标准化后的训练集上进行网格搜索,自动尝试所有参数组合,并进行交叉验证
    grid_search.fit(X_train_scaled, y_train)
    
    # 5. 输出网格搜索调参结果
    # 打印在交叉验证中获得的最佳平均准确率
    print("手写数字 - 最佳交叉验证分数:", grid_search.best_score_)
    # 打印获得最佳结果时所使用的超参数组合,例如 {'n_neighbors': 3}
    print("手写数字 - 最优超参数组合:", grid_search.best_params_)
    # 打印最佳模型对象,该模型已使用最优参数重新训练
    best_model = grid_search.best_estimator_
    
    # 6. 模型评估:用测试集评估模型效果
    # 使用最优模型对测试集进行预测
    y_pred = best_model.predict(X_test_scaled)
    # 计算测试集上的准确率
    acc = accuracy_score(y_test, y_pred)
    print("手写数字 - 测试集准确率:{:.2f}%".format(acc * 100))
    
    # 7. 可视化混淆矩阵(直观展示各数字分类效果)
    # 混淆矩阵能够显示真实标签与预测标签之间的对应关系
    cm = confusion_matrix(y_test, y_pred)
    plt.figure(figsize=(6, 5))
    # 使用 seaborn 的 heatmap 绘制混淆矩阵,annot=True 表示在每个单元格中显示数字
    sns.heatmap(cm, annot=True, cmap='Blues', fmt='d')
    plt.title("手写数字 - 混淆矩阵")
    plt.xlabel("预测值")
    plt.ylabel("真实值")
    plt.show()

# 直接调用函数进行测试
if __name__ == "__main__":
    digits_knn_cv()

输出:

4.3 结果解读

最优 n_neighbors:通过交叉验证,我们找到了在候选参数中使模型表现最佳的邻居数量。

测试集准确率:在手写数字识别任务上,通常准确率能达到90%以上,证明 KNN 在小数据集上也能表现不错。

混淆矩阵:直观展示哪些数字容易混淆(例如数字“3”和“5”),便于进一步分析和改进。

混淆矩阵图的含义与作用

1. 横纵坐标的含义

行(纵轴)代表真实标签(真实的数字 0~9)。

列(横轴)代表模型预测的标签(预测的数字 0~9)。

2. 数值和颜色深浅

• 单元格 (i, j) 内的数值表示:真实类别为 i 的样本中,有多少被预测为 j。

• 越靠近对角线(i = j)代表预测正确的数量;

• 离对角线越远,说明模型将真实类别 i 的样本错误地预测成类别 j。

• 热力图中颜色越深表示数量越多,浅色则表示数量少。

3. 作用

评估模型分类效果:如果对角线上的数值高且远离对角线的数值低,说明模型分类准确度高;反之,说明某些类别容易被混淆。

发现易混淆的类别:通过观察非对角线位置是否有较大的数值,可以知道哪些数字最容易被误判。例如,模型可能经常把“3”预测成“5”,这能提示我们在后续改进中加强这两个类别的区分。

比单纯的准确率更全面:准确率只能告诉你模型整体正确率,而混淆矩阵能告诉你哪类错误最多,便于更有针对性地提升模型性能。


五、总结 & 彩蛋

1. 交叉验证的价值

• 有效避免过拟合,通过多次分组验证,使得模型评估更稳健。

2. 网格搜索的强大

• 自动遍历所有超参数组合,省去手动调参的烦恼,快速锁定“最佳拍档”。

3. KNN 的局限

• 虽然简单易用,但在大规模、高维数据中计算量较大,且对异常值较敏感。

4. 后续进阶

• 可以尝试随机搜索(RandomizedSearchCV)或贝叶斯优化,甚至转向更复杂的模型如 CNN 进行数字识别。


结语

如果你觉得本篇文章对你有所帮助,请记得点赞、收藏、转发和评论哦!你的支持是我继续创作的最大动力。让我们一起在机器学习的道路上不断探索、不断进步,早日成为调参界的“神仙”!

祝学习愉快,炼丹顺利~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/972100.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

UGUI RectTransform的SizeDelta属性

根据已知内容,SizeDelta offsetMax - offsetMin 1.锚点聚拢情况下 输出 那么此时SizeDelta就是UI元素的长宽大小 2. 锚点分散时 引用自此篇文章中的描述 揭秘!anchoredPosition的几何意义! SizeDelta offsetMax - offsetMin (rectMax…

51单片机入门_10_数码管动态显示(数字的使用;简单动态显示;指定值的数码管动态显示)

接上篇的数码管静态显示,以下是接上篇介绍到的动态显示的原理。 动态显示的特点是将所有位数码管的段选线并联在一起,由位选线控制是哪一位数码管有效。选亮数码管采用动态扫描显示。所谓动态扫描显示即轮流向各位数码管送出字形码和相应的位选&#xff…

mybatis使用typeHandler实现类型转换

使用mybatis作为操作数据库的orm框架,操作基本数据类型时可以通过内置的类型处理器完成java数据类型和数据库类型的转换,但是对于扩展的数据类型要实现与数据库类型的转换就需要自定义类型转换器完成,比如某个实体类型存储到数据库&#xff0…

瑞萨RA-T系列芯片ADCGPT功能模块的配合使用

在马达或电源工程中,往往需要采集多路AD信号,且这些信号的优先级和采样时机不相同。本篇介绍在使用RA-T系列芯片建立马达或电源工程时,如何根据需求来设置主要功能模块ADC&GPT,包括采样通道打包和分组,GPT触发启动…

最新智能优化算法:牛优化( Ox Optimizer,OX)算法求解经典23个函数测试集,MATLAB代码

一、牛优化算法 牛优化( OX Optimizer,OX)算法由 AhmadK.AlHwaitat 与 andHussamN.Fakhouri于2024年提出,该算法的设计灵感来源于公牛的行为特性。公牛以其巨大的力量而闻名,能够承载沉重的负担并进行远距离运输。这种…

【linux】在 Linux 服务器上部署 DeepSeek-r1:70b 并通过 Windows 远程可视化使用

【linux】在 Linux 服务器上部署 DeepSeek-r1:70b 并通过 Windows 远程可视化使用 文章目录 【linux】在 Linux 服务器上部署 DeepSeek-r1:70b 并通过 Windows 远程可视化使用个人配置详情一、安装ollama二、下载deepseek版本模型三、在 Linux 服务器上配置 Ollama 以允许远程访…

【Linux网络编程】应用层协议HTTP(请求方法,状态码,重定向,cookie,session)

🎁个人主页:我们的五年 🔍系列专栏:Linux网络编程 🌷追光的人,终会万丈光芒 🎉欢迎大家点赞👍评论📝收藏⭐文章 ​ Linux网络编程笔记: https://blog.cs…

Chrome多开终极形态解锁!「窗口管理工具+IP隔离插件

Web3项目多开,继ads指纹浏览器钱包被盗后,更多人采用原生chrome浏览器,当然对于新手,指纹浏览器每月成本也是一笔不小开支,今天逛Github发现了这样一个解决方案,作者开发了窗口管理工具IP隔离插件&#xff…

Canal同步MySQL增量数据

引言 在现在的系统开发中,为了提高查询效率 , 以及搜索的精准度, 会大量的使用 redis 、memcache 等 nosql 系统的数据库 , 以及 solr 、 elasticsearch 类似的全文检索服务。 那么这个时候, 就又有一个问题需要我们来考虑, 就是数据同步的问题, 如何将实时变化的…

MacOS 15.3 卸载系统内置软件

1、关闭系统完整性(SIP) 进入恢复模式(recovery) 如果您使用的是黑苹果或者白苹果,可以选择 重启按住CommandR 进入,如果是M系列芯片,长按开机键,进入硬盘选择界面进入。 我是MacMini M4芯片,关…

【核心算法篇七】《DeepSeek异常检测:孤立森林与AutoEncoder对比》

大家好,今天我们来深入探讨一下《DeepSeek异常检测:孤立森林与AutoEncoder对比》这篇技术博客。我们将从核心内容、原理、应用场景等多个方面进行详细解析,力求让大家对这两种异常检测方法有一个全面而深入的理解。 一、引言 在数据科学和机器学习领域,异常检测(Anomaly…

Ubuntu24.04无脑安装docker(含图例)

centos系统请看这篇 Linux安装Docker教程(详解) 一. ubuntu更换软件源 请看这篇:Ubuntu24.04更新国内源 二. docker安装 卸载老版docker(可忽略) sudo apt-get remove docker docker-engine docker.io containerd runc更新软件库 sudo a…

thingboard告警信息格式美化

原始报警json内容: { "severity": "CRITICAL","acknowledged": false,"cleared": false,"assigneeId": null,"startTs": 1739801102349,"endTs": 1739801102349,"ackTs": 0,&quo…

✨2.快速了解HTML5的标签类型

✨✨HTML5 的标签类型丰富多样&#xff0c;每种类型都有其独特的功能和用途&#xff0c;以下是一些常见的 HTML5 标签类型介绍&#xff1a; &#x1f98b;结构标签 &#x1faad;<html>&#xff1a;它是 HTML 文档的根标签&#xff0c;所有其他标签都包含在这个标签内&am…

day12_调度和可视化

文章目录 day12_调度和可视化一、任务调度1、开启进程2、登入UI界面3、配置租户4、创建项目5、创建工作流5.1 HiveSQL部署&#xff08;掌握&#xff09;5.2 SparkDSL部署&#xff08;掌握&#xff09;5.3 SparkSQL部署&#xff08;熟悉&#xff09;5.4 SeaTunnel部署&#xff0…

使用nvm管理node.js版本,方便vue2,vue3开发

在Vue项目开发过程中&#xff0c;我们常常会遇到同时维护Vue2和Vue3项目的情况。由于不同版本的Vue对Node.js 版本的要求有所差异&#xff0c;这就使得Node.js 版本管理成为了一个关键问题。NVM&#xff08;Node Version Manager&#xff09;作为一款强大的Node.js 版本管理工具…

Java8适配的markdown转换html工具(FlexMark)

坐标地址&#xff1a; <dependency><groupId>com.vladsch.flexmark</groupId><artifactId>flexmark-all</artifactId><version>0.60.0</version> </dependency> 工具类代码&#xff1a; import com.vladsch.flexmark.ext.tab…

Qt开发①Qt的概念+发展+优点+应用+使用

目录 1. Qt的概念和发展 1.1 Qt的概念 1.2 Qt 的发展史&#xff1a; 1.3 Qt 的版本 2. Qt 的优点和应用 2.1 Qt 的优点&#xff1a; 2.2 Qt 的应用场景 2.3 Qt 的应用案例 3. 搭建 Qt 开发环境 3.1 Qt 的开发工具 3.2 Qt SDK 的下载和安装 3.3 Qt 环境变量配置和使…

vscode插件开发

准备 安装开发依赖 npm install -g yo generator-code 安装后&#xff0c;运行命令 yo code 运行 打开项目&#xff0c; 点击 vscode 调式 按 F5 或点击调试运行按钮 会打开一个新窗口&#xff0c;在新窗口按快捷键 CtrlShiftP &#xff0c;搜索 Hello World 选择执行 右下角出…

自制简单的图片查看器(python)

图片格式&#xff1a;支持常见的图片格式&#xff08;JPG、PNG、BMP、GIF&#xff09;。 import os import tkinter as tk from tkinter import filedialog, messagebox from PIL import Image, ImageTkclass ImageViewer:def __init__(self, root):self.root rootself.root.…