【实验】SegViT: Semantic Segmentation with Plain Vision Transformers

在这里插入图片描述
想要借鉴SegViT官方模型源码部署到本地自己代码文件中

1. 环境配置

官网要求安装mmcv-full=1.4.4和mmsegmentation=0.24.0
在这之前记得把mmcv和mmsegmentation原来版本卸载

pip uninstall mmcv
pip uninstall mmcv-full
pip uninstall mmsegmentation

安装mmcv

其中,mmcv包含两个版本:一个是完整版mmcv(原来叫mmcv-full),一个是精简版mmcv-lite(原来叫mmcv),2.0.0版本之后更名了,具体的区别可以看mmcv官网手册和博客
安装mmcv-full(也就是mmcv完整版)主要参考mmcv官网手册。
如果你要安装mmcv>=2.0.0直接根据官网手册安装即可,不再赘述。
如果你要安装历史版本,例如我安装mmcv-full==1.4.4,可以参考我的记录。
在安装mmcv前,首先要知道自己的pytorch和cuda对应版本。
查看pytorch版本:

python -c 'import torch;print(torch.__version__)'

如果输出版本信息则已经安装pytorch
查看cuda版本:
注意要查你这个环境下pytorch对应的cuda版本
例如
这是我使用nvidia-smi命令查看的cuda版本:
在这里插入图片描述
这是我使用查看pytorch对应cuda版本命令:

python -c 'import torch;print(torch.version.cuda)'

也可以写成:

参考博客:https://blog.csdn.net/qq_49821869/article/details/127700187

python

>>>import torch
>>>torch.version.cuda

在这里插入图片描述
在这里我的pytorch版本应该是1.11.0,对应cuda版本是11.3

参考博客:https://blog.csdn.net/qq_41661809/article/details/125345690

于是,我输入命令:

pip install mmcv-full==1.4.4 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11.0/index.html

不成功,于是我访问了这个网址查看,发现我能用的最低版本也就是1.4.7
在这里插入图片描述
于是我把命令换成了:

pip install mmcv-full==1.4.7 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11.0/index.html

mmcv-full安装结束

安装mmsegmentation

mmsegmentation原本我是按照官网指导安装的,
但是要求mmcv>=2.0.0,而且安装的版本是mmsegmentation==1.0.0,这和我的要求冲突了。
注意mmsegmentation要和mmcv版本匹配:

参考博客:https://blog.csdn.net/CharilePuth/article/details/122909620

在这里插入图片描述

于是我直接:

pip install mmsegmentation==0.24.0

安装成功。
“pip安装包像喝水一样简单”——曾经一位大佬如是说道。

2. 搞代码!

找模型配置文件

进入官网,在Training中找到模型对应的config文件:
在这里插入图片描述
Highlights中我知道了本文的一大亮点就是收缩结构,可以减小计算成本,因此接下来我会选择收缩结构:
在这里插入图片描述

由于我要跑的图片大小为512,因此我在这个代码的Results中找到同样512*512的COCO数据集对应模型:
在这里插入图片描述

返回configs文件夹找到这个数据集对应网络模型:
在这里插入图片描述
在这里插入图片描述
观察其代码得知所用backbone为vit_shrink,解码头为TPNATMHead:
在这里插入图片描述
注意其中的参数设置,同时还要关注__base__的配置文件,其中的参数在模型声明的时候要输入进去。

找模型代码

进入backbone文件夹下找到vit_shrink网络:
在这里插入图片描述
复制粘贴到自己的py文件中
在decode_heads文件夹下找到解码头代码:
在这里插入图片描述
复制粘贴到自己的py文件中

对代码缝缝补补

  1. 补充库文件
    库文件缺什么补什么,例如在tpn_atm_head解码器代码中需要引用另外两个解码器代码中的内容,直接把另外两个解码器的代码ctrl C+V进来,将需要使用的模块留下来即可:
    在这里插入图片描述
    在这里插入图片描述
  2. 检查输入输出
    backbone的输入和输出:
    在这里插入图片描述
    解码器部分的输入输出如图:
    在这里插入图片描述
    写一个SegViT来测试输入输出,注意参考配置文件将对应配置提前声明好:
