每日Attention学习22——Inverted Residual RWKV

模块出处

[arXiv 25] [link] [code] RWKV-UNet: Improving UNet with Long-Range Cooperation for Effective Medical Image Segmentation


模块名称

Inverted Residual RWKV (IR-RWKV)


模块作用

用于vision的RWKV结构


模块结构

在这里插入图片描述


模块代码

注:cpp扩展请参考作者原仓库

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from timm.layers.activations import *
from functools import partial
from timm.layers import DropPath, create_act_layer, LayerType
from typing import Callable, Dict, Optional, Type
from torch.utils.cpp_extension import load


T_MAX = 1024
inplace = True
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
                verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])


def get_norm(norm_layer='in_1d'):
	eps = 1e-6
	norm_dict = {
		'none': nn.Identity,
		'in_1d': partial(nn.InstanceNorm1d, eps=eps),
		'in_2d': partial(nn.InstanceNorm2d, eps=eps),
		'in_3d': partial(nn.InstanceNorm3d, eps=eps),
		'bn_1d': partial(nn.BatchNorm1d, eps=eps),
		'bn_2d': partial(nn.BatchNorm2d, eps=eps),
		# 'bn_2d': partial(nn.SyncBatchNorm, eps=eps),
		'bn_3d': partial(nn.BatchNorm3d, eps=eps),
		'gn': partial(nn.GroupNorm, eps=eps),
		'ln_1d': partial(nn.LayerNorm, eps=eps),
		# 'ln_2d': partial(LayerNorm2d, eps=eps),
	}
	return norm_dict[norm_layer]


def get_act(act_layer='relu'):
	act_dict = {
		'none': nn.Identity,
		'sigmoid': Sigmoid,
		'swish': Swish,
		'mish': Mish,
		'hsigmoid': HardSigmoid,
		'hswish': HardSwish,
		'hmish': HardMish,
		'tanh': Tanh,
		'relu': nn.ReLU,
		'relu6': nn.ReLU6,
		'prelu': PReLU,
		'gelu': GELU,
		'silu': nn.SiLU
	}
	return act_dict[act_layer]


class ConvNormAct(nn.Module):
	def __init__(self, dim_in, dim_out, kernel_size, stride=1, dilation=1, groups=1, bias=False,
				 skip=False, norm_layer='bn_2d', act_layer='relu', inplace=True, drop_path_rate=0.):
		super(ConvNormAct, self).__init__()
		self.has_skip = skip and dim_in == dim_out
		padding = math.ceil((kernel_size - stride) / 2)
		self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride, padding, dilation, groups, bias)
		self.norm = get_norm(norm_layer)(dim_out)
		self.act = get_act(act_layer)(inplace=inplace)
		self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
	
	def forward(self, x):
		shortcut = x
		x = self.conv(x)
		x = self.norm(x)
		x = self.act(x)
		if self.has_skip:
			x = self.drop_path(x) + shortcut
		return x
      

class SE(nn.Module):
    def __init__(
            self,
            in_chs: int,
            rd_ratio: float = 0.25,
            rd_channels: Optional[int] = None,
            act_layer: LayerType = nn.ReLU,
            gate_layer: LayerType = nn.Sigmoid,
            force_act_layer: Optional[LayerType] = None,
            rd_round_fn: Optional[Callable] = None,
    ):
        super(SE, self).__init__()
        if rd_channels is None:
            rd_round_fn = rd_round_fn or round
            rd_channels = rd_round_fn(in_chs * rd_ratio)
        act_layer = force_act_layer or act_layer
        self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True)
        self.act1 = create_act_layer(act_layer, inplace=True)
        self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        x_se = x.mean((2, 3), keepdim=True)
        x_se = self.conv_reduce(x_se)
        x_se = self.act1(x_se)
        x_se = self.conv_expand(x_se)
        return x * self.gate(x_se)
    

