使用pytorch实现一个线性回归训练函数

使用sklearn.dataset 的make_regression创建用于线性回归的数据集

def create_dataset():
    x, y, coef = make_regression(n_samples=100, noise=10, coef=True, bias=14.5, n_features=1, random_state=0)
    return torch.tensor(x), torch.tensor(y), coef

加载数据集,并拆分batchs训练集

def load_dataset(x, y, batch_size):
    data_len = len(y)
    batch_num = data_len // batch_size
    for idx in range(batch_num):
        start = idx * batch_num
        end = idx * batch_num + batch_num
        train_x = x[start : end]
        train_y = y[start : end]
        yield train_x, train_y

定义初始权重和定义计算函数

w = torch.tensor(0.1, requires_grad=True, dtype=torch.float64)
b = torch.tensor(0, requires_grad=True, dtype=torch.float64)
def linear_regression(x):
    return x * w + b

损失函数使用平方差

def linear_loss(y_pred, y_true):
    return (y_pred - y_true) ** 2

优化参数使用梯度下降方法

def sgd(linear_rate, batch_size):
    w.data = w.data - linear_rate * w.grad / batch_size
    b.data = b.data - linear_rate * b.grad / batch_size

训练代码

def train():
    # 加载数据
    x, y, coef = create_dataset()
    data_len = len(y)

    # 定义参数
    batch_size = 10
    epochs = 100
    linear_rate = 0.01

    # 记录损失值
    epochs_loss = []

    # 迭代
    for eid in range(epochs):
        total_loss = 0.0
        for train_x, train_y in load_dataset(x, y, batch_size):
            # 输入模型
            y_pred = linear_regression(train_x)

            # 计算损失
            loss_num = linear_loss(y_pred, train_y.reshape(-1,1)).sum()

            # 梯度清理
            if w.grad is not None:
                w.grad.zero_()
            if b.grad is not None:
                b.grad.zero_()

            # 反向传播
            loss_num.backward()

            # 更新权重
            sgd(linear_rate, batch_size)

            # 统计损失数值
            total_loss = total_loss + loss_num.item()

        # 记录本次迭代的平均损失
        b_loss = total_loss / data_len
        epochs_loss.append(b_loss)
        print("epoch={},b_loss={}".format(eid, b_loss))

    # 显示预测线核真实线的拟合关系
    print(w, b)
    print(coef, 14.5)

    plt.scatter(x, y)

    test_x = torch.linspace(x.min(), x.max(), 1000)
    y1 = torch.tensor([v * w + b for v in test_x])
    y2 = torch.tensor([v * coef + 14.5 for v in test_x])
    plt.plot(test_x, y1, label='train')
    plt.plot(test_x, y2, label='true')
    plt.grid()
    plt.show()

    # 显示损失值变化曲线
    plt.plot(range(epochs), epochs_loss)
    plt.show()

拟合显示还不错

损失值在低5次迭代后基本就很小了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/415549.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

蓝桥杯_中断系统

一 中断 中断,即cpu暂停执行当前程序,转而执行另外一段特殊程序,处理结束后。返回之前暂停程序继续执行。 中断向量,中断服务程序的入口地址,每个中断源都对应一个固定的入口地址。 中断服务函数,内核响应中…

仿牛客网项目---用户注册登录功能的实现

从今天开始我们来写一个新项目,这个项目是一个完整的校园论坛的项目。主要功能模块:用户登录注册,帖子发布和热帖排行,点赞关注,发送私信,消息通知,社区搜索等。这篇文章我们先试着写一下用户的…

重拾前端基础知识:HTML

重拾前端基础知识:HTML 前言HTML编辑器标签锚链接 id和class表单Iframe字符实体HTML CSSHTML JavaScriptHTML 统一资源定位器(Uniform Resource Locators)XHTMLHTML5新的语义化元素视频和音频支持Canvas绘图本地存储表单增强Web Workers地理定位Web存储伸缩矢量图形…

《TCP/IP详解 卷一》第10章 UDP和IP分片

目录 10.1 引言 10.2 UDP 头部 10.3 UDP校验和 10.4 例子 10.5 UDP 和 IPv6 10.6 UDP-Lite 10.7 IP分片 10.7.1 例子:IPV4 UDP分片 10.7.2 重组超时 10.8 采用UDP的路径MTU发现 10.9 IP分片和ARP/ND之间的交互 10.10 最大UDP数据报长度 10.11 UDP服务器…

2024深度学习主流框架对比

tensorFlow 是最受欢迎和广泛使用的深度学习框架之一,目前在github的start数为181k。 TensorFlow是一个由Google Brain团队开发的开源深度学习框架。它允许开发者创建多种机器学习模型,包括卷积神经网络、循环神经网络和深度神经网络等,该框架…

全国产飞腾E2000Q +复旦微FPGA的轨道交通、电力解决方案

产品概述 ITX-XMF201是一款高性能边缘计算网关主板,采用飞腾E2000Q 4核处理器,国产化率达到95%国产化。 板载2电口,2路CAN,6路RS232接口,1路RS485接口,16路GPIO,可以满足银行、轨道交通、电力等…

【ArcGIS Pro二次开发】(83):ProWindow和WPF的一些技巧

