RNN预测下一句文本简单示例

根据句子前半句的内容推理出后半部分的内容,这样的任务可以使用循环的方式来实现。

RNN(Recurrent Neural Network,循环神经网络)是一种用于处理序列数据的强大神经网络模型。与传统的前馈神经网络不同,RNN能够通过其循环结构捕获序列内部的时间依赖性或顺序信息。

在RNN中,每个时间步(timestep)的隐藏状态不仅取决于当前输入,还与上一时间步的隐藏状态有关。这种递归特性使得网络能记忆过去的信息,并将其与当前输入相结合以做出决策或生成输出。

由于存在“梯度消失”和“梯度爆炸”的问题,在长序列建模时原始RNN可能效果不佳。因此,发展出了更复杂的变体,如LSTM(Long Short-Term Memory)和GRU(Gated Recurrent Units),它们通过门控机制更好地保留长期依赖信息。这些改进后的循环神经网络广泛应用于语音识别、自然语言处理(NLP)、机器翻译、视频分析等多种领域。

training_file = 'wordstest.txt' 在里面随便写入一些文章,当做数据,

具体代码如下,写了注释

import torch
import torch.nn.functional as F
import time
import random
import numpy as np
from collections import Counter

# 确保每次结果可复现
RANDOM_SEED = 123
torch.manual_seed(RANDOM_SEED)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def elapsed(sec):
    if sec<60:
        return str(sec) + " sec"
    elif sec<(60*60):
        return str(sec/60) + " min"
    else:
        return str(sec/(60*60)) + " hr"


#中文多文件
def readalltxt(txt_files):
    labels = []
    for txt_file in txt_files:

        target = get_ch_lable(txt_file)
        labels.append(target)
    return labels


def get_ch_lable(txt_file):
    """
    读取数据
    :param txt_file:
    :return:
    """
    labels = ""
    with open(txt_file, 'rb') as f:
        for label in f:
            labels += label.decode('utf-8')
            #labels += label.decode('gb2312')

    return labels


def get_ch_lable_v(txt_file, word_num_map, txt_label=None):
    """
    字符转向量
    :param txt_file:
    :param word_num_map:
    :param txt_label:
    :return:
    """
    words_size = len(word_num_map)
    to_num = lambda word: word_num_map.get(word, words_size)
    if txt_file != None:
        txt_label = get_ch_lable(txt_file)

    labels_vector = list(map(to_num, txt_label))
    return labels_vector


# 文本预处理,生成词向量
training_file = 'wordstest.txt'
training_data = get_ch_lable(training_file)
print("Loaded training data...")
print('样本长度:', len(training_data))
counter = Counter(training_data)
words = sorted(counter)
words_size= len(words)
word_num_map = dict(zip(words, range(words_size)))  # 给每个字构建索引,通过索引来处理计算预测每一个字
print('字表大小:', words_size)
wordlabel = get_ch_lable_v(training_file, word_num_map)


