pytorch之诗词生成--2

先上代码:

# -*- coding: utf-8 -*-
# @File    : dataset.py
# @Author  : AaronJny
# @Time    : 2019/12/30
# @Desc    : 构建数据集
from collections import Counter
import math
import numpy as np
import tensorflow as tf
import settings


class Tokenizer:
    """
    分词器
    """

    def __init__(self, token_dict):
        # 词->编号的映射
        self.token_dict = token_dict
        # 编号->词的映射
        self.token_dict_rev = {value: key for key, value in self.token_dict.items()}
        # 词汇表大小
        self.vocab_size = len(self.token_dict)

    def id_to_token(self, token_id):
        """
        给定一个编号,查找词汇表中对应的词
        :param token_id: 带查找词的编号
        :return: 编号对应的词
        """
        return self.token_dict_rev[token_id]

    def token_to_id(self, token):
        """
        给定一个词,查找它在词汇表中的编号
        未找到则返回低频词[UNK]的编号
        :param token: 带查找编号的词
        :return: 词的编号
        """
        return self.token_dict.get(token, self.token_dict['[UNK]'])

    def encode(self, tokens):
        """
        给定一个字符串s,在头尾分别加上标记开始和结束的特殊字符,并将它转成对应的编号序列
        :param tokens: 待编码字符串
        :return: 编号序列
        """
        # 加上开始标记
        token_ids = [self.token_to_id('[CLS]'), ]
        # 加入字符串编号序列
        for token in tokens:
            token_ids.append(self.token_to_id(token))
        # 加上结束标记
        token_ids.append(self.token_to_id('[SEP]'))
        return token_ids

    def decode(self, token_ids):
        """
        给定一个编号序列,将它解码成字符串
        :param token_ids: 待解码的编号序列
        :return: 解码出的字符串
        """
        # 起止标记字符特殊处理
        spec_tokens = {'[CLS]', '[SEP]'}
        # 保存解码出的字符的list
        tokens = []
        for token_id in token_ids:
            token = self.id_to_token(token_id)
            if token in spec_tokens:
                continue
            tokens.append(token)
        # 拼接字符串
        return ''.join(tokens)


# 禁用词
disallowed_words = settings.DISALLOWED_WORDS
# 句子最大长度
max_len = settings.MAX_LEN
# 最小词频
min_word_frequency = settings.MIN_WORD_FREQUENCY
# mini batch 大小
batch_size = settings.BATCH_SIZE

# 加载数据集
with open(settings.DATASET_PATH, 'r', encoding='utf-8') as f:
    lines = f.readlines()
    # 将冒号统一成相同格式
    lines = [line.replace(':', ':') for line in lines]
# 数据集列表
poetry = []
# 逐行处理读取到的数据
for line in lines:
    # 有且只能有一个冒号用来分割标题
    if line.count(':') != 1:
        continue
    # 后半部分不能包含禁止词
    __, last_part = line.split(':')
    ignore_flag = False
    for dis_word in disallowed_words:
        if dis_word in last_part:
            ignore_flag = True
            break
    if ignore_flag:
        continue
    # 长度不能超过最大长度
    if len(last_part) > max_len - 2:
        continue
    poetry.append(last_part.replace('\n', ''))

# 统计词频
counter = Counter()
for line in poetry:
    counter.update(line)
# 过滤掉低频词
_tokens = [(token, count) for token, count in counter.items() if count >= min_word_frequency]
# 按词频排序
_tokens = sorted(_tokens, key=lambda x: -x[1])
# 去掉词频,只保留词列表
_tokens = [token for token, count in _tokens]

# 将特殊词和数据集中的词拼接起来
_tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + _tokens
# 创建词典 token->id映射关系
token_id_dict = dict(zip(_tokens, range(len(_tokens))))
# 使用新词典重新建立分词器
tokenizer = Tokenizer(token_id_dict)
# 混洗数据
np.random.shuffle(poetry)


