Transformer从菜鸟到新手(一)

引言

这是从Transformer到LLM(大语言模型)系列的第一篇文章,几乎所有的大语言模型都是基于Transformer结构,因此本文回顾一下Transformer的原理与实现细节,包括分词算法BPE的实现。最终利用从零实现的Transformer模型进行英中翻译。

本文主要介绍Transformer中用到的子词分词算法BPE,在文章的最后你会自己实现一个BPE算法。


下篇文章介绍Transformer的核心模块——多头注意力以及位置编码的实现。

Transformer模型

202312060211

Transformer模型如上图所示,它也属于编码器-解码器架构,左边是编码器,右边是解码器。它们都由N个Transformer Block(块)组成。编码器的输出作为解码器多头注意力的Key和Value,Query来自解码器输入经过掩码多头注意力后的结果。

主要涉及以下几个模块:

  • 嵌入表示(Input/Output Embedding) 将每个标记(token)转换为对应的向量表示。

  • 位置编码(Positional Encoding) 由于没有时序信息,需要额外加入位置编码。

  • 多头注意力(Multi-Head Attention) 利用注意力机制对输入进行交互,得到具有上下文信息的表示。根据其所处的位置有不同的变种:邻接解码器嵌入位置是掩码多头注意力,特点是当前位置只能注意本身以及之前位置的信息;掩码多头注意力紧接的多头注意力特点是Key和Value来自编码器的输出,而Query来自底层的输出,目的是在计算输出时考虑输入信息。

  • 层归一化(Layer Normalization) 作用于Transformer块内部子层的输出表示上,对表示序列进行层归一化。

  • 残差连接(Residual connection) 作用于Transformer块内部子层输入和输出上,对其输入和输出进行求和。

  • 位置感知前馈网络(Position-wise Feedforward Network) 通过多层全连接网络对表示序列进行非线性变换,提升模型的表达能力。

下面依次介绍各模块的功能和实现。

嵌入表示

对于输入文本序列,首先要进行分词,而Transformer所采用的分词算法为更能解决OOV(Out of Vocabulary)问题的子词分词算法——BPE(Byte Pair Encoding)。分词之后得到子词标记,当成我们常用的标记使用,来构建词表。最后基于词表大小创建一个对应的嵌入层。

嵌入层的实现如下:

class Embedding(nn.Module):
    def __init__(self, vocab_size: int, d_model: int) -> None:
        """

        Args:
            vocab_size (int): size of vocabulary
            d_model (int): dimension of embeddings
        """
        super().__init__()

        self.embed = nn.Embedding(vocab_size, d_model)
        self.sqrt_d_model = math.sqrt(d_model)

    def forward(self, x: Tensor) -> Tensor:
        """

        Args:
            x (Tensor): (batch_size, seq_length)

        Returns:
            Tensor: (batch_size, seq_length, d_model)
        """
        # multiply embedding values by the square root of the embedding dimension
        return self.embed(x) * self.sqrt_d_model

vocab_size是词表大小;d_model是嵌入层的维度,如原始论文所说,会用嵌入结果与 d model \sqrt{\text{d}_{\text{model}}} dmodel 相乘。

我们重点来看BPE分词算法的实现。

BPE

BPE分词算法主要由以下四个步骤组成:

  1. 标准化
  2. 预分词
  3. 单词拆分成字符
  4. 根据学习到的规则应用到这些拆分

标准化对输入的单词进行一些预处理,包括大小写转换、(英文)复数转换为单数等。我们这里简单处理,只进行大小写转换。

预分词做的是将文本序列按照某一规则拆分最小单元——单词,我们这里英文按空格拆分,中文通过jieba进行分词。

将单词拆分成字符比较简单,英文就拆分成字母;中文拆分成字。

最后利用BPE学习到的合并规则来对拆分后的结果进行合并,得到我们想要的子词(subtoken)。

因此BPE算法主要通过训练算法学习的是这些合并规则。

BPE训练算法的步骤如下:

  1. 初始化语料库
  2. 将语料库中每个单词拆分成字符作为子词,并在单词结尾增加一个</w>字符
  3. 将拆分后的子词构成初始子词词表
  4. 在语料库中统计单词内相邻子词对的频次
  5. 合并频次最高的子词对,合并成新的子词,并将新的子词加入到子词词表
  6. 重复步骤4和5直到进行了设定的合并次数或达到了设定的子词词表大小

我们以参考2中的例子为例介绍BPE算法的实现,注意这只是为了让我们更好地理解原理,实际应用中应该直接使用🤗Tokenizer库。

以下是由几句话构成的语料库:

corpus = [
    "This is the Hugging Face Course.",
    "This chapter is about tokenization.",
    "This section shows several tokenizer algorithms.",
    "Hopefully, you will be able to understand how they are trained and generate tokens.",
]

