大模型面试准备(五):图解 Transformer 最关键模块 MHA

节前,我们组织了一场算法岗技术&面试讨论会,邀请了一些互联网大厂朋友、参加社招和校招面试的同学,针对大模型技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何备战、面试常考点分享等热门话题进行了深入的讨论。


合集在这里:《大模型面试宝典》(2024版) 正式发布!


Transformer 原始论文中的模型结构如下图所示:
图片

上一篇文章讲解了 Transformer 的关键模块 Positional Encoding(大家可以自行翻阅),本篇文章讲解一下 Transformer 的最重要模块 Multi-Head Attention(MHA),毕竟 Transformer 的论文名称就叫 《Attention Is All You Need》。

Transformer 中的 Multi-Head Attention 可以细分为3种,Multi-Head Self-Attention(对应上图左侧Multi-Head Attention模块),Multi-Head Cross-Attention(对应上图右上Multi-Head Attention模块),Masked Multi-Head Self-Attention(对应上图右下Masked Multi-Head Attention模块)。

其中 Self 和 Cross 的区分是对应的 Q和 K、 V是否来自相同的输入。是否Mask的区分是是否需要看见全部输入和预测的输出,Encoder需要看见全部的输入问题,所以不能Mask;而Decoder是预测输出,当前预测只能看见之前的全部预测,不能看见之后的预测,所以需要Mask。

本篇文章主要通过图解的方式对 Multi-Head Attention 的核心思想和计算过程做讲解,喜欢本文记得收藏、点赞、关注。技术和面试交流,文末加入我们

MHA核心思想

在这里插入图片描述

MHA过程图解

注意力计算公式如下:

在这里插入图片描述

图示过程图下:

图片

多头注意力

MHA通过多个头的方式,可以增强自注意力机制聚合上下文信息的能力,以关注上下文的不同侧面,作用类似于CNN的多个卷积核。下面我们就通过一张图来完成MHA的解析:

图片

在这里插入图片描述

单头注意力

知道了多头注意力的实现方式后,那如果是通过单头注意力完成同样的计算,矩阵形式是什么样的呢?下面我还是以一图胜千言的方式来回答这个问题:

图片通过单头注意力的比较,相信大家对多头注意力(MHA)应该有了更好的理解。我们可以发现多头注意力就是将一个单头进行了切分计算,最后又将结果进行了合并,整个过程中的整体维度和计算量基本是不变的,但提升了模型的学习能力。

最后附上一份MHA的实现和Transformer的构建代码:

import torch
import torch.nn as nn

# 定义多头自注意力层
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads  # 多头注意力的头数
        self.d_model = d_model  # 输入维度(模型的总维度)
        self.head_dim = d_model // n_heads  # 每个注意力头的维度
        assert self.head_dim * n_heads == d_model, "d_model必须能够被n_heads整除"  # 断言,确保d_model可以被n_heads整除

        # 线性变换矩阵,用于将输入向量映射到查询、键和值空间
        self.wq = nn.Linear(d_model, d_model)  # 查询(Query)的线性变换
        self.wk = nn.Linear(d_model, d_model)  # 键(Key)的线性变换
        self.wv = nn.Linear(d_model, d_model)  # 值(Value)的线性变换

        # 最终输出的线性变换,将多头注意力结果合并回原始维度
        self.fc_out = nn.Linear(d_model, d_model)  # 输出的线性变换


    def forward(self, query, key, value, mask):
        # 将嵌入向量分成不同的头
        query = query.view(query.shape[0], -1, self.n_heads, self.head_dim)
        key = key.view(key.shape[0], -1, self.n_heads, self.head_dim)
        value = value.view(value.shape[0], -1, self.n_heads, self.head_dim)

        # 转置以获得维度 batch_size, self.n_heads, seq_len, self.head_dim
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # 计算注意力得分
        scores = torch.matmul(query, key.transpose(-2, -1)) / self.head_dim
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attention = torch.nn.functional.softmax(scores, dim=-1)

        out = torch.matmul(attention, value)

        # 重塑以恢复原始输入形状
        out = out.transpose(1, 2).contiguous().view(query.shape[0], -1, self.d_model)

        out = self.fc_out(out)
        return out

