AttentionFreeTransformer 源码解析(一):AFTFull、AFTSimple、AFTLocal

我觉得源码写的很好懂,我就不加注释了,直接上计算流程图。

AFTFull

在这里插入图片描述

class AFTFull(nn.Module):
    def __init__(self, max_seqlen, dim, hidden_dim=64):
        super().__init__()
        '''
        max_seqlen: the maximum number of timesteps (sequence length) to be fed in
        dim: the embedding dimension of the tokens
        hidden_dim: the hidden dimension used inside AFT Full

        Number of heads is 1 as done in the paper
        '''
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.to_q = nn.Linear(dim, hidden_dim)
        self.to_k = nn.Linear(dim, hidden_dim)
        self.to_v = nn.Linear(dim, hidden_dim)
        self.project = nn.Linear(hidden_dim, dim)
        self.wbias = nn.Parameter(torch.Tensor(max_seqlen, max_seqlen))
        nn.init.xavier_uniform_(self.wbias)

    def forward(self, x):
        B, T, _ = x.shape
        Q = self.to_q(x).view(B, T, self.hidden_dim)
        K = self.to_k(x).view(B, T, self.hidden_dim)
        V = self.to_v(x).view(B, T, self.hidden_dim)
        temp_wbias = self.wbias[:T, :T].unsqueeze(0) # sequences can still be variable length

        '''
        From the paper
        '''
        Q_sig = torch.sigmoid(Q)
        temp = torch.exp(temp_wbias) @ torch.mul(torch.exp(K), V)
        weighted = temp / (torch.exp(temp_wbias) @ torch.exp(K))
        Yt = torch.mul(Q_sig, weighted)

        Yt = Yt.view(B, T, self.hidden_dim)
        Yt = self.project(Yt)

        return Yt

AFTSimple

在这里插入图片描述

class AFTSimple(nn.Module):
    def __init__(self, max_seqlen, dim, hidden_dim=64):
        super().__init__()
        '''
        max_seqlen: the maximum number of timesteps (sequence length) to be fed in
        dim: the embedding dimension of the tokens
        hidden_dim: the hidden dimension used inside AFT Full
        
        Number of Heads is 1 as done in the paper.
        '''
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.to_q = nn.Linear(dim, hidden_dim)
        self.to_k = nn.Linear(dim, hidden_dim)
        self.to_v = nn.Linear(dim, hidden_dim)
        self.project = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        B, T, _ = x.shape
        Q = self.to_q(x).view(B, T, self.hidden_dim)
        K = self.to_k(x).view(B, T, self.hidden_dim)
        V = self.to_v(x).view(B, T, self.hidden_dim)

        '''
        From the paper
        '''
        weights = torch.mul(torch.softmax(K, 1), V).sum(dim=1, keepdim=True)
        Q_sig = torch.sigmoid(Q)
        Yt = torch.mul(Q_sig, weights)

        Yt = Yt.view(B, T, self.hidden_dim)
        Yt = self.project(Yt)

        return Yt

AFTLocal

在这里插入图片描述

class AFTLocal(nn.Module):
    def __init__(self, max_seqlen, dim, hidden_dim=64, s=256):
        super().__init__()
        '''
        max_seqlen: the maximum number of timesteps (sequence length) to be fed in
        dim: the embedding dimension of the tokens
        hidden_dim: the hidden dimension used inside AFT Full
        s: the window size used for AFT-Local in the paper

        Number of heads is 1 as done in the paper
        '''
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.to_q = nn.Linear(dim, hidden_dim)
        self.to_k = nn.Linear(dim, hidden_dim)
        self.to_v = nn.Linear(dim, hidden_dim)
        self.project = nn.Linear(hidden_dim, dim)
        self.wbias = nn.Parameter(torch.Tensor(max_seqlen, max_seqlen))
        self.max_seqlen = max_seqlen
        self.s = s
        nn.init.xavier_uniform_(self.wbias)


    def forward(self, x):
        B, T, _ = x.shape
        Q = self.to_q(x).view(B, T, self.hidden_dim)
        K = self.to_k(x).view(B, T, self.hidden_dim)
        V = self.to_v(x).view(B, T, self.hidden_dim)
        self.wbias = nn.Parameter(torch.Tensor([
            [self.wbias[i][j] if math.fabs(i-j) < self.s else 0 for j in range(self.max_seqlen)] 
            for i in range(self.max_seqlen)
            ]))
        temp_wbias = self.wbias[:T, :T].unsqueeze(0) # sequences can still be variable length

        '''
        From the paper
        '''
        Q_sig = torch.sigmoid(Q)
        temp = torch.exp(temp_wbias) @ torch.mul(torch.exp(K), V)
        weighted = temp / (torch.exp(temp_wbias) @ torch.exp(K))
        Yt = torch.mul(Q_sig, weighted)

        Yt = Yt.view(B, T, self.hidden_dim)
        Yt = self.project(Yt)

        return Yt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/68187.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

