从零实现GPT【1】——BPE

文章目录

  • Embedding 的原理
  • 训练
  • 特殊 token 处理和保存
  • 编码
  • 解码
  • 完整代码

BPE,字节对编码

Embedding 的原理

image.png

  • 简单来说就是查表
# 解释embedding
from torch.nn import Embedding
import torch


# 标准的正态分布初始化 也可以用均匀分布初始化
emb = Embedding(10, 32)
res = emb(torch.tensor([[
    0, 1, 2]
]))
print(res.shape)  # torch.Size([1, 3, 32]) [batch, seq_len, dim]
  • 自己实现
# 解释embedding
from torch.nn import Embedding, Parameter, Module
import torch


class MyEmbing(Module):
    def __init__(self, vocab_size, dim):
        super().__init__()
        self.emb_matrix = Parameter(torch.randn(vocab_size, dim))
        # Parameter标记self.emb_matrix需要被训练

    def forward(self, ids):
        return self.emb_matrix[ids]  # 取索引这个操作可以反向传播


emb = MyEmbedding(10, 32)
res = emb(torch.tensor([[
    0, 1, 2]
]))
print(res.shape)  # torch.Size([1, 3, 32]) [batch, seq_len, dim]

训练

  1. 初始化词表,一般是 0-255 个 ASCII 编码
  2. 设置词表大小 Max_size
  3. 循环统计相邻两个字节的频率,取最高的合并后作为新的 token 加入到词表中
  4. 合并新的 token
  5. 重复 c、d,直到词表大小到Max_size 或者 没有更多的相邻 token
class BPETokenizer:
    def __init__(self):
        self.b2i = OrderedDict()  # bytes to id
        self.i2b = OrderedDict()  # id to bytes (b2i的反向映射)
        self.next_id = 0

        # special token
        self.sp_s2i = {}  # str to id
        self.sp_i2s = {}  # id to str

    # 相邻token统计
    def _pair_stats(self, tokens, stats):
        for i in range(len(tokens)-1):
            new_token = tokens[i]+tokens[i+1]
            if new_token not in stats:
                stats[new_token] = 0
            stats[new_token] += 1

    # 合并相邻token
    def _merge_pair(self, tokens, new_token):
        merged_tokens = []

        i = 0
        while i < len(tokens):
            if i+1 < len(tokens) and tokens[i]+tokens[i+1] == new_token:
                merged_tokens.append(tokens[i]+tokens[i+1])
                i += 2
            else:
                merged_tokens.append(tokens[i])
                i += 1
        return merged_tokens

    def train(self, text_list, vocab_size):
        # 单字节是最基础的token,初始化词表
        for i in range(256):
            self.b2i[bytes([i])] = i
        self.next_id = 256

        # 语料转byte
        tokens_list = []
        for text in text_list:
            tokens = [bytes([b]) for b in text.encode('utf-8')]
            tokens_list.append(tokens)

        # 进度条
        progress = tqdm(total=vocab_size, initial=256)

        while True:
            # 词表足够大了,退出训练
            if self.next_id >= vocab_size:
                break

            # 统计相邻token频率
            stats = {}
            for tokens in tokens_list:
                self._pair_stats(tokens, stats)

            # 没有更多相邻token, 无法生成更多token,退出训练
            if not stats:
                break

            # 合并最高频的相邻token,作为新的token加入词表
            new_token = max(stats, key=stats.get)

            new_tokens_list = []
            for tokens in tokens_list:
                # self._merge_pair(tokens, new_token) -> list
                new_tokens_list.append(self._merge_pair(tokens, new_token))
            tokens_list = new_tokens_list

            # new token加入词表
            self.b2i[new_token] = self.next_id
            self.next_id += 1

            # 刷新进度条
            progress.update(1)

        self.i2b = {v: k for k, v in self.b2i.items()}

特殊 token 处理和保存

  • 特殊 token 加到词表中
tokenizer = BPETokenizer()