# 定义Transformer编码器层
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, dim_feedforward, dropout):
        super(TransformerEncoderLayer, self).__init__()
        
        # 多头自注意力层,接收d_model维度输入,使用n_heads个注意力头
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        
        # 第一个全连接层,将d_model维度映射到dim_feedforward维度
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        
        # 第二个全连接层,将dim_feedforward维度映射回d_model维度
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        
        # 用于随机丢弃部分神经元,以减少过拟合
        self.dropout = nn.Dropout(dropout)
        
        # 第一个层归一化层,用于归一化第一个全连接层的输出
        self.norm1 = nn.LayerNorm(d_model)
        
        # 第二个层归一化层,用于归一化第二个全连接层的输出
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, src, src_mask):
        # 使用多头自注意力层处理输入src,同时提供src_mask以屏蔽不需要考虑的位置
        src2 = self.self_attn(src, src, src, src_mask)
        
        # 残差连接和丢弃:将自注意力层的输出与原始输入相加,并应用丢弃
        src = src + self.dropout(src2)
        
        # 应用第一个层归一化
        src = self.norm1(src)

        # 经过第一个全连接层,再经过激活函数ReLU,然后进行丢弃
        src2 = self.linear2(self.dropout(torch.nn.functional.relu(self.linear1(src))))
        
        # 残差连接和丢弃:将全连接层的输出与之前的输出相加,并再次应用丢弃
        src = src + self.dropout(src2)
        
        # 应用第二个层归一化
        src = self.norm2(src)

        # 返回编码器层的输出
        return src


# 实例化模型
vocab_size = 10000  # 词汇表大小(根据实际情况调整)
d_model = 512  # 模型的维度
n_heads = 8  # 多头自注意力的头数
num_encoder_layers = 6  # 编码器层的数量
dim_feedforward = 2048  # 全连接层的隐藏层维度
max_seq_length = 100  # 最大序列长度
dropout = 0.1  # 丢弃率

# 创建Transformer模型实例
model = Transformer(vocab_size, d_model, n_heads, num_encoder_layers, dim_feedforward, max_seq_length, dropout)

最后的最后再贴上一张非常不错的 Transformer 手绘吧!

在这里插入图片描述

技术交流群

前沿技术资讯、算法交流、求职内推、算法竞赛、面试交流(校招、社招、实习)等、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企开发者互动交流~

我们建了算法岗技术与面试交流群, 想要进交流群、需要源码&资料、提升技术的同学,可以直接加微信号:mlc2040。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2040,备注:技术交流

用通俗易懂方式讲解系列

  • 《大模型面试宝典》(2024版) 正式发布!
  • 《大模型实战宝典》(2024版)正式发布!
  • 大模型面试准备(一):LLM主流结构和训练目标、构建流程
  • 大模型面试准备(二):LLM容易被忽略的Tokenizer与Embedding
  • 大模型面试准备(三):聊一聊大模型的幻觉问题
  • 大模型面试准备(四):大模型面试必会的位置编码(绝对位置编码sinusoidal,旋转位置编码RoPE,以及相对位置编码ALiBi)

参考文献:

参考资料:
[1] https://jalammar.github.io/illustrated-transformer/
[2] https://zhuanlan.zhihu.com/p/264468193
[3] https://zhuanlan.zhihu.com/p/662777298

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/490294.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

优秀电源工程师需要的必备技能

随着电源市场的不断扩张,开关电源行业飞速发展,企业对电源工程师的需求日益增加,对电源工程师的技能要求也日渐提高,相信没有一位电源工程师会错过让自己变得更优秀的机会。作为一名数字电源从业者,今天就带大家细数一下优秀电源工程师具备的那些技能。 一、新手必备课程…

[leetcode]283. 移动零

前言:剑指offer刷题系列 问题: 给定一个数组 nums,编写一个函数将所有 0 移动到数组的末尾,同时保持非零元素的相对顺序。 请注意 ,必须在不复制数组的情况下原地对数组进行操作。 示例: 输入: nums …

【ZZULIOJ】1002: 简单多项式求值(Java)

目录 题目描述 输入 输出 样例输入 样例输出 code 题目描述 对用户输入的任一整数,输出以下多项式的值。 输入 输入整数x的值。 输出 输出一个整数,即多项式的值。 样例输入 1 样例输出 11 code import java.util.*;public class Ma…

【AI与WEB3】未来已来:十大领域揭示AI与Web3如何联手重塑全球经济版图

在不远的未来,当科技的脉搏跳动得愈发强劲有力,AI与Web3这两股创新力量正以前所未有的方式交织共舞,犹如科幻电影中的场景跃然现实。在这场颠覆性的技术革命中,我们正见证着十个关键领域的华丽转身,它们如同璀璨的星辰…

Copilot 编程助手的介绍及使用

介绍 Copilot 是2021年由 GitHub 与 OpenAI 合作研发的一款编程助手,同时也是全球首款使用OpenAI Codex模型(GPT-3后代)打造的大规模生成式AI开发工具。 Copilot 底层模型目前经过了数十亿行公开代码的训练,与大多数代码辅助工具…

白话模电:4.耦合、差分、无源滤波、反馈(考研面试常问问题)

一、介绍一下三极管多级放大电路的三种耦合方式及其特点?耦合的目的是什么? 多级放大电路中各放大级之间的连接方式称为耦合方式。常见的耦合方式有三种:阻容耦合(RC耦合)、直接耦合和变压器耦合。 耦合的目的是将信号…

ES6 字符串/数组/对象/函数扩展