DP(区间DP)

石子合并 设有 N 堆石子排成一排&#xff0c;其编号为 1,2,3,…,N。 每堆石子有一定的质量&#xff0c;可以用一个整数来描述&#xff0c;现在要将这 N 堆石子合并成为一堆。 每次只能合并相邻的两堆&#xff0c;合并的代价为这两堆石子的质量之和&#xff0c;合并后与这两堆…

cesium学习记录06-视图、场景与相机

一、视图&#xff08;Viewer&#xff09; viewer是cesium的核心类&#xff0c;是一切的开端。通过new Cesium.Viewer(container, options)来创建一个Viewer对象&#xff0c;而通过这个 Viewer对象&#xff0c;可以添加图层、实体、相机控制等&#xff0c;以及设置一些全局属性…

esp8266使用arduinoJson与tft_espi库发生冲突解决方法

esp8266使用arduinoJson与tft_espi库发生冲突解决方法 arduinoJson与tft_espi库发生冲突解决方法下载arduinoJson5.0版本的&#xff0c;不要用最新版本 示范代码&#xff1a; // Copyright Benoit Blanchon 2014 // MIT License // // Arduino JSON library // https://git…

Unity游戏源码分享-仿帝国时代游戏Demo-uRTS Toolkit

Unity游戏源码分享-仿帝国时代游戏Demo-uRTS Toolkit 游戏的架构值得参考 项目地址&#xff1a;https://download.csdn.net/download/Highning0007/88189905

Pycharm 双击启动失败?

事故 双击 Pycharm 后&#xff0c;出现加载工程&#xff0c;我不想加载这个工程&#xff0c;就点击了弹出的 cancle 取消按钮。然后再到桌面双击 Pycharm 却发现无法启动了。哪怕以管理员权限运行也没用&#xff0c;就是不出界面。 原因未知 CtrlshiftESC 打开后台&#xff…

易服客工作室:如何创建有用的内容日历

利用技巧和工具优化您的内容营销效率和效果。创建一个内容日历&#xff0c;您的整个团队都会从中受益&#xff01; 欢迎来到熙熙攘攘、瞬息万变的内容营销世界&#xff0c;在这里&#xff0c;截止日期到来的速度比喝咖啡的猎豹还要快。 现在&#xff0c;想象一下在没有地图、…

MapBox加载不同风格

初始化MapBox地图&#xff1a; var map new mapboxgl.Map({container: map,zoom: 3,center: [105, 34],//此处更改地图风格style: mapbox://styles/mapbox/satellite-v9,hash: false,});1.户外地图&#xff08;mapbox://styles/mapbox/basic-v9&#xff09;新版&#xff1a;&a…

设计模式之模板方法

一、概述 定义一个操作中的算法的骨架&#xff0c;将一些步骤延迟到子类中。 TemplateMethod使得子类可以不改变一个算法的结构即可重新定义该算法的某些特定步骤。 二、适用性 1.一次性实现一个算法的不变的部分&#xff0c;并将可变的行为留给子类来实现。 2.各子类中公共…

YOLOv5修改注意力机制CBAM

直接上干货 CBAM注意力机制是由通道注意力机制&#xff08;channel&#xff09;和空间注意力机制&#xff08;spatial&#xff09;组成。 传统基于卷积神经网络的注意力机制更多的是关注对通道域的分析&#xff0c;局限于考虑特征图通道之间的作用关系。CBAM从 channel 和 sp…

Kafka与Zookeeper版本对应关系

