【模型学习之路】手写+分析bert

手写+分析bert

目录

前言

架构

embeddings

Bertmodel

预训练任务

MLM

NSP

Bert

后话

netron可视化

code2flow可视化

fine tuning


前言

Attention is all you need!

读本文前,建议至少看懂【模型学习之路】手写+分析Transformer-CSDN博客。

毕竟Bert是transformer的变种之一。

架构

embeddings

Bert可以说就是transformer的Encoder,就像训练卷积网络时可以利用现成的网络然后fine tune就投入使用一样,Bert的动机就是训练一种预训练模型,之后根据不同的场景可以做不同的fine tune。

这里我们还是B代表批次(对于Bert,一个Batch可以输入一到两个句子,输入两个句子时,两个直接拼接就好了),m代表一个batch的单词数,n表示词向量的长度。

Bert的输入是三种输入之和(维度设定我们与本系列上一篇文章保持相同):

token_embeddings  和Transformer完全一样。

segment_embeddings  用来标记句子。第一个句子每个单词标0,第二个句子的每个单词标1。

pos_embeddings  用来标记位置,维度和Transformer中的一样,但是Bert的pos_embeddings是训练出来的(这意味它成为了神经网络里要训练的参数了)。

def get_token_and_segments(tokens_a, tokens_b=None):
    """
    bert的输入之一:token embeddings
    bert的输入之二:segment embeddings
    pos_embeddings在后面的模型里面
    """
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    segments = [0] * (len(tokens_a) + 2)
    if tokens_b is not None:
        tokens += tokens_b + ['<sep>']
        segments += [1] * (len(tokens_b) + 1)
    return tokens, segments

Bertmodel

Bert的单个EncoderLayer和Transformer是一样的,我们直接把上一节的代码复制过来就好。

组装好。

class BertModel(nn.Module):
    def __init__(self, vocab, n, d_ff, h, n_layers,
                 max_len=1000, k=768, v=768):
        super(BertModel, self).__init__()
        self.token_embeddings = nn.Embedding(vocab, n)  # [B, m]->[B, m, vocab]->[B, m, n]
        self.segment_embeddings = nn.Embedding(2, n)  # [B, m]->[B, m, 2]->[B, m, n]
        self.pos_embeddings = nn.Parameter(torch.randn(1, max_len, n))  # [1, max_len, n]
        self.layers = nn.ModuleList([EncoderLayer(n, h, k, v, d_ff)
                                     for _ in range(n_layers)])

    def forward(self, tokens, segments, m):  # m是句子长度
        X = self.token_embeddings(tokens) + \
            self.segment_embeddings(segments)
        X += self.pos_embeddings[:, :X.shape[1], :]
        for layer in self.layers:
            X, attn = layer(X)
        return X

简单测试一下。

# 弄一点数据测试一下
    tokens = torch.randint(0, 100, (2, 10))  # [B, m]
    segments = torch.randint(0, 2, (2, 10))  # [B, m]
    m = 10
    bert = BertModel(100, 768, 3072, 12, 12)
    out = bert(tokens, segments, m)
    print(out.shape)  # [2, 10, 768]

预训练任务

Bert在训练时要做两种训练,这里先画个图表示架构,后面给出分析和代码。

MLM

Maked language model,是指在训练的时候随即从输入预料上mask掉一些单词,然后通过的上下文预测该单词,该任务非常像我们在中学时期经常做的完形填空。

在BERT的实验中,15%的WordPiece Token会被随机Mask掉。在训练模型时,一个句子会被多次喂到模型中用于参数学习,但是Google并没有在每次都mask掉这些单词,而是在确定要Mask掉的单词之后,80%的时候会直接替换为[Mask],10%的时候将其替换为其它任意单词,10%的时候会保留原始Token。(这里就不深入了)

