transformer图像切块与还原(window_partition+window_unpartition)

文章目录

  • 前言
  • 一、切割图像(window_partition)
  • 二、还原图像(window_unpartition)
  • 三、整体代码


前言

假如b ,h,w,c=(3,32,32,768)需将h w按照14尺寸切割,32/14无法整除,需pad为(3,42,42,768)完成固定尺寸块切割,进而完成transformer结构,最终摒弃pad数据还原为(3,32,32,768)。在使用Transformer结构提取特征时,通常会使用window_partition和window_unpartition来划分和还原图像块的过程。这两个步骤是为了将图像分割成小块,送入Transformer网络进行处理,然后再将处理后的特征重新组合成原始图像的尺寸。为此,我摘录TAM大模型处理方法代码,记录图像尺寸切割与还原。


一、切割图像(window_partition)

这一步骤是将原始图像按照设定的窗口大小划分成多个块,并将这些块重新排列成一个较大的矩阵,以便送入Transformer网络。

def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:     
"""     Partition into non-overlapping windows with padding if needed.     
Args:         
x (tensor): input tokens with [B, H, W, C].         
window_size (int): window size.      
Returns:         
windows: windows after partition with [B * num_windows, window_size, window_size, C].         
(Hp, Wp): padded height and width before partition     
"""     
B, H, W, C = x.shape      
pad_h = (window_size - H % window_size) % window_size     pad_w = (window_size - W % window_size) % window_size     
if pad_h > 0 or pad_w > 0:         
	x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))     
Hp, Wp = H + pad_h, W + pad_w      
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)     
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)     
return windows, (Hp, Wp)

二、还原图像(window_unpartition)

在完成特征提取后,使用window_unpartition将处理后的特征重新还原为原始图像的尺寸。这样可以保持特征与原始图像之间的对应关系。

def window_unpartition(     windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] ) -> torch.Tensor:     
"""     
Window unpartition into original sequences and removing padding.     
Args:         
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].         
window_size (int): window size.         
pad_hw (Tuple): padded height and width (Hp, Wp).         
hw (Tuple): original height and width (H, W) before padding.      
Returns:         
x: unpartitioned sequences with [B, H, W, C].     
"""     
Hp, Wp = pad_hw     H, W = hw     
B = windows.shape[0] // (Hp * Wp // window_size // window_size)     
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)     
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)      
if Hp > H or Wp > W:         
	x = x[:, :H, :W, :].contiguous()     
return x

三、整体代码

import torch 
from typing import Optional, Tuple, Type 
import torch.nn.functional as F  




def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:     
"""     
Partition into non-overlapping windows with padding if needed.     
Args:         
x (tensor): input tokens with [B, H, W, C].         
window_size (int): window size.      
Returns:         
windows: windows after partition with [B * num_windows, window_size, window_size, C].         
(Hp, Wp): padded height and width before partition     
"""     
B, H, W, C = x.shape      
pad_h = (window_size - H % window_size) % window_size     
pad_w = (window_size - W % window_size) % window_size     
if pad_h > 0 or pad_w > 0:         
	x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))     
Hp, Wp = H + pad_h, W + pad_w      
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)     
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)     
return windows, (Hp, Wp)  
def window_unpartition( windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] ) -> torch.Tensor:     
"""     
Window unpartition into original sequences and removing padding.     
Args:         
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].         
window_size (int): window size.         
pad_hw (Tuple): padded height and width (Hp, Wp).         
hw (Tuple): original height and width (H, W) before padding.      
Returns:         
x: unpartitioned sequences with [B, H, W, C].     
"""     
Hp, Wp = pad_hw     
H, W = hw     
B = windows.shape[0] // (Hp * Wp // window_size // window_size)     
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)     
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)      
if Hp > H or Wp > W:         
x = x[:, :H, :W, :].contiguous()     
return x  

if __name__ == '__main__':     
	x=torch.randn((3,32,32,768)) # b,h,w,c     
	window_size=14     
	H, W = x.shape[1], x.shape[2]     
	x, pad_hw = window_partition(x, window_size)  # 使用window_size尺寸划分图像块     
	print("使用window_partition填充,修改尺寸格式为:",x.shape)     
	y = window_unpartition(x, window_size, pad_hw, (H, W))  # 在返回原有尺寸      
	print("window_unpartition,返回原有尺寸格式为:",y.shape)


结果显示:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/393340.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

java数据结构与算法刷题-----LeetCode150. 逆波兰表达式求值

java数据结构与算法刷题目录(剑指Offer、LeetCode、ACM)-----主目录-----持续更新(进不去说明我没写完):https://blog.csdn.net/grd_java/article/details/123063846 解题思路 本题也叫后缀表达式,更利于机器处理题目给出的案例都…

Find My资讯|苹果Vision Pro无法通过Find My进行远程定位和发声

苹果 Vision Pro 头显现在已经正式开售,不过根据该公司日前发布的支持文件,这款头显目前缺乏一系列关键查找功能,用户无法在 iCloud 网站或Find My应用中获悉头显的位置,也无法让这款头显远程播放声音。 不过支持文件同时提到 V…

【快速解决】python项目打包成exe文件——vscode软件

目录 操作步骤 1、打开VSCode并打开你的Python项目。 2、在VSCode终端中安装pyinstaller: 3、运行以下命令使用pyinstaller将Python项目打包成exe文件: 其中your_script.py是你的Python脚本的文件名。 4、打包完成后,在你的项目目录中会…

多线程---线程同步,线程通信

