【论文笔记】RS-Mamba for Large Remote Sensing Image Dense Prediction(附Code)

论文作者提出了RS-Mamba(RSM)用于高分辨率遥感图像遥感的密集预测任务。RSM设计用于模拟具有线性复杂性的遥感图像的全局特征,使其能够有效地处理大型VHR图像。它采用全向选择性扫描模块,从多个方向对图像进行全局建模,从多个方向捕捉大的空间特征。

论文链接:https://arxiv.org/abs/2404.02668

code链接:https://github.com/walking-shadow/Official_Remote_Sensing_Mamba

2D全向扫描机制是本研究的主要创新点。作者考虑到遥感影像地物多方向的特点,在VMamba2D双向扫描机制的基础上增加了斜向扫描机制。

 以下是作者针对该部分进行改进的代码:

def antidiagonal_gather(tensor):
    # 取出矩阵所有反斜向的元素并拼接
    B, C, H, W = tensor.size()
    shift = torch.arange(H, device=tensor.device).unsqueeze(1)  # 创建一个列向量[H, 1]
    index = (torch.arange(W, device=tensor.device) - shift) % W  # 利用广播创建索引矩阵[H, W]
    # 扩展索引以适应B和C维度
    expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    # 使用gather进行索引选择
    return tensor.gather(3, expanded_index).transpose(-1,-2).reshape(B, C, H*W)

def diagonal_gather(tensor):
    # 取出矩阵所有反斜向的元素并拼接
    B, C, H, W = tensor.size()
    shift = torch.arange(H, device=tensor.device).unsqueeze(1)  # 创建一个列向量[H, 1]
    index = (shift + torch.arange(W, device=tensor.device)) % W  # 利用广播创建索引矩阵[H, W]
    # 扩展索引以适应B和C维度
    expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    # 使用gather进行索引选择
    return tensor.gather(3, expanded_index).transpose(-1,-2).reshape(B, C, H*W)

def diagonal_scatter(tensor_flat, original_shape):
    # 把斜向元素拼接起来的一维向量还原为最初的矩阵形式
    B, C, H, W = original_shape
    shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1)  # 创建一个列向量[H, 1]
    index = (shift + torch.arange(W, device=tensor_flat.device)) % W  # 利用广播创建索引矩阵[H, W]
    # 扩展索引以适应B和C维度
    expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    # 创建一个空的张量来存储反向散布的结果
    result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype)
    # 将平铺的张量重新变形为[B, C, H, W],考虑到需要使用transpose将H和W调换
    tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2)
    # 使用scatter_根据expanded_index将元素放回原位
    result_tensor.scatter_(3, expanded_index, tensor_reshaped)
    return result_tensor

def antidiagonal_scatter(tensor_flat, original_shape):
    # 把反斜向元素拼接起来的一维向量还原为最初的矩阵形式
    B, C, H, W = original_shape
    shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1)  # 创建一个列向量[H, 1]
    index = (torch.arange(W, device=tensor_flat.device) - shift) % W  # 利用广播创建索引矩阵[H, W]
    expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    # 初始化一个与原始张量形状相同、元素全为0的张量
    result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype)
    # 将平铺的张量重新变形为[B, C, W, H],因为操作是沿最后一个维度收集的,需要调整形状并交换维度
    tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2)
    # 使用scatter_将元素根据索引放回原位
    result_tensor.scatter_(3, expanded_index, tensor_reshaped)
    return result_tensor

class CrossScan(torch.autograd.Function):
    # ZSJ 这里是把图像按照特定方向展平的地方,改变扫描方向可以在这里修改
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        # xs = x.new_empty((B, 4, C, H * W))
        xs = x.new_empty((B, 8, C, H * W))
        # 添加横向和竖向的扫描
        xs[:, 0] = x.flatten(2, 3)
        xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
    
        # 提供斜向和反斜向的扫描
        xs[:, 4] = diagonal_gather(x)
        xs[:, 5] = antidiagonal_gather(x)
        xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1])

        return xs
    
    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        # out: (b, k, d, l)
        B, C, H, W = ctx.shape
        L = H * W
        # 把横向和竖向的反向部分再反向回来,并和原来的横向和竖向相加
        # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        # 把竖向的部分转成横向,然后再相加,再转回最初是的矩阵形式
        # y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        y_rb = y_rb.view(B, -1, H, W)

        # 把斜向和反斜向的反向部分再反向回来,并和原来的斜向和反斜向相加
        y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, -1, L)
        # 把斜向和反斜向的部分都转成原来的最初的矩阵形式,再相加
        y_da = diagonal_scatter(y_da[:, 0], (B,C,H,W)) + antidiagonal_scatter(y_da[:, 1], (B,C,H,W))

        y_res = y_rb + y_da
        # return y.view(B, -1, H, W)
        return y_res


