机器学习-pytorch1(持续更新)

上一节我们学习了机器学习的线性模型和非线性模型的机器学习基础知识,这一节主要将公式变为代码

代码编写网站:https://colab.research.google.com/drive

学习课程链接:ML 2022 Spring

1、Load Data(读取数据)

这需要用到pytorch里面的两个函数Dataset和Dataloader

torch.utils.data.Dataset
torch.utils.data.DataLoader

Dataset:是用来存储数据样本和期望值

dataset = MyDataset(file)

Dataloader:批量对数据进行分组,启用多处理

dataloader = DataLoader(dataset, batch_size, shuffle=True)

// 其中对于shuffle的取值,True表示训练,false表示测试

关于Dataset和Dataloader的关系如下:

 

ML 2022 Spring为图片来源

我们读取完数据,是不是想知道我们的数据长什么样子呢?(我们称数据为Tensors)

首先,它可能是一个一维数据,比如一个音频、一个温度

其次,还可能是一个二维数据,比如一张二值图像

最后,还可能是一个三维数据,比如一个彩色的图像

又有问题了,我们怎么通过编程得到我们图像的大小?

可以使用pytorch里面的shape()函数

我们怎么通过编程创造我们的数据呢?

eg:
x = torch.tensor([[1,-1],[-1,1]])
x = torch.from_numpy(np.array([[1,-1],[-1,1]]))
全0或全1数据
x = torch.zeros([2,2])    # 2*2的全0数据
x = torch.ones([1,2,5])    # 1*2*5的全1数据

 其次,还支持矩阵的运算

Addition:z = x + y
Subtraction:z = x - y
Power:y = x.pow(2)
Summation:y = x.sum()
Mean:y = x.mean()
维度转换:x = x.transpose(dim0,dim1)
消除维度:x = x.squeeze(dim)
增加维度:x = x.unsqueeze(dim)
组合:w = torch.cat([x,y,z],dim=1)

拥有不同的数据类型:

使用.to()可以切换到不同的设备:

CPU: x = x.to('cpu')
GPU: x = x.to('cuda')

 这里就又涉及到如何检查你的GPU了?可以使用以下语句检查你的计算机是否有GPU:

torch.cuda.is_available()

如何计算梯度?

 // 注意矩阵一定要使用小数点

2、Define Neural Network(训练和测试神经网络)

torch.nn.Module

线性: 

 非线性:

Sigmoid Activation:nn.Sigmoid()

ReLU Activation:nn.ReLU()

下面我根据所学的知识构建我自己的神经网络:

3、Loss Function(损失函数) 

x = torch.nn.MSELoss    # 对于回归任务
x = torch.nn.CrossEntropyLoss etc.    # 对于分类任务
loss = x(model_output,expected_value)

4、Optimization Algorithm(优化)

torch.optim

这是基于梯度的优化算法,不断调整参数,减少误差

比如:随机梯度下降(SGD)

torch.optim.SGD(model.parameters(), lr, momentum = 0)

* 调用optimizer.zero_grad()重置模型参数的梯度。

*调用loss.backward()反向传播预测loss的梯度。

*调用optimizer.step()调整模型参数。 

5、Entire Procedure(整个程序)

import torch.utils.data as data
dataset = data.Dataset(file)              # 读取数据
tr_set = DataLoader(dataset,batch_size,shuffle=True)  # 对数据集进行分组
model = MyModel().to(device)              # 建立我的模型并且选择我的设备(cpu or gpu)
criterion = nn.MSELoss()                # 建立损失函数
optimizer = torch.optim.SGD(model.parameters(),0.1)   # 建立优化
# 训练
for epoch in range(n_epochs):             # 迭代数据
  model.train()                    # 训练模型
  for x, y in tr_set:               # 迭代数据集
    optimizer.zero_grad()              # 设置梯度为0
    x, y = x.to(device),y.to(device)       # 将数据移动到设备
    pred = model(x)                # 计算输出
    loss = criterion(pred,y)            # 计算损失函数
    loss.backward()                 # 计算反向梯度
    optimizer.model()                # 优化模型