'''
GRU 构建 RNN 模型
1、将输入的文字索引转为词嵌入
2、将词嵌入结果输入用 GRU 所形成的网络层
3、对步骤 2 的输出结果做全连接处理,得到维度为【字表长度】的预测结果,这个结果代表的是每个文字的频率
'''
class GRURNN(torch.nn.Module):
    def __init__(self, word_size, embed_dim,
                 hidden_dim, output_size, num_layers):
        super(GRURNN, self).__init__()

        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        self.embed = torch.nn.Embedding(word_size, embed_dim)
        self.gru = torch.nn.GRU(input_size=embed_dim,
                                hidden_size=hidden_dim,
                                num_layers=num_layers, bidirectional=True)
        # bidirectional=True 代表网络是双向的,从前往后,从后往前
        # hidden_dim*2 代表包含了两个维度的层数
        # 全连接层(线性层),它将接收前面双向GRU输出的隐藏状态作为输入
        self.fc = torch.nn.Linear(hidden_dim*2, output_size)

    def forward(self, features, hidden):
        embedded = self.embed(features.view(1, -1))
        output, hidden = self.gru(embedded.view(1, 1, -1), hidden)
        output = self.fc(output.view(1, -1))
        return output, hidden

    def init_zero_state(self):
        """
        一个初始化隐藏状态的方法,主要用于循环神经网络(RNN)类的实例。这个方法的作用是为RNN创建一组全零初始隐藏状态。
        self.num_layers * 2: 表示双向RNN时的层数(如果模型是双向的,即参数bidirectional=True),
            因为每个方向都会有一个隐藏层,所以总共有num_layers * 2个隐藏层。
        1: 表示批量大小(batch size),在这里初始化的是单个样本的隐藏状态,因此设置为1。若需要处理批量数据,则应根据实际批量大小调整。
        self.hidden_dim: 表示隐藏层的维度(hidden dimension),也就是每个隐藏单元的特征数量。
        :return:
        """
        init_hidden = torch.zeros(self.num_layers * 2, 1, self.hidden_dim).to(DEVICE)
        return init_hidden


EMBEDDING_DIM = 10  # 向量的维度或者说长度
HIDDEN_DIM = 20  # 每一个隐藏层的神经元数量
NUM_LAYERS = 1  # 隐藏层数量

model = GRURNN(words_size, EMBEDDING_DIM, HIDDEN_DIM, words_size, NUM_LAYERS)
model = model.to(DEVICE)  # 将模型移动到指定设备上进行计算
# model.parameters():获取模型中所有需要优化的参数。
# Adam:是优化算法的一种,它基于梯度下降法,并结合了动量项(Momentum)和自适应学习率调整策略(RMSProp)。Adam通常在很多深度学习任务中表现良好,因为它能够自动调整学习率并减少对初始化学习率的敏感性。
# lr=0.005:表示设置学习率为0.005,这是Adam算法中的一个重要超参数,决定了每次更新参数时步伐的大小。
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)


def evaluate(model, prime_str, predict_len, temperature=0.8):
    """
    评估函数
    :param model:
    :param prime_str: 一个表示起始序列的整数列表,每个整数代表词汇表中的索引
    :param predict_len: 指定要预测的字符或单词数量
    :param temperature: 控制生成文本时随机性的一个超参数,较小的值会让模型更倾向于生成概率最高的结果,较大的值则会增加多样性
    :return:
    """
    hidden = model.init_zero_state().to(DEVICE)
    predicted = ''

    # 处理输入语义
    # 将生成的字符添加到预测结果字符串 predicted 中
    for p in range(len(prime_str) - 1):
        _, hidden = model(prime_str[p], hidden)
        predicted += words[prime_str[p]]
    # 用最后一个输入字符开始进行预测
    inp = prime_str[-1]
    predicted += words[inp]

    for p in range(predict_len):
        output, hidden = model(inp, hidden)

        #从多项式分布中采样
        # 将模型输出转换为分布形式,通过除以温度 temperature 并求指数得到softmax分布
        output_dist = output.data.view(-1).div(temperature).exp()
        # 根据调整后的分布采样下一个字符的索引
        inp = torch.multinomial(output_dist, 1)[0]

        predicted += words[inp]

    return predicted


#定义参数训练模型
training_iters = 5000
display_step = 1000
n_input = 4
step = 0
offset = random.randint(0, n_input+1)  # 每次迭代结束时,将偏移值向后移动 n_input+1 个距离,保证输入样本的相对均匀
end_offset = n_input + 1

