PyTorch 神经网络回归(Regression)任务:关系拟合与优化过程

PyTorch 神经网络回归(Regression)任务:关系拟合与优化过程

本教程介绍了如何使用 PyTorch 构建一个简单的神经网络来实现关系拟合,具体演示了从数据准备到模型训练和可视化的完整过程。首先,利用一维线性空间生成带噪声的数据集,接着定义了一个包含隐藏层和输出层的神经网络。通过使用均方误差损失函数和随机梯度下降优化器,逐步训练神经网络来拟合数据。为了便于理解和监控训练过程,我们使用 matplotlib 实现了动态更新的图形,展示了每次迭代后的预测结果与真实数据的对比。该教程不仅帮助读者理解神经网络的基本架构和训练流程,还展示了如何通过可视化手段更直观地观察模型的优化过程,提升了对模型调优的理解与应用能力。

文章目录

  • PyTorch 神经网络回归(Regression)任务:关系拟合与优化过程
      • 一 导入第三方库
      • 二 设置数据集
      • 三 编写神经网络
      • 四 训练神经网络
        • 可视化训练过程
      • 五 完整代码示例
      • 六 源码地址
      • 七 参考

预备课:PyTorch 激活函数详解:从原理到最佳实践

一 导入第三方库

import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
import os

二 设置数据集

# 生成一维的线性空间数据,并增加一维使其形状为 (100, 1)
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())  # 生成对应的 y 数据,加上噪声模拟真实情况

三 编写神经网络

class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)  # 定义隐藏层,输入维度为 n_feature,输出维度为 n_hidden
        self.predict = torch.nn.Linear(n_hidden, n_output)  # 定义输出层,输入维度为 n_hidden,输出维度为 n_output

    def forward(self, x):
        x = F.relu(self.hidden(x))  # 使用 ReLU 激活函数处理隐藏层的输出
        x = self.predict(x)  # 计算最终输出
        return x
      

在此定义了神经网络的结构,其中隐藏层的输入维度为 n_feature,输出维度为 n_hidden,而输出层的输入维度为 n_hidden,输出维度为 n_output。下图展示了以 3 个神经元为例的网络结构,以帮助理解。
在这里插入图片描述

:如果对上述代码感到困惑,可以暂时将其视为固定写法,专注于理解其基本框架。

四 训练神经网络

# 初始化神经网络
net = Net(n_feature=1, n_hidden=10, n_output=1)  # 定义网络,输入输出各为 1,隐藏层有 10 个神经元
print(net)  # 打印网络结构

# 定义优化器和损失函数
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)  # 使用随机梯度下降法优化网络参数,学习率为 0.2
loss_func = torch.nn.MSELoss()  # 定义均方误差损失函数

plt.ion()  # 开启交互模式,允许动态更新图像

for epoch in range(200):
    prediction = net(x)  # 前向传播,使用当前网络计算预测值
    loss = loss_func(prediction, y)  # 计算预测值与真实值之间的误差

    optimizer.zero_grad()  # 清空上一步的梯度信息
    loss.backward()  # 反向传播,计算梯度
    optimizer.step()  # 根据梯度更新网络参数

    if epoch % 5 == 0:  # 每 5 个周期更新一次图像
        plt.cla()  # 清除当前图像内容
        plt.scatter(x.data.numpy(), y.data.numpy(), label='True Data')
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=2, label='Prediction')
        plt.text(0.5, 0, f'Loss={loss.item():.4f}', fontdict={'size': 20, 'color': 'red'})
        plt.legend()  # 添加图例

        # 保存当前图像
        # file_path = os.path.join(target_directory, f'epoch_{epoch}.png')
        # plt.savefig(file_path)
        # print(f"图像已保存: {file_path}")
        plt.pause(0.1)  # 暂停以更新图像

plt.ioff()  # 关闭交互模式
plt.show()  # 显示最终图像
可视化训练过程

可视化神经网络训练(关系拟合)

:通过引入 matplotlib 实现训练过程的可视化,帮助直观地跟踪模型的学习进展。

五 完整代码示例

import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
import os


class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)  # 定义隐藏层,输入维度为 n_feature,输出维度为 n_hidden
        self.predict = torch.nn.Linear(n_hidden, n_output)  # 定义输出层,输入维度为 n_hidden,输出维度为 n_output

    def forward(self, x):
        x = F.relu(self.hidden(x))  # 使用 ReLU 激活函数处理隐藏层的输出
        x = self.predict(x)  # 计算最终输出
        return x