然后我们执行上面所说的标准化和预分词过程:

for sentence in corpus:
    sentence = sentence.lower()
    print([c for c in jieba.cut(sentence) if c != " "])
['this', 'is', 'the', 'hugging', 'face', 'course', '.']
['this', 'chapter', 'is', 'about', 'tokenization', '.']
['this', 'section', 'shows', 'several', 'tokenizer', 'algorithms', '.']
['hopefully', ',', 'you', 'will', 'be', 'able', 'to', 'understand', 'how', 'they', 'are', 'trained', 'and', 'generate', 'tokens', '.']

为了简单起见,我们可以直接用jieba对其进行分词,它能帮我们把单词和标点符号拆开,但是它会把空格也拆分出来,我们直接过滤掉空格。

有了这些结果,我们就可以统计每个单词出现的频次:

word_freqs = defaultdict(int)

for sentence in corpus:
    sentence = sentence.lower()
    words = [w for w in jieba.cut(sentence) if w != " "]
    for word in words:
        word_freqs[word] += 1

print(word_freqs)
defaultdict(<class 'int'>, {'this': 3, 'is': 2, 'the': 1, 'hugging': 1, 'face': 1, 'course': 1, '.': 4, 'chapter': 1, 'about': 1, 'tokenization': 1, 'section': 1, 'shows': 1, 'several': 1, 'tokenizer': 1, 'algorithms': 
1, 'hopefully': 1, ',': 1, 'you': 1, 'will': 1, 'be': 1, 'able': 1, 'to': 1, 'understand': 1, 'how': 1, 'they': 1, 'are': 1, 'trained': 1, 'and': 1, 'generate': 1, 'tokens': 1})

下面我们将语料库中每个单词拆分成字符作为子词,构建初始子词词表:

alphabet = []
for word in word_freqs.keys():
    for letter in word:
        if letter not in alphabet:
            alphabet.append(letter)

alphabet.sort()
print(alphabet)
[',', '.', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't', 'u', 'v', 'w', 'y', 'z']

我们在词表中加入模型要使用的特殊标记,这里有<PAD><BOS><EOS><UNK>,分别表示填充、句子开头、句子结尾和未知词。

vocab = ["<PAD>", "<UNK>", "<BOS>", "<EOS>"] + alphabet.copy()

初始词表构建好了之后,我们将语料库中每个单词拆分成字符作为子词:

splits = {word: [c for c in word] for word in word_freqs.keys()}
pprint(splits)
{',': [','],
 '.': ['.'],
 'able': ['a', 'b', 'l', 'e'],
 'about': ['a', 'b', 'o', 'u', 't'],
 'algorithms': ['a', 'l', 'g', 'o', 'r', 'i', 't', 'h', 'm', 's'],
 'and': ['a', 'n', 'd'],
 'are': ['a', 'r', 'e'],
 'be': ['b', 'e'],
 'chapter': ['c', 'h', 'a', 'p', 't', 'e', 'r'],
 'course': ['c', 'o', 'u', 'r', 's', 'e'],
 'face': ['f', 'a', 'c', 'e'],
 'generate': ['g', 'e', 'n', 'e', 'r', 'a', 't', 'e'],
 'hopefully': ['h', 'o', 'p', 'e', 'f', 'u', 'l', 'l', 'y'],
 'how': ['h', 'o', 'w'],
 'hugging': ['h', 'u', 'g', 'g', 'i', 'n', 'g'],
 'is': ['i', 's'],
 'section': ['s', 'e', 'c', 't', 'i', 'o', 'n'],
 'several': ['s', 'e', 'v', 'e', 'r', 'a', 'l'],
 'shows': ['s', 'h', 'o', 'w', 's'],
 'the': ['t', 'h', 'e'],
 'they': ['t', 'h', 'e', 'y'],
 'this': ['t', 'h', 'i', 's'],
 'to': ['t', 'o'],
 'tokenization': ['t', 'o', 'k', 'e', 'n', 'i', 'z', 'a', 't', 'i', 'o', 'n'],
 'tokenizer': ['t', 'o', 'k', 'e', 'n', 'i', 'z', 'e', 'r'],
 'tokens': ['t', 'o', 'k', 'e', 'n', 's'],
 'trained': ['t', 'r', 'a', 'i', 'n', 'e', 'd'],
 'understand': ['u', 'n', 'd', 'e', 'r', 's', 't', 'a', 'n', 'd'],
 'will': ['w', 'i', 'l', 'l'],
 'you': ['y', 'o', 'u']}

然后编写一个函数计算每个连续对的频次:

def compute_pair_freqs(splits):
    pari_freqs = defaultdict(int)
    for word, freq in word_freqs.items():
        # word拆分后的列表
        split = splits[word]
        # 至少要有2个字符才能合并
        if len(split) == 1:
            continue

        for i in range(len(split) - 1):
            # word中连续的字符
            pair = (split[i], split[i + 1])
            # 累加其频次
            pari_freqs[pair] += freq

    return pari_freqs