while step < training_iters:
    start_time = time.time()

    # 随机取一个位置偏移
    if offset > (len(training_data) - end_offset):
        offset = random.randint(0, n_input+1)

    # 取出偏移量为 4 的数据长度,因为文本时序列数据
    inwords = wordlabel[offset:offset + n_input]
    # [n_input, -1, 1] 表示重塑后的三维形状:
    # 第一维是序列长度(即每个序列有 n_input 个元素),
    # 第二维 -1 表示自动计算以适应原始数据大小,
    # 第三维为通道数(这里设为1,通常用于表示一维特征)
    inwords = np.reshape(np.array(inwords), [n_input, -1,  1])
    # 编码
    out_onehot = wordlabel[offset+1:offset+n_input+1]
    # 初始化隐藏层
    hidden = model.init_zero_state()
    '''
    模型完成一次前向传播计算并得到损失(loss)后,在反向传播(backpropagation)之前,需要调用这个函数来清零所有可训练参数的梯度。
    在开始新一轮的前向传播和反向传播之前,使用 optimizer.zero_grad() 来清零所有参数的梯度是至关重要的,确保每次优化步骤只基于当前批次数据计算出的梯度来进行参数更新
    '''
    optimizer.zero_grad()

    '''
    模型训练
    '''
    loss = 0.
    # 将输入数据 inwords 和目标数据 out_onehot 转换为PyTorch张量
    inputs, targets = torch.LongTensor(inwords).to(DEVICE), torch.LongTensor(out_onehot).to(DEVICE)
    for c in range(n_input):
        # 当前时间步的输入和前一时间步的隐藏状态运行模型,得到输出 (outputs) 和新的隐藏状态 (hidden)。
        outputs, hidden = model(inputs[c], hidden)
        # 计算当前时间步的交叉熵损失(Cross Entropy Loss),将模型预测的输出与实际的目标标签比较
        loss += F.cross_entropy(outputs, targets[c].view(1))
    # 所有时间步完成后,平均损失值
    loss /= n_input
    # 反向传播计算梯度:调用 .backward() 函数来计算关于损失函数关于模型参数的梯度
    loss.backward()
    # 使用优化器(在这里是 optimizer)根据计算出的梯度更新模型参数
    optimizer.step()

    #输出日志
    # with torch.set_grad_enabled(False): 这一上下文管理器用于在计算过程中暂时禁用梯度计算。这样,在打印损失、评估模型性能等操作时,
    # 不会占用额外的内存来存储中间计算的梯度,同时避免不必要的反向传播计算。
    with torch.set_grad_enabled(False):
        if (step+1) % display_step == 0:
            print(f'Time elapsed: {(time.time() - start_time)/60:.4f} min')
            print(f'step {step+1} | Loss {loss.item():.2f}\n\n')
            # torch.no_grad() 上下文管理器再次禁用梯度计算,以便于高效地进行模型评估,并且不影响之前或之后的梯度计算状态。
            with torch.no_grad():
                print(evaluate(model, inputs, 32), '\n')
            print(50*'=')
    step += 1
    offset += (n_input+1)#中间隔了一个,作为预测

print("Finished!")


# 使用模型
while True:
    prompt = "请输入几个字,最好是%s个: " % n_input
    sentence = input(prompt)
    inputword = sentence.strip()

    try:
        inputword = get_ch_lable_v(None, word_num_map, inputword)
        keys = np.reshape(np.array(inputword), [len(inputword), -1, 1])
        '''
        调用 model.eval() 方法将模型设置为评估模式。在评估模式下,模型中的批量归一化层(如果有)会使用经过训练时平均的移动统计量,
        并且不会更新模型参数(梯度计算被禁用)。
        接下来,通过 with torch.no_grad(): 语句创建了一个临时上下文,在此上下文中执行所有操作时都不会累积梯度。
        这对于生成任务非常关键,因为在这种情况下我们并不关心反向传播以更新模型权重,而是要利用当前模型状态来生成文本。
        '''
        model.eval()
        with torch.no_grad():
            sentence = evaluate(model, torch.LongTensor(keys).to(DEVICE), 32)

        print(sentence)
    except:
        print("该字我还没学会")

运行结果类似:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/352702.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

独享http代理安全性是更高的吗?