class PoetryDataGenerator:
    """
    古诗数据集生成器
    """

    def __init__(self, data, random=False):
        # 数据集
        self.data = data
        # batch size
        self.batch_size = batch_size
        # 每个epoch迭代的步数
        self.steps = int(math.floor(len(self.data) / self.batch_size))
        # 每个epoch开始时是否随机混洗
        self.random = random

    def sequence_padding(self, data, length=None, padding=None):
        """
        将给定数据填充到相同长度
        :param data: 待填充数据
        :param length: 填充后的长度,不传递此参数则使用data中的最大长度
        :param padding: 用于填充的数据,不传递此参数则使用[PAD]的对应编号
        :return: 填充后的数据
        """
        # 计算填充长度
        if length is None:
            length = max(map(len, data))
        # 计算填充数据
        if padding is None:
            padding = tokenizer.token_to_id('[PAD]')
        # 开始填充
        outputs = []
        for line in data:
            padding_length = length - len(line)
            # 不足就进行填充
            if padding_length > 0:
                outputs.append(np.concatenate([line, [padding] * padding_length]))
            # 超过就进行截断
            else:
                outputs.append(line[:length])
        return np.array(outputs)

    def __len__(self):
        return self.steps

    def __iter__(self):
        total = len(self.data)
        # 是否随机混洗
        if self.random:
            np.random.shuffle(self.data)
        # 迭代一个epoch,每次yield一个batch
        for start in range(0, total, self.batch_size):
            end = min(start + self.batch_size, total)
            batch_data = []
            # 逐一对古诗进行编码
            for single_data in self.data[start:end]:
                batch_data.append(tokenizer.encode(single_data))
            # 填充为相同长度
            batch_data = self.sequence_padding(batch_data)
            # yield x,y
            yield batch_data[:, :-1], tf.one_hot(batch_data[:, 1:], tokenizer.vocab_size)
            del batch_data

    def for_fit(self):
        """
        创建一个生成器,用于训练
        """
        # 死循环,当数据训练一个epoch之后,重新迭代数据
        while True:
            # 委托生成器
            yield from self.__iter__()

下面我们逐行分析该代码:我们首先定义一个分词器类:

class Tokenizer:
    """
    分词器
    """

    def __init__(self, token_dict):
        # 词->编号的映射
        self.token_dict = token_dict
        # 编号->词的映射
        self.token_dict_rev = {value: key for key, value in self.token_dict.items()}
        # 词汇表大小
        self.vocab_size = len(self.token_dict)

    def id_to_token(self, token_id):
        """
        给定一个编号,查找词汇表中对应的词
        :param token_id: 带查找词的编号
        :return: 编号对应的词
        """
        return self.token_dict_rev[token_id]

    def token_to_id(self, token):
        """
        给定一个词,查找它在词汇表中的编号
        未找到则返回低频词[UNK]的编号
        :param token: 带查找编号的词
        :return: 词的编号
        """
        return self.token_dict.get(token, self.token_dict['[UNK]'])

    def encode(self, tokens):
        """
        给定一个字符串s,在头尾分别加上标记开始和结束的特殊字符,并将它转成对应的编号序列
        :param tokens: 待编码字符串
        :return: 编号序列
        """
        # 加上开始标记
        token_ids = [self.token_to_id('[CLS]'), ]
        # 加入字符串编号序列
        for token in tokens:
            token_ids.append(self.token_to_id(token))
        # 加上结束标记
        token_ids.append(self.token_to_id('[SEP]'))
        return token_ids

    def decode(self, token_ids):
        """
        给定一个编号序列,将它解码成字符串
        :param token_ids: 待解码的编号序列
        :return: 解码出的字符串
        """
        # 起止标记字符特殊处理
        spec_tokens = {'[CLS]', '[SEP]'}
        # 保存解码出的字符的list
        tokens = []
        for token_id in token_ids:
            token = self.id_to_token(token_id)
            if token in spec_tokens:
                continue
            tokens.append(token)
        # 拼接字符串
        return ''.join(tokens)

看第一段:

def __init__(self, token_dict):
    # 词->编号的映射
    self.token_dict = token_dict
    # 编号->词的映射
    self.token_dict_rev = {value: key for key, value in self.token_dict.items()}
    # 词汇表大小
    self.vocab_size = len(self.token_dict)

首先我们接受一个名为token_dict的参数,将其存储为类的属性,然后创建一个名为token_dict_rev的属性,这是token_dict的反向映射,最后,计算词汇表的大小并将其存储为vocab_size属性。

看下一段:

def id_to_token(self, token_id):
    """
    给定一个编号,查找词汇表中对应的词
    :param token_id: 带查找词的编号
    :return: 编号对应的词
    """
    return self.token_dict_rev[token_id]

