5、MAE:探索视觉预训练模型

目录

1、论文

2、背景与动机

3、回答的问题

4、创新与卖点

5、实现细节

模型框架

具体步骤

简单代码示例

6、一些资料


1、论文

Masked Autoencoders Are Scalable Vision Learnersicon-default.png?t=N7T8https://arxiv.org/pdf/2111.06377.pdf

2、背景与动机

        在深度学习和计算机视觉的领域中,预训练模型已经成为了提高下游任务性能的重要手段。传统上,许多预训练模型如ResNet、VGG等都是在大规模数据集(如ImageNet)上通过监督学习训练得到的。然而,监督学习需要大量的标记数据,这在成本和可扩展性上都是一个不小的挑战。

        最近,自监督学习作为一个新兴研究领域,提供了一种无需手工标注数据的解决方案。自监督学习的一个关键点是设计预测任务,通过这些任务模型可以从输入数据本身学习到有用的表示。在自然语言处理(NLP)领域,BERT通过掩码语言模型(MLM)任务表现出色,这激发了计算机视觉领域对类似方法的探索。

        MAE (Masked Autoencoder) 正是从这样的背景和动机出发,它将自监督学习中的掩码预测任务引入到视觉领域,致力于从图像数据中以无监督的方式学习高效的特征表示。

3、回答的问题

        论文中回答了一个问题。为什么自监督在CV领域的发展要滞后于NLP呢?论文中给了两个解释:

(1)NLP主流方法是Transformer,视觉里CNN是主流方法,结构差异让视觉很难构造类似于“masked autoencoding”的任务。但是ViT的提出解决了这个问题;

(2)语言和视觉的信息密度(information density)差异巨大,前者是强语义的,高信息密度的(highly semantic and information-dense),在NLP中即使只mask一个token,对模型来说可能都是很难的任务,因此模型可以通过学习获得复杂的语言理解能力(sophisticated language understanding),但是对视觉图像来说,信息是高度冗余的,缺失一个patch,可能并不会让模型产生多少困惑,模型可以通过周围的像素信息进行推断

        所以MAE做的一件事就是mask很高比例的patches,制造高难度的学习任务,方法简单但是极其有效

4、创新与卖点

MAE 的核心创新在于其独特的自监督预训练方法。不同于之前的自监督视觉模型通常需要对比学习或复杂的数据增强,MAE 提出了一种简洁高效的方法:

  1. Masking 策略:MAE 对输入图像进行随机遮蔽,只露出一小部分像素,模型的任务是预测被遮蔽部分的原始像素。这种策略减少了模型需要处理的数据量,同时迫使模型学习丰富的上下文信息来重建图像。

  2. 编码器-解码器架构:MAE 采用了一个不对称的编码器-解码器架构,其中编码器只对未被遮蔽的部分进行处理,大幅减少了计算量。解码器则负责图像的重建工作,它的结构相对简单,因为其主要任务是理解编码器提供的特征。

  3. 预训练与微调:MAE 的预训练阶段不依赖于标签,这使得模型可以在非常大的数据集上进行训练。一旦预训练完成,MAE 可以通过微调在各种下游任务上实现优异的性能,包括分类、检测和分割等。

5、实现细节

模型框架

具体步骤

  1. 数据遮掩:首先,在输入图像或序列数据中随机选择一定比例的区域进行遮掩,将其替换为特定的遮掩标记(如0或[MASK])。

  2. 编码阶段:仅将未遮掩的数据部分输入到一个轻量级的Transformer编码器中,以提取局部上下文特征。

  3. 解码阶段:将编码后的向量传递给一个解码器,该解码器通常也是一个Transformer,但会对所有像素或位置进行解码预测,恢复出被遮掩部分的信息。

  4. 损失函数:使用L1或L2距离作为损失函数,衡量预测的像素值或词向量与原始未遮掩数据之间的差异。

  5. 预训练与微调:经过大规模无标签数据上的预训练后,可以将模型参数迁移到特定的下游任务中进行微调,进一步提升任务性能。

简单代码示例

import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    # 用于添加位置信息的模块,通常在Transformer结构中使用
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class Encoder(nn.Module):
    def __init__(self, embed_dim, num_layers, num_heads, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=int(embed_dim * mlp_ratio),
                                      dropout=drop_rate, attention_dropout=attn_drop_rate, bias_qkv=qkv_bias)
            for _ in range(num_layers)])

    def forward(self, src, mask=None):
        output = src
        for layer in self.layers:
            output = layer(output, src_key_padding_mask=mask)
        return output

