深度学习_15_过拟合欠拟合

过拟合和欠拟合
在这里插入图片描述
过拟合和欠拟合是训练模型中常会发生的事,如所要识别手势过于复杂,如五角星手势,那就需要更改高级更复杂的模型去训练,若用比较简单模型去训练,就会导致模型未能抓住手势的全部特征,那简单模型估计只能抓住五角星的其中一个角做特征,那么这个简单模型很可能就会将三角形与五角星混淆,这就是所谓欠拟合

若用识别五角星的复杂模型去识别三角形也是不行的,模型会过拟合,即学习了过多不重要的部分,可能会把三角形每条边所画的时间也当作学习的内容,即便我们人知道什么时候画哪条边都无所谓。

过拟合和欠拟合的表现都是模型的识别精度不够,所以要想判断模型是过拟合还是欠拟合,除了理论还是要多调试

如:
在这里插入图片描述
合适的模型应该是抛物线,上述左边是欠拟合,右边是过拟合

在这里插入图片描述
训练集和测试集

值得注意的是训练集和测试集必须是分开的,训练模型用训练集,一定不能让测试集污染模型

模型过拟的特征即对见过的数据集表现非常好,而对从未见过的模型表现非常差,若不把训练,测试集完全分开,最后的模型过拟合将无法被发现

实例:

完整代码:

import math
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt

# 生成随机的数据集
max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = torch.zeros(max_degree)
true_w[0:4] = torch.Tensor([5, 1.2, -3.4, 5.6])

# 生成特征
features = torch.randn((n_train + n_test, 1))
permutation_indices = torch.randperm(features.size(0))
# 使用随机排列的索引来打乱features张量(原地修改)
features = features[permutation_indices]
poly_features = torch.pow(features, torch.arange(max_degree).reshape(1, -1))
for i in range(max_degree):
    poly_features[:, i] /= math.gamma(i + 1)

# 生成标签
labels = torch.matmul(poly_features, true_w)
labels += torch.normal(0, 0.1, size=labels.shape)


# 以下是你原来的训练函数,没有修改
def evaluate_loss(net, data_iter, loss):
    metric = d2l.Accumulator(2)
    for X, y in data_iter:
        out = net(X)
        y = y.reshape(out.shape)
        l = loss(out, y)
        metric.add(l.sum(), l.numel())
    return metric[0] / metric[1]


def train(train_features, test_features, train_labels, test_labels,
          num_epochs=400):
    loss = nn.MSELoss()
    input_shape = train_features.shape[-1]
    net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))
    batch_size = min(10, train_labels.shape[0])

    train_iter = d2l.load_array((train_features, train_labels.reshape(-1, 1)),
                                batch_size)
    test_iter = d2l.load_array((test_features, test_labels.reshape(-1, 1)),
                               batch_size, is_train=False)
    trainer = torch.optim.SGD(net.parameters(), lr=0.01)

    # 用于存储训练和测试损失的列表
    train_losses = []
    test_losses = []

    for epoch in range(num_epochs):
        train_loss, train_acc = d2l.train_epoch_ch3(net, train_iter, loss, trainer)
        test_loss = evaluate_loss(net, test_iter, loss)

        # 将当前的损失值添加到列表中
        train_losses.append(train_loss)
        test_losses.append(test_loss)

        print(f"Epoch {epoch + 1}/{num_epochs}:")
        print(f"  训练损失: {train_loss:.4f}, 测试损失: {test_loss:.4f}")
    print(net[0].weight)
    # 假设 train_losses 和 test_losses 是已经计算出的损失值列表
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='train', color='blue', linestyle='-', marker='.')
    plt.plot(test_losses, label='test', color='purple', linestyle='--', marker='.')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.title('Loss over Epochs')
    plt.legend()
    plt.grid(True)
    plt.ylim(0, 100)  # 设置y轴的范围从0.01到100
    plt.show()


# 选择多项式特征中的前4个维度
train(poly_features[:n_train, :4], poly_features[n_train:, :4],
      labels[:n_train], labels[n_train:])

分部讲解如下:

问题实例:
在这里插入图片描述