class MLM(nn.Module):
    def __init__(self, vocab, n, mlm_hid):
        super(MLM, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(n, mlm_hid),
            nn.ReLU(),
            nn.LayerNorm(mlm_hid),
            nn.Linear(mlm_hid, vocab))

    def forward(self, X, P):
        # X: [B, m, n]
        # P: [B, p]
        # 这里P指的是记录了要mask的元素的矩阵,若P(i,j)==k,表示X(i,k)被mask了
        p = P.shape[1]
        P = P.reshape(-1)
        batch_size = X.shape[0]
        batch_idx = torch.arange(batch_size)
        batch_idx = torch.repeat_interleave(batch_idx, p)
        X = X[batch_idx, P].reshape(batch_size, p, -1)  # [B, p, n]
        out = self.mlp(X)
        return out

这里的forward的逻辑有点麻烦,要读懂的话可以要手推一下。p是每一个Batch中mask的词的个数。(即在一个Batch中,m个词挑出了p个)。其实意会也行:就是训练时,在X[B, m, n]里每个batch有p个是<mask>,我们把它们挑出来,得到[B, p, n]。

NSP

Next Sentence Prediction的任务是判断句子B是否是句子A的下文。训练数据的生成方式是从平行语料中随机抽取的连续两句话,其中50%保留抽取的两句话,它们符合is_next关系,另外50%的第二句话是随机从预料中提取的,不符合is_next关系。分别记为1 | 0。

这个关系由每个句子的第一个token——<cls>捕捉。

class NSP(nn.Module):
    def __init__(self, n, nsp_hid):
        super(NSP, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(n, nsp_hid),
            nn.Tanh(),
            nn.Linear(nsp_hid, 2))

    def forward(self, X):
        # X: [B, m, n]
        X = X[:, 0, :]  # [B, n]
        out = self.mlp(X)  # [B, 2]
        return out

Bert

下面拼装Bert。

class Bert(nn.Module):
    def __init__(self, vocab, n, d_ff, h, n_layers,
                 max_len=1000, k=768, v=768, mlm_feat=768, nsp_feat=768):
        super(Bert, self).__init__()
        self.encoder = BertModel(vocab, n, d_ff, h, n_layers, max_len, k, v)
        self.mlm = MLM(vocab, n, mlm_feat)
        self.nsp = NSP(n, nsp_feat)

    def forward(self, tokens, segments, m, P=None):
        X = self.encoder(tokens, segments, m)
        mlm_out = self.mlm(X, P) if P is not None else None
        nsp_out = self.nsp(X)
        return X, mlm_out, nsp_out

后话

netron可视化

利用netron可视化。

test_tokens = torch.randint(0, 100, (2, 10))  # [B, m]
test_segments = torch.randint(0, 2, (2, 10))  # [B, m]
test_P = torch.tensor([[1, 2, 4, 6, 8], [1, 3, 4, 5, 6]])
test_m = 10
test_bert = Bert(100, 768, 3072, 12, 12)
test_X, test_mlm_out, test_nsp_out = test_bert(test_tokens, test_segments, test_m, test_P)

modelData = "./demo.pth"
torch.onnx.export(test_bert, (test_tokens, test_segments), modelData)
netron.start(modelData)

截取部分看一下。

code2flow可视化

code2flow可以可视化代码函数和类的相互调用关系。

code2flow.code2flow([r'代码路径.py'], '输出路径.svg')

这里生成的png,其实svg清晰得多。

fine tuning

Bert的精髓在于,Bert只是一个编码器(Encoder),经过MLM和NSP两个任务的训练之后,可以自己在它的基础上训练一个Decoder来输出特定的值、得到特定的效果。这也是Bert的神奇和魅力所在!通过两个任务训练出一个编码器,然后可以通过不同的Decoder达到各种效果!

持续探索Bert......

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/909014.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

word及Excel常见功能使用

最近一直在整理需规文档及表格&#xff0c;Word及Excel需要熟练使用。 Word文档 清除复制过来的样式 当复制文字时&#xff0c;一般会带着字体样式&#xff0c;此时可选中该文字 并使用 ctrlshiftN 快捷键进行清除。 批注 插入->批注&#xff0c;选中文本 点击“批注”…

【C++篇】数据之林:解读二叉搜索树的优雅结构与运算哲学

