机器学习:训练集与测试集分割train_test_split

1 引言

在使用机器学习训练模型算法的过程中,为提高模型的泛化能力、防止过拟合等目的,需要将整体数据划分为训练集和测试集两部分,训练集用于模型训练,测试集用于模型的验证。此时,使用train_test_split函数可便捷高效的实现数据训练集与测试集的划分。

2 train_test_split介绍

train_test_split函数来自scikit-learn库(也称为sklearn),安装命令:

pip install sklearn

函数的导入:

from sklearn.model_selection import train_test_split

1.1 函数定义

def train_test_split(*arrays,test_size=None,train_size=None,random_state=None,
    shuffle=True,stratify=None,):

1.2 参数说明

  • *arrays: 单个数组或元组,表示需要划分的数据集。如果传入多个数组,则必须保证每个数组的第一维大小相同。
  • test_size: 测试集的大小(占总数据集的比例,值为0.0-1.0,表示测试集占总样本比例)。默认值为0.25,即将传入数据的25%作为测试集。
  • train_size: 训练集的大小(占总数据集的比例,值为0.0-1.0,表示训练集占总样本比例)。默认值为None,此时和test_size互补,即训练集的大小为(1-test_size)。
  • random_state: 随机数种子。可以设置一个整数,用于复现结果。默认为None。其实是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。(比如每次都填1,其他参数一样的情况下你得到的随机数组是一样的。但填0或不填,每次都会不一样。)
  • shuffle: 是否随机打乱数据。默认为True。
  • stratify: 可选参数,用于进行分层抽样。传入标签数组,保证划分后的训练集和测试集中各类别样本比例与原始数据集相同。默认为None,即普通的随机划分。(此参数作用是保持测试集与整个数据集里的数据分类比例一致,比如有1000个数据,800个属于A类,200个属于B类。设置stratify = y_lable,test_size=0.25,split之后数据组成如下:training: 750个数据,其中600个属于A类,150个属于B类;testing: 250个数据,其中200个属于A类,50个属于B类)

1.3 返回值

该函数返回一个元组(X_train, X_test, y_train, y_test),其中X_train表示训练集的特征数据,X_test表示测试集的特征数据,y_train表示训练集的标签数据,y_test表示测试集的标签数据。

1.4 注意事项

  • test_sizetrain_size必须至少有一个设置为非None
  • 当传入多个数组时,请确保每个数组的第一维大小相同。
  • random_state要设置一个整数值,从而保证每次获取相同的训练集和测试集
  • 当使用分层抽样时,请确保传入的标签数组是正确的。

3 train_test_split使用

3.1 使用train_test_split分割Iris数据

from sklearn import datasets
from sklearn.model_selection import train_test_split

# 加载Iris数据集
iris = datasets.load_iris()
X = iris.data
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=1)
print(X_train)
print(X_test)

结果展示:

X_train=[[6.5 2.8 4.6 1.5]
 [6.7 2.5 5.8 1.8]
 [6.8 3.  5.5 2.1]
 [5.1 3.5 1.4 0.3]
 [6.  2.2 5.  1.5]
 ......此处数据省略
 [4.9 3.6 1.4 0.1]]
X_test=[[5.8 4.  1.2 0.2]
 [5.1 2.5 3.  1.1]
 [6.6 3.  4.4 1.4]
 [5.4 3.9 1.3 0.4]
 [7.9 3.8 6.4 2. ]
 ......此处数据省略
 [5.2 3.4 1.4 0.2]]

3.2 使用train_test_split分割水果识别数据

在/opt/dataset下存放着水果图片的分类数据文件夹(文件夹名称为标签),每个文件夹下存储着多张对应标签的水果图片,如下所示:

以apple文件夹为例,图片内容如下:

数据加载和分割数据集的代码如下:

from torchvision.datasets import ImageFolder
from sklearn.model_selection import train_test_split

