机器学习——模型评估与选择(拟合、)

【说明】文章内容来自《机器学习——基于sklearn》,用于学习记录。若有争议联系删除。

1、拟合

        拟合是指机器学习模型在训练的过程中,通过更新参数,使得模型不断契合可观测数据(训练集)的过程。欠拟合指的是模型在训练和预测表现都不好,往往由于模型过于简单,如图(a)所示。正常模型指的是模型在训练和预测表现都好,如图 (b)所示。过拟合是指由于模型过于复杂,模型在训练集上的表现很好,但在测试集上表现较差,如图 (c)所示。

1.1 欠拟合

         欠拟合往往是由于模型的数据特征过少或者模型复杂度较低所致。因此,可以通过如下办法解决欠拟合问题:
通过特征工程添加更多特征项。当特征不足或者现有特征与样本标签的相关性不强时,通过组合特征进行处理。
进行模型优化,提升模型复杂度。提升模型的复杂度能够增强模型的拟合能力。例如,在线性模型中可以通过添加高次项等提升模型的复杂度。
减小正则项权重。正则化的目的是防止过拟合,但是现在模型出现了欠拟合,因此需要减小正则项的权重
使用集成方法。融合几个有差异的弱模型,使其成为一个强模型

1.2 过拟合

        过拟合指的是模型在训练集上表现得很好,但是在交叉验证集合测试集上表现得一般,也就是说模型对未知样本的预测表现得一般,泛化(generalization)能力较差。过拟合产生的原因往往有以下几个:

  1. 训练数据不足,数据太少,导致无法描述问题的真实分布。根据统计学的大数定律,在试验不变的条件下,重复试验多次,随机事件的频率将趋近其概率。模型在求解最小值的过程中需要兼顾真实数据拟合和随机误差拟合。
  2. 数据有噪声。噪声数据使得更复杂的模型会尽量覆盖噪声点,即对数据过拟合。
  3. 对模型训练过度,导致模型非常复杂。

针对过拟合产生的原因,抑制过拟合的几种方法如下:

  1. 获取更多的训练样本。通过获取更多的训练样本,可以衰减噪声数据的权重。
  2. 增加正则项权重,减少高次项的影响。例如,通过L1或 L2正则化提高模型的泛
  3. 数据处理。
  • 进行数据清洗,纠正错误的标签,删除错误数据。
  • 进行特征共性检查,利用皮尔逊相关系数计算变量之间的线性相关性,进行特征选择。
  • 重要特征筛选,例如,对决策树模型进行求最大深度、剪枝等操作。
  • 数据降维,通过主成分分析等保留重要特征。

1.3 正则化

        正则化(regularization)目的在于提高模型在未知测试数据上的泛化能力,避免参数过拟合。常用方法是在原模型优化目标的基础上增加针对参数的惩罚(penalty)项,一般有L1正则化与L2正则化两种方法。
L1正则化是基于L1范数实现的,也就是 LASSO 回归,即在目标函数后面加上参数的绝对值之和与参数的积项。

C = C_{o}+\frac{\lambda }{n}\sum_{w}^{}|w|
L2正则化是基于L2范数实现的,也就是岭回归,即在目标函数后面加上参数的平方之和与参数的积项:

C = C_{o}+\frac{\lambda }{2n}\sum_{w}^{}|w^{2}|

L1范数: 指向量中各个元素绝对值之和

\left | \left | x \right | \right | _{1}= |x_{1}|+|x_{2}|+|x_{3}|+\cdots +|x_{n}|

L2范数: 指向量各元素的平方和然后求平方根

\left | \left | x \right | \right | _{2}=\sqrt{x_{1}^{2}+{x_{1}^{2}}+{x_{3}^{2}}+\cdots +{x_{n}^{2}}}

2、模型调参

        模型调参是指从数据特征入手,通过调整模型中的超参数来解决欠拟合和过拟合问题。
Sklearn 提供了网格搜索和随机搜索两种参数调优方法。

2.1 网格搜索

        GridSearchCV可以自动进行超参数组合,通过交叉验证评估模型的优劣。GridSearchCV拆分为GridSearch 和 CV两部分,即网格搜索和交叉验证。网格搜索是指在指定的参数范围内,通过循环和比较调整参数,训练预估器。假设模型有两个参数a,b,参数a有3个取值,参数b有4个取值,共有3X4=12个取值,则依次搜索这12个网格。交叉验证将训练数据集划分为K份(默认为10),依次取其中一份作为测试集,其余为训练集。model_selection 模块提供了GridSearchCV函数用于网格搜索,格式如下:

GridSearchCV(eatimator,param_grid)

【参数说明】

  • estimator:选择使用的分类器,并且传人除需要优化的参数之外的其他参数。每一个分类器都需要一个 scoring参数或者 score方法。
  • param_grid:需要优化的参数的取值,值的类型为字典或者列表。

示例:

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV
#加载鸢尾花数据
iris = datasets.load_iris()
#划分数据集
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, 
                                                   random_state = 3)
