【图像分割】mask2former:通用的图像分割模型详解

最近看到几个项目都用mask2former做图像分割,虽然是1年前的论文,但是其attention的设计还是很有借鉴意义,同时,mask2former参考了detr的query设计,实现了语义和实例分割任务的统一。

1.背景

1.1 detr简介

detr算是第一个尝试用transformer实现目标检测的框架,其设计思路也很简单,就是定义object queries,用来查询是否存在目标以及目标位置的,类似cnn检测中的rpn,产生候选框。在detr中,object queries为(100,b,256)的可学习的参数,其中每个256维的向量代表了检测的box信息,这个信息是由类别和空间信息(box坐标)组成,其中类别信息用于区别类别,而空间信息则描述了目标在图像中的位置。

通过设置query,则不需要像传统cnn检测时预设anchor,最后通过匈牙利匹配算法将query到的目标和gt进行匹配,计算loss。

decoder过程中,query object先初始化为0,然后经过self attention,再和encoder的输出进行cross attention。

1.2 Deformable-DETR简介

Deformable-Detr是在detr的基础上了主要做了2个改进,Deformable attention(可变形注意力)和多尺度特征,通过可变性注意力降低了显存,多尺度特征对小目标检测效果比较好。

(1)Deformable attention(可变形注意力)

这个设计参考了可变性卷积(DCN),后续很多设计都参考了这个。先看下DCN,就是在标准卷积(a)的3 * 3的卷积核上,每个点上增加一个偏移量(dx,dy),让卷积核不规则,可以适应目标的形状和尺度。

对于一般的attention,query与key的每个值都要计算注意力,这样的问题就是耗显存;另外,对图像来说,假设其中有一个目标,一般只有离图像比较近的像素才有用,离比较远的像素,对目标的贡献很少,甚至还有负向的干扰。

Defromable attention的设计思路就是query不与全局的key进行计算,而是至于其周围的key进行计算。至于这个周围要选哪几个位置,就类似DCN,让模型自己去学。

  • 单尺度的可变性注意力机制

DeformAttn的公式如下:

  • 多尺度的可变性注意力机制

多尺度即类似fpn,提取不同尺度的特征,但由于特征的尺寸不一样,需要将不同尺度的特征连接起来。

可变性注意力机制公式如下:

相比单尺度的,多尺度多了一个l,代表第几个尺度,一般取4个层级。

对于一个query,在其参考点(reference point)对应的所有层都采用K个点,然后将每层的K个点特征融合(相加)。

整个deformable atten的流程如下:

2.mask2former

mask2former的设计上使用了deformable detr的可变形注意力。

主要计算过程用下图表示:

2.1 模型改进

(1)masked attention

一般计算过程中,计算atten时只用前景部分计算,减少显存占用。

(2) 多分辨率特征

如上图,图像经过backbone得到4层特征,然后经过Pixel Decoder得到O1,O2,O3,O4,注意O1,O2,O3经过Linear+Deform atten Layer,O4只通过Linear+卷积得到,具体可以区别看上图。

(3) decoder优化

在transformer decoder(这个过程用的是标准attention)计算过程中,query刚开始都是随机初始化的,没有图像特征,如果按常规直接self attention可能学不到充分的信息,所以将ca和sa两个模块反过来,先和pixdecoder得到的图像O1,O2,O3计算ca,再继续计算sa。

2.2 类别和mask分开预测

class和mask预测独立开来,mask只预测是背景还是前景,class负责预测类别,这部分保留了maskformer的设计。

如上图,class通过query加上Linear直接将维度转到(n,k+1),其中k为类别数目。

mask通过decoder和最后一层的mask做外积运算,得到(k,h,w)的tensor,每个k代表一个前景。

采用这种query的方式,既可以做instance也可以做语义分割,query的数量N和类别K数量无关。

2.3 loss优化

mask decoder过程中,主要用最后一层的输出计算loss;同时为了辅助训练,默认开启了auxiliary loss(辅助loss),其他层的输出也去计算loss。

还有一个trick,mask计算loss时,不是mask上的所有点都去计算,而是随机采样一定数目的点去计算loss。默认设置= 12544, i.e., 112 × 112 points,这样可以节省显存。

3.扩展

3.1 DAT:另一个Deform atten设计

另一篇deform atten的论文DAT,和deform attention思路类似,也是学习offset。只不过在偏移量设计上有区别,如下图所示,DAT在当前特征图F上学习offset时,进行了上采样2倍,在得到offset后需要插值回F的尺寸,增加了相对位置的bias。

对比几种查询的注意力结果,vit是全查,swin固定窗口大小,有可能限制查到的key,DCN为可变性卷积,DAT学到的key更好。