# 图像变换
transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize(
                                     mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5]
                                ), ])
# 加载数据集
dataset = ImageFolder('/opt/dataset', transform=transform)

# 划分训练集与测试集
train_dataset, valid_dataset = train_test_split(dataset, test_size=0.2, random_state=10)

batch_size = 64
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/61581.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

数字孪生的「三张皮」问题:数据隐私、安全与伦理挑战

引言 随着数字化时代的来临,数据成为了当今社会的宝贵资源。然而,数据的广泛使用也带来了一系列隐私、安全与伦理挑战。数字孪生作为一种虚拟的数字化实体,通过收集和分析大量数据,模拟和预测现实世界中的各种情境,为…

【Jenkins】Jenkins 安装

Jenkins 安装 文章目录 Jenkins 安装一、安装JDK二、安装jenkins三、访问 Jenkins 初始化页面 Jenkins官网地址:https://www.jenkins.io/zh/download/ JDK下载地址:https://www.oracle.com/java/technologies/downloads/ 清华源下载RPM包地址&#xff…

简单认识ELK日志分析系统

一. ELK日志分析系统概述 1.ELK 简介 ELK平台是一套完整的日志集中处理解决方案,将 ElasticSearch、Logstash 和 Kiabana 三个开源工具配合使用, 完成更强大的用户对日志的查询、排序、统计需求。 好处: (1)提高安全…

【每天40分钟,我们一起用50天刷完 (剑指Offer)】第四十八天 48/50【字符串处理】【最低公共祖先】

专注 效率 记忆 预习 笔记 复习 做题 欢迎观看我的博客,如有问题交流,欢迎评论区留言,一定尽快回复!(大家可以去看我的专栏,是所有文章的目录)   文章字体风格: 红色文字表示&#…

DQN原理和代码实现

参考:王树森《强化学习》书籍、课程、代码 1、基本概念 折扣回报: U t R t γ ⋅ R t 1 γ 2 ⋅ R t 2 ⋯ γ n − t ⋅ R n . U_tR_t\gamma\cdot R_{t1}\gamma^2\cdot R_{t2}\cdots\gamma^{n-t}\cdot R_n. Ut​Rt​γ⋅Rt1​γ2⋅Rt2​⋯γn−…

基于 APN 的 CXL 链路训练

🔥点击查看精选 CXL 系列文章🔥 🔥点击进入【芯片设计验证】社区,查看更多精彩内容🔥 📢 声明: 🥭 作者主页:【MangoPapa的CSDN主页】。⚠️ 本文首发于CSDN&#xff0c…

Dockerfile构建mysql

使用dockerfile构建mysql详细教学加案例 Dockerfile 文件 # 使用官方5.6版本,latest为默认版本 FROM mysql:5.6 #复制my.cof至容器内 ADD my.cnf /etc/mysql/my.cof #设置环境变量 密码 ENV MYSQL_ROOT_PASSWORD123456my.cof 文件 [mysqld] character-set-server…

LNMP搭建

LNMP:目前成熟的企业网站的应用模式之一,指的是一套协同工作的系统和相关软件 能够提供静态页面服务,也可以提供动态web服务。 这是一个缩写 L linux系统,操作系统。 N nginx网站服务,也可也理解为前端&#xff0c…

企业计算机服务器中了locked勒索病毒怎么办,如何预防勒索病毒攻击

计算机服务器是企业的关键信息基础设备,随着计算机技术的不断发展,企业的计算机服务器也成为了众多勒索者的攻击目标,勒索病毒成为当下计算机服务器的主要攻击目标。近期,我们收到很多企业的求助,企业的服务器被locked…

uni-app、H5实现瀑布流效果封装,列可以自定义

文章目录 前言一、效果二、使用代码三、核心代码总结 前言 最近做项目需要实现uni-app、H5实现瀑布流效果封装,网上搜索有很多的例子,但是代码都是不够完整的,下面来封装一个uni-app、H5都能用的代码。在小程序中,一个个item渲染…