pair_freqs = compute_pair_freqs(splits)

for i, key in enumerate(pair_freqs.keys()):
    print(f"{key}: {pair_freqs[key]}")
    if i >= 10:
        break
('t', 'h'): 6
('h', 'i'): 3
('i', 's'): 5
('h', 'e'): 2
('h', 'u'): 1
('u', 'g'): 1
('g', 'g'): 1
('g', 'i'): 1
('i', 'n'): 2
('n', 'g'): 1
('f', 'a'): 1

并且我们查看10个对(未排序)的频次。

现在我们只需要一次遍历就能找到最高频次的对:

    best_pair = None
    max_freq = 0
    for pair, freq in pair_freqs.items():
        if max_freq < freq:
            best_pair = pair
            max_freq = freq

    print(best_pair, max_freq)
('t', 'h') 6

所以第一个合并的是('t', 'h') -> 'th',然后将合并后的子词加入到子词词表,同时合并前的两个字符还在词表中,这样我们的词表扩展了一个标记:

# 学习到的第一条合并规则
merges = {("t", "h"): "th"}
# 加入到词表中
vocab.append("th")

然后别忘了,我们还要应用该合并规则到splits字典中,我们也可以编写一个函数来完成:

def merge_pair(a, b, splits):
    # 合并split中所有的(a,b) -> ab
    for word in word_freqs:
        split = splits[word]
        if len(split) == 1:
            continue

        i = 0
        while i < len(split) - 1:
            # 如果刚好是a和b
            if split[i] == a and split[i + 1] == b:
                # 合并
                split = split[:i] + [a + b] + split[i + 2 :]
            else:
                i += 1
        # 重新赋值给word
        splits[word] = split

    return splits

splits = merge_pair("t", "h", splits)
print(splits["they"])
['th', 'e', 'y']

此时我们就有遍历整个语料所需的全部必要实现,我们可以编写代码调用这些实现,先试试将目标词表大小设置为50:

merges = {}

vocab_size = 50

while len(vocab) < vocab_size:
    pair_freqs = compute_pair_freqs(splits)

    best_pair = None
    max_freq = 0

    for pair, freq in pair_freqs.items():
        if max_freq < freq:
            best_pair = pair
            max_freq = freq
    # 用*对best_pair元组进行解包
    # 得到新的splits
    splits = merge_pair(*best_pair, splits)
    # 学习到的合并规则
    merges[best_pair] = best_pair[0] + best_pair[1]
    # 词表扩充
    vocab.append(best_pair[0] + best_pair[1])

pprint(merges)
{('a', 'b'): 'ab',
 ('a', 't'): 'at',
 ('e', 'n'): 'en',
 ('e', 'r'): 'er',
 ('h', 'o'): 'ho',
 ('ho', 'w'): 'how',
 ('i', 'n'): 'in',
 ('i', 'o'): 'io',
 ('i', 's'): 'is',
 ('io', 'n'): 'ion',
 ('n', 'd'): 'nd',
 ('o', 'u'): 'ou',
 ('s', 'e'): 'se',
 ('t', 'h'): 'th',
 ('t', 'o'): 'to',
 ('th', 'e'): 'the',
 ('th', 'is'): 'this',
 ('to', 'k'): 'tok',
 ('tok', 'en'): 'token',
 ('token', 'i'): 'tokeni',
 ('tokeni', 'z'): 'tokeniz'}

我们学习了21个合并规则,初始词表大小为29(25个字母加上4个特殊标记)。

['<PAD>', '<UNK>', '<BOS>', '<EOS>', ',', '.', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't', 'u', 'v', 'w', 'y', 'z', 'th', 'is', 'er', 'to', 'en', 'this', 'ou', 'se', 'tok', 'token', 'ho', 'nd', 'the', 'in', 'ab', 'tokeni', 'tokeniz', 'at', 'io', 'ion', 'how']

BPE分词最重要的合并规则已经学好了,那如何遇到新文本,要如何应用我们的BPE分词算法呢?

其实就是我们BPE分词算法的四个步骤,回顾一下:

  1. 标准化
  2. 预分词
  3. 单词拆分成字符
  4. 根据学习到的规则应用到这些拆分
def tokenize(text, merges):
    # 1. 标准化和预分词
    text = text.lower()
    words = [w for w in jieba.cut(text) if w != " "]
    # 2. 单词拆分成字符
    splits = [[c for c in word] for word in words]
    # 3. 根据学习到的规则应用到这些拆分
    # 3.1 遍历所有的合并规则
    for pair, merge in merges.items():
        # 3.2 遍历所有的splits
        for idx, split in enumerate(splits):
            i = 0
            while i < len(split) - 1:
                if split[i] == pair[0] and split[i + 1] == pair[1]:
                    split = split[:i] + [merge] + split[i + 2 :]
                else:
                    i += 1
            # 应用合并结果
            splits[idx] = split
    # splits是包含列表的列表
    # 返回一个包含所有子列表元素的新列表
    return sum(splits, [])