在ArcGIS Pro二次开发中,SDK提供了一种工具界面【ArcGIS Pro ProWindow】。 关于ProWindow的用法,之前写过一篇基础的教程: 【ArcGIS Pro二次开发】(13):ProWindow的用法_arcgispro二次开发教程-CSDN博客 主要是对几个常用控件…

Java开发的核心模式 - MVC

文章目录 1、MVC设计模式2、Web开发本质3、服务器的性能瓶颈 1、MVC设计模式 MVC设计模式示意图 在整个Java学习之旅中,MVC(Model-View-Controller)设计模式无疑占据着极其重要的地位,堪称理解和掌握Java项目开发精髓的钥匙。如…

【Python】Python实现串口通信(Python+Stm32)

🎉欢迎来到Python专栏~Python实现串口通信 ☆* o(≧▽≦)o *☆嗨~我是小夏与酒🍹 ✨博客主页:小夏与酒的博客 🎈该系列文章专栏:Python学习专栏 文章作者技术和水平有限,如果文中出现错误,希望…

3D数字孪生

数字孪生(Digital Twin)是物理对象、流程或系统的虚拟复制品,用于监控、分析和优化现实世界的对应物。 这些数字孪生在制造、工程和城市规划等领域变得越来越重要,因为它们使我们能够在现实世界中实施改变之前模拟和测试不同的场景…

Spring 类型转换、数值绑定与验证(三)— Formatting 与 Validation

1 Formatting 在Spring中用于格式化数据及根据地域展示不同格式的数据。 图 Formatting接口 UML 1.1 注解驱动Formatting 自定义像“DateTimeFormat”注解来对相关字段格式化的步骤为: 自定义注解。定义一个实现AnnotationFormatterFactory接口的工厂类。往容器…

H3C OSPF Stub特殊区域+认证实验

H3C OSPF Stub特殊区域认证实验 实验拓扑 ​​ 实验需求 按照图示配置 IP 地址,所有路由器配置环回口 IP 地址为 X.X.X.X/32​ 作为 Router-id,X 为设备编号(R5 除外)按照图示分区域配置 OSPFR1 上配置默认路由,指…

数据卷dockerfile

目录 一、数据卷 1. 简介 2. 数据卷和数据卷容器 1. 数据卷: 2. 数据卷容器: 二、自定义镜像 1. 作用 2. 自定义centos 3. 自定义tomcat8 一、数据卷 1. 简介 数据卷是一个可供一个或多个容器使用的特殊目录,它将主机操作系统目录直…

小白水平理解面试经典题目leetcode. 606 Construct String from Binary Tree【递归算法】

Leetcode 606. 从二叉树构造字符串 题目描述 例子 小白做题 坐在自习室正在准备刷题的小白看到这道题,想想自己那可是没少和白月光做题呢,也不知道小美刷题刷到哪里了,这题怎么还没来问我,难道是王谦谦去做题了? 这…

使用Java和PostGis的全国A级风景区数据入库实战

目录 前言 一、数据介绍 1、空间数据 2、属性表说明 3、QGIS数据预览 二、PostGIS空间数据库设计 1、空间表结构 三、Java空间入库 1、实体定义 2、数据操作Mapper 3、业务层实现 4、入库 5、数据入库验证 总结 前言 星垂平野阔,月涌大江流”“晴川历历…

.NET生成MongoDB中的主键ObjectId

前言 因为很多场景下我们需要在创建MongoDB数据的时候提前生成好主键为了返回或者通过主键查询创建的业务,像EF中我们可以生成Guid来,本来想着要不要实现一套MongoDB中ObjectId的,结果发现网上各种各样的实现都有,不过好在阅读C#…

机器人内部传感器阅读梳理及心得-速度传感器-模拟式速度传感器

速度传感器是机器人内部传感器之一,是闭环控制系统中不可缺少的重要组成部分,它用来测量机器人关节的运动速度。可以进行速度测量的传感器很多,如进行位置测量的传感器大多可同时获得速度的信息。但是应用最广泛、能直接得到代表转速的电压且…

基于stm32F103的座面声控台灯

1.基本内容: 设计一个放置在桌面使用的台灯,使用220v交流电供电。具备显示屏能够实时显示日期(年、月、日和星期),时间(小时、分钟、秒)和温度(摄氏度);能够通…

高校物品捐赠管理系统|基于springboot高校物品捐赠管理系统设计与实现(源码+数据库+文档)

高校物品捐赠管理系统目录 目录 基于springboot高校物品捐赠管理系统设计与实现 一、前言 二、系统功能设计 三、系统实现 1、用户信息管理 2、捐赠信息管理 3、论坛信息管理 4、公告信息管理 四、数据库设计 1、实体ER图 五、核心代码 六、论文参考 七、最新计算…

回归预测 | Matlab实现CPO-BiTCN-BiGRU冠豪猪算法优化双向时间卷积门控循环单元多变量回归预测

回归预测 | Matlab实现CPO-BiTCN-BiGRU冠豪猪算法优化双向时间卷积门控循环单元多变量回归预测 目录 回归预测 | Matlab实现CPO-BiTCN-BiGRU冠豪猪算法优化双向时间卷积门控循环单元多变量回归预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.Matlab实现CPO-BiTCN-B…