对PosWiseFFN的改进: MoE、PKM、UltraMem

先从PosWiseFFN说起

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.GeLU(),
            nn.Linear(d_ff, d_model, bias=False))

    def forward(self, inputs):                                  # inputs: [batch_size, seq_len, d_model]
        residual = inputs
        output = self.fc(inputs)
        return nn.LayerNorm(d_model)(output + residual)  # [batch_size, seq_len, d_model]

如果Attention的维度是d_model,通常PosWiseFFN模型结构就是2个矩阵中间加个Gelu,d_ff是d_model的4倍:第1个矩阵的weight是[d_model, 4*d_model],第2个矩阵的的weight是[4*d_model, d_model]。

PosWiseFFN这个结构也可以理解成一种qkv查询的思路,如果第1个矩阵理解成key,第二矩阵理解成value,那么输入就是[batch_size, seq_len, d_model]的input作为query先去和key做矩阵乘法,得到一个[batch_size, seq_len, 4*d_model]的dots,这个dots过了GeLU后再去和[4*d_model, d_model]的第二个矩阵相乘,这一步变向取了前d_model重要的结果。问题来了,能不能把 4*d_model的d_ff给变得更大呢?Figure 1来自Large Memory Layers with Product Keys的Figure1,图里的|K|在PosWiseFFN里就是 4*d_model。
在这里插入图片描述

下面的PKM简单来说就是把这种qkv查询的思路借用PQ的思想给改进了

PKM(Product Key Memory,这个Product其实就是Product Quantization的Product)