print(tokenize("This is not a token.", merges))
['this', 'is', 'n', 'o', 't', 'a', 'token', '.']

这就是我们最终BPE算法对这行文本分词的结果。注意为了更容易理解,我们上面没有考虑在单词的最后加上</w>字符,我们这里封装的时候考虑进来,但</w>在我们的代码中会被当成多个字符:</w>,为了不对代码进行太多的修改,我们用特殊字符Ġ来代表</w>

最后用类封装所有操作:

class BPETokenizer:
    def __init__(self, special_tokens=[]) -> None:
        self.word_freqs = defaultdict(int)
        self.merges = {}
        self.token_to_id: dict[str, int] = {}
        self.id_to_token: dict[int, str] = {}

        if special_tokens is None:
            special_tokens = []
        special_tokens = [
            "<PAD>",
            "<UNK>",
            "<BOS>",
            "<EOS>",
            "Ġ",  # stands for </w>
        ] + special_tokens

        for token in special_tokens:
            self._add_token(token)

    def _add_token(self, token: str) -> None:
        if token not in self.token_to_id:
            idx = len(self.token_to_id)
            self.token_to_id[token] = idx
            self.id_to_token[idx] = token

    @property
    def vocab_size(self) -> int:
        return len(self.token_to_id)

    def _learn_vocab(self, corpus: list[str]) -> None:
        for sentence in corpus:
            sentence = sentence.lower()
            words = [w + "Ġ" for w in jieba.cut(sentence) if w != " "]
            for word in words:
                self.word_freqs[word] += 1

    def _compute_pair_freqs(self, splits) -> dict[Tuple, int]:
        pair_freqs = defaultdict(int)
        for word, freq in self.word_freqs.items():
            split = splits[word]
            if len(split) == 1:
                continue

            for i in range(len(split) - 1):
                pair = (split[i], split[i + 1])
                pair_freqs[pair] += freq

        return pair_freqs

    def _merge_pair(self, a: str, b: str, splits):
        for word in self.word_freqs:
            split = splits[word]
            if len(split) == 1:
                continue

            i = 0
            while i < len(split) - 1:
                if split[i] == a and split[i + 1] == b:
                    split = split[:i] + [a + b] + split[i + 2 :]
                else:
                    i += 1
            splits[word] = split

        return splits

    def _merge_vocab(self, vocab_size, splits):
        merges = {}

        while self.vocab_size < vocab_size:
            pair_freqs = self._compute_pair_freqs(splits)

            best_pair = None
            max_freq = 0

            for pair, freq in pair_freqs.items():
                if max_freq < freq:
                    best_pair = pair
                    max_freq = freq

            splits = self._merge_pair(*best_pair, splits)
            merges[best_pair] = best_pair[0] + best_pair[1]
            self._add_token(best_pair[0] + best_pair[1])

        return merges

    def train(self, corpus, vocab_size):
        self._learn_vocab(corpus)
        splits = {word: [c for c in word] for word in self.word_freqs.keys()}
        
        for split in splits.values():
            for c in split:
                self._add_token(c)
                
        self.merges = self._merge_vocab(vocab_size, splits)

    def tokenize(self, text):
        text = text.lower()
        words = [w + "Ġ" for w in jieba.cut(text) if w != " "]
        splits = [[c for c in word] for word in words]
        for pair, merge in self.merges.items():
            for idx, split in enumerate(splits):
                i = 0
                while i < len(split) - 1:
                    if split[i] == pair[0] and split[i + 1] == pair[1]:
                        split = split[:i] + [merge] + split[i + 2 :]
                    else:
                        i += 1
                splits[idx] = split
        return sum(splits, [])

主要的修改点就是在预分词后每个单词后面增加了表示结尾的特殊字符Ġ

_learn_vocab()方法中得到的word_freqs字典就变成了:

print(self.word_freqs)
defaultdict(<class 'int'>, {'thisĠ': 3, 'isĠ': 2, 'theĠ': 1, 'huggingĠ': 1, 'faceĠ': 1, 'courseĠ': 1, '.Ġ': 4, 'chapterĠ': 1, 'aboutĠ': 1, 'tokenizationĠ': 1, 'sectionĠ': 1, 'showsĠ': 1, 'severalĠ': 1, 'tokenizerĠ': 1, 'algorithmsĠ': 1, 'hopefullyĠ': 1, ',Ġ': 1, 'youĠ': 1, 'willĠ': 1, 'beĠ': 1, 'ableĠ': 1, 'toĠ': 1, 'understandĠ': 1, 'howĠ': 1, 'theyĠ': 1, 'areĠ': 1, 'trainedĠ': 1, 'andĠ': 1, 'generateĠ': 1, 'tokensĠ': 1})