class SegViT(nn.Module):
    def __init__(self, num_class):
        super(SegViT, self).__init__()
        out_indices = [7,23]
        in_channels = 1024
        img_size = 512
        # checkpoint = './pretrained/vit_large_p16_384_20220308-d4efb41d.pth'
        checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segmenter/vit_large_p16_384_20220308-d4efb41d.pth'

        # self.backbone = get_vit_shrink()
        self.backbone = vit_shrink(
            img_size=(img_size,img_size),embed_dims=in_channels,num_layers=24,drop_path_rate=0.3,num_heads=16,out_indices=out_indices)
        self.decoder = TPNATMHead(
            img_size=img_size,in_channels=in_channels,channels=in_channels,embed_dims=in_channels//2,num_heads=16,num_classes=num_class,num_layers=3, use_stages=len(out_indices))

    def forward(self, _x):
        x = self.backbone(_x)
        out = self.decoder(x)
        # if self.training:
            # return out['pred'], out['ce_aux']
        # else:
            # return out
        return out
 

运行检查out的类型

if __name__ == "__main__":
    x = torch.randn(4, 3, 512, 512)
    net = SegViT(6)
    # flops, params = profile(net, (x,))
    # print('flops: %.2f G, params: %.2f M' % (flops / 1000000000.0, params / 1000000.0))
    # res, aux = net(x)
    res = net(x)
    print(res)

然后发现输出是一个字典类型,prediction是其中键名为pred对应的值,该值为tensor类型,shape大小为(4,6,512,512),输出正确。
接下来要找辅助分支的输出。
在解码器头的forward中发现:
在这里插入图片描述
将注释去掉,得到辅助分支的输出(会将辅助分支的输出结果以字典元素形式加入到atm_out中,可以调试看看),记得要把对应的初始化函数的注释也去掉:
在这里插入图片描述
其中,由于我是单卡运行,于是把SyncBN改成了BN,否则报错。
另外,训练阶段和测试阶段的输出是不一样的,可以调试检查:

    def forward(self, _x):
        x = self.backbone(_x)
        out = self.decoder(x)
        if self.training:
            return out['pred'], out['ce_aux']
        else:
            return out
  1. 加载权重文件
    权重文件注意可以提前下载好
def get_vit_shrink(pretrained=True, img_size=512, in_channels=1024, out_indices=[7,23]):
    model = vit_shrink(
            img_size=(img_size,img_size),embed_dims=in_channels,num_layers=24,drop_path_rate=0.3,num_heads=16,out_indices=out_indices)
    if pretrained:
        checkpoint = '权重文件所在路径'
        # if 'state_dict' in checkpoint: state_dict = checkpoint['state_dict']
        # else: state_dict = checkpoint
        model.load_state_dict(checkpoint, strict=False)
    return model

最终的模型是:

class SegViT(nn.Module):
    def __init__(self, num_class):
        super(SegViT, self).__init__()
        out_indices = [7,23]
        in_channels = 1024
        img_size = 512
        # checkpoint = './pretrained/vit_large_p16_384_20220308-d4efb41d.pth'
        # checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segmenter/vit_large_p16_384_20220308-d4efb41d.pth'

        self.backbone = get_vit_shrink()
        self.decoder = TPNATMHead(
            img_size=img_size,in_channels=in_channels,channels=in_channels,embed_dims=in_channels//2,num_heads=16,num_classes=num_class,num_layers=3, use_stages=len(out_indices))

    def forward(self, _x):
        x = self.backbone(_x)
        out = self.decoder(x)
        if self.training:
            return out['pred'], out['ce_aux']
        else:
            return out
 
  1. 检查最终的输入输出
    结束。

3. 运行模型

在自己的框架里,配置参数,然后运行即可。

结束。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/23561.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

旋翼无人机常用仿真工具

四旋翼常用仿真工具 rviz: 简单的质点(也可以加上动力学姿态),用urdf模型在rviz中显示无人机和飞行轨迹、地图等。配合ROS代码使用,轻量化适合多机。典型的比如浙大ego-planner的仿真: https://github.c…

Java面试知识点(全)-分布式算法- ZAB算法

Java面试知识点(全) 导航: https://nanxiang.blog.csdn.net/article/details/130640392 注:随时更新 研究zookeeper时,必须要了解zk的选举和集群间个副本间的数据一致性。 什么是 ZAB 协议? ZAB 协议介绍 ZAB 协议全称&#xf…

树和二叉树

树 逻辑表示方法 树形表示法 文氏图表示法 凹入表示法 括号表示法 性质 树的结点数等于所有结点的度加一 度为m的树中第i层最多有m的(i-1)次方个结点 高度为h的m次树最多的节点数(等比数列公式求和&am…

【数据结构】什么是堆,如何使用无序数组生成一个堆?

文章目录 一、堆的概念及其介绍二、如何使用无序序列构建一个堆?三、C语言实现堆的基本操作结构体创建与销毁获取堆顶数据与个数及堆的判空堆的插入与删除 源代码分享 一、堆的概念及其介绍 堆(Heap)是计算机科学中一类特殊的数据结构的统称,堆通常是一…

公网远程连接Redis数据库【内网穿透】

文章目录 1. Linux(centos8)安装redis数据库2. 配置redis数据库3. 内网穿透3.1 安装cpolar内网穿透3.2 创建隧道映射本地端口 4. 配置固定TCP端口地址4.1 保留一个固定tcp地址4.2 配置固定TCP地址4.3 使用固定的tcp地址连接 转发自cpolar内网穿透的文章:公网远程连接…

docker构建镜像上传到DockerHub

docker构建镜像上传到DockerHub DockerHub注册账号 DockerHub网址: https://hub.docker.com/ 注册 登录 安装docker docker宿主机环境 centos7 参考网址: https://yeasy.gitbook.io/docker_practice/install/centos 测试 docker 是否安装好 docker -v登录docker 登录 dock…

Chatgpt版本的opencv安装教程

文章目录 前言一、安装opencv方法一二、安装opencv方法二 前言 最近刚买了台RTX 3070的电脑,顺手刷了个ubuntu系统专门玩Carla,为了方便查资料,也顺手搭了浏览chatgpt的环境,用的clash,还挺好用的。然后刚好在看Carla…

如何使用JQuery实现Js二级联动和三级联动

前言:使用JQuery封装好的js方法来实现二级三级联动要比直接使用js来实现二级三级联动要简洁很多。所以说JQuery是个非常强大的、简单易用的、兼容性好的JavaScript库,已经成为前端开发人员不可缺少的一部分,是Web开发中最流行的JavaScript库之…

Mysql数据库对表的基本操作

一.表基本操作 1.当前数据库内创建表 2.查看表 3.删除表 4.修改表结构 5.复制表(结构) 二.表约束创建 1.约束的作用 2.约束的类型 3.演示 一.表基本操作 1.当前数据库内创建表 CREATE TABLE 表名( 列名 列数据类型, 列名 列…

小兔鲜--项目总结3

目录 结算模块-地址切换交互实现 地址切换交互需求分析 打开弹框交互实现 地址激活交互实现 订单模块-生成订单功能实现 支付模块-实现支付功能 支付业务流程 支付模块-支付结果展示 支付模块-封装倒计时函数 理解需求 实现思路分析 会员中心-个人中心信息渲染 分页…

solr快速上手:managed-schema标签详解(三)

0. 引言 core核心是solr中的重中之重,类似数据库中的表,在搜索引擎中也叫做索引,在solr中索引的建立,要先创建基础的数据结构,即schema的相关配置,今天继续来学习solr的核心知识: solr快速上手…

OpenCV——最小外接矩形

目录 一、主要函数二、代码实现三、结果展示 一、主要函数 cv::RotatedRect cv::minAreaRect(const cv::Mat& points );emspminAreaRect 函数用于计算给定点集的最小外接矩形。该矩形的长和宽是可以任意旋转的,因此被称为旋转矩形。 points :是一个…

article-码垛机器人admas仿真

按照运动学仿真的类似步骤为机器人添加材料、运动副和关节驱动,给机器人手腕末端施加50N最大负载,仿真模型如图5-17。 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-AXYQVZPq-1684936426972)(data:image/svgxml;utf8, )] 图…

Python实现ACO蚁群优化算法优化BP神经网络回归模型(BP神经网络回归算法)项目实战

说明:这是一个机器学习实战项目(附带数据代码文档视频讲解),如需数据代码文档视频讲解可以直接到文章最后获取。 1.项目背景 蚁群优化算法(Ant Colony Optimization, ACO)是一种源于大自然生物世界的新的仿生进化算法&#xff0c…

Qt自定义的ColorDialog--仿QColorDialog

Qt已经有了色板选择,但是它使用QDialog形成的,每次调用基本上都成了点一个按钮,谈一个模态框,选择好颜色之后再关掉模态框。 但是,如果想将颜色选择板放在窗口上,并不会有模态的功能就会比较麻烦&#xff…

docker安装mysql8.0.33

1 从docker仓库中拉去mysql 8.0 docker pull mysql:8.0如果使用 docker pull mysql 默认拉取的是最新版本的mysql 上面我拉去的是8.0的版本,最后拉取过来的是8.0.33 如果有想要指定的版本,可以直接写指定版本,如: docker pull my…

pytorch:nn.ModuleList和nn.Sequential、list的用法以及区别

文章目录 在构建网络的时候,pytorch有一些基础概念很重要,比如nn.Module,nn.ModuleList,nn.Sequential,这些类我们称为为容器(containers),可参考containers。本文中我们主要学习nn.…

【Python】正则表达式应用

知识目录 一、写在前面✨二、姓名检查三、解析电影排行榜四、总结撒花😊 一、写在前面✨ 大家好!我是初心,希望我们一路走来能坚守初心! 今天跟大家分享的文章是 正则表达式的应用 ,希望能帮助到大家!本篇…

把字节大佬花3个月时间整理的软件测试面经偷偷给室友,差点被他开除了···

写在前面 “这份软件测试面经看起来不错,等会一起发给他吧”,我看着面前的面试笔记自言自语道。 就在这时,背后传来了leder“阴森森”的声音:“不错吧,我可是足足花了三个月整理的” 始末 刚入职字节的我收到了大学室…

Windows 10 X64 内核对象句柄表解析

fweWindows 很多API函数都会创建和使用句柄(传入参数),句柄代表一个内核对象的内存地址,每个进程都有一个句柄表,它保存着进程拥有的句柄,内核也有一个句柄表 PspCidTable,它保存着整个系统的句柄。 ExpLookupHandleTa…