Pytorch(一)

目录

一、基本操作

二、自动求导机制

 三、线性回归DEMO

3.1模型的读取与保存

3.2利用GPU训练时

四、常见的Tensor形式

五、Hub模块


一、基本操作

操作代码如下:

import torch
import numpy as np

#创建一个矩阵
x1 = torch.empty(5,3)

# 随机值
x2 = torch.rand(5,3)

# 初始化一个全零的矩阵
x3 = torch.zeros(5,3,dtype = torch.long)

# view操作改变矩阵维度
x4 = torch.randn(4,4) #4*4矩阵
y = x4.view(16) #变成一行的矩阵
z = x4.view(-1,8) #变为2*8的矩阵
print(y.size()) #矩阵的尺寸

#与numpy的协同操作
# tensor转array
a = torch.ones(5)
b = a.numpy()

# array转tensor
a1 = np.ones(5)
b1 = torch.from_numpy(a)

二、自动求导机制

案例代码如下:

 

import torch

#计算流程
x = torch.rand(1)
b = torch.rand(1,requires_grad=True)
w = torch.rand(1,requires_grad=True)
y = w * x
z = y + b

# 反向传播计算
z.backward(retain_graph = True)
print(w.grad)
print(b.grad)

 三、线性回归DEMO

 

import numpy as np
import torch
import torch.nn as nn

# 构建线性回归模型
class LinearRegressionModel(nn.Module):
    def __init__(self,input_dim,output_dim):
        super(LinearRegressionModel,self).__init__()
        self.linear = nn.Linear(input_dim,output_dim)

    def forward(self,x):
        out = self.linear(x)
        return out

x_values = [i for i in range(11)]
x_train = np.array(x_values,dtype=np.float32)
x_train = x_train.reshape(-1,1)
print(x_train.shape)

#y = 2x + 1
y_values = [2*i + 1 for i in range(11)]
y_train = np.array(x_values,dtype=np.float32)
y_train = x_train.reshape(-1,1)

# 构建model
input_dim = 1
output_dim = 1

model = LinearRegressionModel(input_dim,output_dim)

# 指定好参数和损失函数
epochs = 1000 #训练次数
learning_rate = 0.01 #学习率
optimizer = torch.optim.SGD(model.parameters(),lr = learning_rate) #优化器
criterion = nn.MSELoss() #损失函数

# 训练模型
for epoch in range(epochs):
    epoch += 1
    #注意转行为tensor
    inputs = torch.from_numpy(x_train)
    labels = torch.from_numpy(y_train)


    #梯度要清零每一次迭代
    optimizer.zero_grad()

    #前向传播
    outputs = model(inputs)

    #计算损失
    loss = criterion(outputs,labels)

    #反向传播
    loss.backward()

    #更新权重参数
    optimizer.step()
    if epoch % 50 ==0:
        print('epoch {},loss {}'.format(epoch,loss.item()))

3.1模型的读取与保存

# 模型的保存与读取
torch.save(model.state.dict(),'model.pkl') #保存
model.load_state_dict(torch.load('model.pkl')) #读取

3.2利用GPU训练时

利用GPU训练时要将数据与模型导入cuda

#注意转行为tensor
inputs = torch.from_numpy(x_train)
labels = torch.from_numpy(y_train)
#利用GPU训练数据时的数据
inputs = torch.from_numpy(x_train).to(device)
labels = torch.from_numpy(y_train).to(device)


model = LinearRegressionModel(input_dim,output_dim)

#使用GPU进行训练
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

四、常见的Tensor形式

  • 1.scalar:通常是指一个数值
  • 2.vector:通常是指一个向量
  • 3.matrix:通常是指一个数据矩阵
  • 4.n-dimensional tensor:高维数据

五、Hub模块

Github地址:https://github.com/pytorch/hub

Hub已有模型:https://pytorch.org/hub/research-models

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/49639.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

合作客户销售数据可视化分析

以一个案例进行实际分析: 数据来源:【地区数据分析】 以此数据来制作报表。 技巧一:词云图 以城市名称来显示合同金额的分布,合同金额越大,则城市文字显示越大。 技巧二:饼图 下面制定一个,合…

力扣 738. 单调递增的数字

题目来源:https://leetcode.cn/problems/monotone-increasing-digits/description/ C题解:像1234就可以直接返回1234,像120需要从个位往高位遍历,2比0大,那么2减一成为1,0变成9,变成119。 clas…

【图像分类】CNN + Transformer 结合系列.1

介绍三篇结合使用CNNTransformer进行学习的论文:CvT(ICCV2021),Mobile-Former(CVPR2022),SegNetr(arXiv2307). CvT: Introducing Convolutions to Vision Transformers, …

SpringMVC 拦截器详解

目录 一、介绍 二、过滤器与拦截器的简单对比 三、自定义拦截器 四、注册拦截器 五、案例演示-登录拦截器 5.1 自定义拦截器 5.2 注册拦截器 编写的初衷是为了自己巩固复习,如果能帮到你将是我的荣幸❣️ 一、介绍 SpringMVC提供的拦截器类似于JavaWeb中的过…

C++网络编程 TCP套接字基础知识,利用TCP套接字实现客户端-服务端通信

1. TCP 套接字编程流程 1.1 概念 流式套接字编程针对TCP协议通信,即是面向对象的通信,分为服务端和客户端两部分。 1.2 服务端编程流程: 1)加载套接字库(使用函数WSAStartup()),创建套接字&…