还是像上面那样进行测试:

corpus = [
    "This is the Hugging Face Course.",
    "This chapter is about tokenization.",
    "This section shows several tokenizer algorithms.",
    "Hopefully, you will be able to understand how they are trained and generate tokens.",
]

tokenizer = BPETokenizer()
tokenizer.train(corpus, 50)
print(tokenizer.tokenize("This is not a token."))
['thisĠ', 'isĠ', 'n', 'o', 't', 'Ġ', 'a', 'Ġ', 'token', 'Ġ', '.Ġ']

这样我们得到了子词词表,可以继续添加encodedecode方法,此时代码如下:

class BPETokenizer:
    def __init__(self, special_tokens=[]) -> None:
        self.word_freqs = defaultdict(int)
        self.merges = {}
        self.token_to_id: dict[str, int] = {}
        self.id_to_token: dict[int, str] = {}

        if special_tokens is None:
            special_tokens = []
        special_tokens = [
            "<PAD>",
            "<UNK>",
            "<BOS>",
            "<EOS>",
            "Ġ",  # stands for </w>
        ] + special_tokens

        for token in special_tokens:
            self._add_token(token)

        self.unk_token = "<UNK>"
        self.unk_token_id = self.token_to_id.get(self.unk_token)

    def _add_token(self, token: str) -> None:
        if token not in self.token_to_id:
            idx = len(self.token_to_id)
            self.token_to_id[token] = idx
            self.id_to_token[idx] = token

    @property
    def vocab_size(self) -> int:
        return len(self.token_to_id)

    def _learn_vocab(self, corpus: list[str]) -> None:
        for sentence in corpus:
            sentence = sentence.lower()
            words = [w + "Ġ" for w in jieba.cut(sentence) if w != " "]
            for word in words:
                self.word_freqs[word] += 1

    def _compute_pair_freqs(self, splits) -> dict[Tuple, int]:
        pair_freqs = defaultdict(int)
        for word, freq in self.word_freqs.items():
            split = splits[word]
            if len(split) == 1:
                continue

            for i in range(len(split) - 1):
                pair = (split[i], split[i + 1])
                pair_freqs[pair] += freq

        return pair_freqs

    def _merge_pair(self, a: str, b: str, splits):
        for word in self.word_freqs:
            split = splits[word]
            if len(split) == 1:
                continue

            i = 0
            while i < len(split) - 1:
                if split[i] == a and split[i + 1] == b:
                    split = split[:i] + [a + b] + split[i + 2 :]
                else:
                    i += 1
            splits[word] = split

        return splits

    def _merge_vocab(self, vocab_size, splits):
        merges = {}

        while self.vocab_size < vocab_size:
            pair_freqs = self._compute_pair_freqs(splits)

            best_pair = None
            max_freq = 0

            for pair, freq in pair_freqs.items():
                if max_freq < freq:
                    best_pair = pair
                    max_freq = freq

            splits = self._merge_pair(*best_pair, splits)
            merges[best_pair] = best_pair[0] + best_pair[1]
            self._add_token(best_pair[0] + best_pair[1])

        return merges

    def train(self, corpus, vocab_size):
        self._learn_vocab(corpus)
        splits = {word: [c for c in word] for word in self.word_freqs.keys()}

        for split in splits.values():
            for c in split:
                self._add_token(c)

        self.merges = self._merge_vocab(vocab_size, splits)

    def tokenize(self, text):
        text = text.lower()
        words = [w + "Ġ" for w in jieba.cut(text) if w != " "]
        splits = [[c for c in word] for word in words]
        for pair, merge in self.merges.items():
            for idx, split in enumerate(splits):
                i = 0
                while i < len(split) - 1:
                    if split[i] == pair[0] and split[i + 1] == pair[1]:
                        split = split[:i] + [merge] + split[i + 2 :]
                    else:
                        i += 1
                splits[idx] = split
        return sum(splits, [])

    def _convert_token_to_id(self, token: str) -> int:
        return self.token_to_id.get(token, self.unk_token_id)

    def _convert_id_to_token(self, index: int) -> str:
        return self.id_to_token.get(index, self.unk_token)

    def _convert_ids_to_tokens(self, token_ids: list[int]) -> list[str]:
        return [self._convert_id_to_token(index) for index in token_ids]

    def _convert_tokens_to_ids(self, tokens: list[str]) -> list[int]:
        return [self._convert_token_to_id(token) for token in tokens]

    def encode(self, text: str) -> list[int]:
        tokens = self.tokenize(text)

        return self._convert_tokens_to_ids(tokens)

    def clean_up_tokenization(self, out_string: str) -> str:
        out_string = (
            out_string.replace("Ġ", " ")
            .replace(" .", ".")
            .replace(" ?", "?")
            .replace(" !", "!")
            .replace(" ,", ",")
            .replace(" ' ", "'")
            .replace(" n't", "n't")
            .replace(" 'm", "'m")
            .replace(" 's", "'s")
            .replace(" 've", "'ve")
            .replace(" 're", "'re")
        )
        return out_string

    def decode(self, token_ids: list[int]) -> str:
        tokens = self._convert_ids_to_tokens(token_ids)
        return self.clean_up_tokenization("".join(tokens))