这段代码定义一个方法id_to_token,接受一个名为token_id的参数,然后在词汇表中查找对应的词并返回,这个方法实际上是通过token_dict_rev属性实现的反向查找。明显,该字典中的键词的编号,值是词。

接着往下看:

def token_to_id(self, token):
    """
    给定一个词,查找它在词汇表中的编号
    未找到则返回低频词[UNK]的编号
    :param token: 带查找编号的词
    :return: 词的编号
    """
    return self.token_dict.get(token, self.token_dict['[UNK]'])

这段代码与上一段的由键到值差不多,是由值找到对应的键。接受名为token作为参数,然后在词汇表中查找对应词的编号并返回。如果词不在词汇表中,则返回低频词[UNK]的编号,注意我们的token_dict字典的键是词,值是编号,我们可以通过词来找到对应的编号,而token_dict_rev的键是编号,值是词,我们可以通过编号找到对应的值。

return self.token_dict.get(token, self.token_dict['[UNK]'])这段代码中,我们使用get方法,我们尝试在self.token_dict中获取键为token的值,也就是找到对应的编号,第二个参数表示如果没找到对应的键,则返回self.token_dict中键为[UNK]的值。(第二个参数表示字典找不到对应键时返回的默认值)。这样可以确保即使词不在词表中,也能返回一个默认值,避免了出现KeyError。

继续看代码:

def encode(self, tokens):
    """
    给定一个字符串s,在头尾分别加上标记开始和结束的特殊字符,并将它转成对应的编号序列
    :param tokens: 待编码字符串
    :return: 编号序列
    """
    # 加上开始标记
    token_ids = [self.token_to_id('[CLS]'), ]
    # 加入字符串编号序列
    for token in tokens:
        token_ids.append(self.token_to_id(token))
    # 加上结束标记
    token_ids.append(self.token_to_id('[SEP]'))
    return token_ids

我们的开始标记调用了我们刚刚定义的token_to_id方法,显然,不可能出现[CLS]这个词,所以得到的是[UNK]对应的编号,显然是一个特殊的编号。
(我们看一下错误的输出,也不算错误,就是对应我们的处理词输出。)

而后遍历tokens中的每个词,将词转化为对应的编号加入到编号序列中,这样我们就可以将我们的汉字类型转化为数字,从而可以进行卷积层的处理。

随后加上结束标记的符号,显然也是对应[UNK]。最后我们返回完整的编号序列。(是一个由数字组成的列表)。

相对应的是解码:

def decode(self, token_ids):
    """
    给定一个编号序列,将它解码成字符串
    :param token_ids: 待解码的编号序列
    :return: 解码出的字符串
    """
    # 起止标记字符特殊处理
    spec_tokens = {'[CLS]', '[SEP]'}
    # 保存解码出的字符的list
    tokens = []
    for token_id in token_ids:
        token = self.id_to_token(token_id)
        if token in spec_tokens:
            continue
        tokens.append(token)
    # 拼接字符串
    return ''.join(tokens)

我们先将特殊字符,也就是开始与结束对应的字符组成一个集合。而后我们创建了一个名为tokens的空列表,用于保存由token_ids中token_id对应词。最后我们使用join方法,将tokens列表中的字符串元素链接起来,形成一个新的字符串,在这里,''表示以空字符串作为连接符,也就是将tokens中的词无缝衔接。

接下来我们定义一些参数,这些参数在setting中已经定义,这里我们直接拿来用:

isallowed_words = settings.DISALLOWED_WORDS
# 句子最大长度
max_len = settings.MAX_LEN
# 最小词频
min_word_frequency = settings.MIN_WORD_FREQUENCY
# mini batch 大小
batch_size = settings.BATCH_SIZEr

然后我们就可以开始加载数据集了:

with open(settings.DATASET_PATH, 'r', encoding='utf-8') as f:
    lines = f.readlines()
    # 将冒号统一成相同格式
    lines = [line.replace(':', ':') for line in lines]
# 数据集列表
poetry = []

通过在setting中已经定义好的路径用只读的方式加载我们的数据,解码的类型是utf-8。f是一个对象,表示被打开的文件。文件对象f会在with代码块结束的时候自动关闭。