MySQL Windows版本下载及安装时默认路径的修改

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、MySQL 下载二、默认路径修改1、安装前准备【非常重要】2、启动安装程序总结1、MySQL下载2、MySQL默认路径修改前言 MySQL 被Oracle收购后,各种操作规范及约束也相应的跟着来了,这不,只…

【前端实习评审】对小说详情模块更新的后端接口压力流程进行了人群优化

大家好,本篇文章分享一下【校招VIP】免费商业项目“推推”第一期书籍详情模块 前端同学的开发文档周最佳作品。该同学来自安徽科技学院土木工程专业。本项目亮点难点: 1.热门书籍在更新点的访问压力; 2.书籍更新通知的及时性和有效性&#xf…

重学C++系列之友元

一、什么是友元 在C中,为了提高程序的效率,在一些场景下,引入友元,但同时类的封装性就会被破坏。 二、怎么实现友元 友元关键字(friend) // 在类中声明另一个类的成员函数来做为友元函数 // 以关键字&…

Centos部署Springboot项目详解

准备启动jar包,app.jar放入指定目录。 一、命令启动 1、启动命令 java -jar app.jar 2、后台运行 nohup java -jar app.jar >/dev/null 2>&1 & 加入配置参数命令 nohup java -Xms512M -Xmx512M -jar app.jar --server.port9080 spring.profiles…

基于鲸鱼优化算法的5G信道估计(Matlab代码实现)

💥💥💥💞💞💞欢迎来到本博客❤️❤️❤️💥💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑…

Unity 工具之 NuGetForUnity 包管理器,方便在 Unity 中的进行包管理的简单使用

Unity 工具之 NuGetForUnity 包管理器,方便在 Unity 中的进行包管理的简单使用 目录 Unity 工具之 NuGetForUnity 包管理器,方便在 Unity 中的进行包管理的简单使用 一、简单介绍 二、NuGetForUnity 的下载导入 Unity 三、NuGetForUnity 在 Unity 的…

Jetbrains idea 代码关闭 注释自动渲染 导致换行不生效

方法1 关闭注释自动渲染 取消勾选 方法2 结尾使用 <br> 强制换行

vector使用

文章目录 vector的介绍vector的使用vector的初始化vector iterator迭代器的使用vector 空间增长问题vector的增删改查 迭代器失效总结 vector的介绍 文档介绍 vector是表示可变大小数组的序列容器。就像数组一样&#xff0c;vector也采用的连续存储空间来存储元素。也就是意味着…

【雕爷学编程】MicroPython动手做(02)——尝试搭建K210开发板的IDE环境3

4、下载MaixPy IDE&#xff0c;MaixPy 使用Micropython 脚本语法&#xff0c;所以不像 C语言 一样需要编译&#xff0c;要使用MaixPy IDE , 开发板固件必须是V0.3.1 版本以上&#xff08;这里使用V0.5.0&#xff09;, 否则MaixPy IDE上会连接不上&#xff0c; 使用前尽量检查固…

宝塔设置云服务器mysql端口转发,实现本地电脑访问云mysql

环境&#xff1a;centos系统使用宝塔面板 实现功能&#xff1a;宝塔设置云服务器mysql端口转发&#xff0c;实现本地电脑访问mysql 1.安装mysql、PHP-7.4.33、phpMyAdmin 5.0 软件商店》搜索 mysql安装即可 软件商店》搜索 PHP安装7.4.33即可&#xff08;只需要勾选快速安装&…

iOS开发-实现快速登录弹窗与微信微博QQ三方登录切换控件

iOS开发-实现快速登录弹窗与微信微博QQ三方登录切换控件。 之前开发中实现快速登录弹窗与微信微博等了切换控件。 一、效果图 二、实现代码 实现背景渐变UIBlurEffect self.blurEffect [UIBlurEffect effectWithStyle:UIBlurEffectStyleLight]; self.effectView [[UIVisu…

linux上适用的反汇编调试软件(对标od)

ubuntu下类似于od软件 经过搜索&#xff0c;在Ubuntu上选用edb-debugger进行动态调试&#xff0c; 下载链接: https://github.com/eteran/edb-debugger 但是依赖反汇编引擎: https://github.com/capstone-engine/capstone 安装 先安装capstone 先下载release的版本&#xf…

基于量子同态加密的安全多方凸包协议

摘要安全多方计算几何(SMCG)是安全多方计算的一个分支。该协议是为SMCG中安全的多方凸包计算而设计的。首先&#xff0c;提出了一种基于量子同态加密的安全双方值比较协议。由于量子同态加密的性质&#xff0c;该协议可以很好地保护量子电路执行过程中数据的安全性和各方之间的…

组合模式-树形结构的处理

A公司需要筛选出年龄35岁及以上(如果是领导&#xff0c;年龄为45岁及以上)的人。其组织架构图如下。 图 A公司部分组织架构图 图 传统解决方案 public class Development {private String name;public Development(String name) {this.name name;}List<Employee> emplo…

需求分析案例:消息配置中心

本文介绍了一个很常见的消息推送需求&#xff0c;在系统需要短信、微信、邮件之类的消息推送时&#xff0c;边界如何划分和如何设计技术方案。 1、需求 一个系统&#xff0c;一般会区分多个业务模块&#xff0c;并拆分成不同的业务系统&#xff0c;例如一个商城的架构如下&am…