【传知代码】掩码自回归编码器法(论文复现)

前言:在探索现代数据科学的前沿领域时,掩码自回归编码器法(Masked Autoencoder,简称MAE)无疑是一个引人注目的亮点。这一技术,凭借其独特的训练机制和卓越的性能,已经在图像识别、自然语言处理以及众多其他领域展现出强大的潜力。今天,我们就来深入剖析掩码自回归编码器法的核心原理、应用前景以及它如何引领我们迈向更加智能的数据分析新时代。

本文所涉及所有资源均在传知代码平台可获取

目录

概述

演示效果

核心代码

写在最后


概述

        掩码自动编码器MAE是一款具有可扩展性的计算机视觉自我监控学习器。它可以从一个不完整或错误的图像序列中提取出感兴趣的信息来进行分类和识别,在图像处理领域得到了广泛的应用。MAE的核心策略包括:对输入图像的随机补丁进行屏蔽,并对遗失的像素进行重建,这一策略是基于两个主要的设计思路,如下:

1)一种非对称编码器-解码器架构,其中编码器只对可见的补丁子集进行操作(没有掩码标记)

2)一个轻量级解码器,它根据潜在表示和掩码标记重建原始图像

        MAE的掩码自编码器是一种简单地自编码方法,它在给定原始信号的部分观测值的情况下重建原始信号。和所有的自编码器一样,MAE有一个将观察到的信号映射到潜在表示的编码器,以及一个从潜在表示重建原始信号的解码器,如下图所示:

与经典的自编码器不同,MAE采用了一种非对称设计,允许编码器仅对部分观察到的信号进行操作(没有掩码标记),并采用了一个轻量级解码器,该解码器根据潜在表示和掩码标记重建全部信号,以下是对相关概念的讲解:

掩码

        与ViT方法相似,MAE将图像分为有规律的非重叠部分,然后MAE对补丁的子集进行取样,并对剩下的补丁进行屏蔽或移除。对于每个候选补丁而言,它是由一系列不同大小和位置分布的样本构成的集合,因此可以用多个阈值来确定这些样本的权重。MAE的取样方法相当直接,它对补丁进行随机取样,不做替换,并按照均匀的方式分布。为了提高效率和减少计算量,提出一种基于均匀抽样技术的快速修复算法。随机采样具有较高的掩码比(即去除补丁的比例),这在很大程度上减少了数据冗余。因此,这导致了一个问题,即不能仅通过从相邻的明显补丁中进行外推来解决。通过均匀分布,可以避免中心偏差(即图像中心附近的掩码补丁越多)。最后一个高度稀疏的输入为设计一个高效的编码器提供了可能性。

MAE编码器

        MAE中编码器为ViT,但仅适用于可见和不屏蔽补丁。与标准ViT类似,MAE中编码器也是通过增加位置嵌入线性投影嵌入补丁再经过一系列Transformer块对结果集进行处理。然而,MAE的编码器只对全集的一小部分(例如25%)进行操作。带MAE可以只利用小部分计算与内存,就可以训练出很大的编码器。

MAE解码器

        MAE解码器输入为编码器可见补丁与掩码令牌构成一个完整令牌集合。每一个掩码标记为共享和学习向量,表示是否有丢失补丁需要被预测。MAE将位置嵌入添加到该全集中的所有令牌中,如果没有这一点,掩码令牌将没有关于其在图像中的位置信息。MAE解码器仅在预训练期间用于执行图像重建任务(仅使用编码器生成识别用图像表示。)因此,可以以独立于编码器设计的方式灵活地设计解码器架构。

重建目标

        MAE通过预测每个掩码补丁的像素值来重建输入,解码器输出中的每个元素是表示补丁的像素值的矢量。解码器末层为线性投影且输出通道个数与块内像素值个数相等。重建解码器输出,形成重建图像。在像素空间中,MAE的损失函数用于计算重建图像与原始图像的均方误差(MSE),这与BERT是一致的,而MAE仅用于计算掩码补丁上的损失。MAE也研究了以各屏蔽补丁归一化像素为重构对象的变体。具体而言,MAE在一个Patch上计算所有像素的均值与标准差并用其归一化这个patch。以归一化像素为重构对象,改善表示质量。