文章目录 二叉搜索树详解&#xff1a;基础与基本操作前言第一章&#xff1a;二叉搜索树的概念1.1 二叉搜索树的定义1.1.1 为什么使用二叉搜索树&#xff1f; 第二章&#xff1a;二叉搜索树的性能分析2.1 最佳与最差情况2.1.1 最佳情况2.1.2 最差情况 2.2 平衡树的优势 第三章&a…

【Mac】安装 VMware Fusion Pro

VMware Fusion Pro 软件已经正式免费提供给个人用户使用&#xff01; 1、下载 【官网】 下拉找到 VMware Fusion Pro Download 登陆账号 如果没有账号&#xff0c;点击右上角 LOGIN &#xff0c;选择 REGISTER 注册信息除了邮箱外可随意填写 登陆时&#xff0c;Username为…

文心一言 VS 讯飞星火 VS chatgpt (383)-- 算法导论24.5 3题

三、对引理 24.10 的证明进行改善&#xff0c;使其可以处理最短路径权重为 ∞ ∞ ∞ 和 − ∞ -∞ −∞ 的情况。引理 24.10(三角不等式)的内容是&#xff1a;设 G ( V , E ) G(V,E) G(V,E) 为一个带权重的有向图&#xff0c;其权重函数由 w : E → R w:E→R w:E→R 给出&…

Linux 服务器使用指南:从入门到登录

&#x1f31f;快来参与讨论&#x1f4ac;&#xff0c;点赞&#x1f44d;、收藏⭐、分享&#x1f4e4;&#xff0c;共创活力社区。 &#x1f31f; &#x1f6a9;博主致力于用通俗易懂且不失专业性的文字&#xff0c;讲解计算机领域那些看似枯燥的知识点&#x1f6a9; 目录 一…

【Maven】——基础入门,插件安装、配置和简单使用,Maven如何设置国内源

阿华代码&#xff0c;不是逆风&#xff0c;就是我疯 你们的点赞收藏是我前进最大的动力&#xff01;&#xff01; 希望本文内容能够帮助到你&#xff01;&#xff01; 目录 引入&#xff1a; 一&#xff1a;Maven插件的安装 1&#xff1a;环境准备 2&#xff1a;创建项目 二…

数据库基础(2) . 安装MySQL

0.增加右键菜单选项 添加 管理员cmd 到鼠标右键 运行 reg文件 在注册表中添加信息 这样在右键菜单中就有以管理员身份打开命令行的选项了 1.获取安装程序 网址: https://dev.mysql.com/downloads/mysql/ 到官网下载MySQL8 的zip包, 然后解压 下载后的包为: mysql-8.0.16-…

cocos开发QA

目录 TS相关foreach循环中使用return循环延迟动态获取类属性 Cocos相关属性检查器添加Enum属性使用Enum报错 枚举“XXX”用于其声明前实现不规则点击区域使用cc.RevoluteJoint的enable激活组件无效本地存储以及相关问题JSON.stringify(map)返回{}数据加密客户端复制文本使用客户…

flutter区别于vue的写法

View.dart 页面渲染&#xff1a; 类似于vue里面使用 <template> <div> <span> <textarea>等标签绘制页面, flutter 里面则是使用不同的控件来绘制页面 样式 与传统vue不同的是 flutter里面没有css/scss样式表&#xff0c; Flutter的理念是万物皆…

DICOM标准:DICOM标准中的公用模块、核心模块详解(一)——病人、研究、序列、参考帧和设备模块属性详解

目录 概述 1 公用病人IE模块 1.1 病人模块 2 公用的研究IE模块 2.1 常规研究模块 2.2 病人研究模块 3 公用序列IE模块 3.1 常规序列模块 3.1.1 常规序列属性描述 4 公用参考帧信息实体模块 4.1 参考帧模块 4.1.1 参考帧属性描述 5 公用设备IE模块 5.1 常规设备模…

轻松搞定项目管理!用对在线项目管理工具助你生产力翻倍!