接着之前的测试:

corpus = [
    "This is the Hugging Face Course.",
    "This chapter is about tokenization.",
    "This section shows several tokenizer algorithms.",
    "Hopefully, you will be able to understand how they are trained and generate tokens.",
]

tokenizer = BPETokenizer()
tokenizer.train(corpus, 50)
# ['thisĠ', 'isĠ', 'n', 'o', 't', 'Ġ', 'a', 'Ġ', 'token', 'Ġ', '.Ġ']
print(tokenizer.tokenize("This is not a token."))
# 50
print(tokenizer.vocab_size)
token_ids = tokenizer.encode("This is not a token.")
# [38, 33, 12, 16, 5, 4, 14, 4, 41, 4, 35]
print(token_ids)
# this is not a token.
print(tokenizer.decode(token_ids))
['thisĠ', 'isĠ', 'n', 'o', 't', 'Ġ', 'a', 'Ġ', 'token', 'Ġ', '.Ġ']
50
[38, 33, 12, 16, 5, 4, 14, 4, 41, 4, 35]
this is not a token.

通过clean_up_tokenization()方法处理空格替换以及最后一个标点符号后有额外空格的问题。

SentencePiece

上面自己实现的BPE仅适用于学习,但实际使用起来分词非常慢。

我们这里使用sentencepiece工具进行BPE分词,它底层基于C++实现,速度非常快。同时支持批数据处理。

import json

import sentencepiece as spm
from concurrent.futures import ProcessPoolExecutor
import os
from utils import convert_to_zh, make_dirs
from config import train_args, model_args


def get_wmt_pairs(data_dir: str, splits=["train", "dev", "test"]):
    chinese_sentences = []
    english_sentences = []
    """
    json content:
    [["english sentence", "中文语句"], ["english sentence", "中文语句"]]
    """
    for split in splits:
        with open(f"{data_dir}/{split}.json", "r", encoding="utf-8") as f:
            data = json.load(f)
            for pair in data:
                english_sentences.append(pair[0] + "\n")
                chinese_sentences.append(pair[1] + "\n")

    assert len(chinese_sentences) == len(english_sentences)

    print(f"the total number of sentences: {len(chinese_sentences)}")

    return chinese_sentences, english_sentences


def get_en_cn_pairs(data_dir: str, splits=["train", "dev", "test"]):
    chinese_sentences = []
    english_sentences = []
    """
    txt content:
    english sentence\t繁体中文语句
    english sentence\t繁体中文语句
    """
    for split in splits:
        with open(f"{data_dir}/{split}.txt", "r", encoding="utf-8") as f:
            lines = f.readlines()
            for line in lines:
                if line:
                    pair = line.strip().split("\t")
                    english_sentences.append(pair[0])
                    chinese_sentences.append(convert_to_zh(pair[1]))

    assert len(chinese_sentences) == len(english_sentences)

    print(f"the total number of sentences: {len(chinese_sentences)}")

    return chinese_sentences, english_sentences

def train_sentencepice_bpe(input_file: str, model_prefix: str, vocab_size: int, character_coverage: float = 0.9995, pad_id:int =0, unk_id:int=1, bos_id:int=2, eos_id:int=3):
    cmd = f"--input={input_file} --model_prefix={model_prefix} --vocab_size={vocab_size} --model_type=bpe --character_coverage={character_coverage} --pad_id={pad_id} --unk_id={unk_id} --bos_id={bos_id} --eos_id={eos_id}"
    spm.SentencePieceTrainer.train(cmd)


