deep learning with pytorch(一)

1.create a basic nerual network model with pytorch

数据集 Iris UCI Machine Learning Repository

fully connected  

目标:创建从输入层的代码开始,向前移动到隐藏层,最后到输出层

# %%
import torch
import torch.nn as nn
import torch.nn.functional as F

# %%
# create a model class that inherits nn.Module 这里是Module 不是model
class Model(nn.Module):
    #input layer (4 features of the flower) -->
    #  Hidden layer1 (number of neurons) -->
    #  H2(n) --> output (3 classed of iris flowers)
    def __init__(self, in_features = 4, h1 = 8, h2 = 9, out_features = 3):
        super().__init__() # instantiate out nn.Module 实例化
        self.fc1 = nn.Linear(in_features= in_features, out_features= h1)
        self.fc2 = nn.Linear(in_features= h1, out_features= h2)
        self.out = nn.Linear(in_features= h2, out_features= out_features)
    
    # moves everything forward 
    def forward(self, x):
        # rectified linear unit 修正线性单元 大于0则保留,小于0另其等于0
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.out(x)

        return x
    

# %%
# before we turn it on we need to create a manual seed, because networks involve randomization every time.
# say hey start here and then go randomization, then we'll get basically close to the same outputs

# pick a manual seed for randomization
torch.manual_seed(seed= 41)
# create an instance of model
model = Model()


2.load data and train nerual network model 

torch.optim

torch.optim — PyTorch 2.2 documentation

1. optimizer.zero_grad()

  • 作用: 清零梯度。在训练神经网络时,每次参数更新前,需要将梯度清零。因为如果不清零,梯度会累加到已有的梯度上,这是PyTorch的设计决策,目的是为了处理像RNN这样的网络结构,它们在一个循环中多次计算梯度。

  • 原理: PyTorch在进行反向传播(backward)时,会累计梯度,而不是替换掉当前的梯度值。因此,如果不手动清零,梯度值会不断累积,导致训练过程出错。

2. loss.backward()

  • 作用: 计算梯度。这一步会根据损失函数对模型参数进行梯度的计算。在神经网络中,损失函数衡量的是模型输出与真实标签之间的差异,通过反向传播算法,可以计算出损失函数关于模型各个参数的梯度。

  • 原理: 反向传播是一种有效计算梯度的算法,它首先计算输出层的梯度,然后逆向逐层传播至输入层。这个过程依赖于链式法则,是深度学习训练中的核心。

3. optimizer.step()

  • 作用: 更新参数。基于计算出的梯度,更新模型的参数。这一步实际上是在执行优化算法(如SGD、Adam等),根据梯度方向和设定的学习率调整参数值,以减小损失函数的值。

  • 原理: 优化器根据梯度下降(或其它优化算法)更新模型参数。梯度指示了损失函数增长最快的方向,因此通过向相反方向调整参数,模型的预测误差会逐渐减小。

  • # %%
    import pandas as pd
    import matplotlib.pyplot as plt
    %matplotlib inline
    
    # %%
    # url = 'https://gist.githubusercontent.com/curran/a08a1080b88344b0c8a7/raw/0e7a9b0a5d22642a06d3d5b9bcbad9890c8ee534/iris.csv'
    my_df = pd.read_csv('dataset/iris.csv')
    
    # %%
    # change last column from strings to integers
    my_df['species'] = my_df['species'].replace('setosa', 0.0)
    my_df['species'] = my_df['species'].replace('versicolor', 1.0)
    my_df['species'] = my_df['species'].replace('virginica', 2.0)
    my_df
    
    # my_df.head()
    # my_df.tail()
    
    # %%
    # train test split ,set X,Y    
    X = my_df.drop('species', axis = 1) # 删除指定列
    y = my_df['species']
    
    # %%
    #Convert these to numpy arrays
    X = X.values
    y = y.values
    # X
    
    # %%
    # train test split
    from sklearn.model_selection import train_test_split
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size= 0.2, random_state= 41)
    
    # %%
    # convert X features to float tensors
    X_train = torch.FloatTensor(X_train)
    X_test = torch.FloatTensor(X_test)
    #convert y labels to long tensors
    y_train = torch.LongTensor(y_train)
    y_test = torch.LongTensor(y_test)
    
    # %%
    # set the criterion of model to measure the error,how far off the predicitons are from the data
    criterion = nn.CrossEntropyLoss()
    # choose Adam optimizer, lr = learing rate (if error does not go down after a bunch of 
    # iterations(epochs), lower our learning rate),学习率越低,学习所需时间越长
    optimizer = torch.optim.Adam(model.parameters(), lr= 0.01)
    # 传进去的参数包括fc1, fc2, out
    # model.parameters
    
    # %%
    # train our model
    # epochs? (one run through all the training data in out network )
    epochs = 100
    losses = []
    for i in range(epochs):
        # go forward and get a prediction
        y_pred = model.forward(X_train) # get a predicted results
    
        #measure the loss/error, gonna be high at first
        loss = criterion(y_pred, y_train) # predicted values vs y_train
    
        # keep track of our losses
        #detach()不再跟踪计算图中的梯度信息,numpy(): 这个方法将PyTorch张量转换成NumPy数组。因为NumPy数组在Python科学计算中非常普遍,很多库和函数需要用到NumPy数组作为输入。
        losses.append(loss.detach().numpy()) 
    
        #print every 10 epoches
        if i % 10 == 0:
            print(f'Epoch: {i} and loss: {loss}')
        
        # do some back propagation: take the error rate of forward propagation and feed it back
        # thru the network to fine tune the weights
        # optimizer.zero_grad() 清零梯度,为新的梯度计算做准备。
        # loss.backward() 计算梯度,即对损失函数进行微分,获取参数的梯度。
        # optimizer.step() 更新参数,根据梯度和学习率调整参数值以最小化损失函数。
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    
    # %%
    # graph it out
    plt.plot(range(epochs), losses)
    plt.ylabel("loss/error")
    plt.xlabel("Epoch")
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/426714.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Angular基础---HelloWorld---Day1