一、引言 在线项目管理是指借助互联网平台和相关软件工具&#xff0c;对项目从启动到结束的全过程进行规划、组织、协调、控制和监督的一种管理方式。它打破了传统项目管理在时间和空间上的限制&#xff0c;使得项目团队成员无论身处何地&#xff0c;都能实时同步项目信息、协…

ERP研究 | 颜值美和道德美,哪个更重要?

摘要 道德美和颜值美都会影响我们的评价。在这里&#xff0c;本研究采用事件相关电位(ERPs)技术探讨了道德美和颜值美如何交互影响社会判断和情感反应。参与者(均为女性)将积极、中性或消极的言语信息与高吸引力或低吸引力面孔进行关联&#xff0c;并对这些面孔进行评分&#…

【Linux】从零开始使用多路转接IO --- epoll

当你偶尔发现语言变得无力时&#xff0c; 不妨安静下来&#xff0c; 让沉默替你发声。 --- 里则林 --- 从零开始认识多路转接 1 epoll的作用和定位2 epoll 的接口3 epoll工作原理4 实现epollserverV1 1 epoll的作用和定位 之前提过的多路转接方案select和poll 都有致命缺点…

利用 Feather 格式加速数据科学工作流:Pandas 中的最佳实践

利用 Feather 格式加速数据科学工作流&#xff1a;Pandas 中的最佳实践 在数据科学中&#xff0c;高效的数据存储和传输对于保持分析流程的流畅性至关重要。传统的 CSV 格式虽然通用&#xff0c;但在处理大规模数据集时速度较慢&#xff0c;特别是在反复读取和写入时。幸运的是…

[极客大挑战 2019]BabySQL 1

[极客大挑战 2019]BabySQL 1 审题 还是SQL注入和之前的是一个系列的。 知识点 联合注入&#xff0c;双写绕过 解题 输入万能密码 发现回显中没有or&#xff0c;猜测是使用正则过滤了or。 尝试双写绕过 登录成功 使用联合查询&#xff0c;本题中过滤了from&#xff0c;w…

全面解析:大数据技术及其应用

&#x1f493; 博客主页&#xff1a;瑕疵的CSDN主页 &#x1f4dd; Gitee主页&#xff1a;瑕疵的gitee主页 ⏩ 文章专栏&#xff1a;《热点资讯》 全面解析&#xff1a;大数据技术及其应用 全面解析&#xff1a;大数据技术及其应用 全面解析&#xff1a;大数据技术及其应用 大…

七次课掌握 Photoshop:基础与入门

Photoshop 是 Adobe 公司开发的功能强大的图像处理软件&#xff0c;被广泛应用于平面设计、网页设计、摄影后期处理、UI 设计等多个领域。 ◆ ◆ ◆ Photoshop 中的核心概念 一、像素 像素&#xff08;Pixel&#xff09;是组成数字图像的基本单位&#xff0c;如同组成人体的细…

G2 基于生成对抗网络(GAN)人脸图像生成

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 基于生成对抗网络&#xff08;GAN&#xff09;人脸图像生成 这周将构建并训练一个生成对抗网络&#xff08;GAN&#xff09;来生成人脸图像。 GAN 原理概述 …

N-155基于springboot,vue宿舍管理系统

开发工具&#xff1a;IDEA 服务器&#xff1a;Tomcat9.0&#xff0c; jdk1.8 项目构建&#xff1a;maven 数据库&#xff1a;mysql5.7 项目采用前后端分离 前端技术&#xff1a;vue3element-plus 服务端技术&#xff1a;springbootmybatis-plus 本项目分为学生、宿舍管理…

友思特应用 | FantoVision边缘计算:多模态传感+AI算法=新型非接触式医疗设备

导读 基于多模态传感技术和先进人工智能技术可有效提升乳腺癌检测的精准性、性价比和效率。友思特 FantoVision 边缘计算机 则为其生物组织数据的高效传输和实时分析提供了坚实基础。 乳腺癌的新型医疗检测方式 乳腺癌是女性面临的最令人担忧的健康问题之一&#xff0c;早期发…