def q_shift(input, shift_pixel=1, gamma=1/4, patch_resolution=None):
    assert gamma <= 1/4
    B, N, C = input.shape
    input = input.transpose(1, 2).reshape(B, C, patch_resolution[0], patch_resolution[1])
    B, C, H, W = input.shape
    output = torch.zeros_like(input)
    output[:, 0:int(C*gamma), :, shift_pixel:W] = input[:, 0:int(C*gamma), :, 0:W-shift_pixel]
    output[:, int(C*gamma):int(C*gamma*2), :, 0:W-shift_pixel] = input[:, int(C*gamma):int(C*gamma*2), :, shift_pixel:W]
    output[:, int(C*gamma*2):int(C*gamma*3), shift_pixel:H, :] = input[:, int(C*gamma*2):int(C*gamma*3), 0:H-shift_pixel, :]
    output[:, int(C*gamma*3):int(C*gamma*4), 0:H-shift_pixel, :] = input[:, int(C*gamma*3):int(C*gamma*4), shift_pixel:H, :]
    output[:, int(C*gamma*4):, ...] = input[:, int(C*gamma*4):, ...]
    return output.flatten(2).transpose(1, 2)


def RUN_CUDA(B, T, C, w, u, k, v):
    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())


class WKV(torch.autograd.Function):
    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0

        half_mode = (w.dtype == torch.half)
        bf_mode = (w.dtype == torch.bfloat16)
        ctx.save_for_backward(w, u, k, v)
        w = w.float().contiguous()
        u = u.float().contiguous()
        k = k.float().contiguous()
        v = v.float().contiguous()
        y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        if half_mode:
            y = y.half()
        elif bf_mode:
            y = y.bfloat16()
        return y

    @staticmethod
    def backward(ctx, gy):
        B = ctx.B
        T = ctx.T
        C = ctx.C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w, u, k, v = ctx.saved_tensors
        gw = torch.zeros((B, C), device='cuda').contiguous()
        gu = torch.zeros((B, C), device='cuda').contiguous()
        gk = torch.zeros((B, T, C), device='cuda').contiguous()
        gv = torch.zeros((B, T, C), device='cuda').contiguous()
        half_mode = (w.dtype == torch.half)
        bf_mode = (w.dtype == torch.bfloat16)
        wkv_cuda.backward(B, T, C,
                          w.float().contiguous(),
                          u.float().contiguous(),
                          k.float().contiguous(),
                          v.float().contiguous(),
                          gy.float().contiguous(),
                          gw, gu, gk, gv)
        if half_mode:
            gw = torch.sum(gw.half(), dim=0)
            gu = torch.sum(gu.half(), dim=0)
            return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
        elif bf_mode:
            gw = torch.sum(gw.bfloat16(), dim=0)
            gu = torch.sum(gu.bfloat16(), dim=0)
            return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
        else:
            gw = torch.sum(gw, dim=0)
            gu = torch.sum(gu, dim=0)
            return (None, None, None, gw, gu, gk, gv)
        

