pytorch实现胶囊网络(capsulenet)

胶囊网络在hinton刚提出来的时候小热过一段时间,之后热度并没有维持多久。vision transformer之后基本少有人问津了。不过这个模型思路挺独特的,值得研究一下。

这个模型的提出是为了解决CNN模型学习到的特征之间没有空间上的关系,从而对于各种变换不鲁棒的缺点。

模型的整体思路如下:

1,胶囊:

抛开论文里花哨的描述,胶囊其实就是特征图上比点更大的单元,本质上我觉得类似transformer的patch。当然也有一定的差别,因为后续要用动态路由更新胶囊,所以胶囊必须要是向量,而不是标量。

2,动态路由:

由于pooling会导致信息丢失,作者使用动态路由来连接两个胶囊层,并更新胶囊。

同时,动态路由也能建立不同层胶囊(特征)在空间上的相对关系。

由于胶囊其实是向量,动态路由算法会根据这些向量的相似性(点积)和一致性(加权)来决定信息传递的路径。

3,整体结构:

1)卷积层

2)PrimaryCaps层:这层的作用就是把卷积特征转变成胶囊的形式

3)DigitCaps层:用动态路由迭代生成高层的胶囊。

4)解码器

4,loss

胶囊网络的损失函数主要由两部分组成:间隔损失(Margin Loss)和重构损失。

在计算间隔损失时,会使用一个阈值(通常设置为0.9和0.1)来区分正样本和负样本。如果某一类的胶囊输出向量的模长大于阈值m+(正样本阈值,例如0.9),则认为该类存在,并将其视为正样本;反之,如果输出向量的模长小于阈值m-(负样本阈值,例如0.1),则认为该类不存在,将其视为负样本。

重构损失的计算通常基于原始输入数据与重构数据之间的差异,例如使用均方误差(MSE)来衡量这种差异。

如果站在2024年的如今再来看当初的设计,其实胶囊的思路还是很像后来的transformer的,有点殊途同归的感觉。


pytorch实现:

1,实现初始胶囊

首先是会用到的压缩函数,压缩函数的作用是将向量的长度压缩到0和1之间,同时保留向量的方向不变。

公式:

def squash(inputs, axis=-1):
    norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)
    scale = norm**2 / (1 + norm**2 + 1e-8) / (norm + 1e-8)
    return scale * inputs

初始胶囊,这一层的作用是将卷积特征转换为胶囊的形式。

class PrimaryCapsule(nn.Module):
    def __init__(self, in_channels, out_channels, dim_caps, kernel_size, stride=1, padding=0):
        super(PrimaryCapsule, self).__init__()
        self.dim_caps = dim_caps
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        outputs = self.conv2d(x)
        outputs = outputs.reshape(x.size(0), -1, self.dim_caps)
        return squash(outputs)

2,实现胶囊层

路由算法

这个伪代码初看起来挺乱的,我翻译成人话如下:

首先,每一次迭代由两层胶囊层做点积后再通过softmax计算出耦合系数c。

耦合系数和下层胶囊的预测计算加权和,这是个投票的过程。

再通过压缩函数,就得到了本层的胶囊v。

因为这是个迭代的过程,需要不断更新耦合系数C。

新的耦合系数由两层胶囊之间的相似度决定。


具体实现中,会对低层胶囊先做一个变换,也就是下面代码里的weight。这个权重矩阵代表的是对下层胶囊的变化,变换之后的结果Ui|j用论文里的话说叫做“prediction vectors”。

胶囊层代码:

class DenseCapsule(nn.Module):
    def __init__(self, in_num_caps, in_dim_caps, out_num_caps, out_dim_caps, routings=3):
        super(DenseCapsule, self).__init__()
        self.in_num_caps = in_num_caps
        self.in_dim_caps = in_dim_caps
        self.out_num_caps = out_num_caps
        self.out_dim_caps = out_dim_caps
        self.routings = routings #路由的迭代次数
        #初始化
        self.weight = nn.Parameter(0.01 * torch.randn(out_num_caps, in_num_caps, out_dim_caps, in_dim_caps))

    def forward(self, x):
        u_hat = torch.squeeze(torch.matmul(self.weight, x[:, None, :, :, None]), dim=-1)
        #从当前计算图中分离出x_hat,这样在后续的反向传播中不会计算其梯度 
        u_hat_detached = u_hat.detach()
        b = torch.zeros(x.size(0), self.out_num_caps, self.in_num_caps).cuda()
        #路由算法
        for i in range(self.routings):
            c = F.softmax(b, dim=1)
            if i == self.routings - 1:
                v = squash(torch.sum(c[:, :, :, None] * u_hat, dim=-2, keepdim=True))
            else:
                v = squash(torch.sum(c[:, :, :, None] * u_hat_detached, dim=-2, keepdim=True))
                b = b + torch.sum(v * u_hat_detached, dim=-1)

        return torch.squeeze(v, dim=-2)

需要将的是u_hat_detached = u_hat.detach()这一步。将u_hat从计算图中分离出来的目的,是为了防止迭代过程中梯度不断累积,导致梯度过大。所以我们可以在后续的路由算法中看出,只有在最后一次计算路由时使用了u_hat,之前的迭代中都是使用的u_hat_detached。从而让整个路由过程中梯度只更新一次。

3,损失函数

def caps_loss(y_true, y_pred, x, x_recon, lambd=0.5):
    L = y_true * torch.clamp(0.9 - y_pred, min=0.) ** 2 + 0.5 * (1 - y_true) * torch.clamp(y_pred - 0.1, min=0.) ** 2
    L_margin = L.sum(dim=1).mean()

    L_recon = nn.MSELoss()(x_recon, x)

    return L_margin + lambd * L_recon

4,整体模型

模型返回两个值,一个是预测的概率,一个是重建的图像。这两个值会分别用来计算间隔损失和重构损失。

class CapsuleNet(nn.Module):
    def __init__(self, input_size, classes, routings):
        super(CapsuleNet, self).__init__()
        self.input_size = input_size
        self.classes = classes
        self.routings = routings
        self.conv1 = nn.Conv2d(input_size[0], 256, kernel_size=9, stride=1, padding=0)
        self.primarycaps = PrimaryCapsule(256, 256, 8, kernel_size=9, stride=2, padding=0)

        self.digitcaps = DenseCapsule(in_num_caps=32*6*6, in_dim_caps=8,
                                      out_num_caps=classes, out_dim_caps=16, routings=routings)

        self.decoder = nn.Sequential(
            nn.Linear(16*classes, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, input_size[0] * input_size[1] * input_size[2]),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU()

    def forward(self, x, y=None):
        x = self.relu(self.conv1(x))
        x = self.primarycaps(x)
        x = self.digitcaps(x)
        length = x.norm(dim=-1)
        if y is None:
            index = length.max(dim=1)[1]
            y = torch.zeros(length.size()).scatter_(1, index.view(-1, 1), 1.)
        reconstruction = self.decoder((x * y[:, :, None]).view(x.size(0), -1))
        return length, reconstruction.view(-1, *self.input_size)

5,注意事项:

1)one-hot

在重建过程中使用的标签y是one-hot形式的,因此在训练和测试时需要加上这行代码,转换一下

targets = F.one_hot(targets, num_classes=classes).to(device)

2) loss

训练和测试时的loss设置如下

loss = caps_loss(y_true=targets,y_pred=y_pred,x=imgs,x_recon=x_recon,lambd=0.5)
        loss = loss.to(device)

其中lambd这个系数决定的是重构损失所占的比例 loss=margin_loss+lambd*recon_loss

总结:

胶囊网络分类结果不算差,在我的一些任务中train from scratch的胶囊网络就超越了imagenet1k上预训练过再finetune的vit。也超过了无预训练的VGG和resnet。(但是不如预训练过的vgg和resnet)。

