PyTorch-线性回归

已经进入大模微调的时代,但是学习pytorch,对后续学习rasa框架有一定帮助吧。

<!--  给出一系列的点作为线性回归的数据,使用numpy来存储这些点。 -->
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],
                    [9.779], [6.182], [7.59], [2.167], [7.042],
                    [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)
y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],
                    [3.366], [2.596], [2.53], [1.221], [2.827],
                    [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)

<!--  转化tensor格式。 -->
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)

<!--  这里的nn.Linear表示的是 y=w*x b,里面的两个参数都是1,表示的是x是1维,y也是1维。当然这里是可以根据你想要的输入输出维度来更改的。 -->
class linearRegression(nn.Module):
    def __init__(self):
        super(linearRegression, self).__init__()
        self.linear = nn.Linear(1, 1)  # input and output is 1 dimension

    def forward(self, x):
        out = self.linear(x)
        return out
model = linearRegression()

<!-- 定义loss和优化函数,这里使用的是最小二乘loss,之后我们做分类问题更多的使用的是cross entropy loss,交叉熵。优化函数使用的是随机梯度下降,注意需要将model的参数model.parameters()传进去让这个函数知道他要优化的参数是那些。 -->
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

<!-- 开始训练 -->
num_epochs = 1000
for epoch in range(num_epochs):
    inputs = Variable(x_train)
    target = Variable(y_train)
 
    # forward
    out = model(inputs) # 前向传播
    loss = criterion(out, target) # 计算loss
 
    # backward
    optimizer.zero_grad() # 梯度归零
    loss.backward() # 反向传播
    optimizer.step() # 更新参数
 
    if (epoch 1) % 20 == 0:
         print(f'Epoch[{epoch+1}/{num_epochs}], loss: {loss.item():.6f}')

<!--训练完成之后我们就可以开始测试模型了-->
model.eval()
predict = model(Variable(x_train))
predict = predict.data.numpy()

<!-- 显示图例 -->
fig = plt.figure(figsize=(10, 5))
plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label='Original data')
plt.plot(x_train.numpy(), predict, label='Fitting Line')

plt.legend() 
plt.show()

<!-- 保存模型 -->
torch.save(model.state_dict(), './linear.pth')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/391376.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

多维时序 | Matlab实现TCN-RVM时间卷积神经网络结合相关向量机多变量时间序列预测

多维时序 | Matlab实现TCN-RVM时间卷积神经网络结合相关向量机多变量时间序列预测 目录 多维时序 | Matlab实现TCN-RVM时间卷积神经网络结合相关向量机多变量时间序列预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍 Matlab实现TCN-RVM时间卷积神经网络结合相关向量机…

跟着pink老师前端入门教程-day24

四、移动端WEB开发之响应式布局 1、响应式开发 1.1 响应式开发原理 就是使用媒体查询针对不同宽度的设备进行布局和样式的设置&#xff0c;从而适配不同设备的目的。 1.2 响应式布局容器 响应式需要一个父级做为布局容器&#xff0c;来配合子级元素来实现变化效果。 原理…

抽象队列同步器 AQS

文章目录 AQS一、AQS 概述1、什么是 AQS &#xff1f;2、AQS 架构图3、AQS 原理概述4、同步状态state5、FIFO等待队列6、AQS 中的 Node7、AQS 的特点 二、AQS 源码&#xff08;以 ReentrantLock 为例&#xff09;1、基本实现2、加锁1&#xff09;lock2&#xff09;addWaiter【1…

虚拟线程详解

前言 JDK21正式发布了虚拟线程 虚拟线程类似Golang中的协程&#xff0c;虚拟线程是轻量级线程&#xff0c;它可以大大减少编写、维护和观察高吞吐量并发应用程序的工作量&#xff0c;能够大大提升服务的高并发性能&#xff0c;允许通过 java.lang.Thread API 的现有代码来使用…

挑战杯 Yolov安全帽佩戴检测 危险区域进入检测 - 深度学习 opencv

1 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; Yolov安全帽佩戴检测 危险区域进入检测 &#x1f947;学长这里给一个题目综合评分(每项满分5分) 难度系数&#xff1a;3分工作量&#xff1a;3分创新点&#xff1a;4分 该项目较为新颖&am…

如何实现Vuex数据持久化

Vuex是一个非常流行的状态管理工具&#xff0c;它可以帮助我们在Vue.js应用中管理和共享数据。然而&#xff0c;当应用重新加载或刷新时&#xff0c;Vuex的状态会被重置&#xff0c;这就导致了数据的丢失。那么&#xff0c;如何才能实现Vuex的数据持久化呢&#xff1f;让我们一…

正确看待OpenAI大模型Sora

2月16日凌晨&#xff0c;OpenAI发布了文生视频模型Sora。官方是这样描述的&#xff1a;Sora is an AI model that can create realistic and imaginative scenes from text instructions.Sora一个人工智能模型&#xff0c;它可以根据文本指令创建逼真和富有想象力的场景。Sora…

【NI-DAQmx入门】调整数据记录长度再进行数据处理