class VRWKV_SpatialMix(nn.Module):
    def __init__(self, n_embd, channel_gamma=1/4, shift_pixel=1):
        super().__init__()
        self.n_embd = n_embd
        attn_sz = n_embd
        self._init_weights()
        self.shift_pixel = shift_pixel
        if shift_pixel > 0:
            self.channel_gamma = channel_gamma
        else:
            self.spatial_mix_k = None
            self.spatial_mix_v = None
            self.spatial_mix_r = None

        self.key = nn.Linear(n_embd, attn_sz, bias=False)
        self.value = nn.Linear(n_embd, attn_sz, bias=False)
        self.receptance = nn.Linear(n_embd, attn_sz, bias=False)
        self.key_norm = nn.LayerNorm(n_embd)
        self.output = nn.Linear(attn_sz, n_embd, bias=False)

        self.key.scale_init = 0
        self.receptance.scale_init = 0
        self.output.scale_init = 0

    def _init_weights(self):
        self.spatial_decay = nn.Parameter(torch.zeros(self.n_embd))
        self.spatial_first = nn.Parameter(torch.zeros(self.n_embd))
        self.spatial_mix_k = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)
        self.spatial_mix_v = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)
        self.spatial_mix_r = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)
    def jit_func(self, x, patch_resolution):
        # Mix x with the previous timestep to produce xk, xv, xr
        B, T, C = x.size()
        # Use xk, xv, xr to produce k, v, r
        if self.shift_pixel > 0:
            xx = q_shift(x, self.shift_pixel, self.channel_gamma, patch_resolution)
            xk = x * self.spatial_mix_k + xx * (1 - self.spatial_mix_k)
            xv = x * self.spatial_mix_v + xx * (1 - self.spatial_mix_v)
            xr = x * self.spatial_mix_r + xx * (1 - self.spatial_mix_r)
        else:
            xk = x
            xv = x
            xr = x
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)
        sr = torch.sigmoid(r)
        return sr, k, v

    def forward(self, x, patch_resolution=None):
        B, T, C = x.size()
        sr, k, v = self.jit_func(x, patch_resolution)
        x = RUN_CUDA(B, T, C, self.spatial_decay / T, self.spatial_first / T, k, v)
        x = self.key_norm(x)
        x = sr * x
        x = self.output(x)
        return x
    

class iR_RWKV(nn.Module):
    def __init__(self, dim_in, dim_out, norm_in=True, has_skip=True, exp_ratio=1.0, norm_layer='bn_2d',
                 act_layer='relu', dw_ks=3, stride=1, dilation=1, se_ratio=0.0,
                 attn_s=True, drop_path=0., drop=0.,img_size=224, channel_gamma=1/4, shift_pixel=1):
        super().__init__()
        self.norm = get_norm(norm_layer)(dim_in) if norm_in else nn.Identity()
        dim_mid = int(dim_in * exp_ratio)
        self.ln1 = nn.LayerNorm(dim_mid)
        self.conv = ConvNormAct(dim_in, dim_mid, kernel_size=1)
        self.has_skip = (dim_in == dim_out and stride == 1) and has_skip
        if attn_s==True:
                self.att = VRWKV_SpatialMix(dim_mid, channel_gamma, shift_pixel)
        self.se = SE(dim_mid, rd_ratio=se_ratio, act_layer=get_act(act_layer)) if se_ratio > 0.0 else nn.Identity()
        self.proj_drop = nn.Dropout(drop)
        self.proj = ConvNormAct(dim_mid, dim_out, kernel_size=1, norm_layer='none', act_layer='none', inplace=inplace)
        self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
        self.attn_s=attn_s
        self.conv_local = ConvNormAct(dim_mid, dim_mid, kernel_size=dw_ks, stride=stride, dilation=dilation, groups=dim_mid, norm_layer='bn_2d', act_layer='silu', inplace=inplace)
        
    def forward(self, x):
        shortcut = x
        x = self.norm(x)
        x = self.conv(x)
        if self.attn_s:
            B, hidden, H, W = x.size()
            patch_resolution = (H,  W)
            x = x.view(B, hidden, -1)  # (B, hidden, H*W) = (B, C, N)
            x = x.permute(0, 2, 1)
            x = x + self.drop_path(self.ln1(self.att(x, patch_resolution)))
            B, n_patch, hidden = x.size()  # reshape from (B, n_patch, hidden) to (B, h, w, hidde
            h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
            x = x.permute(0, 2, 1)
            x = x.contiguous().view(B, hidden, h, w)
        x = x + self.se(self.conv_local(x)) if self.has_skip else self.se(self.conv_local(x))
        x = self.proj_drop(x)
        x = self.proj(x)
        x = (shortcut + self.drop_path(x)) if self.has_skip else x
        return x


if __name__ == '__main__':
    x = torch.randn([1, 64, 11, 11]).cuda()
    ir_rwkv = iR_RWKV(dim_in=64, dim_out=64).cuda()
    out = ir_rwkv(x)
    print(out.shape)  # [1, 64, 11, 11]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/965435.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Git--使用教程