def train_tokenizer(
    source_corpus_path: str,
    target_corpus_path: str,
    source_vocab_size: int,
    target_vocab_size: int,
    source_character_coverage:float = 1.0,
    target_character_coverage:float = 0.9995

) -> None:
    with ProcessPoolExecutor() as executor:
        futures = [
            executor.submit(
                train_sentencepice_bpe,
                source_corpus_path,
                "source",
                source_vocab_size,
                source_character_coverage
            ),
            executor.submit(
                train_sentencepice_bpe,
                target_corpus_path,
                "target",
                target_vocab_size,
                target_character_coverage
            ),
        ]

        for future in futures:
            future.result()

    sp = spm.SentencePieceProcessor()

    source_text =  """
        Tesla is recalling nearly all 2 million of its cars on US roads to limit the use of its 
        Autopilot feature following a two-year probe by US safety regulators of roughly 1,000 crashes 
        in which the feature was engaged. The limitations on Autopilot serve as a blow to Tesla’s efforts 
        to market its vehicles to buyers willing to pay extra to have their cars do the driving for them.
        """

    sp.Load("./source.model")
    print(sp.encode_as_pieces(source_text))
    ids = sp.encode_as_ids(source_text)
    print(ids)
    print(sp.decode_ids(ids))


    
    target_text = """
        今晨(12月14日),中央气象台继续发布寒潮黄色预警、大风蓝色预警、暴雪黄色预警和冰冻黄色预警,
        中东部的大范围雨雪降温天气仍在持续进行中,公众需关注预报预警信息,做好保暖御寒和除雪工作,
        外出需警惕路面湿滑、道路结冰等带来的不利影响,注意出行安全。
    """

    sp.Load("./target.model")
    print(sp.encode_as_pieces(target_text))
    ids = sp.encode_as_ids(target_text)
    print(ids)
    print(sp.decode_ids(ids))
    

 

if __name__ == "__main__":
    make_dirs(train_args.save_dir)


    chinese_sentences, english_sentences = get_wmt_pairs(
        data_dir=train_args.dataset_path, splits=["train", "dev", "test"]
    )

    with open(f"{train_args.dataset_path}/corpus.ch", "w") as f:
        f.writelines(chinese_sentences)

    with open(f"{train_args.dataset_path}/corpus.en", "w") as f:
        f.writelines(english_sentences)
    


    train_tokenizer(
        f"{train_args.dataset_path}/corpus.en",
        f"{train_args.dataset_path}/corpus.ch",
        source_vocab_size=model_args.source_vocab_size,
        target_vocab_size=model_args.target_vocab_size
    )

这里针对WMT中英数据集训练分词器,这样就可以得到训练好的分词器。训练完毕后将分词器移到指定的目录。

想了解更多建议参考官方文档。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/288318.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

IOS:Safari无法播放MP4(H.264编码)

一、问题描述 MP4使用H.264编码通常具有良好的兼容性&#xff0c;因为H.264是一种广泛支持的视频编码标准。它可以在许多设备和平台上播放&#xff0c;包括电脑、移动设备和流媒体设备。 使用caniuse查询H.264兼容性&#xff0c;看似确实具有良好的兼容性&#xff1a; 然而…

Windows系统镜像检测修复建议

当通过镜像检测功能检测出Windows操作系统磁盘上有残留驱动项、系统中存在残留Xen驱动或者存在禁止安装驱动属性设置等异常检测项时&#xff0c;您可以参考本文的操作指导进行修复。 清理注册表残留驱动 HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control注册表树包含了控…

c++基础(对c的扩展)