class CrossMerge(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, D, H, W = ys.shape
        ctx.shape = (H, W)
        ys = ys.view(B, K, D, -1)
        # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        # y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)

        y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        # 把竖向的部分转成横向,然后再相加,再转回最初是的矩阵形式
        y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        y_rb = y_rb.view(B, -1, H, W)

        # 把斜向和反斜向的反向部分再反向回来,并和原来的斜向和反斜向相加
        y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, D, -1)
        # 把斜向和反斜向的部分都转成原来的最初的矩阵形式,再相加
        y_da = diagonal_scatter(y_da[:, 0], (B,D,H,W)) + antidiagonal_scatter(y_da[:, 1], (B,D,H,W))

        y_res = y_rb + y_da
        return y_res.view(B, D, -1)
        # return y
    
    @staticmethod
    def backward(ctx, x: torch.Tensor):
        # B, D, L = x.shape
        # out: (b, k, d, l)
        H, W = ctx.shape
        B, C, L = x.shape
        # xs = x.new_empty((B, 4, C, L))
        xs = x.new_empty((B, 8, C, L))

        # 横向和竖向扫描
        xs[:, 0] = x
        xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
        # xs = xs.view(B, 4, C, H, W)

        # 提供斜向和反斜向的扫描
        xs[:, 4] = diagonal_gather(x.view(B,C,H,W))
        xs[:, 5] = antidiagonal_gather(x.view(B,C,H,W))
        xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1])

        # return xs
        return xs.view(B, 8, C, H, W)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/563537.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

大模型系列课程学习-大预言模型微调方法介绍

1.大语言模型相关基本概念综述 语言模型指对语言进行建模,其起源于语音识别(speech recognition),输入一段音频数据,语音识别系统通常会生成多个句子作为候选,究竟哪个句子更合理? 学术上表达为:描述一段自…

GitHub登录收不到邮箱验证码

由于长时间没有登录GitHub,浏览器可能清除了相应的cookie信息,所以需要对应绑定邮箱进行验证,但因为邮箱长时间没有收到验证码,所以给到以下一种可能解决的方法: 需要输入验证码进行验证 我们可以打开QQ邮箱&#xff0…

Linux——网络管理nmcli

nmcli 不能独立使用,需要对应的服务启动 1. NetworkManager.service 2. 网络配置和服务不相关 3. 通过 nmcl i 建立网络配置和网卡之前的映射关系 网卡 简称:nmcli d DEVICE :物理设备 TYPE: 物理设备类型 ethernet 以太网…

【Java基础】25.包(package)

文章目录 前言一、包的作用二、创建包三、import 关键字四、package 的目录结构五、设置 CLASSPATH 系统变量 前言 为了更好地组织类,Java 提供了包机制,用于区别类名的命名空间。 一、包的作用 把功能相似或相关的类或接口组织在同一个包中&#xff…

Android 性能优化之黑科技开道(二)

3. 其它可以黑科技优化的方向 3.1 核心线程绑定大核 3.1.1 定义 核心线程绑定大核的思路也很容易理解,现在的 CPU 都是多核的,大核的频率比小核要高不少,如果我们的核心线程固定运行在大核上,那么应用性能自然会有所提升。 核…

C++相关概念和易错语法(8)(匿名对象、构造+拷贝构造优化、构造析构顺序)

1.匿名对象 当我们实例化对象后,有的对象可能只使用一次,之后就没用了。这个时候我们往往要主动去析构它,否则会占着浪费空间。但是如果遇到大量的这种情况,我们并不想每次都去创建对象、调用、析构,这样会写出很多重…

软考 系统架构设计师系列知识点之大数据设计理论与实践(15)

接前一篇文章:软考 系统架构设计师系列知识点之大数据设计理论与实践(14) 所属章节: 第19章. 大数据架构设计理论与实践 第4节 Kappa架构 19.4.5 常见Kappa架构变型 1. Kappa架构 Kappa是Uber提出的流式数据处理架构&#xff0…

传统与创新的交响:『线上求签祈福』游戏案例赏析

Part1. 设计背景 在当代社会,寺庙文化正经历一场复兴,尤其受到年轻一代的热烈欢迎。无论是在传统的节假日还是平日里,寺庙总是吸引着众多年轻人前来,他们怀着虔诚的心祈求平安健康或财富好运。在面对生活中难以抉择或无法掌控的情…

JAVA-服务器搭建-创建web后端项目