在Large Memory Layers with Product Keys的Figure1里,q的shape是[…,d_model],k的shape是[d_model, |K|],下面看Figure2里怎么解决|K|过大的问题?图里把d_model维的q劈成q1和q2,q1和q2的维度分别是d_model/2;同样的,把[d_model, |K|]的keys劈成[d_model/2, |K|]的sub-key set 1(下图里不带’的 c 1 c_1 c1, c 2 c_2 c2, c 3 c_3 c3)和[d_model/2, |K|]的sub-key set 2(下图里带’的 c 1 ′ c^{'}_1 c1, c 2 ′ c^{'}_2 c2, c 3 ′ c^{'}_3 c3)。这样两半都出topk,最后从 k 2 k^2 k2里再选出k个,这就是Product Quantization的思想
在这里插入图片描述

代码赏析

代码来自https://github.com/lucidrains/product-key-memory/tree/master,里面einops用的不错,下面给一些注释:

class PKM(nn.Module):
    def __init__(
        self,
        dim,
        heads = 4,
        num_keys = 128,
        topk = 32,
        dim_head = 128,
        input_dropout = 0.,
        query_dropout = 0.,
        value_dropout = 0.,
        attn_dropout = 0.,
        use_layernorm = True,
        pre_layernorm = False,
        differentiable_topk = False,
        concat_values_and_combine = False,
        norm_output = False,
        non_competitive_gates = False # Csordas et al. claims non-competitive gates work even better
    ):
        super().__init__()
        self.topk = topk
        self.heads = heads
        self.num_keys = num_keys
        dim_query = dim_head * heads * 2
        self.to_queries = nn.Linear(dim, dim_query, bias = False)

        # pre-layernorm pattern
        self.pre_layernorm = nn.LayerNorm(dim) if pre_layernorm else nn.Identity()

        # batchnorm would break causality
        self.use_layernorm = use_layernorm

        if use_layernorm:
            self.norm = nn.LayerNorm(dim_head)
        else:
            self.norm = MaskedBatchNorm1D(nn.BatchNorm1d(dim_head))

        # keys
        self.keys = nn.Parameter(torch.zeros(heads, num_keys, 2, dim_head))
        init_(self.keys)

        # values
        self.concat_values_and_combine = concat_values_and_combine
        if concat_values_and_combine:
            values = nn.Embedding(num_keys ** 2, dim_head)

            self.values = nn.Sequential(
                values,
                Reduce('b (h k) d -> b h d', 'sum', h = heads),
                Rearrange('b n d -> b (n d)'),
                nn.Linear(dim_head * heads, dim, bias = False)
            )
        else:
            values = nn.EmbeddingBag(num_keys ** 2, dim, mode = 'sum')
            self.values = values
        init_(values.weight)

        # dropouts
        self.input_dropout = nn.Dropout(input_dropout)
        self.query_dropout = nn.Dropout(query_dropout)
        self.value_dropout = nn.Dropout(value_dropout)
        self.attn_dropout = nn.Dropout(attn_dropout)

        # non competitive gates
        self.gate_activation = nn.Softmax(dim = -1) if not non_competitive_gates else nn.ReLU()
        # use a differentiable topk, based on coordinate descent
        self.differentiable_topk = differentiable_topk
        # https://arxiv.org/abs/2302.06461
        # claims to boost performance of softmax key / value networks by simply layernorming the output
        self.output_norm = nn.LayerNorm(dim) if norm_output else nn.Identity()

    def forward(
        self,
        x,
        input_mask = None,
        gumbel_noise_scale = 0.,
        **kwargs
    ):
        b, t, h = *x.shape[:2], self.heads

        x = self.pre_layernorm(x)
        x = self.input_dropout(x)

        queries = self.to_queries(x)

        #写一下queries的shape: b=batch_size, t=target_seq_len, p=partition, h=num_heads, d=head_dim
        queries = rearrange(queries, 'b t (p h d) -> (b p h) t d', p = 2, h = h)

        # norm and dropout queries
        norm_kwargs = dict(mask = input_mask) if not self.use_layernorm else dict()
        queries = self.norm(queries, **norm_kwargs)
        queries = self.query_dropout(queries)

        queries = rearrange(queries, '(b p h) t d -> p b t h d', p = 2, h = h)

        # similarity to keys
        # keys.shape:heads, num_keys, 2, dim_head。这里的n是keys的batch_size
        # 这里的keys本质上是一个二维数组
        dots = einsum('p b t h d, h n p d -> b t h p n', queries, self.keys)

        # gumbel noise
        if gumbel_noise_scale > 0.:
            dots = dots + gumbel_noise(dots) * gumbel_noise_scale

        # topk scores
        if self.differentiable_topk:
            scores, indices, *_ = coor_descent_topk(dots, k = self.topk, fused = True)
        else:
            scores, indices = dots.topk(k = self.topk, dim = -1)
        # scores are factorized
        (scores_x, scores_y), (indices_x, indices_y) = map(lambda t: t.chunk(2, dim = 3), (scores, indices))

        all_topk = self.topk ** 2

        all_scores = rearrange((
            rearrange(scores_x, '... k -> ... k 1') +
            rearrange(scores_y, '... k -> ... 1 k')
        ), 'b t h ... -> b t h (...)')

        all_indices = rearrange((
            rearrange(indices_x, '... k -> ... k 1') * self.num_keys +
            rearrange(indices_y, '... k -> ... 1 k')
        ), 'b t h ... -> b t h (...)')

        final_topk, final_indices = all_scores.topk(self.topk, dim=-1)
        value_indices = all_indices.gather(-1, final_indices)

        # attention

        attn = self.gate_activation(final_topk)
        attn = self.attn_dropout(attn)

        value_indices, attn = map(lambda t: rearrange(t, 'b t h k -> (b t) (h k)'), (value_indices, attn))

        # aggregate

        if self.concat_values_and_combine:
            out = self.values(value_indices)
        else:
            out = self.values(value_indices, per_sample_weights = attn)

        out = self.value_dropout(out)

        # maybe layernorm the output

        out = self.output_norm(out)

        return rearrange(out, '(b t) d -> b t d', b = b)

UltraMem

来自ULTRA-SPARSE MEMORY NETWORK,字节发这个时候吹“有效解决了MoE推理时高额的访存问题,推理速度较MoE架构提升2-6倍,推理成本最高可降低83%”,猛地一看以为把DeepSeekMoE又给提升了2-6倍,可实际上是下面这个MoE的paper。UltraMem的思路实际上是对PKM思路的一种改进,但字节并没有公布源代码,也不知道他们家的智障豆包用上了没,先摘录一些核心想法,等代码出了再仔细拜读。
在这里插入图片描述
为了解决drawback1和drawback3,把PQ改成了下面的TDQKR,一种基于SVD分解的方法:
在这里插入图片描述

MoE

这个MoE不同于MoE架构LLM中的MoE,而是对PosWiseFFN的改进,来自于Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity,以下是论文中的截图,看一眼就知道大致的思路:
在这里插入图片描述

附录:

  1. https://mp.weixin.qq.com/s/BPGbzAQ5AKPj7yqrOCCuGQ?token=2117558689&lang=zh_CN
  2. https://team.doubao.com/zh/publication/ultra-sparse-memory-network?view_from=research
  3. https://www.cls.cn/detail/1940788

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/969751.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

web第三次作业

弹窗案例 1.首页代码 <!DOCTYPE html><html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>综合案例</title><st…

计数排序

目录 计数排序原理和步骤&#xff1a; 完整代码实现&#xff1a; 计数排序原理和步骤&#xff1a; 当一段数据比较集中在一个范围&#xff0c;比如 98&#xff0c;95&#xff0c;98&#xff0c;91&#xff0c;90&#xff0c;93&#xff0c;94&#xff0c;97&#xff0c;93&…

安装OpenXR运行时 微软Windows Mixed Reality的OpenXR

1、下载openxr示例代码 https://github.com/KhronosGroup/OpenXR-SDK-Source.git mkdir build\win64 cd build\win64 2、编译会生成可执行文件 C:\Work\github\OpenXR-SDK-Source\build\win64\src\tests\hello_xr\Debug\hello_xr.exe 执行 C:\Work\github\OpenXR-SDK-Source\b…

伯克利 CS61A 课堂笔记 07 —— Lists

本系列为加州伯克利大学著名 Python 基础课程 CS61A 的课堂笔记整理&#xff0c;全英文内容&#xff0c;文末附词汇解释。 目录 01 Lists 列表 02 Containers 容器 03 For Statements For 语句 04 Ranges 范围 05 List Comprehensions 列表理解 06 Box-and-Pointer Nota…

数据结构--八大排序算法

1. 直接插入排序 当插入第 i(i>1) 个元素时&#xff0c;前面的 array[0],array[1],…,array[i-1] 已经排好序&#xff0c;此用 array[i] 的排序码与 array[i-1],array[i-2],… 的排序码顺序进行比较&#xff0c;找到插入位置即将 array[i] 插入&#xff0c;原来位置上的元素…

mapbox V3 新特性,添加下雪效果

&#x1f468;‍⚕️ 主页&#xff1a; gis分享者 &#x1f468;‍⚕️ 感谢各位大佬 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍⚕️ 收录于专栏&#xff1a;mapbox 从入门到精通 文章目录 一、&#x1f340;前言1.1 ☘️mapboxgl.Map 地图对象…

STM32 GPIO误触发问题全解析:从噪声干扰到电路设计优化

问题描述 在STM32项目中&#xff0c;配置某GPIO为内部上拉输入模式&#xff0c;并外接了一个上拉电阻。该引脚通过1米长的线束连接至电机控制模块&#xff0c;但出现以下异常&#xff1a; 弯折线束或手指触碰线束时&#xff0c;电机误触发&#xff08;MCU检测到低电平&#x…

pyqt自制简单浏览器(python)

确保已安装 PyQt5 和 PyQtWebEngine 库。 import sys from PyQt5.QtCore import QUrl from PyQt5.QtWidgets import QApplication, QMainWindow, QToolBar, QLineEdit, QAction, QListWidget, QVBoxLayout, QDialog, QMessageBox, QInputDialog, QTabWidget from PyQt5.QtWebE…

【人工智能】如何选择合适的大语言模型,是能否提高工作效率的关键!!!

DeepSeek R1入门指南 导读一、提示语差异1.1 指令侧重点不同1.2 语言风格差异1.3 知识运用引导不同 二、挑选原则2.1 模型选择2.2 提示语设计2.3 避免误区 结语 导读 大家好&#xff0c;很高兴又和大家见面啦&#xff01;&#xff01;&#xff01; 在前面的内容中&#xff0c…

​矩阵元素的“鞍点”​

题意&#xff1a; 一个矩阵元素的“鞍点”是指该位置上的元素值在该行上最大、在该列上最小。 本题要求编写程序&#xff0c;求一个给定的n阶方阵的鞍点。 输入格式&#xff1a; 输入第一行给出一个正整数n&#xff08;1≤n≤6&#xff09;。随后n行&#xff0c;每行给出n个整数…

ChatGPT搜索免费开放:AI搜索引擎挑战谷歌霸主地位全面分析

引言 2025年2月6日&#xff0c;OpenAI宣布ChatGPT搜索功能向所有用户免费开放&#xff0c;且无需注册登录。这一重大举措在搜索引擎行业引发巨大反响&#xff0c;有观点认为"谷歌搜索时代即将结束"。本文将深入分析ChatGPT生成式AI搜索对谷歌搜索业务及全球搜索市场…

NO.18十六届蓝桥杯备战|循环嵌套|乘法表|斐波那契|质数|水仙花数|(C++)

循环嵌套 循环嵌套的使⽤ while &#xff0c; do while &#xff0c; for &#xff0c;这三种循环往往会嵌套在⼀起才能更好的解决问题&#xff0c;就是我们所说的&#xff1a;循环嵌套。这三种循环都可以任意嵌套使⽤ ⽐如&#xff1a; 写⼀个代码&#xff0c;打印⼀个乘法⼝…

国际互联网安全日|Web3 世界的安全挑战与防护指南

2025 年 2 月 11 日是全球 “国际互联网安全日”&#xff08;Safer Internet Day&#xff09;。当我们跨越 Web2 迈入 Web3 时代&#xff0c;互联网安全的内涵也在悄然改变。在 Web2 时代&#xff0c;我们主要关注社交媒体隐私泄露、账号密码被盗、网络诈骗等传统安全问题。而在…

DeepseeK自动写作,自动将回答导出文档

在使用 DeepseeK 进行内容生成时&#xff0c;您是否也遇到了答案导出的困扰&#xff1f;无论是内容创作、数据分析还是项目报告&#xff0c;快速、高效地将生成的答案导出是提升工作效率的关键。本文将为您提供简单易行的解决方案&#xff0c;助您轻松实现 DeepseeK 答案的导出…

deep seek

1.介绍:DeepSeek是一款由国内人工智能公司研发的大型语言模型&#xff0c;拥有强大的自然语言处理能力&#xff0c;能够理解并回答问题&#xff0c;还能辅助写代码、整理资料和解决复杂的数学问题。免费开源&#xff0c;媲美ChatGPT 最近最火爆的AI对话程序。 www.deepseek.com…

数据结构中的邻接矩阵

一、概念 邻接矩阵&#xff08;Adjacency Matrix&#xff09;是图&#xff08;Graph&#xff09;的一种表示方法&#xff0c;用于描述图中顶点之间的连接关系。它是一种常见的数据结构&#xff0c;特别适用于表示稠密图&#xff08;即边数接近于顶点数平方的图&#xff09;。 …

微软AutoGen高级功能——Selector Group Chat

介绍 大家好&#xff0c;这次给大家分享的内容是微软AutoGen框架的高级功能Selector Group Chat(选择器群聊)&#xff0c;"选择器群聊"我在给大家分享的这篇博文的代码中有所体现微软AutoGen介绍——Custom Agents创建自己的Agents-CSDN博客&#xff0c;但是并没有详…

web前端开发中vscode常用的快捷键

1.快速复制一行 快捷键&#xff1a; shiftalt 下箭头(上箭头) 或者 ctrlc 然后 ctrlv 2.选定多个相同的单词 快捷键&#xff1a; ctrl d 先双击选定一个单词&#xff0c;然后按下 ctrl d 可以往下依次选择相同的单词。 这样同时修改相同的单词 3.全局替换某单词 当我们一个…

网络安全要学python 、爬虫吗

网络安全其实并不复杂&#xff0c;只是比普通开发岗位要学习的内容多一点。无论是有过编程基础还是零基础的都可以学习的。网络安全目前可就业的岗位从技术上可分为两部分&#xff1a;web安全和二进制逆向安全。web安全是网络安全的入门方向&#xff0c;内容简单&#xff0c;就…

深入解析哈希表:原理、实现与应用

目录 一、哈希表是什么&#xff1f; 1.1 基本概念 1.2 哈希表的工作原理 二、哈希表的使用方法 2.1 C 标准库中的哈希表 示例&#xff1a;std::unordered_map 的基本用法 2.2 自定义哈希函数 示例&#xff1a;自定义哈希函数 三、什么时候使用哈希表&#xff1f; 3.1…