Table of Contents
- Preface
- 1. Partitioning the Image (window_partition)
- 2. Restoring the Image (window_unpartition)
- 3. Complete Code
Preface
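Suppose a feature tensor has shape (b, h, w, c) = (3, 32, 32, 768) and its h and w dimensions must be cut into windows of size 14. Because 32 is not divisible by 14, the tensor is first padded to (3, 42, 42, 768) so that it can be cut into fixed-size blocks. The blocks then pass through the Transformer structure, and finally the padded data is discarded to restore the shape (3, 32, 32, 768). When a Transformer is used for feature extraction, window_partition and window_unpartition commonly handle this splitting and restoring of image blocks: the image is divided into small blocks, the blocks are fed to the Transformer network, and the processed features are reassembled at the original image size. To document this, I have excerpted the processing code from the TAM large model as a record of how the image size is partitioned and restored.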
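The padded size in the example above follows from simple modular arithmetic. Here is a minimal sketch in plain Python; `padded_size` is a hypothetical helper name introduced only for illustration, and it mirrors the padding formula that window_partition uses:

```python
from typing import Tuple


def padded_size(size: int, window_size: int) -> Tuple[int, int]:
    """Return (pad, padded), where padded is the smallest multiple of window_size >= size.

    Hypothetical helper for illustration; not part of the excerpted code.
    """
    # The outer % maps a full extra window (pad == window_size) back to 0.
    pad = (window_size - size % window_size) % window_size
    return pad, size + pad


window_size = 14
pad_h, Hp = padded_size(32, window_size)
pad_w, Wp = padded_size(32, window_size)
windows_per_image = (Hp // window_size) * (Wp // window_size)

print(pad_h, Hp)                # 10 42
print(windows_per_image)        # 9
print(3 * windows_per_image)    # 27 (batch of 3 images)
print(padded_size(28, 14))      # (0, 28): already divisible, so no padding
```

For a batch of 3 images this gives 27 windows of 14 x 14, which is the leading dimension to expect from window_partition.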
1. Partitioning the Image (window_partition)
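This step pads the input if necessary, splits it into non-overlapping blocks of the configured window size, and stacks those blocks along the batch dimension. The result has shape [B * num_windows, window_size, window_size, C] and can be fed to the Transformer network window by window.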
from typing import Tuple

import torch
import torch.nn.functional as F


def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.
    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape
    # Padding needed to reach the next multiple of window_size (0 if already divisible).
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        # F.pad takes (before, after) pairs starting from the last dimension: C, then W, then H.
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w
    # Split H and W into (number of windows, window_size) pairs.
    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    # Bring the two window-index axes together, then fold them into the batch axis.
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)
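To see how the windows are ordered and where the padding ends up, it helps to run the function on a tiny tensor. The sketch below repeats window_partition so that it runs on its own; the toy sizes (one 5 x 5 image, one channel, 2 x 2 windows) are assumptions chosen only to keep the values readable:

```python
from typing import Tuple

import torch
import torch.nn.functional as F


def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    # Same logic as the function above, repeated so this snippet is self-contained.
    B, H, W, C = x.shape
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w
    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


# Values 1..25 laid out row by row, so any 0 in the output must come from padding.
x = torch.arange(1, 26, dtype=torch.float32).view(1, 5, 5, 1)
windows, (Hp, Wp) = window_partition(x, window_size=2)

print(tuple(windows.shape))            # (9, 2, 2, 1): a 3 x 3 grid of windows
print((Hp, Wp))                        # (6, 6): 5 padded up to the next multiple of 2
print(windows[0, :, :, 0].tolist())    # [[1.0, 2.0], [6.0, 7.0]]   top-left window
print(windows[1, :, :, 0].tolist())    # [[3.0, 4.0], [8.0, 9.0]]   next window to the right
print(windows[8, :, :, 0].tolist())    # [[25.0, 0.0], [0.0, 0.0]]  bottom-right, mostly padding
```

Windows are numbered row by row within each image, and the zero padding appears only in the windows along the bottom and right edges.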
2. Restoring the Image (window_unpartition)
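After feature extraction, window_unpartition reassembles the processed windows into a tensor of the padded size [B, Hp, Wp, C] and then crops away the padding to recover [B, H, W, C]. This keeps each feature aligned with its position in the original image.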
from typing import Tuple

import torch


def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """
    Window unpartition into original sequences and removing padding.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.
    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    # Recover the batch size: total windows divided by windows per image.
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    # Undo the permutation from window_partition, then merge back to [B, Hp, Wp, C].
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
    if Hp > H or Wp > W:
        # Drop the rows and columns that were added as padding.
        x = x[:, :H, :W, :].contiguous()
    return x
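A round-trip check is a quick way to confirm that the two functions are inverses of each other. The sketch below repeats both functions so that it runs standalone. Two things in it are assumptions made for illustration: the channel count is reduced to 8 to keep the check fast, and the Transformer block is replaced by a no-op (a real block would change the values but keep the window shape):

```python
from typing import Tuple

import torch
import torch.nn.functional as F


def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    B, H, W, C = x.shape
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w
    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x


# Same spatial size as the article's example, but with 8 channels instead of 768.
x = torch.randn(3, 32, 32, 8)
windows, pad_hw = window_partition(x, window_size=14)

# Placeholder for the per-window Transformer block (a no-op here).
processed = windows

y = window_unpartition(processed, 14, pad_hw, (32, 32))

print(tuple(windows.shape))   # (27, 14, 14, 8)
print(pad_hw)                 # (42, 42)
print(tuple(y.shape))         # (3, 32, 32, 8)
print(torch.equal(x, y))      # True: padding is removed and no value is moved
```

Because the per-window step is a no-op, the restored tensor matches the input exactly, so the partition and unpartition steps only rearrange and crop the data.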
3. Complete Code
from typing import Tuple

import torch
import torch.nn.functional as F


def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.
    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w
    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """
    Window unpartition into original sequences and removing padding.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.
    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x


if __name__ == '__main__':
    x = torch.randn((3, 32, 32, 768))  # b, h, w, c
    window_size = 14
    H, W = x.shape[1], x.shape[2]
    x, pad_hw = window_partition(x, window_size)  # pad, then split into window_size x window_size blocks
    print("Shape after window_partition (padded and split):", x.shape)
    y = window_unpartition(x, window_size, pad_hw, (H, W))  # merge the blocks and restore the original size
    print("Shape after window_unpartition (original size restored):", y.shape)
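Output: after window_partition the shape is torch.Size([27, 14, 14, 768]), and after window_unpartition it is torch.Size([3, 32, 32, 768]).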