# 特殊token
tokenizer.add_special_tokens(
    (['<|im_start|>', '<|im_end|>', '<|endoftext|>', '<|padding|>']))

# 特殊token
def add_special_tokens(self, special_tokens):
    for token in special_tokens:
        if token not in self.sp_s2i:
            self.sp_s2i[token] = self.next_id
            self.sp_i2s[self.next_id] = token
            self.next_id += 1
  • 保存和加载
tokenizer.save('tokenizer.bin')

def save(self, file):
    with open(file, 'wb') as fp:
        fp.write(pickle.dumps((self.b2i, self.sp_s2i, self.next_id)))

# 还原
tokenizer = BPETokenizer()
tokenizer.load('tokenizer.bin')
print('vocab size:', tokenizer.vocab_size())


def load(self, file):
    with open(file, 'rb') as fp:
        self.b2i, self.sp_s2i, self.next_id = pickle.loads(fp.read())
    self.i2b = {v: k for k, v in self.b2i.items()}
    self.sp_i2s = {v: k for k, v in self.sp_s2i.items()}

编码

  1. 分离特殊 token,用于直接映射特殊 token
  2. 进行编码,特殊 token 直接编码就好,普通 token 继续

while True:

  1. 对于普通 token, 统计相邻 token 频率
  2. 选择合并后的 id 最小的 pair token 合并(也就是优先合并短的)
  3. 重复 c d,直到没有合并的 pair token

不断循环 token,统计相邻 token 的频率,取 id 最小的 pair 进行合并,从而可以拼接成更大的 id

# 编码
ids, tokens = tokenizer.encode(
    '<|im_start|>system\nyou are a helper assistant\n<|im_end|>\n<|im_start|>user\n今天的天气\n<|im_end|><|im_start|>assistant\n')
print('encode:', ids, tokens)
'''
encode: 
[300, 115, 121, 115, 116, 101, 109, 10, 121, 111, 117, 32, 97, 114, 276, 97, 32, 104, 101, 108, 112, 101, 293, 97, 115, 115, 105, 115, 116, 97, 110, 116, 10, 301, 10, 300, 117, 115, 101, 114, 10, 265, 138, 266, 169, 261, 266, 169, 230, 176, 148, 10, 301, 300, 97, 115, 115, 105, 115, 116, 97, 110, 116, 10] 

[b'<|im_start|>', b's', b'y', b's', b't', b'e', b'm', b'\n', b'y', b'o', b'u', b' ', b'a', b'r', b'e ', b'a', b' ', b'h', b'e', b'l', b'p', b'e', b'r ', b'a', b's', b's', b'i', b's', b't', b'a', b'n', b't', b'\n', b'<|im_end|>', b'\n', b'<|im_start|>', b'u', b's', b'e', b'r', b'\n', b'\xe4\xbb', b'\x8a', b'\xe5\xa4', b'\xa9', b'\xe7\x9a\x84', b'\xe5\xa4', b'\xa9', b'\xe6', b'\xb0', b'\x94', b'\n', b'<|im_end|>', b'<|im_start|>', b'a', b's', b's', b'i', b's', b't', b'a', b'n', b't', b'\n']
'''

'''
在Python中,Unicode字符通常以"\x"开头,后面跟着两个十六进制数字,或者以"\u"开头,后面跟着四个十六进制数字。
'''