产生高斯分布随机数x并按上述式子生成训练集和验证集y,并对生成的y再添加一些杂音处理
注意:训练集一定要打乱,不要排序,排序会让训练效果大打折扣,如果训练数据是按照某种特定顺序排列的,那么模型可能会学习到这种顺序并在这个过程中引入偏差,导致模型在未见过的新数据上的泛化能力下降,打乱训练集的目的通常是为了防止模型学习到训练数据中的任何顺序依赖性,这样可以提高模型在随机或未见过的新数据上的泛化能力。

# 生成随机的数据集
max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = torch.zeros(max_degree)
true_w[0:4] = torch.Tensor([5, 1.2, -3.4, 5.6])

# 生成特征
features = torch.randn((n_train + n_test, 1))
permutation_indices = torch.randperm(features.size(0))
# 使用随机排列的索引来打乱features张量(原地修改)
features = features[permutation_indices]
poly_features = torch.pow(features, torch.arange(max_degree).reshape(1, -1))
for i in range(max_degree):
    poly_features[:, i] /= math.gamma(i + 1)

# 生成标签
labels = torch.matmul(poly_features, true_w)
labels += torch.normal(0, 0.1, size=labels.shape)

计算损失函数,并不会更新迭代模型,所以用他来测试模型测试集损失

def evaluate_loss(net, data_iter, loss):
    metric = d2l.Accumulator(2)
    for X, y in data_iter:
        out = net(X)
        y = y.reshape(out.shape)
        l = loss(out, y)
        metric.add(l.sum(), l.numel())
    return metric[0] / metric[1]

训练函数,将X和对应y放在一起,即是进行模型迭代更新,又能计算模型训练损失,测试损失并绘制相应图形

def train(train_features, test_features, train_labels, test_labels,
          num_epochs=400):
    loss = nn.MSELoss()  # 默认取平均损失
    input_shape = train_features.shape[-1]  # 模型大小取train_features最后一项大小
    net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))
    batch_size = min(10, train_labels.shape[0])  # 整体数据集分成<= 10批次

    train_iter = d2l.load_array((train_features, train_labels.reshape(-1, 1)),
                                batch_size)
    test_iter = d2l.load_array((test_features, test_labels.reshape(-1, 1)),
                               batch_size, is_train=False)
    trainer = torch.optim.SGD(net.parameters(), lr=0.01)  # 梯度下降算法

    # 用于存储训练和测试损失的列表
    train_losses = []
    test_losses = []

    for epoch in range(num_epochs):
        train_loss, train_acc = d2l.train_epoch_ch3(net, train_iter, loss, trainer)  # 训练迭代模型
        test_loss = evaluate_loss(net, test_iter, loss)

        # 将当前的损失值添加到列表中
        train_losses.append(train_loss)
        test_losses.append(test_loss)

        print(f"Epoch {epoch + 1}/{num_epochs}:")
        print(f"  训练损失: {train_loss:.4f}, 测试损失: {test_loss:.4f}")
    print(net[0].weight)  # 输出训练好的模型
    # 假设 train_losses 和 test_losses 是已经计算出的损失值列表
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='train', color='blue', linestyle='-', marker='.')
    plt.plot(test_losses, label='test', color='purple', linestyle='--', marker='.')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.title('Loss over Epochs')
    plt.legend()
    plt.grid(True)
    plt.ylim(0, 100)  # 设置y轴的范围从0.01到100
    plt.show()

主函数

# 选择多项式特征中的前4个维度
train(poly_features[:n_train, :4], poly_features[n_train:, :4],
      labels[:n_train], labels[n_train:])

利用上述实例验证欠拟合和过拟合以及正常拟合

在这里插入图片描述
上述函数对应真正的模型为:

true_w[0:4] = torch.Tensor([5, 1.2, -3.4, 5.6])

当然还有一些杂质,可忽略

那么可知预训练模型取四个维度就能做到正常拟合,而取二十个维度就是过拟合,取四个以下维度就是欠拟合

过拟合即取二十维度效果:

在这里插入图片描述
可以看出损失在下降到最低点的时候还会有上升
在这里插入图片描述
这是因为学完主要四个维度后又将本应是0的维度也学习了,也就是学习了无用的杂质。
在这里插入图片描述

欠拟合二维度模型效果:

在这里插入图片描述
损失很大,这也是没办法,毕竟还有很多重要维度没有学习上,本质上是模型过小

正常拟合四维度模型效果:

