模型选择实战

我们现在可以通过多项式拟合来探索这些概念。

import math
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l

生成数据集

给定x,我们将使用以下三阶多项式来生成训练和测试数据的标签:

max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = np.zeros(max_degree)  # 分配大量的空间
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])

features = np.random.normal(size=(n_train + n_test, 1))
np.random.shuffle(features)
poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))
for i in range(max_degree):
    poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!
# labels的维度:(n_train+n_test,)
labels = np.dot(poly_features, true_w)
labels += np.random.normal(scale=0.1, size=labels.shape)

对模型进行训练和测试

首先让我们实现一个函数来评估模型在给定数据集上的损失。

def evaluate_loss(net, data_iter, loss):  #@save
    """评估给定数据集上模型的损失"""
    metric = d2l.Accumulator(2)  # 损失的总和,样本数量
    for X, y in data_iter:
        out = net(X)
        y = y.reshape(out.shape)
        l = loss(out, y)
        metric.add(l.sum(), l.numel())
    return metric[0] / metric[1]

现在定义训练函数。

def train(train_features, test_features, train_labels, test_labels,
          num_epochs=400):
    loss = nn.MSELoss(reduction='none')
    input_shape = train_features.shape[-1]
    # 不设置偏置,因为我们已经在多项式中实现了它
    net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))
    batch_size = min(10, train_labels.shape[0])
    train_iter = d2l.load_array((train_features, train_labels.reshape(-1,1)),
                                batch_size)
    test_iter = d2l.load_array((test_features, test_labels.reshape(-1,1)),
                               batch_size, is_train=False)
    trainer = torch.optim.SGD(net.parameters(), lr=0.01)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',
                            xlim=[1, num_epochs], ylim=[1e-3, 1e2],
                            legend=['train', 'test'])
    for epoch in range(num_epochs):
        d2l.train_epoch_ch3(net, train_iter, loss, trainer)
        if epoch == 0 or (epoch + 1) % 20 == 0:
            animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),
                                     evaluate_loss(net, test_iter, loss)))
    print('weight:', net[0].weight.data.numpy())

三阶多项式函数拟合(正常)

我们将首先使用三阶多项式函数,它与数据生成函数的阶数相同。 结果表明,该模型能有效降低训练损失和测试损失。 学习到的模型参数也接近真实值w=[5,1.2,−3.4,5.6]。

# 从多项式特征中选择前4个维度,即1,x,x^2/2!,x^3/3!
train(poly_features[:n_train, :4], poly_features[n_train:, :4],labels[:n_train], labels[n_train:])

线性函数拟合(欠拟合)

让我们再看看线性函数拟合,减少该模型的训练损失相对困难。 在最后一个迭代周期完成后,训练损失仍然很高。 当用来拟合非线性模式(如这里的三阶多项式函数)时,线性模型容易欠拟合。

# 从多项式特征中选择前2个维度,即1和x
train(poly_features[:n_train, :2], poly_features[n_train:, :2],labels[:n_train], labels[n_train:])

高阶多项式函数拟合(过拟合)

现在,让我们尝试使用一个阶数过高的多项式来训练模型。 在这种情况下,没有足够的数据用于学到高阶系数应该具有接近于零的值。 因此,这个过于复杂的模型会轻易受到训练数据中噪声的影响。 虽然训练损失可以有效地降低,但测试损失仍然很高。 结果表明,复杂模型对数据造成了过拟合。

# 从多项式特征中选取所有维度
train(poly_features[:n_train, :], poly_features[n_train:, :],
      labels[:n_train], labels[n_train:], num_epochs=1500)

总结:

  • 欠拟合是指模型无法继续减少训练误差。过拟合是指训练误差远小于验证误差。
  • 由于不能基于训练误差来估计泛化误差,因此简单地最小化训练误差并不一定意味着泛化误差的减小。机器学习模型需要注意防止过拟合,即防止泛化误差过大。
  • 验证集可以用于模型选择,但不能过于随意地使用它。
  • 我们应该选择一个复杂度适当的模型,避免使用数量不足的训练样本。

借鉴:

4.4. 模型选择、欠拟合和过拟合 — 动手学深度学习 2.0.0 documentation (d2l.ai)