lines=f.readlines():这段代码从打开的文件对象f中读取所有行,并将它们存储在名为lines的列表中。(因为我们的数据集很大,所以这一步很耗时间)。
而后我们对我们的诗词进行处理,将所有行中的‘:’转化为‘:’,即格式统一,但是这里其实我们都转化为“:”也是不影响的。
然后我们创建一个数据集列表,也就是空列表。

接着我们开始对每一行(也就是一首诗)进行处理:

for line in lines:
    # 有且只能有一个冒号用来分割标题
    if line.count(':') != 1:
        continue
    # 后半部分不能包含禁止词
    __, last_part = line.split(':')
    ignore_flag = False
    for dis_word in disallowed_words:
        if dis_word in last_part:
            ignore_flag = True
            break
    if ignore_flag:
        continue
    # 长度不能超过最大长度
    if len(last_part) > max_len - 2:
        continue
    poetry.append(last_part.replace('\n', ''))

这里我们首先要参考一下数据的格式:

可见我们的每首诗在:的前面部分是诗词名,后半部分是内容,如果该行不包含:则表示是数据出现错误,这时我们直接跳过该数据,使用continue。对于没有问题的数据,我们使用split方法将数据分为前半部分诗词名(当然,直接丢掉),和第二部分内容(是我们需要的精华)。

我们定义一个布尔类型的变量ignore_flag用来判断是否将这个数据忽视。我们将禁词一一取出,如果禁词在我们的数据中出现,我们将该布尔变量设置为true,也就是要去除该数据,嵌套遍历完成后,我们通过判断布尔变量值来确定是否进行下一步处理,当然没有问题的数据,我们将其保存并进行下一步处理。

我们在进行下一步处理的时候也要进行判断,显然,当我们的数据长度较长的时候,比如(长恨歌),我们也是不需要的,这属于异常数据,我们用它作为参考生成小篇幅诗词无异于读圣经来学习小学的看图写话。

剩下的部分也就是符合我们要求的数据了,对于这些数据,我们直接将他们放进我们的列表中。注意小细节,我们将换行符转化为空格。(官方解释是确保诗词文本在处理之后仍然保持连续的完整性,而不会因为换行符被分割为很多行,有利于后期对文本的处理和分析)(但是我认为这是多余的,因为对于一行数据来代表一首诗词来说,完全没必要考虑换行符的问题)。

嗯嗯...也不是完全没用。

可见,我们生成诗词的时候,如果考虑到换行符的话,我们可以拉开我们生成的诗词的距离。

继续:

counter = Counter()
for line in poetry:
    counter.update(line)

这段代码创建了一个Counter类(计数器对象),它是collections模块中的一个数据结构,用于统计可哈希对象的出现次数。然后循环迭代poetry中的每一行,其中poetry是一个包含多行诗歌的列表。在每次迭代中,counter.update(line)都会被调用,它会将line中的字符添加到计数器中,并更新它们的出现次数,update()方法接受一个可迭代对象作为参数,它会遍历该对象并更新计数器。

最终,counter对象将会包含整个数据集中每个字符出现的次数。我们将通过一个简单的案例来说明counter函数的用法:

from collections import Counter

poetry = [
    "Roses are red,",
    "Violets are blue,",
    "Sugar is sweet,",
    "And so are you."
]

counter = Counter()
for line in poetry:
    counter.update(line)

print(counter)

输出结果如下:

Counter({' ': 15, 'e': 10, 's': 7, 'a': 6, 'o': 5, 'r': 4, 'u': 4, 't': 3, 'd': 2, 'n': 2, 'y': 2, 'w': 2, 'A': 1, 'R': 1, 'V': 1, 'i': 1, 'l': 1, 'b': 1, 'g': 1, ',': 1, 'S': 1, 'I': 1, '.': 1})

输出的是一个Counter对象。

接下来我们接着对词进行处理:

_tokens = [(token, count) for token, count in counter.items() if count >= min_word_frequency]
# 按词频排序
_tokens = sorted(_tokens, key=lambda x: -x[1])
# 去掉词频,只保留词列表
_tokens = [token for token, count in _tokens]

# 将特殊词和数据集中的词拼接起来
_tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + _tokens
# 创建词典 token->id映射关系
token_id_dict = dict(zip(_tokens, range(len(_tokens))))
# 使用新词典重新建立分词器
tokenizer = Tokenizer(token_id_dict)
# 混洗数据
np.random.shuffle(poetry)