在这里插入图片描述
正常拟合的模型在损失到达最低点后便不再上升,训练出来的模型与真实数据也及其接近

在这里插入图片描述
正常拟合才是我们训练模型的期望状态

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/418143.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Gitlab: 私有化部署

目录 1. 说明 2. 资源要求 3. 安装 4. 配置实践 4.1 服务器 4.2 人员与项目 4.2 部署准备 4.2.1 访问变量及用户账号设置 4.2.2 Runner设置 4.2.3 要点 5. 应用项目 CI/CD 6. 参考 1. 说明 gitlab是一个强大且免费的代码管理/部署工具&#xff0c;能统一集成代码仓…

力扣 674. 最长连续递增序列

题目来源&#xff1a;https://leetcode.cn/problems/longest-continuous-increasing-subsequence/description/ C题解&#xff1a;贪心算法。把所有元素遍历一遍&#xff0c;比较它与上个数的大小&#xff0c;大的话更新长度tmp&#xff0c;小的话初始化长度tmp&#xff0c;并与…

linux nasm汇编中调用printf不报错,但调用scanf报错。抛出了分段错误(核心转储)

当我写了如下汇编时 ; nasm -f elf64 -g -F dwarf charsin.asm ; gcc charsin.o -no-pie -o charsin ; ld -o eatclib eatclib.o ; gdb eatclib[SECTION .data]SPrompt db Enter string data, followed by Enter: ,0IPrompt db Enter an integer value, followed by Enter: ,1…

本科毕业设计:计及并网依赖性的分布式能源系统优化研究。(C语言实现)(内包含NSGA II优化算法)(二)

目录 前言 1、sofc函数 2、光伏板函数 3、集热场函数 4、sofc电跟随策略函数 5、二分法找sofc运行点函数 6、目标函数&#xff1a;成本 7、目标函数&#xff1a;二氧化碳排放量 8、目标函数&#xff1a;并网依赖性 前言 本篇文章介绍的是我的毕业设计&#xff0c;我将C…

【JavaEE】_HttpServletResponse类

目录 1. 核心方法 2. 关于setStatus(400)与sendError 2.1 setStatus(400) 2.2 sendError 3. setHeader方法 4. 构造重定向响应 4.1 使用setHeader和setStatus实现重定向 4.2 使用sendRedirect实现重定向 本专栏已有文章介绍HttpServlet和HttpServletRequest类&#…

KubeEdge 边缘计算应用部署

文章目录 一、概述1.Kubernetes 对 Pod 调度规则1.1.自动调度1.2.定向调度1.3.亲和性调度1.4.污点和容忍 2.KubeEdge 应用部署2.1.KubeEdge应用部署方式2.2.标签操作 二、KubeEdge应用部署1.Node添加标签2.DaemonSet部署2.1.部署到所有节点2.2.部署到边缘节点 3.Deployment部署…

登录认证-Filter(黑马学习笔记)

分析 我们需要使用过滤器Filter来完成案例我们需要使用过滤器Filter来完成案例 我们先来回顾下前面分析过的登录校验的基本流程&#xff1a; ● 要进入到后台管理系统&#xff0c;我们必须先完成登录操作&#xff0c;此时就需要访问登录接口login。 —————————————…

leetcode:860.柠檬水找零

题意&#xff1a;按照支付顺序&#xff0c;进行支付&#xff0c;能够正确找零。 解题思路&#xff1a;贪心策略&#xff1a;针对支付20的客人&#xff0c;优先选择消耗10而不是消耗5&#xff0c;因为5可以用来找零10或20. 代码实现&#xff1a;有三种情况&#xff08;代表三种…

CSS技巧:实现两个div在同一行显示的方法

css如何让两个div在同一行显示 - web开发 - 亿速云 在Web开发中&#xff0c;经常遇到需要将多个元素水平排列在同一行的情况。其中一个常见的需求是将两个div元素放置在同一行上&#xff0c;使它们并排显示。在本文中&#xff0c;我们将介绍几种实现这一效果的CSS方法。 1. 使…

MySQL:索引有哪些(清晰明了)

一提到索引&#xff0c;可能就会想到B树索引、Hash索引、聚簇索引、主键索引、唯一索引、联合索引等等&#xff0c;但这些名词并不能混为一谈&#xff0c;他们有重复的部分&#xff0c;是从不同方面给索引取的名字。 从数据结构上来讲&#xff1a;B树索引、Hash索引、Full-text…

