【PyTorch】模型选择、欠拟合和过拟合

文章目录

  • 1. 理论介绍
  • 2. 实例解析
    • 2.1. 实例描述
    • 2.2. 代码实现
      • 2.2.1. 完整代码
      • 2.2.2. 输出结果

1. 理论介绍

  • 将模型在训练数据上拟合的比在潜在分布中更接近的现象称为过拟合, 用于对抗过拟合的技术称为正则化
  • 训练误差和验证误差都很严重, 但它们之间差距很小。 如果模型不能降低训练误差,这可能意味着模型过于简单(即表达能力不足),无法捕获试图学习的模式。 这种现象被称为欠拟合
  • 训练误差是指模型在训练数据集上计算得到的误差。
  • 泛化误差是指模型应用在同样从原始样本的分布中抽取的无限多数据样本时,模型误差的期望。我们永远不能准确地计算出泛化误差,在实际中,我们只能通过将模型应用于一个独立的测试集来估计泛化误差, 该测试集由随机选取的、未曾在训练集中出现的数据样本构成。
  • 影响模型泛化的因素
    • 可调整参数的数量。当可调整参数的数量(有时称为自由度)很大时,模型往往更容易过拟合。
    • 参数采用的值。当权重的取值范围较大时,模型可能更容易过拟合。
    • 训练样本的数量。即使模型很简单,也很容易过拟合只包含一两个样本的数据集,而过拟合一个有数百万个样本的数据集则需要一个极其灵活的模型。
  • 在机器学习中,我们通常在评估几个候选模型后选择最终的模型。 这个过程叫做模型选择。候选模型可能在本质上不同,也可能是不同的超参数设置下的同一类模型。
  • 为了确定候选模型中的最佳模型,我们通常会使用验证集。验证集与测试集十分相似,唯一的区别是验证集是用于确定最佳模型,测试集是用于评估最终模型的性能
  • K K K折交叉验证:当训练数据稀缺时,将原始训练数据分成 K K K个不重叠的子集。 然后执行 K K K次模型训练和验证,每次在 ( K − 1 ) (K-1) (K1)个子集上进行训练, 并在剩余的一个子集(在该轮中没有用于训练的子集)上进行验证。 最后,通过对 K K K次实验的结果取平均来估计训练和验证误差。
  • 引起过拟合的因素
    • 模型复杂度
      模型复杂度
    • 数据集大小
      • 训练数据集中的样本越少,我们就越有可能(且更严重地)过拟合。
      • 给出更多的数据,拟合更复杂的模型可能是有益的; 如果没有足够的数据,简单的模型可能更有用。

2. 实例解析

2.1. 实例描述

使用以下三阶多项式来生成训练和测试数据 y = 5 + 1.2 x − 3.4 x 2 2 ! + 5.6 x 3 3 ! + ϵ  where  ϵ ∼ N ( 0 , 0. 1 2 ) . y = 5 + 1.2x - 3.4\frac{x^2}{2!} + 5.6 \frac{x^3}{3!} + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.1^2). y=5+1.2x3.42!x2+5.63!x3+ϵ where ϵN(0,0.12).并用1阶(线性模型)、3阶、20阶多项式拟合。

2.2. 代码实现

2.2.1. 完整代码

import os
import numpy as np
import math, torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tensorboardX import SummaryWriter
from rich.progress import track

def evaluate_loss(dataloader, net, criterion):
    """评估模型在指定数据集上的损失"""
    num_examples = 0
    loss_sum = 0.0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.cuda(), y.cuda()
            loss = criterion(net(X), y)
            num_examples += y.shape[0]
            loss_sum += loss.sum()
        return loss_sum / num_examples

def load_dataset(*tensors):
    """加载数据集"""
    dataset = TensorDataset(*tensors)
    return DataLoader(dataset, batch_size, shuffle=True)