我们首先来看第一行,创建了一个列表_tokens,用来包含计数器counter中词频大于等于min_word_frequency的词和它们的出现次数,counter.item返回的是一个键值对,键是词,值是对应的频数。

接下来,我们对_tokens列表进行排序,按照词频从高到低进行降序排序,key=lambda x:-x[1]表示使用每个元素的第二个值,即词频作为进行排序的依据。

_tokens = [token for token, count in _tokens]之后我们将排序后的列表中提取词汇,生成一个只包含词汇的列表,这里丢弃了词频信息,只包含了词汇。

而后我们将一些特殊字符,_tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + _tokens 添加到_tokens列表中,即在词汇列表的最前面。

token_id_dict = dict(zip(_tokens, range(len(_tokens))))然后我们创建一个字典,字典是从词汇到ID的映射关系,当然,前几个索引对应的是特殊词汇,后面按照词汇出现的频率一次对应索引。当然,得到的结果是一个字典。(由词汇到索引)

我们将这个字典传到我们的分词器中,会自动生成由索引到词的映射,以及得到该字典的长度(即词的个数)。

然后我们将我们的诗词的列表进行混洗。

而后我们又定义了一个古诗数据集生成器:

class PoetryDataGenerator:
    """
    古诗数据集生成器
    """

    def __init__(self, data, random=False):
        # 数据集
        self.data = data
        # batch size
        self.batch_size = batch_size
        # 每个epoch迭代的步数
        self.steps = int(math.floor(len(self.data) / self.batch_size))
        # 每个epoch开始时是否随机混洗
        self.random = random

    def sequence_padding(self, data, length=None, padding=None):
        """
        将给定数据填充到相同长度
        :param data: 待填充数据
        :param length: 填充后的长度,不传递此参数则使用data中的最大长度
        :param padding: 用于填充的数据,不传递此参数则使用[PAD]的对应编号
        :return: 填充后的数据
        """
        # 计算填充长度
        if length is None:
            length = max(map(len, data))
        # 计算填充数据
        if padding is None:
            padding = tokenizer.token_to_id('[PAD]')
        # 开始填充
        outputs = []
        for line in data:
            padding_length = length - len(line)
            # 不足就进行填充
            if padding_length > 0:
                outputs.append(np.concatenate([line, [padding] * padding_length]))
            # 超过就进行截断
            else:
                outputs.append(line[:length])
        return np.array(outputs)

    def __len__(self):
        return self.steps

    def __iter__(self):
        total = len(self.data)
        # 是否随机混洗
        if self.random:
            np.random.shuffle(self.data)
        # 迭代一个epoch,每次yield一个batch
        for start in range(0, total, self.batch_size):
            end = min(start + self.batch_size, total)
            batch_data = []
            # 逐一对古诗进行编码
            for single_data in self.data[start:end]:
                batch_data.append(tokenizer.encode(single_data))
            # 填充为相同长度
            batch_data = self.sequence_padding(batch_data)
            # yield x,y
            yield batch_data[:, :-1], tf.one_hot(batch_data[:, 1:], tokenizer.vocab_size)
            del batch_data

    def for_fit(self):
        """
        创建一个生成器,用于训练
        """
        # 死循环,当数据训练一个epoch之后,重新迭代数据
        while True:
            # 委托生成器
            yield from self.__iter__()

我们从头进行分析:

def __init__(self, data, random=False):
    # 数据集
    self.data = data
    # batch size
    self.batch_size = batch_size
    # 每个epoch迭代的步数
    self.steps = int(math.floor(len(self.data) / self.batch_size))
    # 每个epoch开始时是否随机混洗
    self.random = random

我们接受data和可选的random参数,在方法内部,我们将传入的data赋值给self.data,并确定了batch_size属性,我们之后通过数据集的长度和每个批次的长度来计算每一轮训练多少个批次(也就是步数)。
self.random=random表示每个epoch开始时是否随机混洗数据,它的值等于传入的random参数,默认为不随机混洗。

继续看代码:

def sequence_padding(self, data, length=None, padding=None):
    """
    将给定数据填充到相同长度
    :param data: 待填充数据
    :param length: 填充后的长度,不传递此参数则使用data中的最大长度
    :param padding: 用于填充的数据,不传递此参数则使用[PAD]的对应编号
    :return: 填充后的数据
    """
    # 计算填充长度
    if length is None:
        length = max(map(len, data))
    # 计算填充数据
    if padding is None:
        padding = tokenizer.token_to_id('[PAD]')
    # 开始填充
    outputs = []
    for line in data:
        padding_length = length - len(line)
        # 不足就进行填充
        if padding_length > 0:
            outputs.append(np.concatenate([line, [padding] * padding_length]))
        # 超过就进行截断
        else:
            outputs.append(line[:length])
    return np.array(outputs)

我们使用sequence_padding方法,用于将给定的数据填充到相同的长度:

我们传入参数分别是数据,长度,填充的字符。
我们默认填充后的长度是我们数据中的最大长度,这也是我们为什么使用的是64作为最大长度,而将诗词较长的数据进行去除。(不太适合生成长恨歌那样的诗词)。
我们填充的数据编号是PAD对应的编号,即解码的时候对应的也是PAD。
之后我们进行填充,计算出每一行需要填充的长度(归一化长度后的长度减去当前的长度),如果需要进行填充,我们将原数据拼接填充内容作为填充之后的数据。将填充之后的数据放入我们的outputs列表中,否则的话(数据大于我们的最大数据,虽然理论上是不可能的,但是我们也是写一下吧,就只留下到最大长度为止的数据。)

这里值得注意的是,我们传入的是由索引组成的列表。我们得到的也是由数据列表组成的列表,我们通过np.array(outputs)将列表outputs转化为一个numpy数组,其中每个元素对应列表中子列表。便于进一步处理数据。

接下来我们通过__len__来返回步长:

def __len__(self):
    return self.steps

即每轮训练多少个批次,在这里,初始化的时候已经计算好了。

继续哈:

def __iter__(self):
    total = len(self.data)
    # 是否随机混洗
    if self.random:
        np.random.shuffle(self.data)
    # 迭代一个epoch,每次yield一个batch
    for start in range(0, total, self.batch_size):
        end = min(start + self.batch_size, total)
        batch_data = []
        # 逐一对古诗进行编码
        for single_data in self.data[start:end]:
            batch_data.append(tokenizer.encode(single_data))
        # 填充为相同长度
        batch_data = self.sequence_padding(batch_data)
        # yield x,y
        yield batch_data[:, :-1], tf.one_hot(batch_data[:, 1:], tokenizer.vocab_size)
        del batch_data

首先我们获取总样本数,也就是我们的诗词个数,如果self.random=True表示每个epoch开始时需要随机混洗数据集,因此使用np.random.shuffle随机打乱self.data。

然后使用for循环进入每个批次进行训练,(以批次大小为步长遍历数据集,每次迭代都产生一个批次的数据)。用start和end分别表示训练数据开始和结束对应的索引,这里我们要考虑当用累加计算结束位置的时候,不要超过数据的长度。

然后我们逐一对古诗进行编码,将编码得到的结果送入空列表batch_data中,这里要注意我们得到的tokenizer.encode(single_data)是一个由数字组成的列表设为A,然后送入batch_data得到的是一个由A组成的列表,对这个列表进行padding处理,将这个列表中每首诗对应的列表进行扩充。(填充到相同的长度)。

yield batch_data[:, :-1], tf.one_hot(batch_data[:, 1:], tokenizer.vocab_size)
        del batch_data

这一行代码使用yield语句生成一个批次的数据。它返回两个值:batch_data[:,:-1]输入数据,是经过填充的故事序列编码,去掉每个序列的最后一个词,它的形状是(batch_size,sequence_length-1)。(最后一个是句号哦)。
tf.one_hot(batch_data[;,1:],tokenizer.vocab_size)这段代码的目的是将目标数据进行编码,并在这个过程中去掉每个序列中的第一个词,进行独热编码。tokenizer.vocab_size是词汇表的大小,用于确定独热编码的维度。它的形状是(batch_size,sequence_length-1,tokenizer.vocab_size)。

最后del batch_data:

删除批次数据batch_data释放内存,在每次迭代后我们就不需要存储整个批次的数据,因此可以通过删除来释放内存。

为什么删除第一个和最后一个呢?因为我们的起始位置和结束都使用特殊字符进行编码。

最后我们使用:

def for_fit(self):
    """
    创建一个生成器,用于训练
    """
    # 死循环,当数据训练一个epoch之后,重新迭代数据
    while True:
        # 委托生成器
        yield from self.__iter__()