2024高校建设大数据实验室的用途有哪些?

随着当前现代计算机信息技术的快速进步发展,传统的各类高校重点实验室项目建设管理模式已经难以与当前现代信息技术快速发展相相互适应,高校应用大数据重点实验室项目建设模式是高校加快培养一批创新型高校大数据专业人才的重要技术基础。与此同时,高校应用大数据重点实验室项目…

手撕LRU缓存——LinkedHashMap简易源码

题目链接&#xff1a;https://leetcode.cn/problems/lru-cache/description/?envTypestudy-plan-v2&envIdtop-100-liked 原理非常简单&#xff0c;一个双端链表配上一个hash表。 首先我们要知道什么是LRU就是最小使用淘汰。怎么淘汰&#xff0c;链表尾部就是最不常用的直接…

专利:基于2D工业相机的工件目标检测及三维姿态

本发明公开了一种基于2D工业相机的工件目标检测及三维姿态判定方法&#xff0c;首先根据待生产或是待加工工件目标搭建其三维几何模型&#xff0c;并标记该几何模型制定特征点&#xff0c;然后对通过两个2D工业相机分别获得的现场工件目标图像进行目标检测及特征识别&#xff0…

Python电能质量扰动信号分类(六)基于扰动信号特征提取的超强机器学习识别模型

目录 往期精彩内容&#xff1a; 前言 1 数据集和特征提取 1.1 数据集导入 1.2 扰动信号特征提取 2超强模型XGBoost——原理介绍 2.1 原理介绍 2.2 特征数据集制作 3 模型评估和对比 3.1 随机森林分类模型 3.2 支持向量机SVM分类模型 3.3 XGBoost分类模型 代码、数据…

线程安全的集合容器

线程安全的集合容器 List集合中的线程安全的集合容器&#xff1a; 在旧版本中Vector是线程安全的集合容器&#xff0c;在JDK 1.5以后CopyOnWriteArrayList也是线程安全的集合容器&#xff0c;CopyOnWriteArrayList的数据结构是Object类型的数组。 CopyOnWriteArrayList是如何…

最新IE跳转Edge浏览器解决办法(2024.2.29)

最新IE跳转Edge浏览器解决办法&#xff08;2024.2.29&#xff09; 1.前言2. 解决方案2.1.创建快捷方式2.2.效果 3. 遗留问题 1.前言 在前几天我发布过一个关于使用卸载补丁从而解决最新的IE跳转Edge浏览器的解决方案。   但是这个方案其实存在一个BUG&#xff0c;例如我昨天重…

Mac 重新安装系统

Mac 重新安装系统 使用可引导安装器重新安装&#xff08;可用于安装非最新的 Mac OS&#xff0c;系统降级&#xff0c;需要清除所有数据&#xff09; 插入制作好的可引导安装器&#xff08;U盘或者移动固态硬盘&#xff09;&#xff0c;如何制作可引导安装器将 Mac 关机将 Ma…

本地ssh连接服务器成功,而vscode连接服务器超时

解决方案&#xff0c;按下CTRL SHIFT P 或者 COMMAND SHIFT P&#xff0c;然后输入Remote kill&#xff0c;找到以下命令Kill VS Code Server on Host...&#xff0c;然后选择连接失败的服务器进行Kill&#xff0c;之后再次尝试连接即可。 如果还不成功。找一个能登陆的方…

ubuntu基础操作(1)-个人笔记

搜狗输入法Linux官网-首页搜狗输入法for linux—支持全拼、简拼、模糊音、云输入、皮肤、中英混输https://pinyin.sogou.com/linux 1.关闭sudo密码&#xff1a; 终端&#xff08;ctrl alt t&#xff09;输入 sudo visudo 打开visudo 找到 %sudo ALL(ALL:ALL) ALL 这一行…

mount命令最新详细教程

背景 需要在设备上面&#xff0c;自动化运行u盘里面的脚本&#xff0c;并且进入一个产测模式。因此实际使用了这个mount命令&#xff0c;所以&#xff0c;写了这么一篇供大家参考。 一. 定义 mount命令在Linux和类Unix系统中用于挂载文件系统&#xff0c;即将存储设备…