简单实现

        首先,MAE为每个输入补丁生成一个标记(通过添加位置嵌入的线性投影),接下来,MAE随机打乱令牌列表,并根据屏蔽比率删除列表的最后一部分。这个过程为编码器生成一小部分标记,相当于采样补丁而不进行替换。编码后,MAE将一个掩码令牌列表添加到编码补丁列表中,并对这个完整列表纪念性unshuffle(反转随机混洗操作),以将所有标记与其目标对齐。编码器应用于该完整列表(添加了位置嵌入)。如前所述,不需要稀疏运算,这种简单地实现引入了可忽略不计的开销,因为混洗和取消混洗操作很快。

该编码方式参考如下论文内容,地址 :

演示效果

通过如下的方式对项目进行相关部署:

#  linux系统下python=3.7
conda create -n mae python=3.7
conda activate mae

# 下载torch
wget https://download.pytorch.org/whl/cu116/torch-1.13.0%2Bcu116-cp37-cp37m-linux_x86_64.whl
pip install 'torch的下载地址'
# 下载torchvision
wget https://download.pytorch.org/whl/cu116/torchvision-0.14.0%2Bcu116-cp37-cp37m-linux_x86_64.whl
pip install 'torchvision的下载地址'

pip install timm==0.4.5
pip install ipykernel
pip install matplotlib
pip install tensorboard

MAE随机掩码图像实现的效果如下:

核心代码

随机掩码的实现逻辑如下:

def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        # 确定需要保存多少个patch
        len_keep = int(L * (1 - mask_ratio))
        # [1,196] 用batch此时输入的图片可能不止一个,196表示patch的个数
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]      
        # sort noise for each sample 
        # 默认按升序排序,此时返回的是序号,首先获取从低到高排列的序号
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        # 获取ids_shuffle从低到高排列的序号,这样就能还原原始的noise的情况
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep] # 保存数据少的情况
        # [1,49,1024] dim=0 按列进行索引,dim=1按行进行索引,获取x的取值
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask 为0表示没有被掩码,1表示被掩码
        # 将是否被掩码通过mask表示出来
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

编码器的实现逻辑如下:

def forward_encoder(self, x, mask_ratio):
        # embed patches [1,3,224,224]->[1,196,1024]
        x = self.patch_embed(x)

        # add pos embed w/o cls token 除了全局特征,全部加上了位置信息
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)
        # id_restore保存的是原来的位置
        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1) # [1,1,1024]
        # [1,50,1024] 要包含一个class的情况
        x = torch.cat((cls_tokens, x), dim=1) 

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

解码器的实现逻辑如下:

def forward_decoder(self, x, ids_restore):
        # embed tokens [1,50,1024]->[1,50,512]
        x = self.decoder_embed(x)

        # append mask tokens to sequence 获取被掩码的token [1,147,512]
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        # 将经过编码的数据和原始的初始化为0的数据编码在一起。
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        # 将编码的和为编码的重新转变为原始的patch大小,其实本质上只需要考虑编码的位置,因为其余都是随机初始化的
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token
        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection 将其转换为所有像素
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x

写在最后

        掩码自回归编码器法(Masked Autoencoder, MAE)作为一种前沿的深度学习架构,尤其在处理大规模数据集和复杂特征空间时展现出了其独特的优势。通过对输入数据进行部分掩码(即随机遮盖部分输入),MAE迫使模型从剩余的可见数据中预测被遮盖的部分,这种自监督的学习方式有效地提高了模型的泛化能力和鲁棒性。然而,MAE也面临着一些挑战和限制。例如,如何确定最佳的掩码比例和策略仍然是一个开放的问题。此外,MAE在处理某些特定任务时可能不如其他方法有效,这需要根据具体任务和数据集进行选择和调整,我们有望进一步优化和完善MAE的性能,并将其应用于更加广泛和复杂的任务中。

详细复现过程的项目源码、数据和预训练好的模型可从该文章下方附件获取。