这里我们创建一个死循环,表示生成器会无限制的生成数据,这是为了在训练过程中能持续获取数据,这里使用yield from语法来委托另外一个生成器,即self.__iter__()方法生成的数据,委托生成器的作用是将self.__iter__()生成的数据直接传递给外部的迭代器,作为训练数据。

通过这种方式,当调用for_fit方法时,会得到一个生成器对象,每次迭代该生成器,会从self.__iter__()生成的数据中获取一个批次的训练数据,并将其作为生成器的输出,由于采用了死循环的设置,这个生成器会持续的生成数据,直到外部的训练过程停止或中断。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/461035.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【敬伟ps教程】视频动画

文章目录 视频文档视频时间轴帧动画视频文档 工作区需由[基本功能]切换为[动感] 可以看到我们需从时间的维度来编辑动态视觉图像 时间轴:从时间的维度来编辑动态视觉图像 PS提供的时间轴有两种:1、视频时间轴;2、动画时间轴 新建视频文档,点击新建或Ctrl+N,预设选择“胶…

NCV1117ST50T3G线性稳压器芯片中文资料规格书PDF数据手册引脚图图片价格参数

产品概述: NCP1117系列为低压差(LDO)正向线性电压稳压器,能够提供超过1.0A的输出电流,800mA时温度范围内最大压差为1.2V。这一系列包括八个固定输出电压:1.5V、1.8V、2.0V、2.5V、2.85V、3.3V、5.0V 和 12…

LeetCode刷题小记 八、【回溯算法】

1.回溯算法 文章目录 1.回溯算法写在前面1.1回溯算法基本知识1.2组合问题1.3组合问题的剪枝操作1.4组合总和III1.5电话号码的字母组合1.6组合总和1.7组合总和II1.8分割回文串1.9复原IP地址1.10子集问题1.11子集II1.12非递减子序列1.13全排列1.14全排列II1.15N皇后1.16解数独 写…

代码贴--动态顺序表--数据结构

本博客将记录操作系统中的动态顺序表的相关代码 头文件&#xff08;SeList.h&#xff09; #pragma once #include<stdio.h> #include<string.h> #include<stdlib.h> #include<assert.h> typedef int SQDataType; //动态顺序表typedef struct SeqList…

可视化展示与交互编辑:探索3D Web轻量化平台HOOPS WEB Platform在BIM中的新可能性

随着数字技术的飞速发展&#xff0c;建筑行业也在不断迈向数字化转型的道路。在这个过程中&#xff0c;BIM&#xff08;Building Information Modeling&#xff0c;建筑信息模型&#xff09;技术已经成为建筑设计、施工和管理领域中的一项重要工具。 而在BIM的应用中&#xff…

Javascript抓取京东、淘宝商品数据(商品采集商品详情图片抓取)

之前用的方法&#xff1a; let temp []var lists $(#J_goodsList li.gl-item)$.each(lists,function(idx,item){ temp.push({ id:$(item).data(sku), goods_img:$(item).find(img).attr(src), goods_name:$(item).find(.p-name em).text(), market_price:$(item).fi…

C++:函数传参到函数执行结束发生了什么

首先要明确两个概念 函数实参的入栈从右向左栈区从高地址向低地址偏移 接下来看下面一段代码 void fun(int a,int b,int c){std::cout<<&a<<" "<<&b<<" "<<&c<<std::endl; } int main(){fun(1,2,3); }…

中国首个基于区块链的分布式算力网络上线

随着美国人工智能公司OpenAI近期发布的Sora视频模型&#xff0c;全球对高性能算力的需求突破了历史新高。Sora的创新在于它能够以超长生成时间、多角度镜头捕捉&#xff0c;理解物理世界的能力&#xff0c;这不仅是技术的一大突破&#xff0c;更是对算力需求的一大挑战。在这样…

Qt教程 — 3.3 深入了解Qt 控件:Input Widgets部件(2)

目录 1 Input Widgets简介 2 如何使用Input Widgets部件 2.1 QSpinBox组件-窗口背景不透明调节器 2.2 DoubleSpinBox 组件-来调节程序窗口的整体大小 2.3 QTimeEdit、QDateEdit、QDateTimeEdit组件-编辑日期和时间的小部件 Input Widgets部件部件较多&#xff0c;将分为三…