这样的表现放在2017年已经很能打了,没火的原因我感觉有3个:

首先,由于胶囊网络迭代过程需要多次完整的特征图点乘特征图,所以内存消耗和时间消耗都是巨大的。我跑256的图时,24g显存的4090也只能把batch设置成5,运行速度非常慢。放在2017年,只能用1080ti来跑这个模型,简直折磨。(我2018年时也试过这个模型,训练都是按周算的,这谁愿意用啊)

另外一个原因可能是它的改进潜力不大。例如vit的核心机制是自注意力,注意力大家都玩出花来了,各种改进思路都很好借鉴。虽然vit效果很一般,但是后续的改进模型一个比一个厉害。而胶囊网络的核心路由算法想要创新就比较难。

最后还有一点就是原作者没放出胶囊网络在imagenet上的预训练模型。这个对模型热度的影响其实挺大的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/535773.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Sketch3D:用于草图到3D生成的样式一致性指南

Sketch3D: Style-Consistent Guidance for Sketch-to-3D Generation Sketch3D:用于草图到3D生成的样式一致性指南 Wangguandong Zheng 重试 错误原因 Southeast UniversityChina 重试 错误原因 wgdzhengseu.edu.cnHaifeng Xia 重试 错误原因 Southeast Universit…

CSS - 盒子模型、图片模糊、过渡效果、2D图移动、放大缩小、CSS动画、flex布局