首先打开IDEA 点击新建项目 写好名称-模板选择 Web应用程序 -语言选择 Java 构建系统选择 Maven 然后点击下一步 选择版本-选择依赖项 Web Profile 点击创建 点击当前文件-选择编辑配置 选择左上角的加号-选择Tomcat服务器-选择本地 点击配置-选择到Tomcat目录-点击确定 起个…

创建会计凭证:BAPI_ACC_DOCUMENT_POST 增强字段

创建会计凭证:BAPI_ACC_DOCUMENT_POST 增强字段 在ABAP程序中使用BAPI_ACC_DOCUMENT_POST的时候,如果有些字段在Tables参数中没有,比如,现在大家都用Reason code来作为现金流量表的表现方案。但是在BAPI_ACC_DOCUMENT_POST的acco…

Java新特性(jdk8)

第一章-lambda表达式 1.函数式编程思想和Lambda表达式定义格式 1.面向对象思想: 强调的是找对象,帮我们去做事儿 比如:去北京 -> 强调的是怎么去,火车,高铁,飞机,汽车,自行车,腿儿 2.jdk8开始有了一个新的思想:函数式编程思想: 强调的是结…

FreeRTOS之任务挂起和恢复

1.本文介绍FreeRTOS的任务挂起和恢复函数。任务删除后将不再存在,不能恢复,而任务挂起是暂停任务,可以通过调用函数进行恢复。FreeRTOS任务挂起和恢复的主要步骤如下: (1)将相关的宏定义设置为1&#xff1…

OPAM模型(细粒度图像分类)

OPAM模型(细粒度图像分类) 摘要Abstract1. OPAM1.1 文献摘要1.2 细粒度图像分类1.3 研究背景1.4 OPAM模型创新点1.5 OPAM模型1.5.1 补丁过滤1.5.2 显着性提取1.5.3 细粒度区域级注意模型对象-空间约束方法(Object spatial constraint&#xf…

钟薛高创始人称卖红薯也把债还上:网友,您可千万别……

网红雪糕品牌钟薛高,是真的网红属性强到让所有消费品牌羡慕。 纵使跌落神坛、纵使站在「破产」边缘,依然话题感满满,隔段时间,总能上一个热搜。 比如欠薪上热搜、产品降价上热搜、甚至官网微博微信停更,也得上个热搜&…

MLLM | InternLM-XComposer2-4KHD: 支持336 像素到 4K 高清的分辨率的大视觉语言模型

上海AI Lab,香港中文大学等 论文标题:InternLM-XComposer2-4KHD: A Pioneering Large Vision-Language Model Handling Resolutions from 336 Pixels to 4K HD 论文地址:https://arxiv.org/abs/2404.06512 Code and models are publicly available at https://gi…

.net core webapi 添加日志管理看板LogDashboard

.net core webapi 添加日志管理看板LogDashboard 添加权限管理&#xff1a; 我们用的是Nlog文件来配置 <?xml version"1.0" encoding"utf-8" ?> <nlog xmlns"http://www.nlog-project.org/schemas/NLog.xsd"xmlns:xsi"http:/…

网络基础-TCP/IP和OSI协议模型

一、OSI和TCP/IP模型 二、OSI七层模型 三、TCP/IP模型 参考&#xff1a;https://www.cnblogs.com/f-ck-need-u/p/7623252.html

Scanpy(1)数据结构和样本过滤

注&#xff1a;主要讲述scanpy处理数据的结构、数据过滤&#xff08;生信领域&#xff09;和数据预处理&#xff08;和机器学习类似&#xff0c;但是又有不同。&#xff09; 1. Scanpy简介与安装 Scanpy 是一个可扩展的工具包&#xff0c;用于分析与 AnnData&#xff08;一种…

螺纹滑牙的原因有哪些——SunTorque智能扭矩系统

螺纹滑牙的原因&#xff0c;通常是由于在旋紧或旋松过程中&#xff0c;螺纹副之间的摩擦力不足以维持所需的预紧力或工作载荷&#xff0c;导致螺纹副的相对位置发生变化。这种现象可能由多种因素引起&#xff0c;包括材料选择不当、设计不合理、制造工艺缺陷、环境因素以及使用…

欧科云链:香港虚拟资产OTC合规在即,技术监管成市场规范关键

4月12日香港OTC发牌制度公众咨询结束后&#xff0c;欧科云链研究院在星岛日报发表专栏文章&#xff0c;分享对香港OTC市场的调研情况&#xff0c;并提出“技术监管是香港OTC及Web3生态走向规范的关键”。欧科云链研究院认为&#xff0c;随着OTC监管及虚拟资产现货ETF等事件向前…