CV算法工程师的LLM日志(5)Mixture-of-depths——transformers改进结构 【15分钟代码和原理速通】

前言

简而言之,这是google对transformer一些改进设计,如果这个有效性能够证明并普及,那么下一个大模型的transformer范式就是这个了,当然同时也存在mamba和transformer的jamba崛起,不过现在主流还是transformer,让我们看下文章和代码复现的过程,如果看过我的MOE特别篇中MOE的部分,会更加清晰。


CV算法工程师的LLM日志(5)Mixture-of-depths——transformers改进结构 【15分钟代码和原理速通】

  • 前言
  • 一 、Mixture-of-Depths: Dynamically allocating compute in transformer-based language models
  • 二、MODE架构和MOD代码
    • 代码
  • 总结

一 、Mixture-of-Depths: Dynamically allocating compute in transformer-based language models

动机:大模型训练和推理中,有很多计算是没必要的,即在基于Transformer的语言模型中动态地分配计算资源(FLOPs),以优化模型的性能和效率.
Feature:
通过限制每层可以参与自注意力和多层感知机(MLP)计算的标记Token数量来强制执行总计算预算。
MoD方法使用静态计算图,与动态计算图技术不同,它允许在保持硬件效率的同时动态和上下文敏感地分配计算资源。
Moe结合可能性,能够减少模型的计算需求,还能够在保持或提高性能的同时加快模型的推理速度。

总结:核心点是通过路由决策来决定使用哪些层和跳过哪些层。
路由方案(Routing Schemes)(与MOE思路几乎一样)
Token-Choice Routing:
在这种路由方案中,每个标记Token根据自己的偏好被分配到不同的计算路径上。这通常是通过为每个标记Token生成一个概率分布来实现的,然后根据这个分布将标记Token路由到它最偏好的路径。
这种方法可能会导致负载均衡问题,因为不能保证标记Token会均匀地分配到所有可能的路径上。
Expert-Choice Routing:
与Token-Choice Routing不同,Expert-Choice Routing是由每个计算路径根据标记Token的偏好选择一定数量的标记Token(例如,top k个最高权重的token)。
这种方法确保了完美的负载均衡,因为每个路径都会获得相同数量的标记Token。但它也可能导致某些标记Token被过度处理或处理不足,因为一些标记Token可能因为权重高而被多条路径选中,或者没有被任何路径选中。

论文采用的Expert-Choice Routing方案中,由于只使用单一的计算路径,利用了一个隐含的知识:如果规定了每层处理的token数量K小于序列长度,则超出的TOKEN将被丢弃。这意味着,可以根据序列长度和计算容量,有选择地将标记Token路由到或绕过自注意力和MLP计算,从而在一个前向传播过程中减少FLOPs的消耗。

在这里插入图片描述
在这里插入图片描述
和传统的transformer架构区别:

  1. 每个mod-block增加了一个route 线性层
  2. 动态处理逻辑:决策负载均衡
  3. 动态分配token的比例决定top K,这个很重要对应上述说的token长度问题,(根据序列长度和计算容量,有选择地将标记Token路由到或绕过自注意力和MLP计算,从而在一个前向传播过程中减少FLOPs的消耗。)

二、MODE架构和MOD代码

在这里插入图片描述

如图,MoD和MoE的结合,即MoDE模型,可以通过以下两种方式实现:
Staged MoDE:
在这种方法中,MoD机制首先被应用,它决定标记Token是否绕过某些层或者被送往自注意力机制。
然后,MoE机制被应用,它将参与自注意力计算的标记Token分配给不同的专家进行处理。
这种方式的优点是标记Token可以跳过自注意力步骤,直接被送往专家处理,从而节省计算资源。
Integrated MoDE:
在集成的MoDE模型中,MoD的路由功能被集成到MoE的专家选择机制中。
专家集合中包括了“no-op”(无操作)专家,这些专家相当于MoD中的跳过连接,即通过这些专家的标记Token不做任何计算。
路由机制会将标记Token分配给专家或者“no-op”专家,这样可以简化路由的复杂性,并且使得标记Token显式地学会选择是否绕过专家

MOD的结构基于已有的transfomer很可能像去年的MOE一样迅速普及在学术以及工业界。

代码

从代码上来看MOD可以作为即插即用的结构修改形式。针对上述提到的三个特点,可以参考代码:
代码源于Mod

import torch
import torch.nn as nn
from typing import Optional, Tuple, Any
from transformers import PreTrainedModel