def print_hi(name):
    print(f'Hi, {name}')
    # 创建保存图片的目录
    # target_directory = "/Users/your/Desktop/001"
    # if not os.path.exists(target_directory):
    #     os.makedirs(target_directory)
    # 创建数据集
    # 生成一维的线性空间数据,并增加一维使其形状为 (100, 1)
    x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
    y = x.pow(2) + 0.2 * torch.rand(x.size())  # 生成对应的 y 数据,加上噪声模拟真实情况

    # 初始化神经网络
    net = Net(n_feature=1, n_hidden=10, n_output=1)  # 定义网络,输入输出各为 1,隐藏层有 10 个神经元
    print(net)  # 打印网络结构

    # 定义优化器和损失函数
    optimizer = torch.optim.SGD(net.parameters(), lr=0.2)  # 使用随机梯度下降法优化网络参数,学习率为 0.2
    loss_func = torch.nn.MSELoss()  # 定义均方误差损失函数

    plt.ion()  # 开启交互模式,允许动态更新图像

    for epoch in range(200):
        prediction = net(x)  # 前向传播,使用当前网络计算预测值
        loss = loss_func(prediction, y)  # 计算预测值与真实值之间的误差

        optimizer.zero_grad()  # 清空上一步的梯度信息
        loss.backward()  # 反向传播,计算梯度
        optimizer.step()  # 根据梯度更新网络参数

        if epoch % 5 == 0:  # 每 5 个周期更新一次图像
            plt.cla()  # 清除当前图像内容
            plt.scatter(x.data.numpy(), y.data.numpy(), label='True Data')
            plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=2, label='Prediction')
            plt.text(0.5, 0, f'Loss={loss.item():.4f}', fontdict={'size': 20, 'color': 'red'})
            plt.legend()  # 添加图例

            # 保存当前图像
            # file_path = os.path.join(target_directory, f'epoch_{epoch}.png')
            # plt.savefig(file_path)
            # print(f"图像已保存: {file_path}")
            plt.pause(0.1)  # 暂停以更新图像

    plt.ioff()  # 关闭交互模式
    plt.show()  # 显示最终图像


if __name__ == '__main__':
    print_hi('关系拟合')

复制粘贴并覆盖到你的 main.py 中运行,运行结果如下。

Hi, 关系拟合
Net(
  (hidden): Linear(in_features=1, out_features=10, bias=True)
  (predict): Linear(in_features=10, out_features=1, bias=True)
)

六 源码地址

代码地址,GitHub 之 关系拟合 。

七 参考

[1] PyTorch 官方文档

[2] 莫烦 Python

[3] 可视化神经网络 TensorFlow Playground

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/942019.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

渐开线齿轮和摆线齿轮有什么区别?

摆线齿形与渐开线齿形的区别 虽然在比对这两种齿形,但有一个事情希望大家注意:渐开线齿轮只是摆线齿轮的一个特例。 (1)摆线齿形的压力角在啮合开始时最大,在齿节点减小到零,在啮合结束时再次增大到最大…

Debian 12 安装配置 fail2ban 保护 SSH 访问

背景介绍 双十一的时候薅羊毛租了台腾讯云的虚机, 是真便宜, 只是没想到才跑了一个月, 系统里面就收集到了巨多的 SSH 恶意登录失败记录. 只能说, 互联网真的是太不安全了. 之前有用过 fail2ban 在 CentOS 7 上面做过防护, 不过那已经是好久好久之前的故事了, 好多方法已经不…

Vulhub靶场Apache解析漏洞

一.apache_parsing 原理:Apache HTTPD ⽀持⼀个⽂件拥有多个后缀,并为不同后缀执⾏不同的指令。在Apache1.x/2.x中Apache 解析⽂件的规则是从右到左开始判断解析,如果后缀名为不可识别⽂件解析,就再往左判断。如 1.php.xxxxx 打开靶场 创建一个名为1.p…

MATLAB 抛物线拟合(Quadratic,二维)