#标准化
transfer = StandardScaler()
X_train = transfer.fit_transform(X_train)
X_test = transfer.transform(X_test)
#KNN算法预估器
estimator = KNeighborsClassifier()
#多个超参数,引入网格搜索
param_dict = {'n_neighbors':[1,3,5,7,9,11]}
estimator = GridSearchCV(estimator, param_grid = param_dict, cv =10)
estimator.fit(X_train, y_train)
#模型评估
#最佳参数
print('最佳参数:\n',estimator.best_params_)
#最佳结果
print('最佳结果:\n',estimator.best_score_)
#最佳评估器
print('最佳评估器:\n',estimator.best_estimator_)
#交叉验证结果
print('交叉验证结果:\n',estimator.cv_results_)

【运行结果】

2.2 随机搜索

        在处理较少的超参数组合时,GridSearchCV方法比较适用。GridSearchCV可以保证在指定的参数范围内找到精度最高的参数,但是这也是网格搜索的缺陷所在,要求遍历所有可能参数的合,在面对大数据集和多参数的情况下,非常耗时。RandomizedSearchCV随机参数搜索的方法相对于网格搜索方法,找到模型的最优参数的可能性比较大,并且也比较省时。

Sklearn的model_selection模块RandomizedSearchCV方法用于随机搜索,代码如下:

meta_clf = RandomForestClassifier(n_estimators = 20)
给定参数搜索范围:list or distribution
param_dist = {"max_depth": [3, None], #给定list树的最大深度,如果None,节点扩展直到所有叶子
              "max_features": sp_randint(1, 11),  #给定distribution寻找最佳分裂点时考虑的特征数目。可选值,int(具体的数目),float(数目的百分比
              "min_samples_split": sp_randint(2, 11), #分裂内部节点需要的最少样例数。int(具体数目),float(数目的百分比) 
              "bootstrap": [True, False],   #构建数是不是采用有放回样本的方式
              "criterion": ["gini", "entropy"]}  #度量分裂的标准。可选值:“mse”,均方差(mean squared error)
# 用RandomiSearch+CV选取超参数
n_iter_search = 20
random_search = RandomizedSearchCV(clf, param_distributions=param_dist,
                                   n_iter=n_iter_search, cv=5, iid=False)
random_search.fit(X, y)

示例:

from sklearn import datasets 
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV
iris = datasets.load_iris()
iris_feature = iris['data']
iris_target = iris['target']
#建模分析
forest_clf = RandomForestClassifier(random_state = 42)
param_distribs = {'n_estimators': range(10, 100), 'max_depth': range(5, 20)}
random_search = RandomizedSearchCV(forest_clf, param_distribs, n_iter = 50, cv = 3)
random_search.fit(iris_feature, iris_target)
print(random_search.best_params_)
best_model = random_search.best_estimator_

【运行结果】

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/266648.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C# 使用Socket进行简单的通讯

目录 写在前面 代码实现 服务端部分 客户端部分 运行示例 总结 写在前面 在.Net的 System.Net.Sockets 命名空间中包含托管的跨平台套接字网络实现。 System.Net 命名空间中的所有其他网络访问类均建立在套接字的此实现之上。 其中的Socket 类是基于与 Linux、macOS 或 W…

STM32实现三个小灯亮

led.c #include"led.h"void Led_Init(void) {GPIO_InitTypeDef GPIO_VALUE; //???RCC_APB2PeriphClockCmd(RCC_APB2Periph_GPIOC,ENABLE);//???GPIO_VALUE.GPIO_ModeGPIO_Mode_Out_PP;//???? ????GPIO_VALUE.GPIO_PinGPIO_Pin_1|GPIO_Pin_2|GPIO_P…

使用教程之【SkyWant.[2304]】路由器操作系统,破解移动【Netkeeper】校园网【小白篇】

许多高校目前饱受Netkeeper认证的痛苦,普通路由器无法使用, 教你利用SkyWant的Netkeeper认证软件来使你的SkyWant路由器顺利认证上网,全宿舍又可以合作共赢了! 步骤一:正确连接网线,插电开机 正确连接网…

一个简单例子更深入地理解BigQuery 的分区表

首先本文不会讲得很系统, 可以理解为是1个练习, 从这个简单例子中, 我们会体会到分区表与非分区表的操作和效果的区别 准备测试数据 首先, 本人准备了一份csv file , 测试数据, 大概样子如下:…

从零构建tomcat环境

一、官网构建 1.1 下载 一般来说对于开源软件都有自己的官方网站,并且会附上使用文档以及一些特性和二次构建的方法,那么我们首先的话需要从官网或者tomcat上下载到我们需要的源码包。下载地址:官网、Github。 这里需要声明一下&#xff…

Hadoop入门学习笔记——七、Hive语法

视频课程地址:https://www.bilibili.com/video/BV1WY4y197g7 课程资料链接:https://pan.baidu.com/s/15KpnWeKpvExpKmOC8xjmtQ?pwd5ay8 Hadoop入门学习笔记(汇总) 目录 七、Hive语法7.1. 数据库相关操作7.1.1. 创建数据库7.1.2…