def encode(self, text):
    # 特殊token分离
    pattern = '('+'|'.join([re.escape(tok) for tok in self.sp_s2i])+')'
    splits = re.split(pattern, text)  # [ '<|im_start|>', 'user', '<||>' ]

    # 编码结果
    enc_ids = []
    enc_tokens = []
    for sub_text in splits:
        if sub_text in self.sp_s2i:  # 特殊token,直接对应id
            enc_ids.append(self.sp_s2i[sub_text])
            enc_tokens.append(sub_text.encode('utf-8'))
        else:
            tokens = [bytes([b]) for b in sub_text.encode('utf-8')]
            while True:
                # 统计相邻token频率
                stats = {}
                self._pair_stats(tokens, stats)

                # 选择合并后id最小的pair合并(也就是优先合并短的)
                new_token = None
                for merge_token in stats:
                    if merge_token in self.b2i and (new_token is None or self.b2i[merge_token] < self.b2i[new_token]):
                        new_token = merge_token

                # 没有可以合并的pair,退出
                if new_token is None:
                    break

                # 合并pair
                tokens = self._merge_pair(tokens, new_token)
            enc_ids.extend([self.b2i[tok] for tok in tokens])
            enc_tokens.extend(tokens)
    return enc_ids, enc_tokens

解码

# 解码
s = tokenizer.decode(ids)
print('decode:', s)
'''
decode: 
<|im_start|>system
you are a helper assistant
<|im_end|>
<|im_start|>user
今天的天气
<|im_end|><|im_start|>assistant
'''

def decode(self, ids):
    bytes_list = []
    for id in ids:
        if id in self.sp_i2s:
            bytes_list.append(self.sp_i2s[id].encode('utf-8'))
        else:
            bytes_list.append(self.i2b[id])  # self.i2b 这里已经是字节了 id to byte 
    return b''.join(bytes_list).decode('utf-8', errors='replace')

完整代码

from collections import OrderedDict
import pickle
import re
from tqdm import tqdm

# Byte-Pair Encoding tokenization


class BPETokenizer:
    def __init__(self):
        self.b2i = OrderedDict()  # bytes to id
        self.i2b = OrderedDict()  # id to bytes (b2i的反向映射)
        self.next_id = 0

        # special token
        self.sp_s2i = {}  # str to id
        self.sp_i2s = {}  # id to str

    # 相邻token统计
    def _pair_stats(self, tokens, stats):
        for i in range(len(tokens)-1):
            new_token = tokens[i]+tokens[i+1]
            if new_token not in stats:
                stats[new_token] = 0
            stats[new_token] += 1

    # 合并相邻token
    def _merge_pair(self, tokens, new_token):
        merged_tokens = []

        i = 0
        while i < len(tokens):
            if i+1 < len(tokens) and tokens[i]+tokens[i+1] == new_token:
                merged_tokens.append(tokens[i]+tokens[i+1])
                i += 2
            else:
                merged_tokens.append(tokens[i])
                i += 1
        return merged_tokens

    def train(self, text_list, vocab_size):
        # 单字节是最基础的token,初始化词表
        for i in range(256):
            self.b2i[bytes([i])] = i
        self.next_id = 256

        # 语料转byte
        tokens_list = []
        for text in text_list:
            tokens = [bytes([b]) for b in text.encode('utf-8')]
            tokens_list.append(tokens)

        # 进度条
        progress = tqdm(total=vocab_size, initial=256)

        while True:
            # 词表足够大了,退出训练
            if self.next_id >= vocab_size:
                break

            # 统计相邻token频率
            stats = {}
            for tokens in tokens_list:
                self._pair_stats(tokens, stats)

            # 没有更多相邻token, 无法生成更多token,退出训练
            if not stats:
                break

            # 合并最高频的相邻token,作为新的token加入词表
            new_token = max(stats, key=stats.get)

            new_tokens_list = []
            for tokens in tokens_list:
                # self._merge_pair(tokens, new_token) -> list
                new_tokens_list.append(self._merge_pair(tokens, new_token))
            tokens_list = new_tokens_list

            # new token加入词表
            self.b2i[new_token] = self.next_id
            self.next_id += 1

            # 刷新进度条
            progress.update(1)

        self.i2b = {v: k for k, v in self.b2i.items()}

    # 词表大小
    def vocab_size(self):
        return self.next_id

    # 词表
    def vocab(self):
        v = {}
        v.update(self.i2b)
        v.update({id: token.encode('utf-8')
                 for id, token in self.sp_i2s.items()})
        return v

    # 特殊token
    def add_special_tokens(self, special_tokens):
        for token in special_tokens:
            if token not in self.sp_s2i:
                self.sp_s2i[token] = self.next_id
                self.sp_i2s[self.next_id] = token
                self.next_id += 1

    def encode(self, text):
        # 特殊token分离
        pattern = '('+'|'.join([re.escape(tok) for tok in self.sp_s2i])+')'
        splits = re.split(pattern, text)  # [ '<|im_start|>', 'user', '<||>' ]

        # 编码结果
        enc_ids = []
        enc_tokens = []
        for sub_text in splits:
            if sub_text in self.sp_s2i:  # 特殊token,直接对应id
                enc_ids.append(self.sp_s2i[sub_text])
                enc_tokens.append(sub_text.encode('utf-8'))
            else:
                tokens = [bytes([b]) for b in sub_text.encode('utf-8')]
                while True:
                    # 统计相邻token频率
                    stats = {}
                    self._pair_stats(tokens, stats)

                    # 选择合并后id最小的pair合并(也就是优先合并短的)
                    new_token = None
                    for merge_token in stats:
                        if merge_token in self.b2i and (new_token is None or self.b2i[merge_token] < self.b2i[new_token]):
                            new_token = merge_token

                    # 没有可以合并的pair,退出
                    if new_token is None:
                        break

                    # 合并pair
                    tokens = self._merge_pair(tokens, new_token)
                enc_ids.extend([self.b2i[tok] for tok in tokens])
                enc_tokens.extend(tokens)
        return enc_ids, enc_tokens

    def decode(self, ids):
        bytes_list = []
        for id in ids:
            if id in self.sp_i2s:
                bytes_list.append(self.sp_i2s[id].encode('utf-8'))
            else:
                bytes_list.append(self.i2b[id])  # self.i2b 这里已经是字节了 id to byte
        return b''.join(bytes_list).decode('utf-8', errors='replace')

    def save(self, file):
        with open(file, 'wb') as fp:
            fp.write(pickle.dumps((self.b2i, self.sp_s2i, self.next_id)))

    def load(self, file):
        with open(file, 'rb') as fp:
            self.b2i, self.sp_s2i, self.next_id = pickle.loads(fp.read())
        self.i2b = {v: k for k, v in self.b2i.items()}
        self.sp_i2s = {v: k for k, v in self.sp_s2i.items()}