不同于共享代理&#xff0c;独享代理IP为单一用户提供专用的IP&#xff0c;带来了一系列需要考虑的问题。今天我们就一起来看看独享代理IP的优势&#xff0c;到底在哪里。 我们得先来看看什么是代理IP。简单来说&#xff0c;代理服务器充当客户机和互联网之间的中间人。当你使用…

C/C++ - 面向对象编程

面向对象 面向过程编程&#xff1a; 数据和函数分离&#xff1a;在C语言中&#xff0c;数据和函数是分开定义和操作的。数据是通过全局变量或传递给函数的参数来传递的&#xff0c;函数则独立于数据。函数为主导&#xff1a;C语言以函数为主导&#xff0c;程序的执行流程由函数…

复式记账的概念特点和记账规则

目录 一. 复式记账法二. 借贷记账法三. 借贷记账法的记账规则四. 复试记账法应用举例4.1 三栏式账户举例4.2 T型账户记录举例4.3 记账规则验证举例 \quad 一. 复式记账法 \quad 复式记账法是指对于任何一笔经济业务都要用相等的金额&#xff0c;在两个或两个以上的有关账户中进…

GIT使用,看它就够了

一、目的 Git的熟练使用是一个加分项&#xff0c;本文将对常用的Git命令做一个基本介绍&#xff0c;看了本篇文章&#xff0c;再也不会因为不会使用git而被嘲笑了。 二、设置与配置 在第一次调用Git到日常的微调和参考&#xff0c;用得最多的就是config和help命令。 2.1 gi…

4核16G幻兽帕鲁服务器性能测评,真牛

腾讯云幻兽帕鲁服务器4核16G14M配置&#xff0c;14M公网带宽&#xff0c;限制2500GB月流量&#xff0c;系统盘为220GB SSD盘&#xff0c;优惠价格66元1个月&#xff0c;277元3个月&#xff0c;支持4到8个玩家畅玩&#xff0c;地域可选择上海/北京/成都/南京/广州&#xff0c;腾…

第十六章 Spring cloud stream应用

文章目录 前言1、stream设计思想2、编码常用的注解3、编码步骤3.1、添加依赖3.2、修改配置文件3.3、生产3.4、消费3.5、延迟队列3.5.1、修改配置文件3.5.2、生产端3.5.2、消息确认机制 消费端 前言 https://github.com/spring-cloud/spring-cloud-stream-binder-rabbit 官方定…

GPT-SoVITS 本地搭建踩坑

GPT-SoVITS 本地搭建踩坑 前言搭建下载解压VSCode打开安装依赖包修改内容1.重新安装版本2.修改文件内容 运行总结 前言 传言GPT-SoVITS作为当前与BertVits2.3并列的TTS大模型&#xff0c;于是本地搭了一个&#xff0c;简单说一下坑。 搭建 下载 到GitHub点击此处下载 http…

【网站项目】基于SSM的246品牌手机销售信息系统

&#x1f64a;作者简介&#xff1a;拥有多年开发工作经验&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。&#x1f339;赠送计算机毕业设计600个选题excel文件&#xff0c;帮助大学选题。赠送开题报告模板&#xff…

蓝桥杯——每日一练(简单题)

题目 给定n个十六进制正整数&#xff0c;输出它们对应的八进制数。 解析 一、通过input&#xff08;&#xff09;函数获得需要转化的数字个数 二、for循环的到数字 三、for循环先将16进制转化为10进制&#xff0c;再输出8进制 代码 运行结果

直线拟合(支持任意维空间的直线拟合,附代码)

文章目录 一、问题描述二、推导步骤三、 M A T L A B MATLAB MATLAB代码 一、问题描述 给定一系列的三维空间点 ( x i , y i , z i ) , i 1 , 2 , . . . , n (x_i,y_i,z_i),i1,2,...,n (xi​,yi​,zi​),i1,2,...,n&#xff0c;拟合得到直线的方程。本文的直线拟合方法适用于任…

