PyTorch基础2

目录

1.Tensor常见操作

1.1获取元素值

1.2元素值运算

1.2.1算术运算符

1.2.2数学运算

1.2.3三角函数

1.2.4统计学函数

1.3阿达玛积

1.4Tensor相乘

1.5索引操作

1.5.1简单索引

1.5.2列表索引

1.5.3布尔索引

1.5.4索引赋值

1.6张量拼接

1.6.1 torch.cat()

1.6.2 torch.stack()

1.7形状操作

1.7.1 reshape()

1.7.2 view()

1.7.3 transpose()

1.7.4 permute()

1.7.5 flatten()

1.7.6 queeze()

1.7.7 unqueeze()

1.8张量分割

1.8.1 chunk()

 1.8.2 split()

1.9广播机制

9.1 广播机制规则

9.2 广播案例

1.10保存和加载

1.11并行化


1.Tensor常见操作

1.1获取元素值

单个元素tensor转换为Python数值

import torch

torch.manual_seed(40)
tensor1 = torch.rand(size=(4,2,2))
print(tensor1)

# 取出元素
print(tensor1[2][1][0].item())

注意:

  • 如果有多个元素则报错;

1.2元素值运算

1.2.1算术运算符

# 算术运算符: +
ten1 = torch.tensor([[ 100,  400,  900],[1600, 2500, 3600]])
print(ten1+1)

# 算术运算符: -
ten1 = torch.tensor([[ 100,  400,  900],[1600, 2500, 3600]])
print(ten1+-2)

# 算术运算符: *
ten1 = torch.tensor([[ 100,  400,  900],[1600, 2500, 3600]])
print(ten1 *2)

# 算术运算符: /
ten1 = torch.tensor([[ 100,  400,  900],[1600, 2500, 3600]])
print(ten1 /66)

