Google的MLP-MIXer的复现(pytorch实现)

Google的MLP-MIXer的复现(pytorch实现)

该模型原论文实现用的jax框架实现,先贴出原论文的代码实现:

# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional

import einops
import flax.linen as nn
import jax
import jax.numpy as jnp


class MlpBlock(nn.Module):
  mlp_dim: int

  @nn.compact
  def __call__(self, x):
    y = nn.Dense(self.mlp_dim)(x)
    y = nn.gelu(y)
    return nn.Dense(x.shape[-1])(y)


class MixerBlock(nn.Module):
  """Mixer block layer."""
  tokens_mlp_dim: int
  channels_mlp_dim: int

  @nn.compact
  def __call__(self, x):
    y = nn.LayerNorm()(x)
    y = jnp.swapaxes(y, 1, 2)
    y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y) #  (32, 512, 196)
    y = jnp.swapaxes(y, 1, 2)
    x = x + y
    y = nn.LayerNorm()(x)
    return x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y)


class MlpMixer(nn.Module):
  """Mixer architecture."""
  patches: Any
  num_classes: int
  num_blocks: int
  hidden_dim: int
  tokens_mlp_dim: int
  channels_mlp_dim: int
  model_name: Optional[str] = None

  @nn.compact
  def __call__(self, inputs, *, train):
    del train
    x = nn.Conv(self.hidden_dim, self.patches.size,
                strides=self.patches.size, name='stem')(inputs)
    x = einops.rearrange(x, 'n h w c -> n (h w) c')  # 从(32,512,14,14)变成了(32,196,512)
    for _ in range(self.num_blocks):
      x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
    x = nn.LayerNorm(name='pre_head_layer_norm')(x)
    x = jnp.mean(x, axis=1)
    if self.num_classes:
      x = nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros,
                   name='head')(x)
    return x


model_params = {
    'patches': {'size': (16, 16), 'stride': (16, 16)}, # 这里需要一个描述patch大小和步长的对象,例如Flax的stem模块初始化参数
    'num_classes': 10,  # 分类任务的类别数
    'num_blocks': 8,  # Mixer Block的重复次数
    'hidden_dim': 512,  # 隐藏层维度
    'tokens_mlp_dim': 256,  # token mixing的MLP维度
    'channels_mlp_dim': 2048,  # channel mixing的MLP维度
}

# 准备输入数据,例如一批32张图片,每张图片尺寸为512x14x14(假设已经按要求预处理)

# 初始化模型
seed=0
key = jax.random.PRNGKey(seed)
model = MlpMixer.apply(key, **model_params)

input_data = jnp.ones((4096, 224, 224, 3))  # 示例输入数据
# 调用模型进行前向传播
output = model(input_data)

print("Output shape:", output)  # 打印输出形状,预期是(32, 10)如果num_classes=10

该模型的总体框架图如下所示:

在这里插入图片描述

对该框架的讲解,网上已经很多了,就不在此赘述。

实现的pytorch代码如下所示:

class MlpBlock(nn.Module):
    def __init__(self, in_mlp_dim=196, out_mlp_dim=256):
        super(MlpBlock, self).__init__()
        self.mlp_dim = out_mlp_dim
        self.dense1 = nn.Linear(in_mlp_dim, out_mlp_dim)  # 若输入的向量为[32,196, 512]则输入的也应该是512,输出可以自己定
        self.gelu = nn.GELU()
        self.dense2 = nn.Linear(out_mlp_dim, in_mlp_dim)

    def forward(self, x):
        y = self.dense1(x)
        y = self.gelu(y)
        y = self.dense2(y)
        return y


class MixerBlock(nn.Module):
    def __init__(self, tokens_mlp_dim=256, channels_mlp_dim=2048, batch_size=32):
        super(MixerBlock, self).__init__()
        self.batch_size = batch_size
        self.norm1 = nn.LayerNorm(512)  # 对512维的做归一化,默认给最后一个维度做归一化
        self.token_Mixing = MlpBlock(out_mlp_dim=tokens_mlp_dim)
        self.norm2 = nn.LayerNorm(512)      # 对512维的做归一化
        self.channel_mixing = MlpBlock(in_mlp_dim=512, out_mlp_dim=channels_mlp_dim)

    def forward(self, x):
        y = self.norm1(x)
        y = y.permute(0, 2, 1)
        y = self.token_Mixing(y)
        y = y.permute(0, 2, 1)
        x = x + y
        y = self.norm2(x)
        return x + self.channel_mixing(y)