if __name__ == '__main__':
    # 加载语料
    cn = open('dataset/train-cn.txt', 'r').read()
    en = open('dataset/train-en.txt', 'r').read()

    # 训练
    tokenizer = BPETokenizer()
    tokenizer.train(text_list=[cn, en], vocab_size=300)

    # 特殊token
    tokenizer.add_special_tokens(
        (['<|im_start|>', '<|im_end|>', '<|endoftext|>', '<|padding|>']))

    # 保存
    tokenizer.save('tokenizer.bin')

    # 还原
    tokenizer = BPETokenizer()
    tokenizer.load('tokenizer.bin')
    print('vocab size:', tokenizer.vocab_size())

    # 编码
    ids, tokens = tokenizer.encode(
        '<|im_start|>system\nyou are a helper assistant\n<|im_end|>\n<|im_start|>user\n今天的天气\n<|im_end|><|im_start|>assistant\n')
    print('encode:', ids, tokens)

    # 解码
    s = tokenizer.decode(ids)
    print('decode:', s)

    # 打印词典
    print('vocab:', tokenizer.vocab())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/735013.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

探索Agent AI智能体的未来

随着人工智能&#xff08;AI&#xff09;技术的飞速发展&#xff0c;Agent AI智能体正成为一种改变世界的新力量。这些智能体不仅在当前的技术领域中发挥着重要作用&#xff0c;而且在未来将以更深远的影响改变我们的生活、工作和社会结构。本文将探讨Agent AI智能体的现状、潜…