# 验证
model.eval()                      # 将模型设置为评估模式
total_loss = 0          
for x,y in dv_set:                  # 对数据集进行迭代
  x,y = x.to(device),y.to(device)          # 将数据移动到涉笔
  with torch.no_grad():                # 不可迭代的计算
    pred = model(x)                # 计算输出
    loss = criterion(pred,y)           # 计算损失函数
  total_loss += loss.cpu().item()*len(x)      # 累加损失误差
  avg_loss = total_loss / len(dv_set.dataset)   # 计算平均损失
# 测试
model.eval()                       # 将模型设置为评估模式
preds = []
for x in dv_set:                   # 对数据集进行迭代
  x = x.to(device)                  # 将数据移动到涉笔
  with torch.no_grad():                # 不可迭代的计算
    pred = model(x)                # 计算输出
    preds.append(pred.cpu())             # 收集预测

// model.eval()  :更改模型的行为

//  with torch.no_grad() :防止对验证/测试数据进行意外训练

当我们训练完模型,也完成了测试,为了不使模型丢失,我们需要保存模型,pytorch也为我们提供了保存模型的方法。

保存模型:torch.save(model.state_dict(),path)

下次我们使用已经训练完成的模型,或者想继续训练,我们需要读取模型。

读取模型:ckpt = torch.load(path)     model.load_state_dict(ckpt)

// 这只是我根据所听的课自己写的笔记,如果有什么错误欢迎指正!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/444714.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

领域模型设计-COLA架构

前言 当我们需要创建的新应用的时候,往往需要站在一个长远的角度来设计我们的系统架构。有时候我们接手一个老的应用的时候,会发现由于创建之初没有好好规划系统架构,导致我们后期开分成本和维护成本都非常高。近些年来领域模型的系统设计非常…

Day26:安全开发-PHP应用模版引用Smarty渲染MVC模型数据联动RCE安全

目录 新闻列表 自写模版引用 Smarty模版引用 代码RCE安全测试 思维导图 PHP知识点: 功能:新闻列表,会员中心,资源下载,留言版,后台模块,模版引用,框架开发等 技术:输…

Pygame教程07:键盘常量+键盘事件的2种捕捉方式

------------★Pygame系列教程★------------ Pygame教程01:初识pygame游戏模块 Pygame教程02:图片的加载缩放旋转显示操作 Pygame教程03:文本显示字体加载transform方法 Pygame教程04:draw方法绘制矩形、多边形、圆、椭圆、弧…

【Java探索之旅】数据类型与变量,字面常量,整型变量

🎥 屿小夏 : 个人主页 🔥个人专栏 : Java入门到精通 🌄 莫道桑榆晚,为霞尚满天! 文章目录 📑前言一、字面常量二、数据类型三、变量3.1 变量概念3.2 语法格式 四、整型变量4.1 整型变…

运维随录实战(13)之docker搭建mysql集群(pxc)

了解 MySQL 集群之前,先看看单节点数据库的弊病 大型互联网程序用户群体庞大,所以架构需要特殊设计。单节点数据库无法满足大并发时性能上的要求。单节点的数据库没有冗余设计,无法满足高可用。单节点 MySQL无法承载巨大的业务量,数据库负载巨大常见 MySQL 集群方案 Re…

.NET高级面试指南专题十六【 装饰器模式介绍,包装对象来包裹原始对象】

装饰器模式(Decorator Pattern)是一种结构型设计模式,用于动态地给对象添加额外的职责,而不改变其原始类的结构。它允许向对象添加行为,而无需生成子类。 实现原理: 装饰器模式通过创建一个包装对象来包裹原…

云原生之容器编排实践-ruoyi-cloud项目部署到K8S:Nginx1.25.3

背景 前面搭建好了 Kubernetes 集群与私有镜像仓库,终于要进入服务编排的实践环节了。本系列拿 ruoyi-cloud 项目进行练手,按照 MySQL , Nacos , Redis , Nginx , Gateway , Auth ,…

DDoS和CC攻击的原理

目前最常见的网络攻击方式就是CC攻击和DDoS攻击这两种,很多互联网企业服务器遭到攻击后接入我们德迅云安全高防时会问到,什么是CC攻击,什么又是DDoS攻击,这两个有什么区别的,其实清楚它们的攻击原理,也就知…