文章目录 了解版本对应Kafka安装包Kafka源码包 了解 比如&#xff1a; kafka_2.11-1.1.1.jar包 其中2.11表示的是Scala的版本&#xff0c;因为Kafka服务器端代码完全由Scala语音编写。”-“后面的1.1.1表示的kafka的版本信息。遵循一个基本原则&#xff0c;Kafka客户端版本和服…

MySQL5.7数据库、Navicat Premium1.6可视化工具安装教程【详细教程】

文章目录 一、MySQL、Navicat、注册机地址二、安装&#xff08;一&#xff09;、MySQL安装&#xff08;二&#xff09;、Navicat Premium安装&#xff08;三&#xff09;、集活Navicat Premium 三、遇到的问题1、Are you sure your navicat has not beenpatched/modified befor…

Spring 事务管理

目录 1. 事务管理 1.1. Spring框架的事务支持模型的优势 1.1.1. 全局事务 1.1.2. 本地事务 1.1.3. Spring框架的一致化编程模型 1.2. 了解Spring框架的事务抽象&#xff08;Transaction Abstraction&#xff09; 1.2.1. Hibernate 事务设置 1.3. 用事务同步资源 1.3.1…

Malloc动态内存分配

在C语言中我们会使用malloc来动态地分配内存&#xff0c;这样做的一个主要理由是有些数据结构的大小只有在运行时才能确定。例如&#xff0c;如果你正在编写一个程序&#xff0c;需要用户输入一些数据&#xff0c;但你不知道用户会输入多少数据&#xff0c;那么你就需要使用动态…

vue3使用pinia和pinia-plugin-persist做持久化存储

1、安装依赖 pnpm i pinia // 安装 pinia pnpm i pinia-plugin-persist // 安装持久化存储插件2、main.js引入 import App from ./App.vue const app createApp(App)//pinia import { createPinia } from pinia import piniaPersist from pinia-plugin-persist //持久化插件 …

Python中enumerate用法详解

目录 1.简介 2.语法 3.参数 4.返回值 5.详解 6.实例 7.补充 1.简介 enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列&#xff0c;同时列出数据和数据下标&#xff0c;一般用在 for 循环当中。 2.语法 以下是 enumerate() 方法的语…

Qt应用开发(基础篇)——堆栈窗口 QStackedWidget

一、前言 QStackedWidget继承于QFrame&#xff0c;QFrame继承于QWidget&#xff0c;是Qt常用的堆栈窗口部件。 框架类QFrame介绍 QStackedWidget堆栈窗口&#xff0c;根据下标切换&#xff0c;一次显示一个小部件&#xff0c;常用于应用界面切换、图片轮询播放等场景。 二、QSt…

通义千问开源模型部署使用

首先可以参考modelScope社区给出的使用文档&#xff0c;已经足够全面 通义千问-7B-Chat 但在按照文档中步骤部署时&#xff0c;还是有些错误问题发生&#xff0c;可以搜索参考的解决方式不多&#xff0c;所以记录下来 个人电脑部署 这里不太建议使用自己的笔记本部署通义千…

exchange partition global index

EXCHANGE PARTITION with a Table having a UNIQUE INDEX and PK Constraint. (Doc ID 1620636.1)​编辑To Bottom In this Document Symptoms Changes Cause Solution References APPLIES TO: Oracle Database - Enterprise Edition - Version 11.2.0.3 and later Oracle Da…

Spring整合MyBatis(详细步骤)

Spring与Mybatis的整合&#xff0c;大体需要做两件事&#xff0c; 第一件事是:Spring要管理MyBatis中的SqlSessionFactory 第二件事是:Spring要管理Mapper接口的扫描 具体的步骤为: 步骤1:项目中导入整合需要的jar包 <dependency><!--Spring操作数据库需要该jar包…

gazebo 导入从blender导出的dae等文件

背景&#xff1a; gazebo 模型库里的模型在我需要完成的任务中不够用&#xff0c;还是得从 solidworks、3DMax, blender这种建模软件里面在手动画一些&#xff0c;或者去他们的库里面在挖一挖。 目录 1 blender 1-1 blender 相关links 1-2 install 2 gazebo导入模型 2-1 g…