需要注意的是&#xff0c;初学者很容易造成一个大循环&#xff0c;导致采集循环的执行时间过长&#xff0c;最佳操作是采集循环只干采集的事&#xff0c;另起一个循环做数据拆解或分析。 有时需要以一定的采样率获取数据并记录所需的长度。然而&#xff0c;在处理这些数据时&am…

高校疫情防控系统的全栈开发实战

✍✍计算机编程指导师 ⭐⭐个人介绍&#xff1a;自己非常喜欢研究技术问题&#xff01;专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目&#xff1a;有源码或者技术上的问题欢迎在评论区一起讨论交流&#xff01; ⚡⚡ Java实战 |…

硬错误-STM32

需要修改栈大小 还得是野火的文档比较讲得深一点。

Transformer面试十问

1 Scaled Dot-Product Attention中为什么要除以 d k \sqrt{d_k} dk​ ​? 1. 从纯数学上考虑&#xff1a;对于输入均值为0,方差为1的分布&#xff0c;点乘后结果其方差为dk&#xff0c;所以需要缩放一下。下图为原论文注释。 2. 从神经网络上考虑&#xff1a;防止在计算点积…

【教学类-19-08】20240214《ABAB式-规律黏贴18格-手工纸15*15CM-一页3种图案,A空,纵向、无边框》(中班)

背景需求 利用15*15CM手工纸制作AB色块手环&#xff08;手工纸自带色彩&#xff09;&#xff0c;一页3个图案&#xff0c;2条为一组&#xff0c;黏贴成一个手环 素材准备 代码展示 # # 作者&#xff1a;阿夏 # 时间&#xff1a;2024年2月14日 # 名称&#xff1a;正方形数字卡…

《剑指Offer》笔记题解思路技巧优化 Java版本——新版leetcode_Part_3

《剑指Offer》笔记&题解&思路&技巧&优化_Part_3 &#x1f60d;&#x1f60d;&#x1f60d; 相知&#x1f64c;&#x1f64c;&#x1f64c; 相识&#x1f622;&#x1f622;&#x1f622; 开始刷题1. LCR 138. 有效数字——表示数值的字符串2. LCR 139. 训练计划…

数据结构对链表的初步认识(一)

已经两天没有更新了&#xff0c;今天就写一篇数据结构的链表吧&#xff0c;巩固自己也传授知识&#xff0c;不知道各位是否感兴趣看看这一篇有关联表的文章。 目录 链表的概念与结构 单向链表的实现 链表各个功能函数 首先我在一周前发布了一篇有关顺序表的文章&#xff0c;…

基于RTOS的嵌入式软件开发与可靠性提升

&#xff08;本文为简单介绍&#xff0c;观点来自网络&#xff09; 随着科技的快速发展&#xff0c;嵌入式系统无所不在&#xff0c;从你的智能手表到汽车的自动驾驶系统&#xff0c;它们都在静静地改变我们的世界。而在这一切的背后&#xff0c;实时操作系统&#xff08;RTOS&…

OpenAI 发布文生视频大模型 Sora,AI 视频要变天了,视频创作重新洗牌!AGI 还远吗?

一、一觉醒来&#xff0c;AI 视频已变天 早上一觉醒来&#xff0c;群里和朋友圈又被刷屏了。 今年开年 AI 界最大的震撼事件&#xff1a;OpenAI 发布了他们的文生视频大模型 Sora。 OpenAI 文生视频大模型 Sora 的横空出世&#xff0c;预示着 AI 视频要变天了&#xff0c;视…

Google Gemini 1.5:引领跨模态AIGC信息分析理解与视频内容推理的新篇章,与 Open AI 决一高下!

Gemini 1.5具有100万token的上下文理解能力&#xff0c;是目前最强&#xff01;具有跨模态理解和推理&#xff1a;能够对文本、代码、图像、音频和视频进行高度复杂的理解和推理。允许分析1小时视频、11小时音频、超过30,000行代码或超过700,000字的文本。不过谷歌这个Gemini 1…

简单聊聊k8s,和docker之间的关系

前言 随着云原生和微服务架构的快速发展&#xff0c;Kubernetes和Docker已经成为了两个重要的技术。但是有小伙伴通常对这两个技术的关系产生疑惑&#xff1a; 既然有了docker&#xff0c;为什么又出来一个k8s&#xff1f; 它俩之间是竞品的关系吗&#xff1f; 傻傻分不清。…

数据预处理 —— AI算法初识

一、预处理原因 AI算法对数据进行预处理的原因主要基于以下几个核心要点&#xff1a; 1. **数据清洗**&#xff1a; - 数据通常包含缺失值、异常值或错误记录&#xff0c;这些都会干扰模型训练和预测准确性。通过预处理可以识别并填充/删除这些不完整或有问题的数据。 2. **数…

问题记录——c++ sort 函数 和 严格弱序比较

引出 看下面这段cmp函数的定义 //按照vector第一个元素升序排序 static bool cmp(const vector<int>& a, const vector<int>& b){return a[0] < b[0]; }int eraseOverlapIntervals(vector<vector<int>>& intervals) {//按区间左端排序…