C#,数值计算,用割线法(Secant Method)求方程根的算法与源代码

1 割线法 割线法用于求方程 f(x) 0 的根。它是从根的两个不同估计 x1 和 x2 开始的。这是一个迭代过程,包括对根的线性插值。如果两个中间值之间的差值小于收敛因子,则迭代停止。 亦称弦截法,又称线性插值法.一种迭代法.指用割线近似曲线求…

【JavaScript 漫游】【033】Cookie 总结

文章简介 本篇文章为【JavaScript 漫游】专栏的第 033 篇文章,主要记录了浏览器模型中 Cookie 相关的知识点。 Cookie 概述 Cookie 是服务器保存在浏览器的一小段文本信息,一般大小不能超过4KB。浏览器每次向服务器发出请求,就会自动附上这…

gensim 实现 TF-IDF;textRank 关键词提取

目录 TF-IDF 提取关键词 介绍 代码 textRAnk 提取关键词 这里只写了两种简单的提取方法,不需要理解上下文,如果需要基于一些语义提取关键词用 LDA:TF-IDF,textRank,LSI_LDA 关键词提取-CSDN博客 TF-IDF 提取关键词…

【框架学习 | 第三篇】Spring上篇(Spring入门、核心功能、Spring Bean——>定义、作用域、生命周期、依赖注入)

文章目录 1.Spring简述1.1什么是Spring框架?1.2Spring的核心功能1.2.1 IOC(1)IOC介绍(2)控制?反转? 1.2.2 AOP(1)AOP介绍(2)专业术语(…

BadUsb制作

BadUsb制作 一个树莓派pico kali监听 需要的文件 https://pan.baidu.com/s/1_kyzXIqk9JWHGHstTgq7sQ?pwd6666 1.将pico插入电脑 2.将Bad USB固件中的文件复制到pico中,pico会重启 3.将Bad USB目录文件复制进去(打开Bad USB目录文件复制) …

Spring Boot单元测试与热部署简析

1 Spring Boot的简介 Spring Boot是一个用于构建独立的、生产级别的Spring应用程序的框架。它简化了Spring应用程序的开发过程,提供了自动配置和默认配置,使得开发者只需专注于业务逻辑的实现,而不用去关注繁琐的配置问题。 Spring …

伪分布Hadoop的安装与部署

1.实训目标 (1)熟悉掌握使用在Linux下安装JDK。 (2)熟悉掌握使用在Linux下安装Hadoop。 (3)熟悉掌握使用配置SSH免密登录。 2.实训环境与软件 环境 版本 说明 Windows 10系统 64位 操作电脑配置 …

HTML 学习笔记(一)开始

一、介绍: 首先引用百度百科的一段话作为介绍:   HTML的全称为超文本标记语言,是一种标记语言。它包括一系列标签,通过这些标签可以将网络上的文档格式统一,使分散的Internet资源连接为一个逻辑整体。HTML文本是由HTML命令组成的描述性文本…

[贰],万能开篇HelloWorld

1,新建项目 File/New/Project Android/Android Application Project 输入程序名字HelloWorld Next Next 选择Blank Activity 修改为HelloWorldActivity 2,异常点 2.1,异常1:No resource found that matches the given name The…

如何在Windows系统使用固定tcp公网地址ssh远程Kali系统

文章目录 1. 启动kali ssh 服务2. kali 安装cpolar 内网穿透3. 配置kali ssh公网地址4. 远程连接5. 固定连接SSH公网地址6. SSH固定地址连接测试 简单几步通过[cpolar 内网穿透](cpolar官网-安全的内网穿透工具 | 无需公网ip | 远程访问 | 搭建网站)软件实现ssh 远程连接kali! …

ubuntu23.10安装搜狗拼音

1.添加fcitx仓库 sudo add-apt-repository ppa:fcitx-team/nightly 更新: sudo apt-get update 安装fcitx sudo apt-get install fcitx fcitx安装成功 切换输入系统为fcitx

设计模式 代理模式

代理模式主要使用了 Java 的多态,主要是接口 干活的是被代理类,代理类主要是接活, 你让我干活,好,我交给幕后的类去干,你满意就成,那怎么知道被代理类能不能干呢? 同根就成&#xff…