手学深度学习 2.0.0 documentation (d2l.ai)](https://zh-v2.d2l.ai/chapter_multilayer-perceptrons/underfit-overfit.html#id11)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/346249.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Redis中BigKey的分析与优化

Redis中BigKey的分析与优化 Redis以其出色的性能和易用性,在互联网技术栈中占据了重要的地位。 但是,高效的工具使用不当也会成为性能瓶颈。在Redis中,BigKey是常见的性能杀手之一,它们会消耗过多的内存,导致网络拥塞…

4 课程分类查询

4 课程分类查询 4.1 需求分析 下边根据内容管理模块的业务流程,下一步要实现新增课程,在新增课程界面,有三处信息需要选择,如下图: 课程等级、课程类型来源于数据字典表,此部分的信息前端已从系统管理服…

将vue组件发布成npm包

文章目录 前言一、环境准备1.首先最基本的需要安装nodejs,版本推荐 v10 以上,因为需要安装vue-cli2.安装vue-cli 二、初始化项目1.构建项目2.开发组件/加入组件3. 修改配置文件 三、调试1、执行打包命令2、发布本地连接包3、测试项目 四、发布使用1、注册…

从零开始训练 YOLOv8最新8.1版本教程说明(包含Mac、Windows、Linux端 )同之前的项目版本代码有区别

从零开始训练 YOLOv8 - 最新8.1版本教程说明 本文适用Windows/Linux/Mac:从零开始使用Windows/Linux/Mac训练 YOLOv8 算法项目 《芒果 YOLOv8 目标检测算法 改进》 适用于芒果专栏改进 YOLOv8 算法 文章目录 官方 YOLOv8 算法介绍改进网络代码汇总第一步 配置环境1.1 系列配…

【GitHub项目推荐--不错的 Go 学习项目】【转载】

开源实时性能分析平台 Pyroscope 是基于 Go 的开源实时性能分析平台,在源码中添加几行代码 pyroscope 就能帮你找出源代码中的性能问题和瓶颈、CPU 利用率过高的原因,调用树展示帮助你理解程序,支持 Go、Python、Ruby 语言。 Pyroscope 可以…

力(FFT,acwing2313)

题目路径: https://www.acwing.com/problem/content/2315/ 思路:

css实现右边边框分割线 渐变色,边框四角样式

分割线 代码 .data-item:first-of-type {border-right: 2px solid;border-image: linear-gradient(to top,rgba(0, 0, 0, 0.1) 0%,rgba(81, 110, 197, 0.76) 50%,rgba(0, 0, 0, 0.1) 100%)1;padding: 15px 0;}四角边框样式 代码 .chart-box {cursor: pointer;background: line…

[docker] Docker 网络

一、Docker 网络 1.1 Docker 网络实现原理 Docker使用Linux桥接,在宿主机虚拟一个Docker容器网桥(docker0),Docker启动一个容器时会根据Docker网桥的网段分配给容器一个IP地址,称为Container-IP,同时Docker网桥是每个容器的默认…

axi_quad_spi

文章目录 系统框图正常模式XIP模式 性能IP 核配置AXI Interface OptionXIP ModePerformance Mode SPI OptionsModeTransaction WidthFrequency RatioSlave DeviceEnable Master ModeEnable STARTUP Primitive 寄存器映射0x40 (SRR) 软件复位0x60 (SPICR) SPI控制0x64 (SPISR) S…

Docker安装Clickhouse详细教程

简介 ClickHouse是一种列式数据库管理系统,专门用于高性能数据分析和数据仓库应用。它是一个开源的数据库系统,最初由俄罗斯搜索引擎公司Yandex开发,用于满足大规模数据分析和报告的需求。 特点 开源的列式存储数据库管理系统,…

C# 将HTML网页、HTML字符串转换为PDF文件

将HTML转换为PDF可实现格式保留、可靠打印、文档归档等多种用途,满足不同领域和情境下的需求。本文将通过以下两个示例,演示如何使用第三方库Spire.PDF for .NET和QT插件在C# 中将Html 网页(URL)或HTML字符串转为PDF文件。 HTML转…

ubuntu20根目录扩容

ubuntu根目录/ 或者 /home文件夹有时出现空间满了的情况,可以用gparted工具进行空间的重新分配。 首先,如果你是双系统,需要从windows系统下磁盘压缩分配一部分未使用的空间给ubuntu,注意压缩的空间要邻接ubuntu所在盘的位置。 …

NTFS 磁盘管理 :NTFS Disk by Omi NTFS

NTFS Disk by Omi NTFS是一款专为Mac系统设计的NTFS文件系统读写解决方案的工具。它可以帮助Mac用户方便地访问和管理NTFS格式的硬盘、U盘、移动硬盘以及其他存储设备,提供高效稳定的NTFS卷管理功能。 NTFS 磁盘管理 :NTFS Disk by Omi NTFS 该软件的主…

C#调用SqlSugar操作达梦数据库报错“无效的表或视图名”

安装达梦数据库后,使用SqlSugar连接测试数据库并基于DBFirst方式创建数据库表对应的类,主要代码如下: SqlSugarClient db new SqlSugarClient(new ConnectionConfig(){DbType DbType.Dm,ConnectionString "Serverlocalhost; User Id…

仓储管理系统——软件工程报告(项目管理)⑦

项目管理 一、管理计划 这个项目的计划是一个关键的阶段,它需要考虑到多个因素,包括软件规模的度量、工作量的估算以及详细的进度计划,以确保项目按时、高质量地完成。 软件规模度量: 在软件工程中,度量软件规模是…

基于TriDet的时序动作检测算法训练自己的slowfast数据

最近一直在研究时序动作识别和检测,也一直关注着目前的最新进展,有好的算法,我都会在我自己的数据集上运行看看,一方面是为自己累积相关算法,另一方面也是想看看,目前最新的算法是否可以应用到一些项目上。…

Linux-共享内存

文章目录 前言一、system V共享内存申请共享内存挂载共享内存删除共享内存挂载删除共享内存 二、示例代码三.运行效果 前言 在这之前我们已经学习了两种进程间通信方式:匿名管道和命名管道。 从我们之前的学习已经知道,想让多个进程间进行通信就需要让他…

Redis面试

1.说说什么事redis Redis是一种基于键值对的NoSql数据库。 Redis中的value支持string(字符串)、hahs(哈希)、list、set、zset(有序集合)、bitmaps(位图),HyperLoglog等数…

快速上手的AI工具-文心一言辅助学习

前言 大家好晚上好,现在AI技术的发展,它已经渗透到我们生活的各个层面。对于普通人来说,理解并有效利用AI技术不仅能增强个人竞争力,还能在日常生活中带来便利。无论是提高工作效率,还是优化日常任务,AI工…

谷粒商城配置虚拟机

一、创建虚拟机 之前有在VM里面建一个ubuntu的虚拟机,准备拿来直接用,网络设置为NAT模式,查看我的虚拟机是虚拟机:192.168.248.128 主机: 192.168.2.12。可以互相ping通。 二、linux安装docker Docker docker是虚拟…