ROCm上来自Transformers的双向编码器表示(BERT)

14.8. 来自Transformers的双向编码器表示(BERT) — 动手学深度学习 2.0.0 documentation (d2l.ai)

代码

import torch
from torch import nn
from d2l import torch as d2l

#@save
def get_tokens_and_segments(tokens_a, tokens_b=None):
    """获取输入序列的词元及其片段索引"""
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    # 0和1分别标记片段A和B
    segments = [0] * (len(tokens_a) + 2)
    if tokens_b is not None:
        tokens += tokens_b + ['<sep>']
        segments += [1] * (len(tokens_b) + 1)
    return tokens, segments

#@save
class BERTEncoder(nn.Module):
    """BERT编码器"""
    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, num_layers, dropout,
                 max_len=1000, key_size=768, query_size=768, value_size=768,
                 **kwargs):
        super(BERTEncoder, self).__init__(**kwargs)
        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
        self.segment_embedding = nn.Embedding(2, num_hiddens)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(f"{i}", d2l.EncoderBlock(
                key_size, query_size, value_size, num_hiddens, norm_shape,
                ffn_num_input, ffn_num_hiddens, num_heads, dropout, True))
        # 在BERT中,位置嵌入是可学习的,因此我们创建一个足够长的位置嵌入参数
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len,
                                                      num_hiddens))

    def forward(self, tokens, segments, valid_lens):
        # 在以下代码段中,X的形状保持不变:(批量大小,最大序列长度,num_hiddens)
        X = self.token_embedding(tokens) + self.segment_embedding(segments)
        X = X + self.pos_embedding.data[:, :X.shape[1], :]
        for blk in self.blks:
            X = blk(X, valid_lens)
        return X

vocab_size, num_hiddens, ffn_num_hiddens, num_heads = 10000, 768, 1024, 4
norm_shape, ffn_num_input, num_layers, dropout = [768], 768, 2, 0.2
encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape, ffn_num_input,
                      ffn_num_hiddens, num_heads, num_layers, dropout)

tokens = torch.randint(0, vocab_size, (2, 8))
segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])
encoded_X = encoder(tokens, segments, None)
encoded_X.shape

tokens = torch.randint(0, vocab_size, (2, 8))
segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])
encoded_X = encoder(tokens, segments, None)
encoded_X.shape

mlm = MaskLM(vocab_size, num_hiddens)
mlm_positions = torch.tensor([[1, 5, 2], [6, 1, 5]])
mlm_Y_hat = mlm(encoded_X, mlm_positions)
mlm_Y_hat.shape

mlm_Y = torch.tensor([[7, 8, 9], [10, 20, 30]])
loss = nn.CrossEntropyLoss(reduction='none')
mlm_l = loss(mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y.reshape(-1))
mlm_l.shape

#@save
class NextSentencePred(nn.Module):
    """BERT的下一句预测任务"""
    def __init__(self, num_inputs, **kwargs):
        super(NextSentencePred, self).__init__(**kwargs)
        self.output = nn.Linear(num_inputs, 2)

    def forward(self, X):
        # X的形状:(batchsize,num_hiddens)
        return self.output(X)

encoded_X = torch.flatten(encoded_X, start_dim=1)
# NSP的输入形状:(batchsize,num_hiddens)
nsp = NextSentencePred(encoded_X.shape[-1])
nsp_Y_hat = nsp(encoded_X)
nsp_Y_hat.shape

nsp_y = torch.tensor([0, 1])
nsp_l = loss(nsp_Y_hat, nsp_y)
nsp_l.shape

#@save
class BERTModel(nn.Module):
    """BERT模型"""
    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, num_layers, dropout,
                 max_len=1000, key_size=768, query_size=768, value_size=768,
                 hid_in_features=768, mlm_in_features=768,
                 nsp_in_features=768):
        super(BERTModel, self).__init__()
        self.encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape,
                    ffn_num_input, ffn_num_hiddens, num_heads, num_layers,
                    dropout, max_len=max_len, key_size=key_size,
                    query_size=query_size, value_size=value_size)
        self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens),
                                    nn.Tanh())
        self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)
        self.nsp = NextSentencePred(nsp_in_features)

    def forward(self, tokens, segments, valid_lens=None,
                pred_positions=None):
        encoded_X = self.encoder(tokens, segments, valid_lens)
        if pred_positions is not None:
            mlm_Y_hat = self.mlm(encoded_X, pred_positions)
        else:
            mlm_Y_hat = None
        # 用于下一句预测的多层感知机分类器的隐藏层,0是“<cls>”标记的索引
        nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))
        return encoded_X, mlm_Y_hat, nsp_Y_hat

代码解析