文章目录 1. 创建Angular 项目2.对Angular架构的最基本了解3.创建并引用新的组件(component)4.对Angular架构新的认识(多组件)5.组件中业务逻辑文件的编辑(ts文件)6.标签中属性的绑定(1) ID的绑定(2) class…

AI大模型与小模型之间的“脱胎”与“反哺”(第二篇)

此图片来源于网络 21. **跨模态学习(Cross-Modal Learning)**: 如果各个行业AI小模型涉及多种数据类型或模态,可以通过跨模态学习技术让大模型理解并整合这些不同模态之间的关联,从而提升对多行业复杂问题的理解和解…

【Redis 主从复制】

文章目录 1 :peach:环境配置:peach:1.1 :apple:三种配置方式:apple:1.2 :apple:验证:apple:1.3 :apple:断开复制和切主:apple:1.4 :apple:安全性:apple:1.5 :apple:只读:apple:1.6 :apple:传输延迟:apple: 2 :peach:拓扑结构:peach:2.1 :apple:⼀主⼀从结构:apple:2.2 :apple:⼀…

【FPGA/IC】CRC电路的Verilog实现

前言 在通信过程中由于存在各种各样的干扰因素,可能会导致发送的信息与接收的信息不一致,比如发送数据为 1010_1010,传输过程中由于某些干扰,导致接收方接收的数据却成了0110_1010。为了保证数据传输的正确性,工程师们…

30天JS挑战(第十五天)------本地存储菜谱

第十五天挑战(本地存储菜谱) 地址:https://javascript30.com/ 所有内容均上传至gitee,答案不唯一,仅代表本人思路 中文详解:https://github.com/soyaine/JavaScript30 该详解是Soyaine及其团队整理编撰的,是对源代…

【数据结构】B树

1 B树介绍 B树(英语:B-tree),是一种在计算机科学自平衡的树,能够保持数据有序。这种数据结构能够让查找数据、顺序访问、插入数据及删除的动作,都在对数时间内完成。B树,概括来说是一个一般化的…

CAS外部云迁移vmware虚拟机兼容性问题处理

CAS外部云迁移vmware虚拟机兼容性问题处理 1、迁移过程中报错实图 2、问题原因 打开虚拟机存储的位置,发现文件夹下存在ctk.vmdk的文件 3、在vmware右键虚拟机编辑设置 注:虚拟机需要先关机 点击虚拟机选项——高级——编辑设置 将ctk.ENABLED改为…

第五套CCF信息学奥赛c++练习题 CSP-J认证初级组 中小学信奥赛入门组初赛考前模拟冲刺题(选择题)

第五套中小学信息学奥赛CSP-J考前冲刺题 1、不同类型的存储器组成了多层次结构的存储器体系,按存取速度从快到慢排列的是 A、快存/辅存/主存 B、外存/主存/辅存 C、快存/主存/辅存 D、主存/辅存/外存 答案:C 考点分析:主要考查计算机相关知识&…