文章目录 1. 模板字符串1.1 ${} 使用1.2 字符串扩展(1) ! includes() / startsWith() / endsWith()(2) repeat() 2. 数值扩展2.1 二进制 八进制写法2.2 ! Number.isFinite() / Number.isNaN()2.3 inInteger()2.4 ! 极小常量值Number.EPSILON2.5 Math.trunc()2.6 Math.sign() 3.…

蓝桥-数位排序

目录 题目链接: 思路: 代码: 题目链接: 0数位排序 - 蓝桥云课 (lanqiao.cn) 思路: 自定义排序比较函数 用一个函数来求某个数的数位和 sum() 用一个函数,自定义排序比较函数…

说说2024年度孝感建筑类初中级职称申报评审

说说2024年度孝感建筑类初中级职称申报评审 认真看,错过了就失去2024年申报孝感中级职称评审的机会。孝感中级职称申报评审一年两次,上半年一次,下半年一次。注意!职称水平能力测试是重点。 建筑类职称水平能力测试一年就一次机…

【精简】Spring笔记

文章目录 跳转链接(学习路线)及前言(更新中) 快速入门配置文件详解依赖注入(bean实例化)自动装配集合注入使用spring加载properties文件容器注解开发bean管理注解开发依赖注入第三方bean整合mybatis整合junit AOP入门案例切入点表…

1.Git快速入门

文章目录 Git快速入门1.Git概述2.SCM概述3.Git安装3.1 软件下载3.2 软件安装3.3 软件测试 Git快速入门 1.Git概述 Git是一个免费的,开源的分布式版本控制系统,可以快速高效地处理从小型到大型的各种项目,Git易于学习,占用空间小&…

全面剖析Java多线程编程,抢红包、抽奖实战案例

黑马Java进阶教程,全面剖析Java多线程编程,含抢红包、抽奖实战案例 1.什么是多线程? 2.并发与并行 CPU有这些,4,8,16,32,64 表示能同时进行的线程 3.多线程的第一种实现方式 package com.itheima.reggie;/*** Author lpc* Date …

关系型数据库mysql(7)sql高级语句

目录 一.MySQL常用查询 1.按关键字(字段)进行升降排序 按分数排序 (默认为升序) 按分数升序显示 按分数降序显示 根据条件进行排序(加上where) 根据多个字段进行排序 ​编辑 2.用或(or&…

Word通配符替换章节序号

这里写自定义目录标题 通配符替换章节序号切换域通配符替换内容插入编号切换域代码 通配符替换章节序号 碎片化学习word通配符知识 切换域 切换域:Alt F9 域都变成静态文字:Ctrl/Command Shift F9 通配符 内容通配符单个数字[0-9]多个数字&#…

【python从入门到精通】-- 第二战:注释和有关量的解释

🌈 个人主页:白子寰 🔥 分类专栏:python从入门到精通,魔法指针,进阶C,C语言,C语言题集,C语言实现游戏👈 希望得到您的订阅和支持~ 💡 坚持创作博文…

每天上万简历,录取不到1%!阿里腾讯的 offer 都给了哪些人?

三月天杨柳醉春烟~正是求职好时节~ 与去年秋招的冷淡不同,今年春招市场放宽了许多,不少企业纷纷抛出橄榄枝,各大厂的只差把“缺人”两个字写在脸上了。 字节跳动技术方向开放数10个类型岗位,研发需求占比60%,非研发新增…

【数据结构与算法】java有向带权图最短路径算法-Dijkstra算法(通俗易懂)

目录 一、什么是Dijkstra算法二、算法基本步骤三、java代码四、拓展(无向图的Dijkstra算法) 一、什么是Dijkstra算法 Dijkstra算法的核心思想是通过逐步逼近的方式,找出从起点到图中其他所有节点的最短路径。算法的基本步骤如下:…

【剑指offr--C/C++】JZ22 链表中倒数最后k个结点

一、题目 二、思路及代码 遍历链表并存入vector容器&#xff0c;通过下标取出对应位置元素或者返回空 /*** struct ListNode {* int val;* struct ListNode *next;* ListNode(int x) : val(x), next(nullptr) {}* };*/ #include <cstddef> #include <iterator> #…

轻松搞定!使用Python操作 xlsx 文件绘制饼图

今天&#xff0c;跟大家一起来学习用Python操作xlsx文件&#xff0c;然后绘制了一个饼图。你知道吗&#xff0c;这个过程居然比我想象中的还要简单&#xff01;只需要几行代码&#xff0c;就能轻松搞定&#xff01; 首先&#xff0c;安装一个叫做openpyxl的库&#xff0c;它可…

住在我心里的猴子:焦虑那些事儿 - 三余书屋 3ysw.net

精读文稿 您好&#xff0c;本期我们解读的是《住在我心里的猴子》。这是一本由患有焦虑症的作家所著&#xff0c;关于焦虑症的书。不仅如此&#xff0c;作者的父母和哥哥也都有焦虑症&#xff0c;而作者的母亲后来还成为了治疗焦虑症的专家。这本书的中文版大约有11万字&#x…