if __name__ == '__main__':
    # 全局参数设置
    num_epochs = 400
    batch_size = 10
    learning_rate = 0.01

    # 创建记录器
    def log_dir():
        root = "runs"
        if not os.path.exists(root):
            os.mkdir(root)
        order = len(os.listdir(root)) + 1
        return f'{root}/exp{order}'
    writer = SummaryWriter(log_dir())

    # 生成数据集
    max_degree = 20             # 多项式最高阶数
    n_train, n_test = 100, 100  # 训练集和测试集大小

    true_w = np.zeros(max_degree+1)
    true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])

    features = np.random.normal(size=(n_train + n_test, 1))
    np.random.shuffle(features)
    poly_features = np.power(features, np.arange(max_degree+1).reshape(1, -1))
    for i in range(max_degree+1):
        poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!
    labels = np.dot(poly_features, true_w)
    labels += np.random.normal(scale=0.1, size=labels.shape)    # 加高斯噪声服从N(0, 0.01)

    poly_features, labels = [
        torch.as_tensor(x, dtype=torch.float32) for x in [
            poly_features, labels.reshape(-1, 1)]]
    
    def loop(model_degree):
        # 创建模型
        net = nn.Linear(model_degree+1, 1, bias=False).cuda()
        nn.init.normal_(net.weight, mean=0, std=0.01)
        criterion = nn.MSELoss(reduction='none')
        optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)

        # 加载数据集
        dataloader_train = load_dataset(poly_features[:n_train, :model_degree+1], labels[:n_train])
        dataloader_test = load_dataset(poly_features[n_train:, :model_degree+1], labels[n_train:])
        
        # 训练循环
        for epoch in track(range(num_epochs), description=f'{model_degree}-degree'):
            for X, y in dataloader_train:
                X, y = X.cuda(), y.cuda()
                loss = criterion(net(X), y)
                optimizer.zero_grad()
                loss.mean().backward()
                optimizer.step()

            writer.add_scalars(f"{model_degree}-degree", {
                "train_loss": evaluate_loss(dataloader_train, net, criterion),
                "test_loss": evaluate_loss(dataloader_test, net, criterion),
            }, epoch)
        print(f"{model_degree}-degree: weights =", net.weight.data.cpu().numpy())

    for model_degree in [1, 3, 20]:
        loop(model_degree)
    writer.close()

2.2.2. 输出结果

权重

  • 采用1阶多项式(线性模型)拟合
    1
  • 采用3阶多项式拟合
    3
  • 采用20阶多项式拟合
    20

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/225257.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

mysql知识分享(包含安装卸载)(一)

如果博客有错误,请佬指正。 目录 注意:打开cmd时要有管理员身份打开,重要 为何使用数据库? 数据库的相关概念 关系型数据库 关系型数据库设计规则 表,记录,字段 表的关联关系 一对一关联 一对多关系 …

基于Apache DolphinScheduler构建数据平台的实践分享

很荣幸收到社区的邀请来给大家分享一下公司基于构建数据平台的实践。我将首先介绍DataLink大数据平台,然后讲解调度系统的升级和选型,最后分享我们是如何基于Apache DolphinScheduler构建数据平台的,以及平台的现状和未来的规划。 公司介绍 …

ReentrantLock

文章目录 相对于 synchronized 它具备如下特点: 可中断,别的线程可以打断自己,避免自己竞争不到锁一直block,发生死锁可以设置超时时间:tryLock,可以解决死锁可以设置为公平锁,一般不会用支持多…

Zookeeper单机模式搭建

1、下载 ​wget https://dlcdn.apache.org/zookeeper/zookeeper-3.6.3/apache-zookeeper-3.6.3-bin.tar.gz 2、解压 tar -zxvf apache-zookeeper-3.6.3-bin.tar.gz 3、进入 apache-zookeeper-3.6.3-bin目录下,创建data cd apache-zookeeper-3.6.3-bin mkdir da…

JVM GUI可视化监控及诊断工具

工具既述 使用命令行工具或组合能帮您获取目标Java应用性能相关的基础信息,但它们存在下列局限: 无法获取方法级别的分析数据,如方法间的调用关系、各方法的调用次数和调用时间等(这对定位应用性能瓶颈至关重要)。要…

resnet 图像分类的项目

1. 项目文件 文件下载资源:resnet 图像分类的项目代码 本章利用reset34 作图像分类,包括计算训练集和测试集的loss、accuracy曲线,对训练好的网络在训练集测试集上求混淆矩阵 data 文件为训练集测试集,图像按照文件夹摆放inferenc…

汽车电子 -- CAN文件格式ASC

Vector提供了两种记录数据格式的格式规范:BLF和ASC。 先讲讲ASC。 参看:图文详解CAN Log文件 - ASC文件格式 一、ASC文件格式 在Vector提供的 CAN_LOG_TRIGGER_ASC_Format.pdf 提取码:ltjv 文件中,规定了CANoe/CANalyzer ASC记…

利用TCP通信实现文件传输和通信

前言 我们上一章已经熟悉了理论知识&#xff0c;这一章来练习一下 1.实现文件的传输 1.1 客户端 dir_client.c #include <stdio.h> #include <sys/socket.h> #include <netinet/in.h> #include <arpa/inet.h> #include <sys/types.h> #include …

Java ThreadLocal 实现原理 与 如何使用弱引用解决内存泄漏问题