Git的框架讲解 Git 是一个分布式版本控制系统&#xff0c;其架构设计旨在高效地管理代码版本&#xff0c;支持分布式协作&#xff0c;并确保数据的完整性和安全性。 Git 的核心组件&#xff1a; 工作区&#xff08;Working Directory&#xff09;&#xff1a; 工作区是你在本…

智慧停车系统:不同规模停车场的应用差异与YunCitys解决方案

在智慧停车领域&#xff0c;不同规模停车场因自身特点&#xff0c;对智慧停车系统的需求和应用效果存在显著差异。云创智城凭借丰富的经验和先进的技术&#xff0c;为各类规模停车场打造了贴合需求的智慧停车系统&#xff0c;下面为您详细剖析。 小型停车场&#xff1a;精准高…

snort的学习记录

一、what is snort&#xff1f;什么是snort? Snort 是一款开源的 网络入侵检测系统&#xff08;NIDS&#xff09; 和 网络入侵防御系统&#xff08;NIPS&#xff09;&#xff0c;能够实时监控网络流量&#xff0c;检测恶意行为&#xff08;如端口扫描、SQL注入、DDoS攻击等&a…

PHP-trim

[题目信息]&#xff1a; 题目名称题目难度PHP-trim1 [题目考点]&#xff1a; trim() 函数移除字符串两侧的空白字符或其他预定义字符。[Flag格式]: SangFor{dl9hFiITmhQNAJysCgigAskyCZ6kQaDc}[环境部署]&#xff1a; docker-compose.yml文件或者docker tar原始文件。 ht…

maven如何不把依赖的jar打包到同一个jar?

spring boot项目打jar包部署&#xff1a; 经过以下步骤&#xff0c; 最终会形成maven依赖的多个jar&#xff08;包括lib下添加的&#xff09;、 我们编写的程序代码打成一个jar&#xff0c;将程序jar与 依赖jar分开&#xff0c;便于管理&#xff1a; success&#xff1a; 最终…

【ArcGIS Pro 简介1】

ArcGIS Pro 是由 Esri &#xff08;Environmental Systems Research Institute&#xff09;公司开发的下一代桌面地理信息系统&#xff08;GIS&#xff09;软件&#xff0c;是传统 ArcMap 的现代化替代产品。它结合了强大的空间分析能力、直观的用户界面和先进的三维可视化技术…

DeepSeek 部署过程中的问题

文章目录 DeepSeek 部署过程中的问题一、部署扩展&#xff1a;docker 部署 DS1.1 部署1.2 可视化 二、问题三、GPU 设置3.1 ollama GPU 的支持情况3.2 更新 GPU 驱动3.3 安装 cuda3.4 下载 cuDNN3.5 配置环境变量 四、测试 DeepSeek 部署过程中的问题 Windows 中 利用 ollama 来…

Elasticsearch基本使用详解

文章目录 Elasticsearch基本使用详解一、引言二、环境搭建1、安装 Elasticsearch2、安装 Kibana&#xff08;可选&#xff09; 三、索引操作1、创建索引2、查看索引3、删除索引 四、数据操作1、插入数据2、查询数据&#xff08;1&#xff09;简单查询&#xff08;2&#xff09;…

点(线)集最小包围外轮廓效果赏析

“ 图像、点集、线集合最小外轮廓计算应用较为广泛&#xff0c;如抠图、神奇选择、LOD、碰撞检查等领域&#xff0c;提高场景效率” 1.前言 作者基于递归迭代求解实现点集的最小外轮廓计算&#xff0c;在CGLib库中实现&#xff0c;已集成于CGViewer&#xff0c;可联系作者试用&…

IOPS与吞吐量、读写块大小及延迟之间的关系

IOPS&#xff08;每秒输入/输出操作次数&#xff09;、吞吐量、读写块大小及延迟是衡量存储系统性能的四个关键指标&#xff0c;它们之间存在密切的关系。以下从多个方面详细说明这些指标之间的关系&#xff1a; 1. IOPS与吞吐量的关系 公式关系&#xff1a;吞吐量&#xff0…