class TokenRouter(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.weight_predictor = nn.Linear(embed_dim, 1)

    def forward(self, x):
        weights = self.weight_predictor(x).squeeze(-1)  # [batch_size, seq_len]
        return weights

class MoD(nn.Module):
    def __init__(self, capacity, block):
        super().__init__()
        self.router = TokenRouter(block.hidden_size)
        self.block = block
        self.capacity = capacity
        self.training_step = 0

    def forward(self,
                x: torch.Tensor,
                attention_mask: torch.Tensor,
                position_ids: torch.Tensor,
                past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]],
                output_attentions: bool,
                use_cache: bool,
                cache_position: Optional[torch.Tensor] = None,
                **kwargs: Any
                ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        b, s, d = x.shape
        weights = self.router(x)
        if self.router.training:
            self.training_step += 1 if self.training_step < 1000 else 999
            self.capacity = 0.125 + ((1 - 0.125) * (1. / self.training_step))

        k = int(self.capacity * s)
        top_k_values, top_k_indices = torch.topk(weights, k, dim=1, sorted=True)
        threshold = top_k_values[:, -1].unsqueeze(-1)
        selected_mask = weights > threshold

        # Use torch.gather to select tokens
        selected_tokens = torch.gather(x, 1, top_k_indices.unsqueeze(-1).expand(-1, -1, d))
        selected_position_ids = torch.gather(position_ids, 1, top_k_indices)

        # Create a causal mask for the selected tokens
        if attention_mask is not None:
            selected_attention_mask = torch.gather(attention_mask, 1, top_k_indices.unsqueeze(-1).expand(-1, -1, s))
            selected_attention_mask = torch.gather(selected_attention_mask, 2, top_k_indices.unsqueeze(1).expand(-1, s, -1))
        else:
            selected_attention_mask = None

        # Apply the block to the selected tokens
        if use_cache:
            selected_cache_position = torch.gather(cache_position, 1, top_k_indices) if cache_position is not None else None
            block_output = self.block(
                selected_tokens,
                attention_mask=selected_attention_mask,
                position_ids=selected_position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=selected_cache_position,
                **kwargs
            )
            if len(block_output) == 2:
                processed_tokens, cache = block_output
            else:
                processed_tokens, cache = block_output[0], None
        else:
            processed_tokens = self.block(
                selected_tokens,
                attention_mask=selected_attention_mask,
                position_ids=selected_position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                **kwargs
            )[0]

        # Apply weights to the processed tokens
        processed_tokens = processed_tokens * torch.where(selected_mask, weights.unsqueeze(-1), torch.zeros_like(weights).unsqueeze(-1))

        # Combine the processed tokens with the original tokens
        output = torch.where(selected_mask.unsqueeze(-1), processed_tokens, x)

        return (output, cache) if cache is not None else (output,)


总结

MOD的结构和MOE是天然的相似,整合起来的MODE可以试试fine-tune。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/564321.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

nlp 自然语言处理的dataset数据库积累

下面的这个和 entity recognition有关的。 Weights & Biases

巧用波卡生态优势,Mythical Games 引领 Web3 游戏新航向

Polkadot 对创新、安全和治理的承诺为 Mythical Games 提供了极大的发展价值。这个链上生态不仅将支持 Mythical Games 成长发展&#xff0c;还将帮助其他 Mythos 合作伙伴来壮大建设项目。 —— Mythical Games 创始人兼首席执行官 John Linden 近期 Web3 游戏行业又有新动向&…

《C语言深度解剖》(8):一篇文章彻底学会Visual Studio 调试技巧,新手必看!

&#x1f921;博客主页&#xff1a;醉竺 &#x1f970;本文专栏&#xff1a;《C语言深度解剖》 &#x1f63b;欢迎关注&#xff1a;感谢大家的点赞评论关注&#xff0c;祝您学有所成&#xff01; ✨✨&#x1f49c;&#x1f49b;想要学习更多数据结构与算法点击专栏链接查看&am…

创建电商产品说明书的这些雷,你踩了几条

现如今电商的流行&#xff0c;让电商产品说明书不仅是产品的“身份证”&#xff0c;更是商家与消费者沟通的桥梁。但是&#xff0c;在创建电商产品说明书时&#xff0c;稍不注意就可能踩到“雷区”&#xff0c;给消费者留下不好的印象&#xff0c;甚至影响销量。今天&#xff0…

【计算机2区】毕业快刊 —— 非黑!非预警!各指标优异!

No.1 工程综合类SCIE 【期刊简介】IF&#xff1a;6.0-7.0&#xff0c;JCR1区&#xff0c;中科院2区 【版面类型】纯正刊&#xff0c;仅10篇版面 【自引率】13.30%&#xff08;位于安全阈值内&#xff09; 【年发文量】400篇左右&#xff08;发文量稳定&#xff09; 【国人…

单机三pxc节点集群,+docker-haproxy2.0负载均衡实现

一.下载 https://www.haproxy.org/download/2.0/src/haproxy-2.0.5.tar.gz 或者在这里下载&#xff08;下面需要的各个配置文件都有&#xff09;&#xff1a; https://download.csdn.net/download/cyw8998/89170129 二.编写文件&#xff0c;制作docker镜像 1.Dockerfile&a…

信创产业发展迅速,信创测试需要伴随

信创产业的发展现状呈现出蓬勃的生机与活力。这一领域不仅构成了数据安全、网络安全的基石&#xff0c;更是新型基础设施建设的重要一环。信创产业涵盖了众多关键领域&#xff0c;如云计算、软件&#xff08;包括操作系统、中间件、数据库及应用软件&#xff09;、硬件&#xf…