目录 一、ThreadLocal 有什么用二、ThreadLocal 使用示例三、ThreadLocal 实现原理四、ThreadLocal 如何是使用弱引用解决内存泄漏问题4.1、强引用内存泄漏分析4.1、弱引用解决内存泄漏问题 一、ThreadLocal 有什么用 ThreadLocal 诞生于 JDK 1.2&#xff0c;用于解决多线程间的…

基于ssm学院党员管理系统论文

摘 要 互联网发展至今&#xff0c;无论是其理论还是技术都已经成熟&#xff0c;而且它广泛参与在社会中的方方面面。它让信息都可以通过网络传播&#xff0c;搭配信息管理工具可以很好地为人们提供服务。针对鄂尔多斯应用技术学院党员信息管理混乱&#xff0c;出错率高&#x…

javaWebssh图书系统myeclipse开发mysql数据库MVC模式java编程计算机网页设计

系统前景 图书有很多&#xff0c;老的图书书的管理靠纸介质&#xff0c;浪费人力和物力&#xff0c;给图书管理者带来极大的资源浪费。随着计算机信息化的普及&#xff0c;对图书的管理带来本质的改变&#xff0c;图书的销售情况以及&#xff0c;图书管理&#xff0c;以及年终对…

振弦采集仪助力岩土工程质量控制

振弦采集仪助力岩土工程质量控制 随着工程建设规模越来越大&#xff0c;建筑结构的安全性和稳定性越来越成为人们所关注的焦点。岩土工程在工程建设中占据着非常重要的地位&#xff0c;岩土工程质量控制更是至关重要。而振弦采集仪作为一种先进的检测设备&#xff0c;正得到越…

Linux命令之ps

Linux命令之ps ps命令的基本用法**常用的ps命令示例** ps命令的基本用法 ps: 显示当前终端会话中属于当前用户的进程列表。 ps -ef: 显示系统中所有进程的列表&#xff0c;包括其他用户的进程。 ps -aux: 显示详细的进程信息&#xff0c;包括CPU和内存使用情况等。 -e选项&…

浅谈如何写开发信和报价?外贸邮件怎么写?

外贸开发信要写报价进去吗&#xff1f;写开发信能加产品价格吗&#xff1f; 开发信和报价是连接您和潜在客户之间的纽带&#xff0c;它们有助于传达您的价值主张、产品或服务的优势以及价格细节。蜂邮EDM将探讨如何撰写令人印象深刻的开发信和报价&#xff0c;以吸引更多的潜在…

AIGC:使用变分自编码器VAE实现MINIST手写数字生成

1 变分自编码器介绍 变分自编码器&#xff08;Variational Autoencoders&#xff0c;VAE&#xff09;是一种生成模型&#xff0c;用于学习数据的分布并生成与输入数据相似的新样本。它是一种自编码器&#xff08;Autoencoder&#xff09;的扩展&#xff0c;自编码器是一种用于…

严蔚敏数据结构p17(2.19)——p18(2.24) (c语言代码实现)

目录 2.19已知线性表中的元素以值递增有序排列,并以单链表作存储结构。试写一高效的算法,删除表中所有值大于 mink 且小于 maxk 的元素(若表中存在这样的元素&#xff09;同时释放被删结点空间,并分析你的算法的时间复杂度(注意:mink 和 maxk 是给定的个参变量,它们的值可以和表…

【JavaEE】生产者消费者模式

作者主页&#xff1a;paper jie_博客 本文作者&#xff1a;大家好&#xff0c;我是paper jie&#xff0c;感谢你阅读本文&#xff0c;欢迎一建三连哦。 本文于《JavaEE》专栏&#xff0c;本专栏是针对于大学生&#xff0c;编程小白精心打造的。笔者用重金(时间和精力)打造&…

从零开始学习 JS APL(一):完整指南和实例解析

本章内容主要是按一下来&#xff1a; 操作DOM BOM 比如 控制网页元 素交互等各种网页 交互效果 以下是我总结笔记&#xff08;仅供参考&#xff09; webAPL 获取DOM对象 变量声明有三个 var let 和 const 我们应该用那个呢&#xff1f; 首先var 先排除&#xff0c;老派写法…

UDP协议实现群聊

代码&#xff1a; import java.awt.*; import java.awt.event.*; import javax.swing.*; import java.net.*; import java.io.IOException; import java.lang.String; public class liaotian extends JFrame{ private static final int DEFAULT_PORT8899; private J…

机器的深度强化学习算法可以被诱导

设计一个好的奖励函数是机器深度强化学习算法的关键之一。奖励函数用于给予智能体&#xff08;机器&#xff09;在环境中采取不同行动时的反馈信号&#xff0c;以指导其学习过程。一个好的奖励函数应该能够引导智能体朝着期望的行为方向学习&#xff0c;并尽量避免潜在的问题&a…