class MlpMixer(nn.Module):
    def __init__(self, patches, num_classes, num_blocks, hidden_dim, tokens_mlp_dim, channels_mlp_dim):
        super(MlpMixer, self).__init__()
        self.stem = nn.Conv2d(3, hidden_dim, kernel_size=patches, stride=patches)
        self.mixer_block_1 = MixerBlock()
        self.mixer_blocks = nn.ModuleList([MixerBlock(tokens_mlp_dim, channels_mlp_dim) for _ in range(num_blocks)])
        self.pre_head_norm = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x):
        x = self.stem(x)
        b, c, h, w = x.shape
        x = x.view(b, c, -1).permute(0, 2, 1)
        for mixer_block in self.mixer_blocks:
            x = mixer_block(x)
        x = self.pre_head_norm(x)
        x = x.mean(dim=1)
        x = self.head(x)
        return x


# model = MlpMixer(16, 10, 6, 512, 256, 2048)
# input_tensor = torch.randn(32, 3, 224, 224)  # (batch size, num_patches, input_dim)
# output = model(input_tensor)
# print(output)

在将flax框架的代码改为pytorch实现的时候,还是踩了不少的坑,在此讲一下,希望后面做的人,可以避免。

1.在flax框架的nn.linear层中没有输入维度,只有一个输出维度。

2.在处理两个差异的时候,如输入维度[32,196,512],其中代表的意思分别为batch_size为32,196为图片在经过patch之后的224*224输入之后经过patch=16,变为14 * 14即196,512会在二维卷积处理之后输出的channel类似。

1.在flax框架的nn.linear层中没有输入维度,只有一个输出维度。

2.在处理两个差异的时候,如输入维度[32,196,512],其中代表的意思分别为batch_size为32,196为图片在经过patch之后的224*224输入之后经过patch=16,变为14 * 14即196,512会在二维卷积处理之后输出的channel类似。

在nn.linear那儿的in_channel与第三个维度保持一致,就可以不必将其三维的转换为二维的。同时在对layernorm那儿转换的时候,默认也是对最后一个维度进行正则化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/635458.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【杂七杂八】Huawei Gt runner手表系统降级

文章目录 Step1:下载安装修改版华为运动与健康Step2:在APP里进行配置Step3:更新固件(时间会很长) 目前在使用用鸿蒙4 111版本的手表系统,但是感觉睡眠检测和运动心率检测一言难尽,于是想到是否能回退到以前的版本&…

从需求角度介绍PasteSpider(K8S平替部署工具适合于任何开发语言)

你是否被K8S的强大而吸引,我相信一部分人是被那复杂的配置和各种专业知识而劝退,应该还有一部分人是因为K8S太吃资源而放手! 这里介绍一款平替工具PasteSpider,PasteSpider是一款使用c#编写的linux容器部署工具(使用PasteSpider和…

算法学习:快速排序

🔥 个人主页:空白诗 文章目录 🚀 引言📌 快速排序算法核心思想1. 选择基准值(Pivot)2. 分区操作(Partitioning)3. 递归排序子序列 📌 JavaScript 实现1. 快速排序主函数2…

Mac 安装 git

文章目录 前言一、介绍二、下载三、验证四、配置五、Git常用命令六、git提交和撤销工作流程代码提交和提交同步代码撤销和撤销同步 FAQ1.homebrew 下载解决方法一(强烈推荐):解决方法二: 总结 前言 Git 是一个开源的分布式版本控…

Flutter问题 --- 用web端打开项目提示:The XMLHttpRequest onError callback was called.

跨域问题: 打开当前使用的flutter sdk,找到flutter\packages\flutter_tools\lib\src\web\chrome.dart文件并打开. 在--disable-extensions位置后添加--disable-web-security 找到flutter\bin\cache目录,删除flutter_tools.stamp和flutter_tools.snaps…

为什么说 Redis 是单线程的?——Java全栈知识(25)

为什么说 Redis 是单线程的? 我们常说的 Redis 是单线程的,但是我前面在讲持久化机制的时候又说 RDB 的持久化是通过主进程 fork 出一个子进程来实现 RDB 持久化。那么 Redis 到底是多线程还是单线程的呢? Redis 的网络 IO 和键值的读写是单…

一个简约高级视差效果PR动态图文开场视频模板

这是一个高质量且易于定制的pr模板。具有模块化结构,可以轻松更改内容。包括视频教程,即使是新手小白也可以轻松套用模板制作视频。 主要特点: 水平(19201080)和垂直(10801920)分辨率&#xff…

如何用PLC监听消防系统485总线通讯获取火灾报警数据

本文由艺捷自动化编写,艺捷自动化旗下产品有艺捷自动化网站和易为二维码说明书小程序(微信) 前言,先来说一下为什么会有这么一个奇怪的应用。在一个自动化系统改造升级项目中,甲方要求把消防的画面加到他们的后台上。…

判定字符是否唯一