Android studio配置Flutter(看这一篇就够了)

Flutter 是 Google 推出并开源的移动应用开发框架&#xff0c;主打跨平台、高保真、高性能。开发者可以通过 Dart 语言开发 App&#xff0c;一套代码同时运行在 iOS 和 Android平台。 Flutter 提供了丰富的组件、接口&#xff0c;开发者可以很快地为 Flutter 添加 Native&#…

#vscode | poetry | 虚拟环境 | Interpreter# 使用Poetry进行Python项目依赖管理和VSCode环境配置

系统安装poetry curl -sSL https://install.python-poetry.org | python3 - 安装 poetry --version 验证安装是否成功 项目安装poetry poetry install install 命令从当前项目中读取 pyproject.toml 文件&#xff0c;解析依赖项并安装它们。 Vscode配置 对应虚拟环境的in…

攻防打点|Shiro漏洞利用大全【附工具】

Shiro反序列化漏洞在目前攻防打点中仍然可以使用,如一些废弃的忘记关掉的旁站之类的。。。 「手工如何判断是否存在shiro」 特征码为响应包存在rememberMe=deleteMe 打开burp进行抓包,在请求包中添加Cookie: rememberMe=me,查看返回包中是否存在rememberMe=deleteMe。 「工…

可视化大屏在政务领域应用非常普遍,带你看看

可视化大屏在政务领域的应用非常普遍&#xff0c;政务领域需要处理大量的数据和信息&#xff0c;通过可视化大屏可以将这些数据以直观、易懂的方式展示出来&#xff0c;帮助政府决策者和工作人员更好地了解和分析数据&#xff0c;从而做出更准确、科学的决策。 在政务领域&…

API接口新探索:一键获取商品标题、分类与店铺名称

一、引言 在当今信息化社会&#xff0c;电子商务的蓬勃发展使得各类商品信息浩如烟海。为了高效地获取商品信息&#xff0c;许多开发者选择使用API接口。API&#xff08;Application Programming Interface&#xff0c;应用程序编程接口&#xff09;是一种定义明确的方法&…

玩转压力管理,轻松高效编程

程序员缓解工作压力的小窍门 在当今快速发展的科技时代&#xff0c;程序员作为数字世界的建筑师&#xff0c;面临着高强度、高压力的工作环境。为保持工作效率和创新能力&#xff0c;同时也确保身心健康和个人热情的持久续航&#xff0c;采取科学合理的减压策略至关重要。 方…

Django中的定时任务与后台任务队列的实践

&#x1f47d;发现宝藏 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。【点击进入巨牛的人工智能学习网站】。 在Web开发中&#xff0c;处理定时任务和后台任务队列是很常见的需求。Django作为一个功能强…

了解边缘计算,在制造行业使用边缘计算。

边缘计算是一种工业元宇宙技术&#xff0c;可以帮助组织实现其数据的全部潜力。 处理公司的所有数据可能具有挑战性&#xff0c;而边缘计算可以帮助公司更快地处理数据。在制造业中&#xff0c;边缘计算可以帮助进行预测性维护和自动驾驶汽车操作等工作。 什么是边缘计算? …

Spring Boot 自动装配执行流程

Spring Boot 自动装配执行流程 Spring Boot 自动装配执行流程如下&#xff1a; Spring Boot 启动时会创建一个 SpringApplication实例&#xff0c;该实例存储了应用相关信息&#xff0c;它负责启动并运行应用。实例化 SpringApplication 时&#xff0c;会自动装载META-INF/spr…

go语言通过TCP协议实现聊天室样例

1、服务端&#xff1a; package mainimport ("fmt""net""sync" )type ChatServer struct {clients map[string]net.ConnclientsMux sync.Mutex }func NewChatServer() *ChatServer {return &ChatServer{clients: make(map[string]net.Co…

【NoC片上网络 On-Chip Network】应用程序的网络流量 合成网络流量

应用程序的网络流量 and 合成网络流量 1. 应用程序的网络流量 APPLICATION TRAFFIC2. 合成网络流量 SYNTHETIC TRAFFIC3. 合成网络流量的具体介绍 应用程序的网络流量 and 合成网络流量 1. 应用程序的网络流量 APPLICATION TRAFFIC 在 MPSoC(多处理器片上系统) 中&#xff…

书生·浦语大模型第二期实战营(6)作业

1。完成 Lagent Web Demo 使用&#xff0c;并在作业中上传截图。 文档可见 Lagent Web Demo 2、完成 AgentLego 直接使用部分&#xff0c;并在作业中上传截图。 文档可见 直接使用 AgentLego

前端crypto-js, 文件加密,判断相同文件、图片(MD5,SHA256)

文章目录 前情提要应用场景实战解析最后前情提要 大家好,今天我们来接触一个库crypto-js 没错,上面是有道翻译的截图,为了我们得到的信息更权威,这个库是用来加密的,但介绍是说,已经停止维护,但并不影响我们在前端项目中的使用,所以学学也没有坏处 应用场景 判断图片…