文章目录 命令空间引用基本本质引用作为参数引用的使用场景 内联函数引出基本概念 函数补充默认参数函数重载c中函数重载定义条件函数重载的原理 命令空间 定义 namespace是单独的作用域 两者不会相互干涉 namespace 名字 { //变量 函数 等等 }eg namespace nameA {int num;v…

TextView ClickableSpan 事件分发的坑

TextView 的 ClickableSpan 有两个坑&#xff1a; 默认情况下&#xff0c;点击 ClickableSpan 的文本时会同时触发绑定在 TextView 的监听事件&#xff1b;默认情况下&#xff0c;点击 ClickableSpan 的文本之外的文本时&#xff0c;TextView 会消费该事件&#xff0c;而不会传…

MySQL运维实战(2.2)忘记密码如何处理

作者&#xff1a;俊达 引言 当你突然忘记了一个普通用户的密码&#xff0c;而又想着通过管理员账号去改密码时&#xff0c;却猛的发现所有管理员账号的密码都离谱地被你忘了。嗨呀&#xff0c;这可真是个尴尬的大麻烦&#xff01;root账户通常是MySQL中的大boss&#xff0c;你…

Redis(一)

1、redis Redis是一个完全开源免费的高性能&#xff08;NOSQL&#xff09;的key-value数据库。它遵守BSD协议&#xff0c;使用ANSI C语言编写&#xff0c;并支持网络和持久化。Redis拥有极高的性能&#xff0c;每秒可以进行11万次的读取操作和8.1万次的写入操作。它支持丰富的数…

【LeetCode:69. x 的平方根 | 二分】

&#x1f680; 算法题 &#x1f680; &#x1f332; 算法刷题专栏 | 面试必备算法 | 面试高频算法 &#x1f340; &#x1f332; 越难的东西,越要努力坚持&#xff0c;因为它具有很高的价值&#xff0c;算法就是这样✨ &#x1f332; 作者简介&#xff1a;硕风和炜&#xff0c;…

程序员提问的艺术:28.4K Star指南,告别成为办公室讨厌鬼!

Github: https://github.com/ryanhanwu/How-To-Ask-Questions-The-Smart-Way 原文&#xff1a;http://www.catb.org/~esr/faqs/smart-questions.html ✅为什么讨厌某些提问者 未自行尝试解决问题&#xff1a; ❌“怎么用Java写一个排序算法&#xff1f;” &#x1f44d;&#…

Plantuml之EBNF语法介绍(二十七)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a;多媒…

java spring核心技术AOP面向切面编程图文并茂包含例子demo

base Aspect-oriented programming面向切面的程序设计用于将那些与业务无关,但却对多个对象产生影响的公共行为和逻辑,抽取并封装为一个可重用的模块,这个模块被命名为“切面”(Aspect)场景: 权限认证,日志和事务处理.demo 基本背景 // 背景: 1. 模拟数据库操作增删改查 …

tomcat session cookie值设置逻辑

tomcat session cookie 值设置&#xff0c;tomcat jsessionid设置 ##调用request.getSession() Controller RequestMapping("/cookie") public class CookieController {RequestMapping("/tomcatRequest")ResponseBodypublic String tomcatRequest(HttpS…

CentOS 7 实战指南:文本处理命令详解

前言 在Linux系统中&#xff0c;文本处理是非常基础却又必不可少的一项技能。如果你正在使用CentOS系统&#xff0c;那么学会如何利用文本操作命令来高效地处理文本文件无疑将会是一个强有力的工具。 本篇文章将介绍一些最常用和最实用的文本操作命令&#xff0c;并通过详尽的…

医院配电能效监管方案

摘要:本文以医院能源监管系统为研究对象,采用智能化技术组建数据库、构建智能化的能耗信息管理系统,实现对医院的能源利用状况进行实时、准确的动态监管。具体而言,该系统建设的主要功能是对医院的能源消耗进行采集、上报、汇总与分析,并生成动态的数据和报表曲线,以及利用分析…

访问学者J1签证的申请流程

访问学者J1签证是许多人前往美国进行学术研究和文化交流的重要途径之一。申请J1签证需要经过一系列步骤和程序&#xff0c;让知识人网小编带大家来了解一下申请流程吧。 首先&#xff0c;申请者需要确认自己符合J1签证的资格要求。这包括被美国的赞助机构或组织接受&#xff0c…

uniapp中uview组件库的Input 输入框 的使用方法

目录 #平台差异说明 #基本使用 #输入框的类型 #可清空字符 #下划线 #前后图标 #前后插槽 API #Props #Events #Methods #Slots 去除fixed、showWordLimit、showConfirmBar、disableDefaultPadding、autosize字段 此组件为一个输入框&#xff0c;默认没有边框和样式…

mysql查询表里的重复数据方法:

1 2 3 4 INSERT INTO hk_test(username, passwd) VALUES (qmf1, qmf1),(qmf2, qmf11) delete from hk_test where usernameqmf1 and passwdqmf1 MySQL里查询表里的重复数据记录&#xff1a; 先查看重复的原始数据&#xff1a; 场景一&#xff1a;列出username字段有重读的数…

jdk动态代理中invoke的return返回的值有什么用?

目录 首先在接口中定义一个行为再定义一个目标角色实现接口&#xff0c;实现行为去代理角色类中解决一下报错&#xff0c;但是什么都不要写 invoke的return返回的值是调用方法中返回的值 下面我们来实例看一下 首先在接口中定义一个行为 public String toMarry02();再定义一个…

金和OA C6 UploadFileEditorSave.aspx 文件上传漏洞复现

0x01 产品简介 金和OA协同办公管理系统软件(简称金和OA),本着简单、适用、高效的原则,贴合企事业单位的实际需求,实行通用化、标准化、智能化、人性化的产品设计,充分体现企事业单位规范管理、提高办公效率的核心思想,为用户提供一整套标准的办公自动化解决方案,以帮助…

关于执行 roslaunch xxxxx yyyy.launch 后,没能进入 RViz 就卡死的问题

Problem 话不多说&#xff0c;先看图。 终端也会提示有报错&#xff08;可能是这种&#xff0c;但不确定&#xff09;&#xff1a; 这是发现问题所在之后&#xff0c;故意改错&#xff0c;然后尝试的。☝ Solution 总以为是显卡的问题&#xff0c;一直在研究怎么在 Ubuntu2…

适合前后端开发的可视化编辑器(拖拽控件)

分享一个面向研发人群使用的前后端分离的低代码软件——JNPF。 JNPF与市面上其他的低代码&#xff08;轻流、宜搭、微搭、简道云、轻流、活字格等等&#xff09;&#xff0c;后者更倾向于非编程人员使用&#xff0c;让业务线人员自行构建应用程序。而 JNPF 这款低代码产品是面向…