图解PyTorch中的torch.gather函数和 scatter 函数

前言

torch.gather在目前基于 transformer or query based 的目标检测中,在最后获取目标结果时,经常用到。

这里记录下用法,防止之后又忘了。

介绍

torch.gather

在这里插入图片描述
官方文档对torch.gather()的定义非常简洁

定义:从原tensor中获取指定dim和指定index的数据
看到这个核心定义,我们很容易想到gather()的基本想法其实就类似从完整数据中按索引取值般简单,比如下面从列表中按索引取值

lst = [1, 2, 3, 4, 5]
value = lst[2]  # value = 3
value = lst[2:4]  # value = [3, 4]

上面的取值例子是取单个值或具有逻辑顺序序列的例子,而对于深度学习常用的批量tensor数据来说,我们的需求可能是选取其中多个且乱序的值,此时gather()就是一个很好的tool,它可以帮助我们从批量tensor中取出指定乱序索引下的数据,因此其用途如下

用途:方便从批量tensor中获取指定索引下的数据,该索引是高度自定义化的,可乱序的

示例

我们找个3x3的二维矩阵做个实验

import torch

tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)

输出结果

tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])

2.1 输入行向量index,并替换行索引(dim=0)

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)

输出结果

tensor([[9, 7, 5]])

过程如图所示
在这里插入图片描述

2.2 输入行向量index,并替换列索引(dim=1)

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

输出结果

tensor([[5, 4, 3]])

过程如图所示
在这里插入图片描述

2.3 输入列向量index,并替换列索引(dim=1)

index = torch.tensor([[2, 1, 0]]).t()
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

输出结果

tensor([[5],
        [7],
        [9]])

过程如图所示
在这里插入图片描述

scatter

基本是 gather 的反过程,是将数据添加进去,
doc:https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

example:

>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])
>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
        [6, 7, 0, 0, 8],
        [0, 0, 0, 0, 0]])

>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='multiply')
tensor([[2.0000, 2.0000, 2.4600, 2.0000],
        [2.0000, 2.0000, 2.0000, 2.4600]])
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='add')
tensor([[2.0000, 2.0000, 3.2300, 2.0000],
        [2.0000, 2.0000, 2.0000, 3.2300]])

具体过程见 gather 的就好~一摸一样,一个获取,一个填入。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/509039.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

[Java基础揉碎]注解

目录 介绍 基本的Annotation介绍 1) Override:限定某个方法,是重写父类方法,该注解只能用于方法 2) Deprecated: 用于表示某个程序元素(类,方法等)已过时 3) SuppressWarnings:抑制编译器警告 四大元注解 ​编辑 Retention ​编辑​编辑 Docume…

【MySQL】MySQL故障排查与环境优化

案例 MySQL是目前企业最常见的数据库之一 日常维护管理的过程中,会遇到很多故障 mysql默认配置无法满足高性能要求 数据流向 MySQL常见的故障案例 故障一 故障现象: ERROR 2002 (HY000): Cant connect to local MySQL server through socket data/m…

【计算机网络】select/poll

多路转接 - select/poll 一、I/O 多路转接之 select1. select 接口2. select 的使用3. select 的优缺点 二、I/O 多路转接之 poll1. poll 接口2. poll 的使用3. poll 与 select 的对比 一、I/O 多路转接之 select 多路转接属于 IO 复用方式的一种。系统提供 select() 函数来实…

如何在 Mac 上恢复已删除的数据

如果您丢失了 Mac 上的数据,请不要绝望。恢复数据比您想象的要容易,并且有很多方法可以尝试。 在 Mac 上遭受数据丢失是每个人都认为永远不会发生在他们身上的事情之一......直到它发生。不过,请不要担心,因为您可以通过多种方法…

数据结构(六)——图的应用

6.4 图的应用 6.4.1 最小生成树 对于⼀个带权连通⽆向图G (V, E),⽣成树不同,每棵树的权(即树中所有边上的权值之和)也可能不同。设R为G的所有⽣成树的集合,若T为R中边的权值之和最小的生成树,则T称为G的…

解析网约车微服务中台架构:打造智能高效的出行平台

随着互联网技术的不断发展,网约车行业已经成为了城市出行的重要方式之一。为了应对市场竞争和用户需求的不断变化,各大网约车平台纷纷采用了微服务中台架构,以构建智能高效的出行平台。本文将深入探讨网约车微服务中台架构的核心概念、关键特…

反序列化漏洞

常见的反序列化流量特征: 像这种st2 045、068、shiro反序列化、fastjson这些java反序列化一类的流量特征 shiro就看cookie中Rememberme字段,什么都要从这里传 fastjson:可以在提交的包中找找json格式的数据,重点看一下有无rmi或…