回顾今年的618大战:除了卷低价,还有别的出路吗?

今年的618刚刚落下帷幕&#xff0c;大促期间&#xff0c;一些电商平台纷纷备足马力、迎接挑战&#xff0c;反倒是一向领跑的淘宝京东公开表示&#xff0c;今年取消了618预售制。 互联网电商20年来&#xff0c;每年618、双11轮流登场&#xff0c;“低价大战”愈演愈烈&#xff0…

【C++】类和对象2.0

俺来写笔记了&#xff0c;哈哈哈&#xff0c;浅浅介绍类和对象的知识点&#xff01; 1.类的6个默认成员函数 俺们定义一个空类&#xff1a; class N {}; 似乎这个类N里面什么都没有&#xff0c;其实不是这样子的。这个空类有6个默认的成员函数 。 默认成员函数&#xff1a…

Android 你应该知道的学习资源 进阶之路贵在坚持

coderzheaven 覆盖各种教程&#xff0c;关于Android基本时案例驱动的方式。 非常推荐 thenewcircle 貌似是个培训机构&#xff0c;多数是收费的&#xff0c;不过仍然有一些free resources值得你去挖掘。 coreservlets 虽然主打不是android&#xff0c;但是android的教程也​ 是…

【前端技术】标签页通讯localStorage、BroadcastChannel、SharedWorker的技术详解

&#x1f604; 19年之后由于某些原因断更了三年&#xff0c;23年重新扬帆起航&#xff0c;推出更多优质博文&#xff0c;希望大家多多支持&#xff5e; &#x1f337; 古之立大事者&#xff0c;不惟有超世之才&#xff0c;亦必有坚忍不拔之志 &#x1f390; 个人CSND主页——Mi…

MySQL之复制(十二)

复制 复制的问题和解决方案 未定义的服务器ID 如果没有在my.cnf里面定义服务器ID,可以通过CHANGE MASTER TO 来设置备库&#xff0c;但却无法启动复制。 mysql>START SLAVE; ERROR 1200(HY000):The server is not configured as slave;fix in config file or with CHANG…

实验13 简单拓扑BGP配置

实验13 简单拓扑BGP配置 一、 原理描述二、 实验目的三、 实验内容四、 实验配置五、 实验步骤 一、 原理描述 BGP&#xff08;Border Gateway Protocol&#xff0c;边界网关协议&#xff09;是一种用于自治系统间的动态路由协议&#xff0c;用于在自治系统&#xff08;AS&…

汇聚荣做拼多多运营怎么样?

汇聚荣做拼多多运营怎么样?在电商行业竞争日益激烈的今天&#xff0c;拼多多作为一家迅速崛起的电商平台&#xff0c;吸引了众多商家入驻。对于汇聚荣这样的企业而言&#xff0c;选择在拼多多上进行商品销售和品牌推广&#xff0c;无疑需要一套高效的运营策略。那么&#xff0…

技术师增强版,系统级别的工具!【不能用】

数据安全是每位计算机用户都关心的重要问题。在日常使用中&#xff0c;我们经常面临文件丢失、系统崩溃或病毒感染等风险。为了解决这些问题&#xff0c;我们需要可靠且高效的数据备份与恢复工具。本文将介绍一款优秀的备份软件&#xff1a;傲梅轻松备份技术师增强版&#xff0…

【MySQL数据库】:MySQL视图特性

目录 视图的概念 基本使用 准备测试表 创建视图 修改视图影响基表 修改基表影响视图 删除视图 视图规则和限制 视图的概念 视图是一个虚拟表&#xff0c;其内容由查询定义&#xff0c;同真实的表一样&#xff0c;视图包含一系列带有名称的列和行数据。视图中的数据…

地下管线管网三维建模系统MagicPipe3D