文章目录 一、简介二、实现代码三、实现效果参考资料一、简介 这里仍然是最小二乘法的应用,其推导过程如下所述: 1.二次函数模型: 其中,a、b 和 c 是需要确定的参数。 2.最小二乘法 假设我们有一组数据点 ( x 1 ​ , y 1

重温设计模式--原型模式

文章目录 原型模式定义原型模式UML图优点缺点使用场景C 代码示例深拷贝、浅拷贝 原型模式定义 用原型实例指定创建对象的种类,并且通过拷贝这些原型创建新的对象; 核心中的核心就是 克隆clone ,后面讲 原型模式是一种创建型设计模式,它的主要…

mac iterm2 使用 lrzsz

前言 mac os 终端不支持使用 rz sz 上传下载文件,本文提供解决方法。 mac 上安装 brew install lrzsz两个脚本 注意:/usr/local/bin/iterm2-send-zmodem.sh 中的 sz命令路径要和你mac 上 sz 命令路径一致。 /usr/local/bin/iterm2-recv-zmodem.sh 中…

【基础篇】1. JasperSoft Studio编辑器与报表属性介绍

编辑器介绍 Jaspersoft Studio有一个多选项卡编辑器,其中包括三个标签:设计,源代码和预览。 Design:报表设计页面,可以图形化拖拉组件设计报表,打开报表文件的主页面Source:源代码页码&#xff…

【magic-dash】01:magic-dash创建单页面应用及二次开发

文章目录 一、magic-dash是什么1.1 安装1.2 使用1.2.1 查看内置项目模板1.2.2 生成指定项目模板1.2.3 查看当前magic-dash版本1.2.4 查看命令说明1.2.5 内置模板列表二、创建虚拟环境并安装magic-dash三、magic-dash单页工具应用开发3.1 创建单页面项目3.1.1 使用命令行创建单页…

《Pytorch框架CV开发-从入门到实战》

目录 1.环境部署2.自动梯度计算张量 tensor3.线性回归4.逻辑回归6.人工神经网络的基本概念6.1 感知器6.2 激活函数6.3多层感知器6.4 反向传播算法——前向传播6.5 反向传播算法——反向传播6.6 反向传播算法——训练方法7.Pytorch基础数据集8.手写数字识别人工神经网络训练8.1 …

WebRTC学习二:WebRTC音视频数据采集

系列文章目录 第一篇 基于SRS 的 WebRTC 环境搭建 第二篇 基于SRS 实现RTSP接入与WebRTC播放 第三篇 centos下基于ZLMediaKit 的WebRTC 环境搭建 第四篇 WebRTC 学习一:获取音频和视频设备 第五篇 WebRTC学习二:WebRTC音视频数据采集 文章目录 系列文章…

国自然联合项目|影像组学智能分析理论与关键技术|基金申请·24-12-25

小罗碎碎念 该项目为国自然联合基金项目,执行年限为2019年1月至2022年12月,直接费用为204万元。 项目研究内容包括影像组学分析、智能计算、医疗风险评估等,旨在通过模拟医生诊断过程,推动人工智能在医疗领域的创新。 项目取得了…

怎样配备公共配套设施,才能让啤酒酿造流程高效环保?

今天,天泰邀请大家和我一起走进啤酒厂,了解水、蒸汽、压缩空气和二氧化碳这些基础设施如何助力啤酒生产,实现高效与环保的完美结合。 水 水是啤酒酿造的基础,啤酒厂对水质的要求极高。为了确保水质达标,啤酒厂设有专…

医疗行业 UI 设计系列合集(一):精准定位

在当今数字化时代,医疗行业与信息技术的融合日益紧密,UI 设计在其中扮演着至关重要的角色。精准定位的 UI 设计能够显著提升医疗产品与服务的用户体验,进而对医疗效果和患者满意度产生积极影响。 一、医疗行业 UI 设计的重要性概述 医疗行业…

本科阶段最后一次竞赛Vlog——2024年智能车大赛智慧医疗组准备全过程——12使用YOLO-Bin

本科阶段最后一次竞赛Vlog——2024年智能车大赛智慧医疗组准备全过程——12使用YOLO-Bin ​ 根据前面内容,所有的子任务已经基本结束,接下来就是调用转化的bin模型进行最后的逻辑控制了 1 .YOLO的bin使用 ​ 对于yolo其实有个简单的办法,也…

EMC整改

首先我们来从EMC测试项目构成说起,EMC主要包含两大项:EMI(干扰)和EMS(产品抗干扰和敏感度),当然这两大项中又包括许多小项目。 EMI主要测试项: RE(产品辐射&#xff0c…

Xcode 16 编译弹窗问题、编译通过无法,编译通过打包等问题汇总

问题1:打包的过程中不断提示 :codesign 想要访问你的钥匙串中的密钥“develop 或者distribution 证书” 解决:打开钥匙串,点击证书---显示简介---信任----改为始终信任 (记住 :不能只修改钥匙的显示简介的…

go window安装protoc protoc生成protobuf文件

1. 下载: Releases protocolbuffers/protobuf GitHub 2. 解压缩: 3. 配置环境变量: 选择系统变量->Path -> 新增 解压缩后的bin路径 4. 打印版本: protoc --version 5. 安装protoc-gen-go cmd 下输入安装命令&#xff0…

在【Arduino IDE】中在线下载和离线下载【ESP系列开发板的SDK】

在线下载 打开Arduino IDE,依次点击 文件➔首选项➔其他开发板管理器地址,复制粘贴以下的开发板管理地址: https://arduino.me/packages/esp32.json https://arduino.me/packages/esp8266.json 如下图所示,然后点击确定&#xf…

Arduino PID 控制教程

在控制系统中,控制器在出现错误和干扰的情况下将特定系统的输出校正为目标。最流行的控制器类型是PID ,它是比例、积分和微分的首字母缩写。在以下教程中,我将向您展示如何在项目中使用这种控制器。 什么是PID? 如上所述,PID 是比例、积分和微分的缩写。这种控制器仅用于反…

vue2 升级为 vite 打包

VUE2 中使用 Webpack 打包、开发,每次打包时间太久,尤其是在开发的过程中,本文记录一下 VUE2 升级Vite 步骤。 安装 Vue2 Vite 依赖 dev 依赖 vitejs/plugin-vue2": "^2.3.3 vitejs/plugin-vue2-jsx": "^1.1.1 vite&…