这段代码是基于PyTorch框架实现的BERT(Bidirectional Encoder Representations from Transformers)模型。BERT是一种预训练语言表示模型,它可以用于各种自然语言处理(NLP)任务。下面是代码的中文解析:
1. get_tokens_and_segments(tokens_a, tokens_b=None) 函数用于获取输入句子的词元(tokens)及其对应的片段索引。如果有第二个句子 tokens_b,则会进行拼接,并用不同的索引来标识不同的句子。
2. BERTEncoder 类定义了BERT的编码器结构,它包含嵌入层(用于将词元转换为向量表示)、位置嵌入和多个Transformer编码块。
3. forward 方法定义了模型的前向传播逻辑。它将输入的词元和片段索引通过编码器进行编码,并返回编码后的向量表示。
4. 其中 tokens 是批量输入数据的词元索引,`segments` 是对应的片段索引,这里模拟了输入数据作为模型的示例。
5. 创建一个 BERTEncoder 实例,该实例就是BERT模型的编码器部分,类似于 Transformer 模型中的编码器层。
6. MaskLM 类未在代码中定义,通常用来实现BERT的掩码语言模型任务,它在一定比例的输入词元上应用掩码,并训练模型来预测这些被掩码的词元。
7. NextSentencePred 类定义了BERT的下一句预测(Next Sentence Prediction, NSP)任务,是一个简单的二分类器,用来预测给定的两个句子片段是否在原始文本中顺序相邻。
8. BERTModel 类将编码器、掩码语言模型(MaskLM),以及下一句预测(NSP)整合为完整的BERT模型。它通过前向传播来处理输入,同时能够根据需求进行掩码语言模型预测和下一句预测。
9. 模型实例化后,通过随机生成的 tokens 和 segments 调用其 forward 方法,得到编码后的向量 encoded_X,同时执行MLM和NSP任务,输出预测结果。
10. 最后计算MLM和NSP任务的损失,这些损失通常用于训练模型。`CrossEntropyLoss` 是在类别预测问题中经常使用的一个损失函数。
整体来看,这段代码展示了如何构建一个基于BERT结构的模型,其中涵盖了BERT的两个典型预训练任务:掩码语言模型和下一句预测。需要注意的是,这个代码片段作为一个解析,但实际中运行它需要额外的上下文(例如 MaskLM 类的实现)和适当的数据准备和预处理步骤。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/647381.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

html中被忽略的简单标签

1&#xff1a; alt的作用是在图片不能显示时的提示信息 <img src"https://img.xunfei.cn/mall/dev/ifly-mall-vip- service/business/vip/common/202404071019208761.jp" alt"提示信息" width"100px" height"100px" /> 2&#…

CTF之Web_python_block_chain

这种题对于我来说只能看大佬的wp&#xff08;但是这一题是wp都看不懂&#xff0c;只能表达一下我的理解了&#xff09; &#xff08;最后有简单方法&#xff0c;前面一种没看懂没关系&#xff09; 下面这一部分是首页的有用部分 访问/source_code,得到源码&#xff1a; # -*-…

mysql之递归sql

mysql之递归sql 递归sql在一些公司是不允许使用的&#xff0c;会涉及数据库压力&#xff0c;所以会在代码里递归查询&#xff0c;但有些公司开发流程没有规定&#xff0c;且数据库数据量不大&#xff0c;之前写过好几遍了&#xff0c;老是记不住&#xff0c;记录一下 通过父级…

LiveGBS流媒体平台GB/T28181用户手册-版本信息:查看机器码、切换查看流媒体服务

LiveGBS流媒体平台GB/T28181用户手册--版本信息:查看机器码、切换查看流媒体服务 1、版本信息1.1、查看机器码1.2、多个流媒体服务1.3、提交激活 2、搭建GB28181视频直播平台 1、版本信息 版本信息页面&#xff0c;可以查看到信令服务 流媒体服务相关信息&#xff0c;包含硬件…

MySQL--存储引擎

一、存储引擎介绍 1.介绍 存储引擎相当于Linux的文件系统&#xff0c;以插件的模式存在&#xff0c;是作用在表的一种属性 2.MySQL中的存储引擎类型 InnoDB、MyISAM、CSV、Memory 3.InnoDB核心特性的介绍 聚簇索引、事务、MVCC多版本并发控制、行级锁、外键、AHI、主从复制特…

网络安全等级保护:正确配置 Linux

正确配置 Linux 对Linux安全性的深入审查确实是一项漫长的任务。原因之一是Linux设置的多样性。用户可以使用Debian、Red Hat、Ubuntu或其他Linux发行版。有些可能通过shell工作&#xff0c;而另一些则通过某些图形用户界面&#xff08;例如 KDE 或 GNOME&#xff09;工作&…

零基础学Java第二十三天之网络编程Ⅱ

1. InetAddress类 用来表示主机的信息 练习&#xff1a; C:\Windows\system32\drivers\etc\ hosts 一个主机可以放多个个人网站 www.baidu.com/14.215.177.37 www.baidu.com/14.215.177.38 www.taobao.com/183.61.241.252 www.taobao.com/121.14.89.253 2. Socket 3.…

细粒度图像分类论文(AAM模型方法)阅读笔记

细粒度图像分类论文阅读笔记 摘要Abstract1. 用于细粒度图像分类的聚合注意力模块1.1 文献摘要1.2 研究背景1.3 本文创新点1.4 计算机视觉中的注意力机制1.5 模型方法1.5.1 聚合注意力模块1.5.2 通道注意力模块通道注意力代码实现 1.5.3 空间注意力模块空间注意力代码实现 1.5.…