地下管网是保障城市运行的基础设施和“生命线”。随着实景三维中国建设的推进&#xff0c;构建地下管网三维模型与地上融合的数字孪生场景&#xff0c;对于提升智慧城市管理至关重要&#xff01;针对现有三维管线建模数据差异大、建模交互弱、模型效果差、缺乏语义信息等缺陷&a…

多功能投票系统(ThinkPHP+FastAdmin+Uniapp)

让决策更高效&#xff0c;更民主&#x1f31f; ​基于ThinkPHPFastAdminUniapp开发的多功能系统&#xff0c;支持图文投票、自定义选手报名内容、自定义主题色、礼物功能(高级授权)、弹幕功能(高级授权)、会员发布、支持数据库私有化部署&#xff0c;Uniapp提供全部无加密源码…

Android MVP模式 入门

View&#xff1a;对应于布局文件 Model&#xff1a;业务逻辑和实体模型 Controllor&#xff1a;对应于Activity 看起来的确像那么回事&#xff0c;但是细细的想想这个View对应于布局文件&#xff0c;其实能做的事情特别少&#xff0c;实际上关于该布局文件中的数据绑定的操…

高通安卓12-安卓系统定制2

将开机动画打包到system.img里面 在目录device->qcom下面 有lito和qssi两个文件夹 现在通过QSSI的方式创建开机动画&#xff0c;LITO方式是一样的 首先加入自己的开机动画&#xff0c;制作过程看前面的部分 打开qssi.mk文件&#xff0c;在文件的最后加入内容 PRODUCT_CO…

【SSM】医疗健康平台-管理端-检查组管理

技能目标 掌握新增检查组功能的实现 掌握查询检查组功能的实现 掌握编辑检查组功能的实现 掌握删除检查组功能的实现 体检的检查项种类繁多&#xff0c;为了方便管理和快速筛选出类别相同的检查项&#xff0c;医疗健康将类别相同的检查项放到同一个检查组中进行管理&#…

ANR灵魂拷问:四大组件中的onCreate-onReceive方法中Thread-sleep(),会产生几个ANR-

findViewById(R.id.btn).setOnClickListener(new View.OnClickListener() { Override public void onClick(View v) { sleepTest(); } }); sleepTest方法详情 public void sleepTest(){ new Handler().postDelayed(new Runnable() { Override public void run() { Button but…

<Rust><iced>在iced中显示gif动态图片的一种方法

前言 本文是在rust的GUI库iced中在窗口显示动态图片GIF格式图片的一种方法。 环境配置 系统&#xff1a;window 平台&#xff1a;visual studio code 语言&#xff1a;rust 库&#xff1a;iced、image 概述 在iced中&#xff0c;提供了image部件&#xff0c;从理论上说&…

软考 系统架构设计师系列知识点之杂项集萃(44)

接前一篇文章&#xff1a;软考 系统架构设计师系列知识点之杂项集萃&#xff08;43&#xff09; 第71题 设有员工实体Employee&#xff08;员工号&#xff0c;姓名&#xff0c;性别&#xff0c;年龄&#xff0c;电话&#xff0c;家庭住址&#xff0c;家庭成员&#xff0c;关系…

自动驾驶⻋辆环境感知:多传感器融合

目录 一、多传感器融合技术概述 二、基于传统方法的多传感器融合 三、基于深度学习的视觉和LiDAR的目标级融合 四、基于深度学习的视觉和LiDAR数据的前融合方法 概念介绍 同步和配准 时间同步 标定 摄像机内参标定&#xff08;使用OpenCV&#xff09; 摄像机与LiDAR外…

【FreeRTOS】任务状态改进播放控制

这里写目录标题 1 任务状态1.1 阻塞状态(Blocked)1.2 暂停状态(Suspended)1.3 就绪状态(Ready)1.4 完整的状态转换图 2 举个例子3 编写代码 参考《FreeRTOS入门与工程实践(基于DshanMCU-103).pdf》 本节课实现音乐任务的创建&#xff0c;音乐播放的暂停与继续播放&#xff0c;删…