# 算术运算符: //
ten1 = torch.tensor([[ 100,  400,  900],[1600, 2500, 3600]])
print(ten1 //66)

# 算术运算符: %
ten1 = torch.tensor([[ 100,  400,  900],[1600, 2500, 3600]])
print(ten1 %66)

1.2.2数学运算

(1)关于元素值运算的api,返回新的张量。

floor()向下取整
ceil()向上取整
round()四舍五入
trunc()裁剪,只保留整数部分
frac()只保留小数部分
fix()向零方向舍入
# 不修改原始数据
tensor1 = torch.tensor([[1,2,3],[4,5,6]])
print(tensor1.add(1))
print(tensor1.sub(1))
print(tensor1.mul(10))
print(tensor1.div(10))
print(tensor1.pow(2))
print(tensor1)


ten = torch.tensor([[1.12, 2.54, -3.78], [4.45, -45.50, 6.54], [17.0, -8.110, 9.90]],device='cuda')

print(ten.floor())
print(ten.ceil())
print(ten.fix())
print(ten.round())
print(ten.trunc())
print(ten.frac())
print(ten.abs())

(2)关于元素值运算的api,带有_的方法则会修改原始数据。 

# 修改原始数据
tensor1 = torch.tensor([[1,2,3],[4,5,6]])
print(tensor1.add_(1))
print(tensor1.sub_(1))
print(tensor1.mul_(10))
print(tensor1.pow_(2))
tensor1 = torch.tensor([[1,2,3],[4,5,6]],dtype=float)
print(tensor1.div_(10))
print(tensor1)

1.2.3三角函数

  • torch.cos(input,out=None)

  • torch.cosh(input,out=None) # 双曲余弦函数

  • torch.sin(input,out=None)

  • torch.sinh(input,out=None) # 双曲正弦函数

  • torch.tan(input,out=None)

  • torch.tanh(input,out=None) # 双曲正切函数

torch.set_printoptions(sci_mode=False)
ten = torch.tensor([[1.12, 2.54, -3.78], [4.45, -45.50, 6.54], [17.0, -8.110, 9.90]],device='cuda')

print(ten.sin())
print(ten.sinh())
print(ten.cos())
print(ten.cosh())
print(ten.tan())
print(ten.tanh())

1.2.4统计学函数

mean()张量的平均值
median()张量的中位数
mode()张量的众数
sum()张量的元素之和
std()张量的标准差
var()张量的方差
max()张量的最大值
min()张量的最小值
sort()对张量进行排序
topk()返回张量中的前 k 个最大值或最小值
histc()计算张量的直方图
unique()返回去重后的张量。
bincount()张量中每个元素的出现次数
ten = torch.tensor([[1.12, 2.54, -3.78,-3.78], [-45.50, -45.50,-45.0, 6.54], [17.0,17.0, -8.110, 9.90]],device='cuda')

print(ten.mean())
print(ten.median())
print(ten.mode())

print(ten.sum())
print(ten.std())
print(ten.var())
print(ten.max())
print(ten.min())
print(ten.sort())
print(ten.sqrt())

print(ten.topk(k=3,dim=-1,sorted=True))
print(ten.histc())
# 返回去重后的结果
print(ten.unique())

ten = torch.tensor([100, 1, 2, 3, 4],device='cuda')
# 输入:第一个位置:表示0有多少个,第二个位置表示1有多少个
print(ten.bincount())

1.3阿达玛积

阿达玛积指的是矩阵相应位的元素相乘,可以使用 mul() 或者 * 来实现;

ten1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
ten2 = torch.tensor([[4, 5, 6],[1, 2, 3]])
print(ten1 * ten2)
print(ten1.mul(ten2))

1.4Tensor相乘

点积运算将两个向量映射为一个标量,是向量之间的基本操作。

点积运算要求如果第一个矩阵的shape是 (N, M),那么第二个矩阵 shape必是 (M, P),点积运算结果的shape为 (N, P)。

使用:@符号或者matmul()方法完成Tensor的乘法,mm(tensor1,tensor2)只限用于二维矩阵相乘。

ten1 = torch.tensor([[1, 2, 3]])
ten2 = torch.tensor([[4], [5], [6]])

print(ten1 @ ten2)
print(ten1.matmul(ten2))
print(torch.mm(ten1,ten2))

print(ten2 @ ten1)
print(ten2.matmul(ten1))
print(torch.mm(ten2,ten1))

1.5索引操作

索引可以灵活地访问、修改张量中的特定元素或子集,从而简化代码并提高操作效率。

1.5.1简单索引

ten = torch.randint(low=0,high=10,size=(3,3)).cuda()
print(ten)

# 行索引
print(ten[1])

# 列索引
print(ten[:,1:2])

# 固定位置索引
print(ten[0:3,1:2])

# 单个元素索引
print(ten[0][1])
print(ten[0,1])

1.5.2列表索引

使用list批量的制定要索引的元素位置须注意list的所在维度。

ten = torch.randint(low=0,high=13,size=(3,4)).cuda()
print(ten)

# 列表索引
print(ten[[0,1],[2,3]])

# 行级别的索引
print(ten[[[2],[1]],[1,2,3]])

1.5.3布尔索引

(1)根据条件选择张量中的元素。

ten = torch.randint(low=0,high=10,size=(4,4)).cuda()
mask = ten>5
print(ten)
print(mask)
print(ten[mask])

(2)行级别的条件索引。

ten = torch.randint(low=0,high=17,size=(4,4)).cuda()
print(ten)

# 索引行(其列满足:....)
# 索引行(其第3列的元素大于10)
print(ten[ten[:,2]>10])
# 索引行(其第1列的元素大于2,且第2列的元素是偶数)
print(ten[(ten[:,1]%2==0) & (ten[:,0]>2)])

# 索引列(其行满足:....)
# 索引列(其第3行元素大于10)
print(ten[:,ten[2]>10])

1.5.4索引赋值

通过索引取地址,再进行批量赋值。

ten = torch.eye(4)
print(ten)

ten[:,1:2]=10
print(ten)

1.6张量拼接

1.6.1 torch.cat()

在元素级别上操作

torch.cat([ ],dim)(concatenate 的缩写)用于沿现有维度拼接张量。

注意:要拼接的张量在除了指定拼接的维度之外的所有维度上的大小必须相同。

ten1 = torch.tensor([[1,2,3],[4,5,6]])
ten2 = torch.tensor([[7,8,9],[77,88,99]])
print(torch.cat([ten1,ten2],dim=0))

ten1 = torch.tensor([[1,2,3],[4,5,6]])
ten2 = torch.tensor([[7,8,9],[77,88,99]])
print(torch.cat([ten1,ten2],dim=1))

1.6.2 torch.stack()

在张量级别上操作

torch.stack([ ],dim) 用于在新维度上堆叠张量。

注意:要堆叠的张量必须具有相同的形状。

技巧:堆叠指一人出一个交替添加,拼接指一人出完下个人在出完。

ten1 = torch.tensor([[1,2,3],[4,5,6]])
ten2 = torch.tensor([[7,8,9],[77,88,99]])
print(torch.stack([ten1,ten2],dim=0))

ten1 = torch.tensor([[1,2,3],[4,5,6]])
ten2 = torch.tensor([[7,8,9],[77,88,99]])
print(torch.stack([ten1,ten2],dim=1))

1.7形状操作

在 PyTorch 中,调整张量的维度和结构,以适应不同的计算需求。

1.7.1 reshape()

可以用于将张量转换为不同的形状,但要确保转换后的形状与原始形状具有相同的元素数量。

ten1 = torch.randint(low=0,high=13,size=(3,4,2)).cuda()
print(ten1)
ten2 = ten1.reshape(4,2,3)
print(ten2)

# -1表示:此位数字自动计算
ten1 = torch.randint(low=0,high=13,size=(3,4,2)).cuda()
print(ten1)
ten2 = ten1.reshape(4,2,-1)
print(ten2)
ten3 = ten1.reshape(4,-1)
print(ten3)

1.7.2 view()

(1)view进行形状变换的特征

  • 张量在内存中是连续的;

  • 返回的是原始张量视图,不重新分配内存,效率更高;

  • 如果张量在内存中不连续(比如张量进行了转置操作),view 将无法执行,并抛出错误。

(2)内存连续性

我们在进行变形或转置操作时,很容易造成内存的不连续性。

(3) view() 和 reshape() 比较

view:高效,但需要张量在内存中是连续的;

reshape:更灵活,但涉及内存复制,更消耗内存资源;

ten1 = torch.randint(low=0,high=13,size=(3,4)).cuda()
print(ten1)
print(ten1.is_contiguous())

ten2 = ten1.t()
print(ten2)
print(ten2.is_contiguous())

# tensor连续,可用view()
ten3 = ten1.view(2,-1)
print(ten1)

# # tensor不连续,不可用view()
# ten4 = ten2.view(2,-1)
# print(ten4)

1.7.3 transpose()

transpose 用于交换张量的两个维度,返回的是原张量的视图。

img_ten1 = torch.randint(low=0,high=256,size=(3,200,100)).cuda()
# print(ten1)

img_ten2 = img_ten1.transpose(dim0=2,dim1=0)
print(img_ten2.shape)
print(img_ten2)

1.7.4 permute()

permute 用于改变张量的所有维度顺序

img_ten1 = torch.randint(low=0,high=256,size=(3,200,100)).cuda()

img_ten2 = img_ten1.permute(1,2,0)
print(img_ten2.shape)
print(img_ten2)

1.7.5 flatten()

flatten 用于元素不变,消除维度。

tensor.flatten(start_dim=0, end_dim=-1)

 参数:

  • start_dim:从哪个维度开始展平。

  • end_dim:在哪个维度结束展平。默认值为 -1,表示展平到最后一个维度。

ten = torch.randint(low=0,high=256,size=(4,2,2)).cuda()
print(ten)

# 展平成二维
ten2 = ten.flatten(start_dim=1,end_dim=2)
print(ten2)

# 展平成一维
ten3 = ten.flatten(start_dim=0,end_dim=-1)
print(ten3)

1.7.6 queeze()

用于移除所有大小为 1 的维度,或者移除指定位置且大小为 1 的维度。

ten = torch.randint(low=0,high=256,size=(1,3,2,1)).cuda()
# print(ten)

ten1 = ten.squeeze(dim=0)
print(ten1.shape)

ten2 = ten.squeeze(-1)
print(ten2.shape)

ten3 = ten.squeeze(3).squeeze(0)
print(ten3.shape)

ten4 = ten.squeeze()
print(ten4.shape)

1.7.7 unqueeze()

用于在指定位置插入一个大小为 1 的新维度。

ten = torch.randint(low=0,high=256,size=(3,4)).cuda()

ten1 = ten.unsqueeze(dim=0)
print(ten1.shape)

ten2 = ten.unsqueeze(-1)
print(ten2.shape)

ten3 = ten2.unsqueeze(0).unsqueeze(2)
print(ten3.shape)

1.8张量分割

可以按照指定的大小或者块数进行分割。

1.8.1 chunk()

 按块数进行分割。

ten = torch.randint(low=0,high=256,size=(3,4,5)).cuda()

# 分成chunks份
tensor_tup1=ten.chunk(chunks=2,dim=2)
print(tensor_tup1)

# 若chunks>dim 则按dim分割
print(ten.chunk(5))

 1.8.2 split()

按指定的大小进行分割。

ten = torch.randint(low=0,high=256,size=(3,4,5)).cuda()

# 分割的每一份的大小:split_size
tensor_tup1=ten.split(split_size=10,dim=2)
print(tensor_tup1)

# 若split_size>dim 则割的每一份的大小为:dim
print(ten.chunk(5))

1.9广播机制

广播机制支持了不同形状的张量进行计算的实现,而无需显式地调整张量的形状。广播机制通过自动扩展较小维度的张量,使其与较大维度的张量兼容,从而实现按元素计算。

9.1 广播机制规则

广播机制需要遵循以下规则:

  • 每个张量的维度至少为1

  • 满足右对齐

9.2 广播案例

1D和2D张量广播

# 1D和2D张量广播
ten1 =torch.Tensor(([10],[20],[30])) 
ten2 =torch.Tensor(([[1,2,3],[4,5,6],[7,8,9]])) 

ten3 = ten1+ten2
print(ten3)

ten1 =torch.Tensor(([10],[20],[30])) # 维度(3,1)->广播(3,3)
ten4 =torch.Tensor(([1,2,3])) # 维度(1,3)->广播(3,3)

ten5= ten1+ten4
print(ten3)

2D 和 3D 张量广播

广播机制会根据需要对两个张量进行形状扩展,以确保它们的形状对齐,从而能够进行逐元素运算。

# 2D 和 3D 张量广播
ten1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
ten2 = torch.tensor([[[2, 3, 4]], [[5, 6, 7]]])
print(ten1.shape,ten2.shape)
print(ten1)
print(ten2)

ten3 =ten1+ten2
print()
print(ten3.shape)
print(ten3)

1.10保存和加载

Pytorch支持,张量数据可以保存下来并再次加载使用。

# 1. 定义一个张量
ten = torch.tensor([[1.12, 2.54, -3.78], [-45.50, -45.50, 6.54], [17.0, -8.110, 9.90]],device='cuda')
# 2. 保存到文件中
torch.save(ten, "ten.pt")

# 1. 加载张量
x = torch.load("ten.pt")
# 2. 打印张量
print(ten,ten.device)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 3. 加载到指定设备
ten = torch.load("ten.pt", map_location=device)
print(ten, ten.device)

1.11并行化

在 PyTorch 中,可以查看和设置用于 CPU 运算的线程数。

PyTorch 使用多线程来加速 CPU 运算,但有时可能需要调整线程数来优化性能。

(1)查看线程数

# 查看当前 PyTorch 使用的线程数
num=torch.get_num_threads()
print(num)

(2)设置线程数

设置线程数时,确保应考虑 CPU 核心数和其他进程的资源需求,以获得最佳性能。

(3)注意事项

  • 线程数设置过高可能会导致线程相互竞争,反而降低性能;

  • 线程数设置过低可能会导致计算资源未得到充分利用;

  • 当使用 GPU 进行计算时,线程数设置对性能影响较小,因为 GPU 计算并不依赖于 CPU 线程数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/922749.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【数据结构】【线性表】栈的基本概念(附c语言源码)

栈的基本概念 讲基本概念还是回到数据结构的三要素:逻辑结构,物理结构和数据运算。 从逻辑结构来讲,栈的各个数据元素之间是通过是一对一的线性连接,因此栈也是属于线性表的一种从物理结构来说,栈可以是顺序存储和顺…

OpenOCD之J-Link下载

1.下载USB Dirver Tool.exe,选择J-Link dirver,替换成WinUSB驱动。(⭐USB Dirver Tool工具可将J-Link从WinUSB驱动恢复为默认驱动⭐) 下载方式 ①官方网址:https://visualgdb.com/UsbDriverTool/ ②笔者的CSDN链接&…

【JavaEE初阶 — 多线程】定时器的应用及模拟实现

目录 1. 标准库中的定时器 1.1 Timer 的定义 1.2 Timer 的原理 1.3 Timer 的使用 1.4 Timer 的弊端 1.5 ScheduledExecutorService 2. 模拟实现定时器 2.1 实现定时器的步骤 2.1.1 定义类描述任务 定义类描述任务 第一种定义方法 …

ssm168基于jsp的实验室考勤管理系统网页的设计与实现+jsp(论文+源码)_kaic

毕 业 设 计(论 文) 题目:实验室考勤管理系统设计与实现 摘 要 现代经济快节奏发展以及不断完善升级的信息化技术,让传统数据信息的管理升级为软件存储,归纳,集中处理数据信息的管理方式。本实验室考勤管…

原生微信小程序在顶部胶囊左侧水平设置自定义导航兼容各种手机模型

无论是在什么手机机型下,自定义的导航都和右侧的胶囊水平一条线上。如图下 以上图iphone12,13PRo 以上图是没有带黑色扇帘的机型 以下是调试器看的wxml的代码展示 注意:红色阔里的是自定义导航(或者其他的logo啊,返回之…

Python 获取微博用户信息及作品(完整版)

在当今的社交媒体时代,微博作为一个热门的社交平台,蕴含着海量的用户信息和丰富多样的内容。今天,我将带大家深入了解一段 Python 代码,它能够帮助我们获取微博用户的基本信息以及下载其微博中的相关素材,比如图片等。…

springcloud alibaba之shcedulerx实现分布式锁

文章目录 1、shcedulerx简介2、基于mysq分布式锁实现3、注解方式使用分布式锁4、编码方式使用分布式锁 1、shcedulerx简介 springcloud alibaba shcedulerx看起来有点像xxl job那样的任务调度中间件,其实它是一个分布式锁框架,含有两种实现一种基于DB实…

【LLM训练系列02】如何找到一个大模型Lora的target_modules

方法1:观察attention中的线性层 import numpy as np import pandas as pd from peft import PeftModel import torch import torch.nn.functional as F from torch import Tensor from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig from typ…

Selenium的八种定位方式

1. 通过 ID 定位 ID 是最直接和高效的方式来定位元素,因为每个页面中的 ID 应该是唯一的。 from selenium import webdriverdriver webdriver.Chrome(executable_pathpath/to/chromedriver) driver.get(https://example.com)# 通过 ID 定位 element driver.find…

MySQL底层概述—1.InnoDB内存结构

大纲 1.InnoDB引擎架构 2.Buffer Pool 3.Page管理机制之Page页分类 4.Page管理机制之Page页管理 5.Change Buffer 6.Log Buffer 1.InnoDB引擎架构 (1)InnoDB引擎架构图 (2)InnoDB内存结构 (1)InnoDB引擎架构图 下面是InnoDB引擎架构图,主要分为内存结构和磁…

丹摩|丹摩智算平台深度评测

1. 丹摩智算平台介绍 随着人工智能和大数据技术的快速发展,越来越多的智能计算平台涌现,为科研工作者和开发者提供高性能计算资源。丹摩智算平台作为其中的一员,定位于智能计算服务的提供者,支持从数据处理到模型训练的全流程操作…

基于企业微信客户端设计一个文件下载与预览系统

在企业内部沟通与协作中,文件分享和管理是不可或缺的一部分。企业微信(WeCom)作为一款广泛应用于企业的沟通工具,提供了丰富的API接口和功能,帮助企业进行高效的团队协作。然而,随着文件交换和协作的日益增…

LLM的原理理解6-10:6、前馈步骤7、使用向量运算进行前馈网络的推理8、注意力层和前馈层有不同的功能9、语言模型的训练方式10、GPT-3的惊人性能

目录 LLM的原理理解6-10: 6、前馈步骤 7、使用向量运算进行前馈网络的推理 8、注意力层和前馈层有不同的功能 注意力:特征提取 前馈层:数据库 9、语言模型的训练方式 10、GPT-3的惊人性能 一个原因是规模 大模型GPT-1。它使用了768维的词向量,共有12层,总共有1.…

大模型系列11-ray

大模型系列11-ray PlasmaPlasmaStore启动监听处理请求 ProcessMessagePlasmaCreateRequest请求PlasmaCreateRetryRequest请求PlasmaGetRequest请求PlasmaReleaseRequestPlasmaDeleteRequestPlasmaSealRequest ObjectLifecycleManagerGetObjectSealObject ObjectStoreRunnerPlas…

开源动态表单form-create-designer 扩展个性化配置的最佳实践教程

在开源低代码表单设计器 form-create-designer 的右侧配置面板里,field 映射规则为开发者提供了强大的工具去自定义和增强组件及表单配置的显示方式。通过这些规则,你可以简单而高效地调整配置项的展示,提升用户体验。 源码地址: Github | G…

美创科技入选2024数字政府解决方案提供商TOP100!

11月19日,国内专业咨询机构DBC德本咨询发布“2024数字政府解决方案提供商TOP100”榜单。美创科技凭借在政府数据安全领域多年的项目经验、技术优势与创新能力,入选收录。 作为专业数据安全产品与服务提供商,美创科技一直致力于为政府、金融、…

地平线 bev_cft_efficientnetb3 参考算法-v1.2.1

01 概述 在自动驾驶感知算法中 BEV 感知成为热点话题,BEV 感知可以弥补 2D 感知的缺陷构建 3D “世界”,更有利于下游任务和特征融合。 地平线集成了基于 bev 的纯视觉算法,目前已支持 ipm-based 、lss-based、 transformer-based&#xff…

C#里怎么样检测文件的属性?

C#里怎么样检测文件的属性? 对于文件来说,在C#里有一种快速的方法来检查文件的属性。 比如文件是否已经压缩, 文件是否加密, 文件是否是目录等等。 属性有下面这么多: 例子演示如下: /** C# Program to View the Information of the File*/ using System; using Syste…

最新‌VSCode保姆级安装教程(附安装包)

文章目录 一、VSCode介绍 二、VSCode下载 下载链接:https://pan.quark.cn/s/19a303ff81fc 三、VSCode安装 1.解压安装文件:双击打开并安装VSCode 2.勾选我同意协议:然后点击下一步 3.选择目标位置:点击浏览 4.选择D盘安装…

传输控制协议(TCP)和用户数据报协议(UDP)

一、传输控制协议(TCP) 传输控制协议(Transmission Control Protocol,TCP)是一种面向连接的、可靠的、基于字节流的传输层通信协议,由 IETF 的 RFC 793 定义。 它通过三次握手建立连接,确保数…