每日一题——LeetCode859

方法一 个人方法: 首先s和goal要是长度不一样或者就只有一个字符这两种情况可以直接排除剩下的情况s和goal的长度都是一样的,s的长度为2也是特殊情况,只有s的第一位等于goal的第二位,s的第二位等于goal的第一位才能满足剩下的我们…

生成allure报告出现:ALLURE REPORT UNKNOWN

问题:点击浏览器查看时无法查看到报告 错误代码: if __name__ "__main__":pytest.main([./test_study/test_fixture.py])os.system("allure generate ./temps -o ./temps --clean") 结果导向: 解决:因为…

ZooKeeper 使用介绍和原理详解

目录 1. 介绍 重要性 应用场景 2. ZooKeeper 架构 服务角色 数据模型 工作原理 3. 安装和配置 下载 ZooKeeper 安装和配置 启动 ZooKeeper 验证和管理 停止和关闭 4. ZooKeeper 数据模型 数据结构和层次命名空间: 节点类型和 Watcher 机制&#xff…

SpringMVC:整合 SSM 上篇

文章目录 SpringMVC - 03整合 SSM 上篇一、准备工作二、MyBatis 层1. dao 层2. service 层 三、Spring 层四、SpringMVC 层五、执行六、说明 SpringMVC - 03 整合 SSM 上篇 用到的环境: IDEA 2019(JDK 1.8)MySQL 8.0.31Tomcat 8.5.85Maven…

OpenCV利用HSV颜色区间分离不同物体

需求 当前有个需求是从一个场景中将三个不同的颜色的二维码分离出来,如下图所示。 这里有两个思路可以使用 思路一是通过深度学习的方式,训练一个能够识别旋转边界框的模型,但是需要大量的数据进行模型训练,此处缺少训练数据&a…

WARNING: HADOOP_SECURE_DN_USER has been replaced by HDFS_DATANODE_SECURE_USER.

Hadoop启动时警告,但不影响使用,强迫症的我还是决定寻找解决办法 WARNING: HADOOP_SECURE_DN_USER has been replaced by HDFS_DATANODE_SECURE_USER. Using value of HADOOP_SECURE_DN_USER.原因是Hadoop安装配置于root用户下,对文件需要进…

智能优化算法应用:基于金枪鱼群算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用:基于金枪鱼群算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于金枪鱼群算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.金枪鱼群算法4.实验参数设定5.算法结果6.…

深度学习(八):bert理解之transformer

1.主要结构 transformer 是一种深度学习模型,主要用于处理序列数据,如自然语言处理任务。它在 2017 年由 Vaswani 等人在论文 “Attention is All You Need” 中提出。 Transformer 的主要特点是它完全放弃了传统的循环神经网络(RNN&#x…

PHP函数定义和分类

函数的含义和定义格式 在PHP中,允许程序员将常用的流程或者变量等组件组织成一个固定的格式实现特定功能,也就是说函数是具有特定功能特定格式的代码段。 函数的定义格式如下: function 函数名(参数1,参数2,参数n) {…

适配器模式学习

适配器模式(Adapter)将一个类的接口转换成客户希望的另外一个接口。Adapter 模式使得原本由于接口不兼容而不能一起工作的那些类可以一起工作。 适配器模式分为类适配器模式和对象适配器模式两种,前者类之间的耦合度比后者高,且要…

【高数定积分求解旋转体体积】 —— (上)高等数学|定积分|柱壳法|学习技巧

🌈个人主页: Aileen_0v0 🔥热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 💫个人格言:"没有罗马,那就自己创造罗马~" 目录 Shell method Setting up the Integral 例题 Example 1: Example 2: Example 3: Computing…

RasaGPT对话系统的工作原理

RasaGPT 结合了 Rasa 和 Langchain 这 2 个开源项目,当超出 Rasa 现有意图(out_of_scope)的时候,就会执行 ActionGPTFallback,本质上就是利用 Langchain 做了一个 RAG,调用 LLM API。RasaGPT 涉及的技术栈比较多而复杂&#xff0c…

如何更好地理解和掌握 KMP 算法?

KMP算法是一种字符串匹配算法,可以在 O(nm) 的时间复杂度内实现两个字符串的匹配。本文将引导您学习KMP算法,阅读大约需要30分钟。 1、字符串匹配问题 所谓字符串匹配,是这样一种问题:“字符串 P 是否为字符串 S 的子串&#xf…

springcloud-gateway-2-鉴权

目录 一、跨域安全设置 二、GlobalFilter实现全局的过滤与拦截。 三、GatewayFilter单个服务过滤器 1、原理-官方内置过滤器 2、自定义过滤器-TokenAuthGatewayFilterFactory 3、完善TokenAuthGatewayFilterFactory的功能 4、每一个服务编写一个或多个过滤器&#xff0c…