【传知科技】关注有礼     公众号、抖音号、视频号

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/640896.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

《我的阿勒泰》观后感(二、返璞归真也是一种美)

看了李娟的小说《我的阿勒泰》逐渐悟到一个道理,返璞归真也是一种美,没必要每个人的人生三十年的年华,都去追求房子,车子等逐渐贬值的东西。人究竟应该追求怎样的一种活法? 什么是城市化?这是我听到的最好…

osgearth 3.5 vs 2019编译

下载源码 git clone --recurse-submodules https://github.com/gwaldron/osgearth.git 修改配置文件 主要是修改bootstrap_vcpkg.bat,一处是vs的版本,第二处是-DCMAKE_BUILD_TYPERELEASE 构建 执行bootstrap_vcpkg.bat vs中生成安装 vs2019打开bu…

spring boot打的包直接运行

Spring Boot 提供了一个插件 spring-boot-maven-plugin 把程序打包成一个可执行的jar包&#xff0c;直接执行java -jar xxx.jar即可以启动程序 1、引用 spring-boot-maven-plugin插件 <build><plugins><plugin><groupId>org.springframework.boot<…

LED显示屏的智能化发展与未来趋势

摘要&#xff1a;随着智能化技术的飞速发展&#xff0c;LED显示屏行业也迎来了新的变革。本文将探讨LED显示屏的智能化发展方向&#xff0c;包括人屏互动、大屏中控智能化&#xff0c;以及智能LED显示屏在不同领域的应用前景。 1、引言 在智能化浪潮的推动下&#xff0c;LED显示…

GPT-4o: 未来的智能助手

GPT-4o: 未来的智能助手 在这个信息爆炸的时代&#xff0c;人工智能&#xff08;AI&#xff09;已经成为我们生活中不可或缺的一部分。作为OpenAI最新推出的语言模型&#xff0c;GPT-4o不仅继承了前几代模型的优点&#xff0c;还在多个方面进行了显著的提升。本文将带你深入了解…

家政预约小程序03分类管理

目录 1 创建数据源2 搭建导航菜单3 搭建小程序4 设置变量5 变量绑定总结 家政预约小程序里&#xff0c;在首页需要展示家政可以开展的各类业务。我们把业务按照类别进行划分&#xff0c;本篇我们介绍一下管理后台的维护功能以及小程序的展示功能。 1 创建数据源 为了管理和展示…

WiFi蓝牙模块开发配置过程中需要注意的细节

在很多产品的应用场景中&#xff0c;WIFI网络会给我们提供很多便捷&#xff0c;MCU开发中大多使用串口WIFI蓝牙模块来实现产品接入WIFI网络中。   具体的使用模型如下图所示&#xff1a;整个系统涉及到WIFI网络、手机、服务器平台以及我们设计的产品&#xff0c;一个完整的生…

uniapp+php服务端实现苹果iap内购的消耗性项目和非续期订阅项目,前后端代码加逻辑分析

前言&#xff1a;公司的项目app在上架苹果商店时发现人家要求里面的部分购买项目必须使用iap购买的方式&#xff0c;使用原本的微信支付方式审核不给通过&#xff0c;无奈只能重新研究这个东西。做起来还是有点麻烦&#xff0c;主要是网上的文章很少&#xff0c;不能直接硬抄。…

彩信JSON接口对接发送

随着通讯技术的飞速发展&#xff0c;传统的短信已经无法满足人们日益增长的沟通需求。在这样的背景下&#xff0c;群发彩信作为一种更为先进、更为丰富的信息传递方式&#xff0c;逐渐受到了企业和个人的青睐。那么&#xff0c;群发彩信应该怎么对接&#xff0c;又具体有哪些优…

经常碰到的20个等待事件

经常碰到的20个等待事件 oracle等待事件简介 DBA团队维护的部分应用运行在oracle数据库平台&#xff0c;为及时了解数据库的运行情况&#xff0c;需要建立涵盖各个维度的监控体系&#xff0c;包括实例状态、空间使用率、ORA错误等数十项监控指标。这其中有一个有效判断数据库…