如何用一根网线和51单片机做简单门禁[带破解器]

仓库:https://github.com/MartinxMax/Simple_Door 支持原创是您给我的最大动力… 原理 -基础设备代码程序- -Arduino爆破器程序 or 51爆破器程序- 任意选一个都可以用… —Arduino带TFT屏幕——— —51带LCD1602——— 基础设备的最大密码长度是0x7F&#xff0c;因为有一位…

10.Golang中的map

目录 概述map实践map声明代码 map使用代码 结束 概述 map实践 map声明 代码 package mainimport ("fmt" )func main() {// 声明方式1var map1 map[string]stringif map1 nil {fmt.Println("map1为空")}// 没有分配空间&#xff0c;是不能使用的// map…

Vulnhub-dc6

信息收集 # nmap -sn 192.168.1.0/24 -oN live.port Starting Nmap 7.94 ( https://nmap.org ) at 2024-01-25 14:39 CST Nmap scan report for 192.168.1.1 Host is up (0.00075s latency). MAC Address: 00:50:56:C0:00:08 (VMware) Nmap scan report for 192.168.1.2…

IS-IS:10 ISIS路由渗透

ISIS的非骨干区域&#xff0c;无明细路由&#xff0c;容易导致次优路径问题。可以引入明细路由。 在IS-IS 网络中&#xff0c;所有的 level-2 和 level-1-2 路由器构成了一个连续的骨干区域。 level-1区域必须且只能与骨干区域相连&#xff0c;不同 level-1 区域之间不能直接…

.NET高级面试指南专题一【委托和事件】

在C#中&#xff0c;委托&#xff08;Delegate&#xff09;和事件&#xff08;Event&#xff09;是两个重要的概念&#xff0c;它们通常用于实现事件驱动编程和回调机制。 委托定义&#xff1a; 委托是一个类&#xff0c;它定义了方法的类型&#xff0c;使得可以将方法当作另一个…

SpringMVC第六天(拦截器)

概念 拦截器(Interceptor)是一种动态拦截方法调用的机制&#xff0c;在SpringMVC中动态拦截控制器方法的执行 作用&#xff1a; 在指定的方法调用前后执行预先设定的代码 阻止原始方法的执行 拦截器与过滤器的区别 归属不同&#xff1a;Filter属于Servlet技术&#xff0c;I…

递归方法猴子吃桃问题

public class A {public static void main(String[] args) {System.out.println("第一天有&#xff1a;"f(1)"个");System.out.println("第二天有&#xff1a;"f(2)"个");System.out.println(".....");System.out.println(&…

【揭秘】ForkJoinTask全面解析

内容摘要 ForkJoinTask的显著优点在于其高效的并行处理能力&#xff0c;它能够将复杂任务拆分成多个子任务&#xff0c;并利用多核处理器同时执行&#xff0c;从而显著提升计算性能&#xff0c;此外&#xff0c;ForkJoinTask还提供了简洁的API和强大的任务管理机制&#xff0c…

保姆级教学:Java项目从0到1部署到云服务器

目录 1、明确内容 2、apt 2.1、apt 语法 2.2、常用命令 2.3、更新apt 3、安装JDK17 4、安装MySQL 4.1、安装 4.2、检查版本及安装位置 4.3、初始化MySQL配置⭐ 4.4、检查状态 4.5、配置远程访问⭐ 4.6、登录MySQL 4.7、测试数据库 4.8、设置权限与密码⭐ 5、安…

cmake工具的安装

1、简介 CMake 是一个开源的、跨平台的自动化建构系统。它用配置文件控制编译过程的方式和Unix的make相似&#xff0c;只是CMake并不依赖特定的编译器。CMake并不直接建构出最终的软件&#xff0c;而是产生标准的建构文件&#xff08;如 Unix 的 Makefile 或 Windows Visual C …