【操作系统】想要更好的学习计算机,操作系统的知识必不可少!!!

操作系统的概念 导言一、日常生活中的操作系统二、计算机系统层次结构三、操作系统的定义3.1 控制和管理计算机资源3.2 组织、调度计算机的工作与资源的分配3.3 给用户和其他软件提供方便接口与环境3.4 总结 四、操作系统的目标和功能4.1 作为管理者4.1.1 处理机管理4.1.2 存储…

记录Xshell使用ed25519公钥免密链接SSH

试了半天,Xshell好像没办法导入linux生成的ssh公钥,因此需要以下步骤实现免密登录 结论,在linux公钥文件中,将客户端生成的ed25519公钥加上去即可(一个公钥单独一行) 1.使用Linux生成秘钥文件(不需要输入私钥密码passphrase)或者直接创建一…

【css】文本过长溢出一行不换行普通css以及antd实现

.text-box { white-space: nowrap; /* 防止文字换行 */ overflow: hidden; /* 隐藏超出div的内容 */ text-overflow: ellipsis; /* 当内容超出时,显示省略号 */ max-width: calc(100% - 80px); /* 假设按钮宽度为80px,则设置div的最大宽度为容器宽度…

【企业管理精粹】华为腾讯全套人力资源管理精品资料合集

以下是华为&腾讯全套人力资源管理资料目录,如需下载,请前往星球下载,海量免费资料等你领取: 华为全套企业管理资料合集,共23专题。 1.绩效考核 华为内训绝密资料:绩效管理与绩效考核.ppt 华为绩效管理与绩效考核制…

软考高级架构师;线程的同步和互斥、临界区、临界资源、信号量、PV 操作概念和例题

作者:明明如月学长, CSDN 博客专家,大厂高级 Java 工程师,《性能优化方法论》作者、《解锁大厂思维:剖析《阿里巴巴Java开发手册》》、《再学经典:《Effective Java》独家解析》专栏作者。 热门文章推荐&am…

信息系统项目管理师——第15章项目风险管理

本章节内容属于10大管理知识领域,选择、案例、论文都会考。 选择题,稳定考3分左右,新教材基本考课本原话,这个分不能丢。 案例题,本期考的概率中等。 论文题,202305刚考过,这期不会考。 1管理基…

基于SSM的“电费管理系统”的设计与实现(源码+数据库+文档+PPT)

基于SSM的“电费管理系统”的设计与实现(源码数据库文档PPT) 开发语言:Java 数据库:MySQL 技术:SSM,VUE 工具:IDEA/Ecilpse、Navicat、Maven 系统展示 系统功能结构图 系统登录页面 用户列表信息页面 添加用户信…

mkcert生成ssl证书+nginx部署局域网内的https服务访问问题

文章目录 mkcert生成ssl证书nginx部署局域网内的https服务访问问题1、下载mkcert查看自己的电脑是arm还是amd架构 2、安装mkcert3、测试mkcert是否安装成功4、查看CA证书存放位置5、打开windows的证书控制台6、生成自签证书,可供局域网内使用其他主机访问以下是nginx部署https服…

财富池指标公式--通达信免费指标公式源码合集--第二期

财富池免费通达信指标公式源码第二期来啦,今天给大家分享3个不同功能用法的指标,如果大家想要小编发布什么类型的指标,请多多在评论区留言呀! 一、通达信犀牛王指标公式 当在0轴线上出现蓝色加玫红色柱的信号时,是上涨…

LeetCode:331. 验证二叉树的前序序列化(模拟 Java)

目录 331. 验证二叉树的前序序列化 题目描述: 实现代码与解析: 模拟 原理思路: 331. 验证二叉树的前序序列化 题目描述: 序列化二叉树的一种方法是使用 前序遍历 。当我们遇到一个非空节点时,我们可以记录下这个节…

JAV八股--redis

如何保证Redis和数据库数据一致性 关于异步通知中消息队列和Canal的内容。 redisson实现的分布式锁的主从一致性 明天继续深入看这个系列问题 介绍IO复用模型

5个网络基础概念

说到网络,有五大基础概念是不得不提的,IP地址,子网掩码、网关、DHCP服务和PPPoE拨号,这些都是日常做电脑或路由器网络配置经常用到的,相信很多人都听过这些概念念,也知道都是一串串数字,但具体是…

远控桌面多任务并发文件保密传输

远程桌面文件传输是一个重要的功能,大多数远控都是用的桌面程序模式,利用系统自带复制粘贴拖拽文件拷贝功能,做一个ole调用对接,可以将很多控制权交给操作系统。 但我做的是浏览器版,浏览器是沙盒原理,为了…