采用自动微分进行模型的训练

 自动微分训练模型

 简单代码实现:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的线性回归模型
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(1, 1)  # 输入维度是1,输出维度也是1

    def forward(self, x):
        return self.linear(x)

# 准备训练数据
x_train = torch.tensor([[1.0], [2.0], [3.0]])
y_train = torch.tensor([[2.0], [4.0], [6.0]])

# 实例化模型、损失函数和优化器
model = LinearRegression()
criterion = nn.MSELoss()  # 均方误差损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降优化器

# 训练模型
epochs = 1000
for epoch in range(epochs):
    # 前向传播
    outputs = model(x_train)
    loss = criterion(outputs, y_train)

    # 反向传播
    optimizer.zero_grad()  # 清空之前的梯度
    loss.backward()  # 自动计算梯度
    optimizer.step()  # 更新模型参数

    if (epoch+1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

# 测试模型
x_test = torch.tensor([[4.0]])
predicted = model(x_test)
print(f'预测值: {predicted.item():.4f}')

代码分解:

1.定义一个简单的线性回归模型:

  • LinearRegression 类继承自nn.Module,这是所有神经网络模型的基类
  • 在 __init__ 方法中,定义了一个线性层 self.linear,它的输入维度是1,输出维度也是1。
  • forward 方法定义了数据在模型中的传播路径,即输入 x 经过 self.linear 层后得到输出。
    class LinearRegression(nn.Module):
        def __init__(self):
            super(LinearRegression, self).__init__()
            self.linear = nn.Linear(1, 1)  # 输入维度是1,输出维度也是1
    
        def forward(self, x):
            return self.linear(x)
    

2.准备训练数据:

  • x_train 和 y_train 分别是输入和目标输出的训练数据。每个张量表示一个样本,x_train 中的每个元素是一个维度为1的张量,因为模型的输入维度是1。
    x_train = torch.tensor([[1.0], [2.0], [3.0]])
    y_train = torch.tensor([[2.0], [4.0], [6.0]])
    

3.实例化模型,损失函数和优化器:

  • model 是我们定义的 LinearRegression 类的一个实例,即我们要训练的线性回归模型。
  • criterion 是损失函数,这里选择了均方误差损失(MSE Loss),用于衡量预测值与实际值之间的差异。
  • optimizer 是优化器,这里选择了随机梯度下降(SGD),用于更新模型参数以最小化损失。
    model = LinearRegression()
    criterion = nn.MSELoss()  # 均方误差损失函数
    optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降优化器
    

4.训练模型:

  • 这里进行了1000次迭代的训练过程。
  • 在每个迭代中,首先进行前向传播,计算模型对 x_train 的预测输出 outputs,然后计算损失 loss
  • 调用 optimizer.zero_grad() 来清空之前的梯度,然后调用 loss.backward() 自动计算梯度,最后调用 optimizer.step() 来更新模型参数
    epochs = 1000
    for epoch in range(epochs):
        # 前向传播
        outputs = model(x_train)
        loss = criterion(outputs, y_train)
    
        # 反向传播
        optimizer.zero_grad()  # 清空之前的梯度
        loss.backward()  # 自动计算梯度
        optimizer.step()  # 更新模型参数
    
        if (epoch+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
    

5.测试模型:

  • x_test 是用来测试模型的输入数据,这里表示输入为4.0。
  • model(x_test) 对 x_test 进行前向传播,得到预测结果 predicted
  • predicted.item() 取出预测结果的标量值并打印出来。
    x_test = torch.tensor([[4.0]])
    predicted = model(x_test)
    print(f'预测值: {predicted.item():.4f}')
    

运行结果:

运行结果如下:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/799035.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

链接追踪系列-07.logstash安装json_lines插件

进入docker中的logstash 容器内: jelexbogon ~ % docker exec -it 7ee8960c99a31e607f346b2802419b8b819cc860863bc283cb7483bc03ba1420 /bin/sh $ pwd /usr/share/logstash $ ls bin CONTRIBUTORS Gemfile jdk logstash-core modules tools x-pack …

如何预防最新的baxia变种勒索病毒感染您的计算机?

引言 在当今数字化时代,网络安全威胁层出不穷,其中勒索病毒已成为企业和个人面临的重大挑战之一。近期,.baxia勒索病毒以其高隐蔽性和破坏性引起了广泛关注。本文将详细介绍.baxia勒索病毒的特点、传播方式,并给出相应的应对策略…

2024-07-15 Unity插件 Odin Inspector3 —— Button Attributes

文章目录 1 说明2 Button 特性2.1 Button2.2 ButtonGroup2.3 EnumPaging2.4 EnumToggleButtons2.5 InlineButton2.6 ResponsiveButtonGroup 1 说明 ​ 本章介绍 Odin Inspector 插件中有关 Button 特性的使用方法。 2 Button 特性 2.1 Button 依据方法,在 Inspec…

YOLOv8训练自己的数据集(超详细)

一、准备深度学习环境 本人的笔记本电脑系统是:Windows10 YOLO系列最新版本的YOLOv8已经发布了,详细介绍可以参考我前面写的博客,目前ultralytics已经发布了部分代码以及说明,可以在github上下载YOLOv8代码,代码文件夹…

力扣经典题目之->移除值为val元素的讲解,的实现与讲解

一:题目 博主本文将用指向来形象的表示下标位的移动。 二:思路 1:两个整形,一个start,一个end,在一开始都 0,即这里都指向第一个元素。 2:在查到val之前,查一个&…

昇思25天学习打卡营第12天|MindSpore 助力下的 GPT2:数据集加载处理及模型全攻略

环境配置 %%capture captured_output 此乃 Jupyter Notebook 中的一个魔法命令,其作用在于捕获后续单元格中的输出,并将之存储于变量 captured_output 当中,而非直接于输出区域予以显示。如此一来,便可隐匿某些可能存在的输出信息…

【学习笔记】无人机(UAV)在3GPP系统中的增强支持(九)-无人机服务区分离

引言 本文是3GPP TR 22.829 V17.1.0技术报告,专注于无人机(UAV)在3GPP系统中的增强支持。文章提出了多个无人机应用场景,分析了相应的能力要求,并建议了新的服务级别要求和关键性能指标(KPIs)。…

小程序里面使用vant ui中的vant-field组件,如何使得输入框自动获取焦点

//.wxml <van-fieldmodel:value"{{ userName }}"placeholder"请输入学号"focus"{{focusUserName}}"/>// .js this.setData({focusUserName: true});vant-field

钡铼ARMxy控制器在智能网关中的应用

随着IoT物联网技术的飞速发展&#xff0c;智能网关作为连接感知层与网络层的枢纽&#xff0c;可以实现感知网络和通信网络以及不同类型感知网络之间的协议转换。钡铼技术的ARMxy系列控制器凭借其高性能、低功耗和高度灵活性的特点&#xff0c;在智能网关中发挥了关键作用&#…

RPC与服务的注册发现

文章目录 1. 什么是远程过程调用(RPC)?2. RPC的流程3. RPC实践4. RPC与REST的区别4.1 RPC与REST的相似之处4.2 RPC与REST的架构原则4.3 RPC与REST的主要区别 5. RPC与服务发现5.1 以zookeeper为服务注册中心5.2 以etcd为服务注册中心 6. 小结参考 1. 什么是远程过程调用(RPC)?…

大语言模型诞生过程剖析

过程图如下 &#x1f4da; 第一步&#xff1a;海量文本的无监督学习 得到基座大模型&#x1f389; &#x1f50d; 原料&#xff1a;首先&#xff0c;我们需要海量的文本数据&#xff0c;这些数据可以来自互联网上的各种语料库&#xff0c;包括书籍、新闻、科学论文、社交媒体帖…

<数据集>光伏板缺陷检测数据集<目标检测>

数据集格式&#xff1a;VOCYOLO格式 图片数量&#xff1a;2400张 标注数量(xml文件个数)&#xff1a;2400 标注数量(txt文件个数)&#xff1a;2400 标注类别数&#xff1a;4 标注类别名称&#xff1a;[Crack,Grid,Spot] 序号类别名称图片数框数1Crack8688922Grid8248843S…

全栈智能家居系统设计方案:STM32+Linux+多协议(MQTT、Zigbee、Z-Wave)通信+云平台集成

1. 项目概述 随着物联网技术的快速发展,智能家居系统正在成为现代生活中不可或缺的一部分。本文介绍了一个基于STM32微控制器和Linux系统的智能家居解决方案,涵盖了硬件设计、软件架构、通信协议以及云平台集成等方面。 该系统具有以下特点: 采用STM32作为终端设备的控制核心…

springboot3——项目部署

springboot的项目开发完了&#xff0c;怎么样把他放到服务器上或者生产环境上让他运行起来跑起来。就要牵扯到项目部署&#xff0c;打包的方式了。 springboot支持jar和war: 打jar包&#xff1a;默认方式&#xff0c;项目开发完打个jar包&#xff0c;通过命令把jar包起起来就…

汇川ST 实现分拣

//初始化 IF init FALSE THEN// 初始化init : 1 ;//45 Y数组 BOOL[8] [OFF发料Y OFF分拣Y OFF送料Y OFF取料Y OFF摆取Y OFF摆放Y OFF升降Y OFF夹料Y] [OFF发料Y OFF分拣Y OFF送料Y OFF取料Y OFF摆取Y OFF摆放Y OFF升降Y OFF夹料Y] 不保持 私有 Y0(*Y数组[0] BOOL OFF 发料…

MySQL 中的几种锁

MySQL 中的锁 #按锁粒度如何划分? 按锁粒度划分的话&#xff0c;MySQL 的锁有&#xff1a; 表锁&#xff1a;开销小&#xff0c;加锁快&#xff1b;锁定力度大&#xff0c;发生锁冲突概率高&#xff0c;并发度最低;不会出现死锁。行锁&#xff1a;开销大&#xff0c;加锁慢…

unity宏编译版本

在写c程序的时候我们通常可以用不同的宏定义来控制不同版本的编译内容&#xff0c;最近有个需求就是根据需要编译一个完全体验版本&#xff0c;就想到了用vs的那套方法。经过研究发现unity也有类似的控制方法。 注意这里设置完后要点击右下的应用&#xff0c;我起先就没有设置…

7/13 - 7/15

vo.setId(rs.getLong("id"))什么意思&#xff1f; vo.setId(rs.getLong("id")); 这行代码是在Java中使用ResultSet对象&#xff08;通常用于从数据库中检索数据&#xff09;获取一个名为"id"的列&#xff0c;并将其作为long类型设置为一个对象…

深度学习基础:Numpy 数组包

数组基础 在使用导入 Numpy 时&#xff0c;通常给其一个别名 “np”&#xff0c;即 import numpy as np 。 数据类型 整数类型数组与浮点类型数组 为了克服列表的缺点&#xff0c;一个 Numpy 数组只容纳一种数据类型&#xff0c;以节约内存。为方便起见&#xff0c;可将 Nu…

简洁实用的原创度检测工具AntiPlagiarism NET 4.132

AntiPlagiarism NET是一个适用于Windows的程序&#xff0c;它允许您检查文本的唯一性和从不同Internet来源借用的存在。使用AntiPlagiarism NET&#xff0c;您可以&#xff1a; 将程序用于不同的目的该程序适用于学生、教师、记者、文案作者和其他需要检查其文本或其他作者文本…