Parasoft C++Test软件静态分析操作指南_软件质量度量

系列文章目录 Parasoft CTest软件安装指南 Parasoft CTest软件静态分析操作指南_编码规范/标准检查 Parasoft CTest软件静态分析操作指南_软件质量度量 Parasoft CTest软件静态分析_自动提取静态分析数据生成文档 Parasoft CTest软件单元测试_操作指南 Parasoft CTest软件单元…

Mqtt_Java_IDEA中编写“发布者”和“订阅者”

1Java创建项目 2导入依赖 将下面Mqtt的库名复制到 <dependencies> 下面 <dependency><groupId>org.eclipse.paho</groupId><artifactId>org.eclipse.paho.client.mqttv3</artifactId><version>1.2.5</version></d…

20212416 2023-2024-2 《移动平台开发与实践》第5次作业

百度地图应用 1.实验内容2.实验过程2.1 Android Studio配置2.1. 创建一个Android项目2.2 在项目中本地集成BaiduMap SDK 2.2 编写代码2.2.1 配置AndroidManifest.xml文件2.2.2 编写UI界面布局文件2.2.3 编写主函数代码2.2.4 运行结果 3.学习中遇到的问题及解决4.学习感悟与思考…

netdiscover一键收集子网内的所有信息(KALI工具系列六)

目录 1、KALI LINUX简介 2、netdiscover工具简介 3、在KALI中使用netdiscover 3.1 目标主机IP&#xff08;win&#xff09; 3.2 KALI的IP 4、命令示例 4.1 扫描子网整个网段 4.2 指定网卡进行扫描 4.3 扫描网卡的公共网络 4.4 快速扫描网卡的公共lan地址 4.5 设置…

网络拓扑—DHCP服务配置

文章目录 DHCP服务搭建相关配置细节前提安装DHCP服务 DHCP服务搭建 相关配置细节前提 系统&#xff1a;Windows Server 2003 IP网段&#xff1a;10.0.0.0/24 三台机子&#xff1a; 普通PC机 DHCP服务器 路由器&#xff08;两块网卡&#xff0c;连接内外网&#xff09; //注…

Java进阶学习笔记6——继承的介绍

继承的学习目标&#xff1a; 认识继承&#xff1b; 继承的好处、应用场景 什么是继承&#xff1f; Java中提供了一个关键字extends&#xff0c;用这个关键字&#xff0c;可以让一个类和另外一个类建立父子关系。 继承的特点: 子类能继承父类的非私有成员&#xff08;成员变…

利用sql注入对某非法网站的渗透

本文仅用于技术讨论&#xff0c;切勿用于违法途径&#xff0c;且行且珍惜&#xff0c; 所有非经授权的渗透&#xff0c;都是违法行为 前言 这段时间一直在捣鼓sql注入&#xff0c;最近又通过一个sql注入点&#xff0c;成功进入某个非法网站的后台&#xff0c;拿到整个网站的…

mac版本Phpstudy本地环境安装Discuz教程【2024】

此方法适用于m1版本的mac版本Phpstudy本地环境安装Discuz&#xff0c;当然同样使用更高版本的mac端。网上各种安装教程参差不齐&#xff0c;根本解决不了小白的入门需求&#xff0c;以下是最新且直接明了的安装教程。 Phpstudy本地环境安装Discuz教程&#xff1a; 1、安装Phps…

渗透测试 一个很奇怪的支付漏洞

新手实战刷课网站、好玩又有趣&#xff01; 第一步 打开网站、任意账户名密码登陆发现验证码可重复利用 这时候我们可以试试admin账号、发现如果账号正确会提示账户已存在、反之回显账户密码错误 第二步 既然验证码可以重复利用&#xff1b;而且账号名有回显 这时候我们试…

安装harbor出现问题: Running 1/1 ✘ Network harbor_harbor Error

安装harbor出现问题&#xff1a; [] Running 1/1 ✘ Network harbor_harbor Error 0.2s failed to create network harbor_harbor: Error response from daemon: Fa…