模型设计上,参考swin-transformer,只将最后2层替换Deformable attention,效果最好。

3.2 视频实例分割跟踪

mask2former用于视频分割,结构如下

模型结构上和图像的分割基本一致。

修改主要在transformer decoder,包含以下3个地方:

(1)增加时间编码t

主要在Transformer decoder过程,图像的位置编码为(x,y),对于视频,由于考虑了多帧数据,增加时间t进行编码,位置编码为(x,y,t)。

       # b, t, c, h, w
        assert x.dim() == 5, f"{x.shape} should be a 5-dimensional Tensor, got {x.dim()}-dimensional Tensor instead"
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(1), x.size(3), x.size(4)), device=x.device, dtype=torch.bool)
        not_mask = ~mask
        z_embed = not_mask.cumsum(1, dtype=torch.float32)  # not_mask【bath,t,h,w】1代表时间列的索引,cumsum累加计算,得到位置id
        y_embed = not_mask.cumsum(2, dtype=torch.float32)  # h
        x_embed = not_mask.cumsum(3, dtype=torch.float32)  # w
        if self.normalize:
            eps = 1e-6
            z_embed = z_embed / (z_embed[:, -1:, :, :] + eps) * self.scale
            y_embed = y_embed / (y_embed[:, :, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        dim_t_z = torch.arange((self.num_pos_feats * 2), dtype=torch.float32, device=x.device)
        dim_t_z = self.temperature ** (2 * (dim_t_z // 2) / (self.num_pos_feats * 2))

        pos_x = x_embed[:, :, :, :, None] / dim_t  # [b,t,h,w]->[b,t,h,w,d] xy编码的d长度是位置编码向量长度的一半
        pos_y = y_embed[:, :, :, :, None] / dim_t
        pos_z = z_embed[:, :, :, :, None] / dim_t_z # z用编码向量长度,然后和xy编码相加
        pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
        pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
        pos_z = torch.stack((pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
        pos = (torch.cat((pos_y, pos_x), dim=4) + pos_z).permute(0, 1, 4, 2, 3)  # b, t, c, h, w

(2) query和多帧数据进行atten计算

        for i in range(self.num_feature_levels):
            size_list.append(x[i].shape[-2:])
            pos.append(self.pe_layer(x[i].view(bs, t, -1, size_list[-1][0], size_list[-1][1]), None).flatten(3))
            src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])  #level_embed size [level_num,d],level embed和输入相加

            # NTxCxHW => NxTxCxHW => (TxHW)xNxC  # 多帧数据融合
            _, c, hw = src[-1].shape
            pos[-1] = pos[-1].view(bs, t, c, hw).permute(1, 3, 0, 2).flatten(0, 1)
            # 其中src是Pixel decoder的输出
            src[-1] = src[-1].view(bs, t, c, hw).permute(1, 3, 0, 2).flatten(0, 1)

(3)query和mask计算优化

如代码所示,query和mask 外积计算,从q外积mask得到mask的shape为[b,q,t,h,w],也就是得到(b,q,t)个instance mask,然后query的instance mask和每帧的gt计算loss。

    def forward_prediction_heads(self, output, mask_features, attn_mask_target_size):
        decoder_output = self.decoder_norm(output)
        decoder_output = decoder_output.transpose(0, 1)
        outputs_class = self.class_embed(decoder_output)
        mask_embed = self.mask_embed(decoder_output)
        # query和mask 外积计算,从q外积mask得到[b,q,t,h,w]个mask
        outputs_mask = torch.einsum("bqc,btchw->bqthw", mask_embed, mask_features)
        b, q, t, _, _ = outputs_mask.shape

        # NOTE: prediction is of higher-resolution
        # [B, Q, T, H, W] -> [B, Q, T*H*W] -> [B, h, Q, T*H*W] -> [B*h, Q, T*HW]
        attn_mask = F.interpolate(outputs_mask.flatten(0, 1), size=attn_mask_target_size, mode="bilinear", align_corners=False).view(
            b, q, t, attn_mask_target_size[0], attn_mask_target_size[1])
        # must use bool type
        # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
        attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
        attn_mask = attn_mask.detach()

        return outputs_class, outputs_mask, attn_mask

训练时是以instance作为一个基础单元,假设有t帧图像,有n个instance(实例),instance和frame的关系如下图表示:

instance在每帧上都可能存在或者不存在。对于每个instance,初始化t个mask,初始化为0,所以instace的shape是[b,n,t,h,w],如果这个instance在某帧上存在,即赋真值mask,用于匹配计算loss;不存在,即为0。

instance在每帧上都是同一个物体(形态可能变化,但是instance id是相同的),所以预测instance的类别时,每个instance只需要预测一个类别即可,所以类别的shape为[b,n]

3.3 思考

sam(segment anything model)可以通过prompt进行分割,但是缺乏类别信息,可以参考mask2former的思想,mask和类别是独立的,可以添加分类的query,接一个分类的分支,然后在coco等数据集上单独训练这个分支,让sam分割后增加类别信息。

4.参考资料

  • mask2former论文
  • mask2former代码


附赠

【一】上千篇CVPR、ICCV顶会论文
【二】动手学习深度学习、花书、西瓜书等AI必读书籍
【三】机器学习算法+深度学习神经网络基础教程
【四】OpenCV、Pytorch、YOLO等主流框架算法实战教程

➤ 在助理处自取:

➤ 还可咨询论文辅导❤【毕业论文、SCI、CCF、中文核心、El会议】评职称、研博升学、本升海外学府!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/772518.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

.NET发布成单个文件后获取不到程序所在路径的问题

.net程序不发布成单个文件&#xff0c;所以运行都是正常的&#xff0c;但是发布成单个文件后发现使用&#xff1a; var basePath Path.GetDirectoryName((System.Reflection.Assembly.GetExecutingAssembly().Location)); 获取不到应用程序所在的路径了。 找一下几个获取本程…

Flutter集成高德导航SDK(Android篇)(JAVA语法)

先上flutter doctor&#xff1a; flutter sdk版本为&#xff1a;3.19.4 引入依赖&#xff1a; 在app的build.gradle下&#xff0c;添加如下依赖&#xff1a; implementation com.amap.api:navi-3dmap:10.0.700_3dmap10.0.700navi-3dmap里面包含了定位功能&#xff0c;地图功能…

Cloudflare 推出一款免费对抗 AI 机器人的可防止抓取数据工具

上市云服务提供商Cloudflare推出了一种新的免费工具&#xff0c;可防止机器人抓取其平台上托管的网站以获取数据以训练AI模型。 一些人工智能供应商&#xff0c;包括谷歌、OpenAI 和苹果&#xff0c;允许网站所有者通过修改他们网站的robots.txt来阻止他们用于数据抓取和模型训…

系统架构设计师教程(清华第2版)<第1章 绪论>解读

系统架构设计师教程 第一章 绪论 1.1 系统架构概述1.1.1 系统架构的定义及发展历程1.1.2 软件架构的常用分类及建模方法1.1.3 软件架构的应用场景1.1.4 软件架构的发展未来1.2 系统架构设计师概述1.2.1 架构设计师的定义、职责和任务1.2.2 架构设计师应具备的专业素质1.3 如何成…

Unity中TimeLine的一些用法

Unity中TimeLine的一些用法 概念其他 概念 无Track模式&#xff08;PlayableAsset、PlayableBehaviour&#xff09; 1. 两者关系 运行在PlayableTrack中作用 PlayableBehaviour 实际执行的脚本字段并不会显示在timeline面板上 PlayableAsset PlayableBehaviour的包装器&#x…

电脑彻底删除的文件还能恢复吗怎么弄 电脑删除的文件怎么恢复 回收站也删了

实测可行的文件恢复方法&#xff0c;无论是彻底删除的文件&#xff0c;还是被清空的回收站文件&#xff0c;使用该方法都可以轻松找回。整个恢复过程操作简单&#xff0c;并且绝不会损伤电脑硬件。这意味着&#xff0c;您再也不用为误删文件而焦虑了。有关电脑彻底删除的文件还…

【Windows】Bootstrap Studio(网页设计)软件介绍及安装步骤

软件介绍 Bootstrap Studio 是一款专为前端开发者设计的强大工具&#xff0c;主要用于快速创建现代化的响应式网页和网站。以下是它的主要特点和功能&#xff1a; 直观的界面设计 Bootstrap Studio 提供了直观的用户界面&#xff0c;使用户能够轻松拖放元素来构建网页。界面…

audo dl上使用tensorrt llm,baichuan7B为例

1. 在社区镜像搜索 nvidia 找一个tensorrt llm 0.10 以上的版本&#xff0c;系统盘30g安装软件应该够用&#xff0c;免费的数据盘50G用来存放模型。baichuan7B原始模型应该会占用14G&#xff0c;转换为fp16的 ckpt后再占用14G&#xff0c;build后占用14G。总共需要占用42G&…

视频太大发不出去怎么处理,视频太大发不了邮件怎么办

在数字化时代&#xff0c;视频已成为我们分享生活、传递信息的重要方式。然而&#xff0c;当遇到视频文件过大&#xff0c;无法发送或分享时&#xff0c;你是否感到困扰&#xff1f;别担心&#xff0c;本文将为你揭秘轻松解决视频太大发不了的问题。 电脑频编辑器可以用于简单的…

工业智能网关的作用有哪些?工业智能网关与传统网关的主要区别-天拓四方

工业智能网关是一种专为工业环境设计的网络设备&#xff0c;具备数据采集、传输、协议转换以及边缘计算等功能。它作为连接工业设备与互联网的关键枢纽&#xff0c;不仅实现了工业设备的互联互通&#xff0c;还通过对采集到的数据进行实时分析&#xff0c;为工业生产的智能化管…

第一百四十三节 Java数据类型教程 - Java Boolean包装类

Java数据类型教程 - Java Boolean包装类 布尔类的对象包装一个布尔值。 Boolean.TRUE和Boolean.FALSE是布尔类型的两个常量&#xff0c;用于表示布尔值true和false值。 我们可以使用构造函数或valueOf()工厂方法创建一个布尔对象。 当解析字符串时&#xff0c;此类将处理“t…

软考-软件设计师 知识点整理(一篇就过了 建议收藏)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 一、前言&#x1f680;&#x1f680;&#x1f680;二、正文☀️☀️☀️1.进制转换2.码制3.浮点数表示4.逻辑运算5.奇偶校验6.CRC循环冗余7.海明校验码8.CPU组成&am…

mybatis实现动态sql

第一章、动态SQL MyBatis 的强大特性之一便是它的动态 SQL。如果你有使用 JDBC 或其它类似框架的经验&#xff0c;你就能体会到根据不同条件拼接 SQL 语句的痛苦。例如拼接时要确保不能忘记添加必要的空格&#xff0c;还要注意去掉列表最后一个列名的逗号。利用动态 SQL 这一特…

关于虚拟机CentOS 7使用ssh无法连接(详细)

虚拟机CentOS 7使用ssh无法连接 猜测&#xff1a;可能是虚拟机软件的网关和和centos7的网关不同导致的问题。 首先打开CentOS7的终端, 输入ifconfig&#xff0c;查看一下系统的ip 打开虚拟机的虚拟网络编辑器, 查看一下网关, 发现确实不一样. 这里有两种方式, 要么修改虚…

104.二叉树的最大深度——二叉树专题复习

深度优先搜索&#xff08;DFS&#xff09;是一种常用的递归算法&#xff0c;用于解决树形结构的问题。在计算二叉树的最大深度时&#xff0c;DFS方法会从根节点开始&#xff0c;递归地计算左右子树的最大深度&#xff0c;然后在返回时更新当前节点所在路径的最大深度。 如果我…

gin项目部署到服务器并后台启动

文章目录 一、安装go语言环境的方式1.下载go安装包&#xff0c;解压&#xff0c;配置环境变量2.压缩项目上传到服务器并解压3.来到项目的根目录3.开放端口&#xff0c;运行项目 二、打包的方式1.在项目的根目录下输入以下命令2.把打包好的文件上传到服务器3.部署网站4.ssl证书 …

Web前端开发——HTML快速入门

HTML&#xff1a;控制网页的结构CSS&#xff1a;控制网页的表现 一、什么是HTML、CSS &#xff08;1&#xff09;HTML &#xff08;HyperText Markup Languaqe&#xff1a;超文本标记语言&#xff09; 超文本&#xff1a;超越了文本的限制&#xff0c;比普通文本更强大。除了…

vienna整流器过零畸变原因分析

Vienna整流器是一种常见的三电平功率因数校正&#xff08;PFC&#xff09;整流器&#xff0c;广泛应用于电源和电能质量控制领域。由于其高效率、高功率密度和低谐波失真的特点&#xff0c;Vienna整流器在工业和电力电子应用中具有重要地位。然而&#xff0c;在实际应用中&…

新手拍短视频的些许建议

1、尽早行动&#xff0c;拒绝完美主义&#xff0c;有手机就能上车&#xff0c;一开始别花太多时间在打磨细节上。总是要准备好了后再做&#xff0c;就总比别人慢一步&#xff0c;可能永远也追不上了&#xff1b; 2、坚持发&#xff0c;度过难熬的启动期就行&#xff0c;不要走…

比Proxmox VE更易用的免费虚拟化平台

之前虚拟化一直玩Proxmox VE&#xff0c;最近发现一个更易用的虚拟化软件CSYun&#xff0c;他与Proxmox VE类似&#xff0c;都是一个服务器虚拟化平台。它不像VMware ESXi那么复杂&#xff0c;对于个人使用者和中小企业是一个比较好的选择。 这个软件所在的网址为&#xff1a;…