FPGA与ASIC:到底选哪个好?

不少人想转行FPGA&#xff0c;但在ASIC和FPGA之间犹豫不决。要做出选择&#xff0c;首先需要清楚两者的区别和各自特点。 FPGA (Field Programmable Gate Array) 是一种现场可编程门阵列芯片&#xff0c;本质上它是一种半定制的芯片&#xff0c;可以根据需要重新编程&#xff…

LLMs之data:synthetic-data-generator的简介、安装和使用方法、案例应用之详细攻略

LLMs之data&#xff1a;synthetic-data-generator的简介、安装和使用方法、案例应用之详细攻略 目录 synthetic-data-generator的简介 1、核心功能和优势 2、特点 synthetic-data-generator的安装和使用方法 1、安装 pip安装 安装依赖项 运行应用 2、使用方法 快速入…

第二天:系统从BIOS/UEFI到GRUB/bootloader的启动过程

目录 **一、BIOS/UEFI初始化阶段****二、引导加载程序&#xff08;GRUB&#xff09;的启动过程****1. BIOS模式下的GRUB分阶段加载****2. UEFI模式下的GRUB加载** **三、操作系统内核加载与初始化****四、关键组件与配置文件****五、故障排查与恢复****总结**常见问题如何在UEF…

113,【5】 功防世界 web unseping

进入靶场 代码审计 <?php // 高亮显示当前 PHP 文件的源代码&#xff0c;方便开发者查看代码结构和内容 highlight_file(__FILE__);// 定义一个名为 ease 的类 class ease {// 私有属性 $method&#xff0c;用于存储要调用的方法名private $method;// 私有属性 $args&…

图解BWT(Burrows-Wheeler Transform) 算法

Burrows-Wheeler Transform (BWT) 是一种数据转换算法, 主要用于数据压缩领域. 它由 Michael Burrows 和 David Wheeler 在 1994 年提出, 广泛应用于无损数据压缩算法(如 bzip2)中. BWT 的核心思想是通过重新排列输入数据, 使得相同的字符更容易聚集在一起, 从而提高后续压缩算…

DeepSeek辅助学术写作【句子重写】效果如何?

句子重写(功能指数:★★★★★) 当我们想引用一篇文章中的一-些我们认为写得很好的句子时&#xff0c;如果直接将原文加人自己的文章&#xff0c;那么即使我们标注上了引用&#xff0c;也依旧会被查重软件计算在重复比例中。查重比例过高的话&#xff0c;会影响投稿或毕业答辩送…

【CAD】卸载清理注册表后安装失败/错误 1402无法打开主键:UNKNOWN Components DAFE。。。

这是因为注册表中的 CurrentVersion\Installer\UserData二级子目录中Components的文件缺少权限&#xff0c;只需要将权限添加为administors即可 一图流&#xff0c;就是把权限给到Administrators&#xff0c;然后就再以管理员权限安装就可以了&#xff0c;还有如果破解的软件可…

w192中国陕西民俗网的设计与实现

&#x1f64a;作者简介&#xff1a;多年一线开发工作经验&#xff0c;原创团队&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的网站项目。 代码可以查看文章末尾⬇️联系方式获取&#xff0c;记得注明来意哦~&#x1f339;赠送计算机毕业设计600个选题excel文…

2.6日学习总结

题目一&#xff1a; AC代码&#xff1a; #include <stdio.h>// 宏 _for 用于简化 for 循环 #define _for(i, a, b) for (int i (a); i < (b); i)// 最大节点数 #define MAXN 1000010// 树节点结构体 typedef struct {int left;int right; } Node;// 树节点数组 Nod…

【数据结构】(4) 线性表 List

一、什么是线性表 线性表就是 n 个相同类型元素的有限序列&#xff0c;每一个元素只有一个前驱和后继&#xff08;除了第一个和最后一个元素&#xff09;。 数据结构中&#xff0c;常见的线性表有&#xff1a;顺序表、链表、栈、队列。 二、什么是 List List 是 Java 中的线性…