class MaskedAutoencoder(nn.Module):
    def __init__(self, image_size, patch_size, num_channels, embed_dim, num_layers, num_heads, mlp_ratio, num_classes):
        super(MaskedAutoencoder, self).__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_patches = (image_size // patch_size) ** 2
        self.encoder = nn.Sequential(
            nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size),
            nn.LayerNorm(embed_dim),
        )
        self.pos_embed = PositionalEncoding(embed_dim)
        self.transformer_encoder = Encoder(embed_dim, num_layers, num_heads, mlp_ratio)
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, num_channels * patch_size ** 2),
            nn.PixelShuffle(patch_size),
        )
        self.to_patch_embedding = nn.Sequential(
            nn.Unflatten(dim=1, unflattened_size=(num_patches, embed_dim)),
            nn.Dropout(p=0.1),
        )

    def forward(self, x, mask_ratio=0.75):
        B, C, H, W = x.shape
        assert H == W, "Input image must be square"
        x = self.encoder(x)
        x = self.pos_embed(x)
        
        # 随机掩码
        rand_mask = torch.rand(B, self.num_patches, 1, 1, device=x.device) < mask_ratio
        masked_x = x.clone()
        masked_x[rand_mask] = 0.

        # 编码
        encoded_patches = self.transformer_encoder(self.to_patch_embedding(masked_x))

        # 解码
        reconstructed_image = self.decoder(encoded_patches)

        return reconstructed_image

# 初始化模型
model = MaskedAutoencoder(image_size=224, patch_size=16, num_channels=3, embed_dim=768, num_layers=12, num_heads=12, mlp_ratio=4., num_classes=0)

# 假设我们有输入数据x
x = torch.randn((10, 3, 224, 224))

# 计算重构后的图像
reconstruction = model(x)

6、一些资料