题目链接 判定字符是否唯一 题目描述 注意点 0 < len(s) < 100s[i]仅包含小写字母 解答思路 首先想到的是使用数组存储字母是否出现过&#xff0c;如果多次出现则直接返回false为了不适用额外的数据结构&#xff0c;可以使用位运算判定字符是否唯一&#xff0c;思路…

Kibanna安装配置

环境&#xff1a;windows10、ES&#xff08;8.13.3&#xff09;、Kibana&#xff08;8.13.3&#xff09;、Logstash&#xff08;8.13.3&#xff09; 1.Kibanna安装配置 Kibanna对ES的数据进行可视化、分析和监控 Download Kibana Free | Get Started Now | ElasticDownload K…

如何让大模型更聪明?

如何让大模型更聪明&#xff1f; *随着人工智能技术的飞速发展&#xff0c;大模型在多个领域展现出了前所未有的能力&#xff0c;但它们仍然面临着理解力、泛化能力和适应性等方面的挑战。那么&#xff0c;如何让大模型变得更聪明呢&#xff1f; 方向一&#xff1a;算法创新 …

手写tomcat(Ⅱ)——Socket通信+tomcat静态资源的获取

Socket通信简介 参考文章&#xff1a;socket通讯原理及例程&#xff08;一看就懂&#xff09; socket是介于应用层&#xff08;http协议&#xff09;和传输层&#xff08;TCP/UDP协议&#xff09;之间的一层虚拟层 Socket是一个程序&#xff0c;符合TCP/UDP协议的规范&…

GNU Radio之OFDM Divide和Matrix Transpose底层C++实现

文章目录 前言一、OFDM Divide 模块1、简介2、模块作用3、参数意义4、C 具体实现 二、Matrix Transpose 模块1、简介2、参数意义3、C 具体实现 前言 gr-radar 中的 OFDM Divide 模块是GNU Radio中的一个组件&#xff0c;专门用于处理正交频分复用&#xff08;OFDM&#xff09;…

力扣HOT100 - 72. 编辑距离

解题思路&#xff1a; 动态规划 class Solution {public int minDistance(String word1, String word2) {int n1 word1.length();int n2 word2.length();int[][] dp new int[n1 1][n2 1];for (int j 1; j < n2; j) dp[0][j] dp[0][j - 1] 1;for (int i 1; i < …

【Android安全】AOSP版本对应编号| AOSP版本适配Pixel或Nexus型号 | 驱动脚本下载地址

AOSP版本对应编号 https://source.android.com/docs/setup/about/build-numbers?hlzh-cn#source-code-tags-and-builds 例如android-8.1.0_r1 对应的编号是OPM1.171019.011 可以适配Pixel 2 XL AOSP驱动脚本下载 编译AOSP时&#xff0c;需要Google的驱动&#xff0c;后面才…

SaToken+SpringBoot+Redis前后端分离登录认证

目录 前言一、创建工程项目&#x1f38d;1.1 创建后端工程1.2 创建前端工程 二、业务代码&#x1f38a;后端代码前端代码 三、测试参考资料 前言 Sa-Token 是一款 Java 语言的权限认证框架&#xff0c;提供了灵活、高效、易用的权限认证和会话管理功能。它是 SpringBoot、Spri…

解除网页禁止选择

控制台输入以下命令 复制&#xff1a;javascript:void(document.body.οncοpy) 可选&#xff1a;javascript:void(document.body.onselectstart) 拖拉&#xff1a;javascript:void(document.body.οnmοuseup)

短剧系统源码解析与应用

在数字化时代&#xff0c;短剧作为一种新兴的娱乐形式&#xff0c;因其内容紧凑、节奏快速而受到广大年轻群体的喜爱。短剧系统源码的开发和应用&#xff0c;不仅为创作者提供了一个展示才华的平台&#xff0c;也为观众带来了全新的观看体验。本文将对短剧系统源码进行解析&…

MVSnet 代码详解(pytorch)

大致过一下MVSnet 论文中核心的点对应代码应该怎么写。 forward 函数需要 照片&#xff0c;映射矩阵&#xff0c;以及深度值。 照片的shape是 &#xff08;1&#xff0c;5,3&#xff0c;1184,1600&#xff09;代表着1个batch,5张图片&#xff0c;然后一次是每张图片的channel和…

【软考】设计模式之装饰器模式

目录 1. 说明2. 应用场景3. 结构图4. 构成5. 适用性6. 优点7. 缺点8. java示例 1. 说明 1.动态地给一个对象添加一些额外的职责。2.Decorator Pattern。3.就增加功能而言&#xff0c;装饰器模式比生成子类更加灵活。4.一种在不改变现有对象结构的情况下&#xff0c;动态地给对…