基于Python django的人脸识别门禁系统,附源码

博主介绍&#xff1a;✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、Python技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&…

报Invalid value type for attribute ‘factoryBeanObjectType‘: java.lang.String错误

在springboot中使用Mybatis出现Invalid value type for attribute factoryBeanObjectType: java.lang.String 1、没有使用mybatis 检查pom文件里面的mybatis 可能是缺少这个依赖&#xff0c;或者版本过低 重新导入依赖 <dependency><groupId>org.mybatis.spri…

ubuntu(20.04)-安装JAVA环境-IDEA

1.下载IDEA 2.解压文件 sudo tar -zxvf idealC-2022.2.3.tar.gz -C /opt 3.添加环境变量&#xff1a; .vim ~/.bashrc export IDEA_HOME/opt/ideaIC-2022.2.3/ export PATH${IDEA_HOME}/bin:$PATH source ~/.bashrc 4.启动&#xff1a; cd /opt/ideaIC-2…

数据结构 第3章:栈与队列

文章目录 1. 栈1.1 栈的基本概念1.2 栈的基本操作1.3 栈的顺序存储实现1.4 栈的链式存储实现 2. 队列2.1 队列的基本概念2.2 队列的基本操作2.3. 队列的顺序存储实现2.4 队列的链式存储实现2.5 双端队列 3. 栈与队列的应用3.1 栈在括号匹配中的应用3.2 栈在表达式求值中的应用3…

聚观早报 | 微博2023年Q4营收;宝马集团2023财年营收

聚观早报每日整理最值得关注的行业重点事件&#xff0c;帮助大家及时了解最新行业动态&#xff0c;每日读报&#xff0c;就读聚观365资讯简报。 整理丨Cutie 3月16日消息 微博2023年Q4营收 宝马集团2023财年营收 海信发布中文大模型 岚图梦想家新增两款车型 vivo X Fold…

Transformer中注意力层和位置感知前馈层的分工与合作

1. 注意力层和位置感知前馈层的分工与合作 在Transformer模型中&#xff0c;注意力层&#xff08;自注意力机制&#xff09;和位置感知前馈层&#xff08;Position-wise Feed-Forward Networks, FFNs&#xff09;分别承担着不同的任务&#xff1a; 注意力层&#xff08;自注意…

【Godot4自学手册】第二十四节利用DirectionalLight2D节点实现夜幕降临

根据我们的游戏情节&#xff0c;今天我们将要实现夜晚的来临&#xff0c;主要是用DirectionalLight2D节点来实现&#xff0c;当与NPC对完话后&#xff0c;整改场景逐渐变得黑暗。 一、添加DirectionalLight2D节点 切换到主场景main&#xff0c;选择根目录&#xff0c;单击添加…

Linux下的多线程编程:原理、工具及应用(2)

&#x1f3ac;慕斯主页&#xff1a;修仙—别有洞天 ♈️今日夜电波&#xff1a;Flower of Life—陽花 0:34━━━━━━️&#x1f49f;──────── 4:46 &#x1f504; ◀️ ⏸ ▶️ ☰ …

Acwing.1262 鱼塘钓鱼(多路归并)

题目 有 N个鱼塘排成一排&#xff0c;每个鱼塘中有一定数量的鱼&#xff0c;例如&#xff1a;N5时&#xff0c;如下表&#xff1a; 即&#xff1a;在第 1个鱼塘中钓鱼第 1分钟内可钓到 10条鱼&#xff0c;第 2分钟内只能钓到 8条鱼&#xff0c;……&#xff0c;第 5分钟以后再…

Vite为什么比Webpack快

一、引言 主流的前端构建工具包括以下几种&#xff1a; Webpack&#xff1a;当下最热门的前端资源模块化管理和打包工具。它能够将许多松散的模块按照依赖和规则打包成符合生产环境部署的前端资源。同时&#xff0c;Webpack还支持代码分割&#xff0c;可以按需加载模块&#…

webshell隐藏哥斯拉流量修改sqlmap改ua

webshell隐藏 windows 1.隐藏shell attrib "文件名" s h attrib "文件名" -s -h 2.利用系统代号隐藏shell 创建文件夹名为Computer.{20D04FE0-3AEA-1069-A2D8-08002B30309D}&#xff0c;此时文件夹将变成我的电脑&#xff0c;无法看到里面的东西&…