MAE(Masked Autoencoders) - 知乎简介MAE(Masked Autoencoders)是用于CV的自监督学习方法,优点是扩展性强的(scalable),方法简单。在MAE方法中会随机mask输入图片的部分patches,然后重构这些缺失的像素。MAE基于两个核心设计:(1)不对称的(…icon-default.png?t=N7T8https://zhuanlan.zhihu.com/p/446761025

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/312331.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【uniapp】新课uniapp零基础入门到项目打包(微信小程序/H5/vue/安卓apk)全掌握

一、uniapp和HBuilderX介绍 uni-app官方网站&#xff1a;https://uniapp.dcloud.net.cn/ 为什么要学习uniapp&#xff1f; 1、一套代码可以打包到不同的应用平台&#xff1b;一套代码编到十几个平台&#xff0c;这不是梦想。眼见为实&#xff0c;扫描以下二维码&#xff0c;…

2023年全国职业院校技能大赛软件测试赛题—单元测试卷⑧

单元测试 一、任务要求 题目1&#xff1a;根据下列流程图编写程序实现相应处理&#xff0c;执行j10*x-y返回文字“j1&#xff1a;”和计算值&#xff0c;执行j(x-y)*(10⁵%7)返回文字“j2&#xff1a;”和计算值&#xff0c;执行jy*log(x10)返回文字“j3&#xff1a;”和计算值…

Qt / day04

1. 思维导图 2. 编写程序实现闹钟如下 文件 代码 clock.h #include "clock.h" #include "ui_clock.h"Clock::Clock(QWidget *parent): QWidget(parent), ui(new Ui::Clock), ptrTimerShow(new QTimer(this)), ptrTimerAlarm(new QTimer(this)), ptrSpeec…

Spark六:Spark 底层执行原理SparkContext、DAG、TaskScheduler

Spark底层执行原理 学习Spark运行流程 学习链接&#xff1a;https://mp.weixin.qq.com/s/caCk3mM5iXy0FaXCLkDwYQ 一、Spark运行流程 流程&#xff1a; SparkContext向管理器注册并向资源管理器申请运行Executor资源管理器分配Executor&#xff0c;然后资源管理器启动Execut…

iOS rootless无根越狱解决方案

据游戏工委数据统计&#xff0c;2023年国内游戏市场实际销售收入与用户规模双双创下新高&#xff0c;游戏普遍采用多端并发方式&#xff0c;成为收入增长的主因之一。 中国市场实际销售收入及增长率丨数据来源&#xff1a;游戏工委 多端互通既是机遇&#xff0c;也是挑战。从游…

使用 Docker 和 Diffusers 快速上手 Stable Video Diffusion 图生视频大模型

本篇文章聊聊&#xff0c;如何快速上手 Stable Video Diffusion (SVD) 图生视频大模型。 写在前面 月底计划在机器之心的“AI技术论坛”做关于使用开源模型 “Stable Diffusion 模型” 做有趣视频的实战分享。 因为会议分享时间有限&#xff0c;和之前一样&#xff0c;比较简…

MySQL篇—通过Clone插件进行远程克隆数据(第三篇,总共三篇)

在介绍 Clone 最终篇之前&#xff0c;我们先简要回顾一下前面所讲的内容。在第一篇中&#xff0c;我们探讨了 Clone 的用途、使用的前提条件、存在的限制&#xff0c;以及它的备份原理。Clone 是一种用于复制和备份数据的工具&#xff0c;它能够快速高效地创建数据的精确副本。…

超维空间M1无人机使用说明书——53、ROS无人机二维码识别与降落——V2升级版本

引言&#xff1a;使用二维码引导无人机实现精准降落&#xff0c;首先需要实现对二维码的识别和定位&#xff0c;可以参考博客的二维码识别和定位内容。本小节主要是通过获取拿到的二维码位置&#xff0c;控制无人机全向的移动和降落&#xff0c;本小节再V1版本的基础上增加了动…

【AI视野·今日NLP 自然语言处理论文速览 第七十二期】Mon, 8 Jan 2024

AI视野今日CS.NLP 自然语言处理论文速览 Mon, 8 Jan 2024 Totally 17 papers &#x1f449;上期速览✈更多精彩请移步主页 Daily Computation and Language Papers DeepSeek LLM: Scaling Open-Source Language Models with Longtermism Authors DeepSeek AI Xiao Bi, Deli Ch…

【修图】AI修图工具

人脸替换 免费的人脸替换工具&#xff1a; Face Swap&#xff1a; https://vmodel.ai/face-swap 支持单人换脸、多人换脸 AI消除 SnapEdit https://snapedit.app/remove-object 不付费的话只能下载清晰度较低的版本 但我试了几个在线的AI消除工具&#xff0c;SnapEdit算是…

最新版docker-compose安装

Ubuntu/Kali 下载安装最新版 docker-compose # FastGit加速 sudo curl -L "https://hub.fgit.cf/docker/compose/releases/download$(curl -L -i -s -o /dev/null -w "%{url_effective}\n" https://hub.fgit.cf/docker/compose/releases/latest | awk -F tag …

docker-compose部署kafka、SASL模式(密码校验模式)

一.基础kafka部署 zookeeper&#xff0c;kafka&#xff0c;kafka-ui docker-compose.yml 注意点&#xff1a;192.168.1.20 是宿主机的ip version: "3" services:zookeeper:image: wurstmeister/zookeepercontainer_name: zookeeperrestart: alwaysports:- 2181:2…

【python】python新年烟花代码【附源码】

欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 新年的钟声即将敲响&#xff0c;为了庆祝这个喜庆的时刻&#xff0c;我们可以用 Python 编写一个炫彩夺目的烟花盛典。本文将详细介绍如何使用 Pygame 库创建一个令人惊叹的烟花效果。 一、效果图&#xff1a; 二…

质量好洗地机有哪些?洗地机口碑榜

在很多人眼中&#xff0c;洗地机可能被简单地视为一种高价的拖把&#xff0c;但作为一个经验丰富的洗地机测评博主&#xff0c;我要强调洗地机在家务工作中的巨大价值。它不仅仅是一种清洁工具&#xff0c;更是集扫地、拖地、洗地以及擦干地板等多项功能于一身的强大设备。通过…

【python 的各种模块】(9) 在python使用PIL,即pillow模块

目录 1 导入PIL模块&#xff08;pillow&#xff09; 1.1 导入PIL模块 1.1.1 可用的导入形式 1.1.2 常用的导入形式 1.1.3 PIL下面的常用子模块 2 用 PIL读入&#xff0c;生成和显示图片 2.1 用 PIL.Image.open() 可以读入图片 2.2 用PIL.Image.new() 生成新图片 2.3 …

详解矩阵的正交化(附例题分析)

目录 一. 矩阵Gram-Schmidt正交化的好处 二. 矩阵标准正交化过程 三. 例题 3.1 标准正交化 3.2 算法小结 3.3 优化分析 四. 小结 矩阵有两类等价关系 矩阵对角化 特殊矩阵 一. 矩阵Gram-Schmidt正交化的好处 假如有三个线性独立的向量a,b,c&#xff0c;他们是标准正…

antv/x6_2.0学习使用(五、路由)

X6 默认提供了以下几种路由: 路由名称说明normal默认路由&#xff0c;原样返回路径点orth正交路由&#xff0c;由水平或垂直的正交线段组成oneSide受限正交路由&#xff0c;由受限的三段水平或垂直的正交线段组成manhattan智能正交路由&#xff0c;由水平或垂直的正交线段组成…

【Linux】Linux 系统编程——tree 命令

文章目录 1. 命令概述2. 命令格式3. 常用选项4. 相关描述4.1 tree 命令安装 5. 参考示例5.1 创建树形目录5.2 使用 tree 命令查看树形目录 1. 命令概述 tree 命令用于在命令行界面以树状图形式显示目录及其子目录的内容。这个命令递归地列出所有子目录&#xff0c;并可选择显示…

uⅤ打印-小理光上海RYPC后台运动系统

uⅤ打印-小理光上海RYPC后台运动系统

应用在游戏机触摸屏中的触摸感应芯片

触屏游戏机的屏幕是由液晶屏和触控层组成的。触控层分为电容式触屏和电阻式触屏两种。电容式触屏是将悬空电极和屏幕玻璃上的电极组成静电场,当人体接近屏幕时,就会改变静电场分布,从而实现触摸的位置探测。而电阻式触屏则是利用玻璃上的两层电极之间通电形成一个电阻值,当手指…