在ubuntu上安装hadoop完分布式

准备工作 Xshell安装包 Xftp7安装包 虚拟机安装包 Ubuntu镜像源文件 Hadoop包 Java包 一、安装虚拟机 创建ubuntu系统 完成之后会弹出一个新的窗口 跑完之后会重启一下 按住首先用ctrlaltf3进入命令界面,输入root,密码登录管理员账号 按Esc 然后输入 …

蓝牙BLE 5.0、5.1、5.2和5.3区别

随着科技的不断发展,蓝牙技术也在不断进步,其中蓝牙BLE(Bluetooth Low Energy)是目前应用广泛的一种蓝牙技术,而BLE 5.0、5.1、5.2和5.3则是其不断升级的版本。本文将对这四个版本的区别进行详细的比较。 一、BLE 5.0…

为啥要用C艹不用C?

在很多时候,有人会有这样的疑问 ——为什么要用C?C相对于C优势是什么? 最近两年一直在做Linux应用,能明显的感受到C带来到帮助以及快感 之前,我在文章里面提到环形队列 C语言,环形队列 环形队列到底是怎么回…

FPGA高端项目:FPGA基于GS2971的SDI视频接收+纯verilog图像缩放+多路视频拼接,提供8套工程源码和技术支持

目录 1、前言免责声明 2、相关方案推荐本博已有的 SDI 编解码方案本方案的SDI接收转HDMI输出应用本方案的SDI接收图像缩放应用本方案的SDI接收HLS图像缩放HLS多路视频拼接应用本方案的SDI接收HLS动态字符叠加输出应用本方案的SDI接收HLS多路视频融合叠加应用本方案的SDI接收GTX…

【代码】Android|获取压力传感器、屏幕压感数据(大气压、原生和Processing)

首先需要分清自己需要的是大气压还是触摸压力,如果是大气压那么就是TYPE_PRESSURE,可以参考https://source.android.google.cn/docs/core/interaction/sensors/sensor-types?hlzh-cn。如果是触摸压力就是另一回事,我需要的是触摸压力。 不过…

【算法沉淀】刷题笔记:并查集 带权并查集+实战讲解

🎉🎉欢迎光临🎉🎉 🏅我是苏泽,一位对技术充满热情的探索者和分享者。🚀🚀 🌟特别推荐给大家我的最新专栏《数据结构与算法:初学者入门指南》📘&am…

Windows Server 各版本搭建文件服务器实现共享文件(03~19)

一、Windows Server 2003 打开服务器,点击左下角开始➡管理工具➡管理您的服务器➡添加或删除角色 点击下一步等待测试 勾选自定义配置,点击下一步 选择文件服务器,点击下一步 勾选设置默认磁盘空间,数据自己更改,最…

Onenote软件新建笔记本时报错:无法在以下位置新建笔记本

报错现象: 当在OneNote软件上,新建笔记本时: 然后,尝试重新登录微软账户,也不行,提示报错: 解决办法: 打开一个新的记事本,复制粘贴以下内容: C:\Users\Adm…

如何防御跨站请求伪造(CSRF)攻击?

CSRF 英文全称是 Cross-site request forgery,所以又称为“跨站请求伪造”,是指恶意诱导用户打开被精心构造的网站,在该网站中,利用用户的登录状态发起的跨站请求。简单来讲,CSRF 就是利用了用户的登录状态&#xff0c…

WordPress建站入门教程:如何在本地电脑搭建WordPress网站?

前面跟大家分享了『WordPress建站入门教程:如何安装本地WordPress网站运行环境?』,接下来boke112百科就继续跟大家分享本地电脑如何搭建WordPress网站。 小皮面板(phpstudy)的“软件管理 – 网站程序”虽然可以一键部…

excel统计分析——拉丁方设计

参考资料:生物统计学 拉丁方设计也是随机区组设计,是对随机区组设计的一种改进。它在行的方向和列的方向都可以看成区组,因此能实现双向误差的控制。在一般的试验设计中,拉丁方常被看作双区组设计,用于提高发现处理效应…

身份证识别系统(安卓)

设计内容与要求: 通过手机摄像头捕获身份证信息,将身份证上的姓名、性别、出生年月、身份证号码保存在数据库中。1)所开发Apps软件至少需由3-5个以上功能性界面组成。要求:界面美观整洁、方便应用;可以使用Android原生…