Godot 4 源码分析 - Path2D与PathFollow2D

学习演示项目dodge_the_creeps,发现里面多了一个Path2D与PathFollow2D 研究GDScript代码发现,它主要用于随机生成Mob var mob_spawn_location get_node(^"MobPath/MobSpawnLocation")mob_spawn_location.progress randi()# Set the mobs dir…

【机器学习】编码、创造和筛选特征

在机器学习和数据科学领域中,特征工程是提取、转换和选择原始数据以创建更具信息价值的特征的过程。假设拿到一份数据集之后,如何逐步完成特征工程呢? 文章目录 一、特性类型分析1.1 数值型特征1.2 类别型特征1.3 时间型特征1.4 文本型特征1.…

Android Studio安装AI编程助手Github Copilot

csdn原创谢绝转载 简介 文档链接 https://docs.github.com/en/copilot/getting-started-with-github-copilot 它是个很牛B的编程辅助工具,装它,快装它. 支持以下IDE: IntelliJ IDEA (Ultimate, Community, Educational)Android StudioAppC…

数据库操作系列-Mysql, Postgres常用sql语句总结

文章目录 1.如果我想要写一句sql语句,实现 如果存在则更新,否则就插入新数据,如何解决?MySQL数据库实现方案: ON DUPLICATE KEY UPDATE写法 Postgres数据库实现方案:方案1:方案2:关于更新:如何实…

【云原生】K8S二进制搭建一

目录 一、环境部署1.1操作系统初始化 二、部署etcd集群2.1 准备签发证书环境在 master01 节点上操作在 node01与02 节点上操作 三、部署docker引擎四、部署 Master 组件4.1在 master01 节点上操 五、部署Worker Node组件 一、环境部署 集群IP组件k8s集群master01192.168.243.1…

【雕爷学编程】MicroPython动手做(31)——物联网之Easy IoT

1、物联网的诞生 美国计算机巨头微软(Microsoft)创办人、世界首富比尔盖茨,在1995年出版的《未来之路》一书中,提及“物物互联”。1998年麻省理工学院提出,当时被称作EPC系统的物联网构想。2005年11月,国际电信联盟发布《ITU互联网…

在 Ubuntu 上安装 Docker 桌面

Ubuntu 22.04 (LTS) 安装 Docker 桌面 要成功安装 Docker Desktop,您必须: 满足系统要求拥有 64 位版本的 Ubuntu Jammy Jellyfish 22.04 (LTS) 或 Ubuntu Impish Indri 21.10。对于非 Gnome 桌面环境,必须安装 gnome-terminal:…

机器学习笔记 - YOLO-NAS 最高效的目标检测算法之一

一、YOLO-NAS概述 YOLO(You Only Look Once)是一种对象检测算法,它使用深度神经网络模型,特别是卷积神经网络,来实时检测和分类对象。该算法首次在 2016 年由 Joseph Redmon、Santosh Divvala、Ross Girshick 和 Ali Farhadi 发表的论文《You Only Look Once: Unified, Re…

Excel·VBA表格横向、纵向相互转换

如图:对图中区域 A1:M6 横向表格,转换成区域 A1:C20 纵向表格,即 B:M 列转换成每2列一组按行写入,并删除空行。同理,反向操作就是纵向表格转换成横向表格 目录 横向转纵向实现方法1转换结果 实现方法2转换结果 纵向转横…

ThreadLocal有内存泄漏问题吗

对于ThreadLocal的原理不了解或者连Java中的引用类型都不了解的可以看一下我的之前的一篇文章Java中的引用和ThreadLocal_鱼跃鹰飞的博客-CSDN博客 我这里也简单总结一下: 1. 每个Thread里都存储着一个成员变量,ThreadLocalMap 2. ThreadLocal本身不存储数据&…