盒子模型 CSS盒子模型是指在网页布局中,每个元素都被看作是一个矩形的盒子,这个盒子由内容区域、内边距、边框和外边距组成。盒子模型在CSS中用于确定元素在页面中的尺寸、位置和边距。 盒子模型由以下几个部分组成: 内容区域(…

行云堡垒国密算法应用与信创支持

一、 国密算法和信创的介绍 1.1 什么是国密算法 国密算法是国家密码管理局制定颁布的一系列的密码标准,即已经被国家密码局认定的国产密码算法,又称商用密码(是指能够实现商用密码算法的加密,解密和认证等功能的技术)…

Qlik Sense : Crosstable在数据加载脚本中使用交叉表

什么是Crosstable? 交叉表是常见的表格类型,特点是在两个标题数据正交列表之间显示值矩阵。如果要将数据关联到其他数据表格,交叉表通常不是最佳数据格式。 本主题介绍了如何逆透视交叉表,即,在数据加载脚本中使用 L…

批归一化(BN)在神经网络中的作用与原理

文章目录 1. 批归一化(BN)在神经网络中的作用与原理1.1 作用与优势1.2 原理与推导 2. 将BN应用于神经网络的方法2.1 训练时的BN 2. 将BN应用于神经网络的方法2.1 训练时的BN2.2 测试时的BN代码示例(Python): 3. BN的优…

机器学习-09-图像处理01-理论

总结 本系列是机器学习课程的系列课程,主要介绍机器学习中图像处理技术。 参考 02图像知识 色彩基础知识整理-色相、饱和度、明度、色调 图像特征提取(VGG和Resnet特征提取卷积过程详解) Python图像处理入门 【人工智能】PythonOpenCV…

基于python的天气数据可视化系统、Flask框架,爬虫采集天气数据,可视化分析

系统介绍 基于Python的天气预测可视化分析系统,该项目的主要流程和功能包括: 数据获取: 使用Python的pandas库从2345天气网(http://tianqi.2345.com/Pc/GetHistory)抓取山东省各市区县在2021年至2023年间的天气历史数…

【方法】PDF密码如何取消?

对于重要的PDF文件,很多人会设置密码保护,那后续不需要保护了,如何取消密码呢? 今天我们来看看,PDF的两种密码,即“限制密码”和“打开密码”,是如何取消的,以及忘记密码的情况要怎…

文献学习-33-一个用于生成手术视频摘要的python库

VideoSum: A Python Library for Surgical Video Summarization Authors: Luis C. Garcia-Peraza-Herrera, Sebastien Ourselin, and Tom Vercauteren Source: https://arxiv.org/pdf/2303.10173.pdf 这篇文章主要关注的是如何通过视频摘要来简化和可视化手术视频&#xff0c…

计算机基础知识-第4章-真值表和逻辑运算、位运算

一、真值表与逻辑运算 真值表 真值表是什么呢?我们来看百度百科的定义。表征逻辑事件输入和输出之间全部可能状态的表格。列出命题公式真假值的表。通常以1表示真,0 表示假。命题公式的取值由组成命题公式的命题变元的取值和命题联结词决定,…

开源监控zabbix对接可视化工具grafana教程

今天要给大家介绍的是开源监控工具zabbix对接可视化工具grafana问题。 有一定运维经验的小伙伴大抵都或多或少使用过、至少也听说过开源监控工具zabbix,更进一步的小伙伴可能知道zabbix在数据呈现方面有着明显的短板,因此需要搭配使用第三方的可视化工具…

Qlik Sense :use Peek function to Group by and Get Rowno

Question Row number based on groups of data Calculate row number for groups 有时候我们需要基于分组来对数据进行内部排序,例如一个iddate,把不同的属性的记录标记为123,又或者把重复记录标记出来 Solved: Calculate row number for…

MacOS安装openMP报错【已解决】

error: Target “WLBG” links to: OpenMP::OpenMP_CXX but the target was not found. Possible reasons include: * There is a typo in the target name. * A find_package call is missing for an IMPORTED target. * An ALIAS target is missing. 最开始是报这个错&#x…

云上配置Hadoop环境

Hadoop概述 Hadoop技术主要是由下面这三个组件组合而成的: HDFS是一个典型的主从模式架构。 HDFS的基础架构 HDFS的集群搭建 一点准备工作 其实这一块没啥内容,就是将Hadoop官网下载下来的Hadoop的tar包上传到我们服务器上的文件目录下: …

2024考研调剂须知

----------------------------------------------------------------------------------------------------- 考研复试科研背景提升班 教你快速深入了解掌握考研复试面试中的常见问题以及注意事项,系统的教你如何在短期内快速提升自己的专业知识水平和编程以及英语…

Vue ElementUI el-input-number 改变控制按钮 icon 箭头为三角形

el-input-number 属性 controls-position 值为 right 时&#xff1b; <el-input-number v-model"num" controls-position"right" :min"1" :max"10"></el-input-number>原生效果 修改后效果 CSS 修改 .el-input-number…

点亮一颗 LED: 单片机 ch32v003 (RISC-V) 使用 rust 编写固件

首发日期 2024-04-09, 以下为原文内容: 使用 rust 编写单片机的程序 ? 很新, 但没问题. 使用 RISC-V CPU 的单片机 (比如 ch32v003) ? 也没问题. 同时使用 ? 哦嚯, 问题出现了 !! ch32v003 是一款使用 rv32ec 指令集的国产单片机, 很便宜 (某宝零卖只要 0.4 元一个, 在同档…

学习JavaEE的日子 Day33 File类,IO流

Day33 1.File类 File是文件和目录路径名的抽象表示 File类的对象可以表示文件&#xff1a;C:\Users\Desktop\hhy.txt File类的对象可以表示目录路径名&#xff1a;C:\Users\Desktop File只关注文件本身的信息&#xff08;文件名、是否可读、是否可写…&#xff09;&#xff0c…

基于SSM的电影网站(有报告)。Javaee项目。ssm项目。

演示视频&#xff1a; 基于SSM的电影网站&#xff08;有报告&#xff09;。Javaee项目。ssm项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系结构&#xff0c;通过Spring SpringMv…

MySQL 全文检索

不是所有的数据表都支持全文检索 MySQL支持多种底层数据库引擎&#xff0c;但是并非所有的引擎支持全文检索 &#xff0c;目前最常用引擎是是MyISAM和InnoDB&#xff1b;前者支持全文检索&#xff0c;后者不支持。 booolean模式操作符 操作符含义必须有-必须不包含>包含对应…