Superset,基于浏览器的开源BI工具

BI工具是数据分析的得力武器&#xff0c;目前市场上有很多BI软件&#xff0c;众所周知的有Tableau、PowerBI、Qlikview、帆软等&#xff0c;其中大部分是收费软件或者部分功能收费。这些工具一通百通&#xff0c;用好一个就够了&#xff0c;重要的是分析思维。 我一直用的Tabl…

【数据结构/C语言】深入理解 双向链表

&#x1f493; 博客主页&#xff1a;倔强的石头的CSDN主页 &#x1f4dd;Gitee主页&#xff1a;倔强的石头的gitee主页 ⏩ 文章专栏&#xff1a;数据结构与算法 在阅读本篇文章之前&#xff0c;您可能需要用到这篇关于单链表详细介绍的文章 【数据结构/C语言】深入理解 单链表…

python内置函数map/filter/reduce详解

在Python中&#xff0c;map(), filter(), 和 reduce() 是内置的高级函数(实际是class)&#xff0c;用于处理可迭代对象&#xff08;如列表、元组等&#xff09;的元素。这些函数通常与lambda函数一起使用&#xff0c;以简洁地表达常见的操作。下面我将分别解释这三个函数。 1. …

echarts-地图

使用地图的三种的方式&#xff1a; 注册地图(用json或svg,注册为地图)&#xff0c;然后使用map地图使用geo坐标系&#xff0c;地图注册后不是直接使用&#xff0c;而是注册为坐标系。直接使用百度地图、高德地图&#xff0c;使用百度地图或高德地图作为坐标系。 用json或svg注…

Selenium 高频面试题及答案

1、什么是 Selenium&#xff1f;它用于做什么&#xff1f; Selenium 是一个用于自动化测试的开源框架。它提供了多种工具和库&#xff0c;用于模拟用户在不同浏览器和操作系统上的行为&#xff0c;并且可用于测试网页应用程序。 2、Selenium WebDriver 和 Selenium IDE 有何区…

【机器学习300问】100、怎么理解卷积神经网络CNN中的池化操作?

一、什么是池化&#xff1f; 卷积神经网络&#xff08;CNN&#xff09;中的池化&#xff08;Pooling&#xff09;操作是一种下采样技术&#xff0c;其目的是减少数据的空间维度&#xff08;宽度和高度&#xff09;&#xff0c;同时保持最重要的特征并降低计算复杂度。池化操作不…

JavaWeb_Web——Maven

介绍&#xff1a; Maven是Apache公司发行的&#xff0c;一个Java项目管理和构建工具 作用&#xff1a; 1.方便的依赖管理 2.统一的项目结构 3.标准的项目构建流程 模型&#xff1a; Maven通过项目对象模型(POM)和依赖管理模型(Dependency)管理依赖(jar包)&#xff0c;如果新添…

新闻稿海外媒体投稿,除了美联社发稿(AP)和彭博社宣发(Bloomberg),还有哪些优质的国外媒体平台可以选择

发布高质量的新闻稿到海外媒体&#xff0c;除了美联社发稿&#xff08;AP&#xff09;和彭博社发稿&#xff08;Bloomberg&#xff09;&#xff0c;还有许多其他优质的媒体平台可以选择。以下是一些受欢迎和高效的海外媒体发布平台&#xff1a; 路透社 (Reuters) 路透社是全球最…

HILL密码

一&#xff1a;简介 Hill密码又称希尔密码是运用基本矩阵论原理的替换密码&#xff0c;属于多表代换密码的一种&#xff0c;由L e s t e r S . H i l l Lester S. HillLesterS.Hill在1929年发明。 二&#xff1a;原理 1.对于每一个字母&#xff0c;我们将其转化为对应的数字&am…

[Android]联系人-删除修改

界面显示 添加按钮点击&#xff0c;holder.imgDelete.setlog();具体代码 public MyViewHolder onCreateViewHolder(NonNull ViewGroup parent, int viewType) {//映射布局文件&#xff0c;生成相应的组件View v LayoutInflater.from(parent.getContext()).inflate(R.layout.d…

[ C++ ] 类和对象( 中 ) 2

目录 前置和后置重载 运算符重载和函数重载 流插入流提取的重载 全局函数访问类私有变量 友员 const成员 取地址及const取地址操作符重载 前置和后置重载 运算符重载和函数重载 流插入流提取的重载 重载成成员函数会出现顺序不同的情况&#xff08;函数重载形参顺序必须相…

渗透工具CobaltStrike工具的下载和安装

一、CobalStrike简介 Cobalt Strike(简称为CS)是一款基于java的渗透测试工具&#xff0c;专业的团队作战的渗透测试工具。CS使用了C/S架构&#xff0c;它分为客户端(Client)和服务端(Server)&#xff0c;服务端只要一个&#xff0c;客户端可有多个&#xff0c;多人连接服务端后…