线程同步 1.概述 线程同步是多线程编程中的一个重要概念,它指的是在多线程环境中,通过一定的机制保证多个线程按照某种特定的方式正确、有序地执行。这主要是为了避免并发问题,如死锁、竞态条件、资源争用等,确保数据的一致性和完…

Python 基于 AI 动物识别技术的研究与实现,附源码

博主介绍:✌程序员徐师兄、7年大厂程序员经历。全网粉丝30W、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇…

hope实验室预备役第三次测试题解

目录 1.选数 2.奇怪的电梯 3.无线通讯网 4. Rotate Colored Subsequence 5.LOWER 6.Error Correction 1.选数 P1036 [NOIP2002 普及组] 选数 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 题目描述 已知 n 个整数 1,2,⋯ ,x1​,x2​,⋯,xn​,以及 1 个整…

VNCTF 2024 Web方向 WP

Checkin 题目描述:Welcome to VNCTF 2024~ long time no see. 开题,是前端小游戏 源码里面发现一个16进制编码字符串 解码后是flag CutePath 题目描述:源自一次现实渗透 开题 当前页面没啥好看的,先爆破密码登录试试。爆破无果…

洗地机什么牌子最好?家用洗地机推荐

如今洗地机已经在家庭中扮演着至关重要的角色,随着人们对居住环境的卫生要求越来越高,洗地机作为结合了吸尘和拖地为一体的清洁工具,不仅可以高效的帮助我们清洁地板,节省时间,还可以为我们节省很多收纳空间。那么&…

typeScript 类型推论

什么是类型推论? 类型推论是 TypeScript 中的一个特性,它允许开发人员不必显式地指定变量的类型。相反,开发人员可以根据变量的使用情况让 TypeScript 编译器自动推断出类型。例如,如果开发人员将一个字符串赋值给一个变量&#…

【力扣白嫖日记】1795.每个产品在不同商店的价格

前言 练习sql语句,所有题目来自于力扣(https://leetcode.cn/problemset/database/)的免费数据库练习题。 今日题目: 1795.每个产品在不同商店的价格 表:Products 列名类型product_idintstore1intstore2intstore3in…

项目中和兄弟部门难以高效协作?你需要注意这四点

在组织架构日益复杂的今天,靠一个人单打独斗完成工作或项目越来越难,也越来越不可能。不知你是否留意过,无论招聘什么岗位,几乎所有企业都在强调“团队合作”。 这里的团队不光指的是同部门协作,要包括公司内部的跨部门…

网络原理 - HTTP/HTTPS(1)

HTTP HTTP是什么 HTTP("全程超文本协议")是一种应用非常广泛的应用层协议. 文本:字符串(能在utf8/gbk)码表上找到合法字符. 超文本:不仅是字符串,还能携带图片啥的(HTML). 富文本:类似于word文档这种. HTTP诞生于1991年.目前已经发展为最主流使用的一种应用层协议.…

不等式的证明之二

不等式的证明之二 证明下述不等式证法一证法二证法二的补充 证明下述不等式 设 a , b , c a,b,c a,b,c 是正实数,请证明下述不等式: 11 a 5 a 6 b 11 b 5 b 6 c 11 c 5 c 6 a ≤ 3 \begin{align} \sqrt{\frac{11a}{5a6b}}\sqrt{\frac{11b}{5b6c}…

centos7如何切换到root用户

在 CentOS 7 中,你可以通过几种方式切换到 root 用户。最常用的方法是使用 su (switch user) 命令或者 sudo 命令。这里是如何使用这些命令的详细说明: 使用 su 命令 打开终端。输入以下命令并按下回车键:su -系统会提示你输入 root 用户的…

云手机在引流方面有什么优势?

对于电商商家而言,无论是在亚马逊还是其他平台,有效的流量来源主要集中在短视频引流和社交电商营销。要在新兴社交平台为企业电商带来更多流量,不可忽视云手机的关键作用和独特优势。 云手机的定义与作用 在经营TikTok、Facebook和INS账号时&…

外汇110:外汇做空是什么意思?如何运作?一文读懂

外汇市场允许卖空,就像众多金融市场一样。但什么是卖空呢?如何外汇做空?在本文中,我们将讨论如何做空货币。什么是外汇做空? 外汇做空(Short Selling)是外汇市场上的一种投资方式。它指的是投资…

Java面向对象案例之设计用户去ATM机存款取款(三)

需求及思路分析 业务代码需求: 某公司要开发“银行管理系统”,请使用面向对象的思想,设计银行的储户信息,描述存款、取款业务。 储户类的思路分析: 属性:用户姓名、密码、身份证号、账号、帐户余额 方法&a…

vue生命周期函数

父子组件加载顺序 加载渲染过程 父beforeCreate->父created->父beforeMount->子beforeCreate->子created->子beforeMount->子mounted->父mounted子组件更新过程 父beforeUpdate->子beforeUpdate->子updated->父updated父组件更新过程 父beforeU…

JS画布内生成图标,并实现拖拽,连线,刷新

JS有现成的拖拽命令&#xff0c;但是只能实现简单的拖拽功能&#xff0c;下面演示的可以在画布的任意一个地方拖拽&#xff0c;并停留在画布的任意地方。 整个框架代码如下&#xff1a; <html> <head><meta charset"UTF-8"><title>拖拽放置…

【详解】图的概念和存储结构(邻接矩阵,邻接表)

目录 图的基本概念&#xff1a; 图的存储结构 邻接矩阵&#xff08;GraphByMatrix&#xff09;&#xff1a; 基本参数&#xff1a; 初始化&#xff1a; 获取顶点元素在其数组中的下标 &#xff1a; 添加边和权重&#xff1a; 获取顶点的度&#xff1a; 打印图&#xf…