python Seq2Seq模型源码实战,超详细Encoder-Decoder模型解析实战;早期机器翻译模型源码demo

1.Seq2Seq(Encoder-Decoder)模型简介

Seq2Seq(Encoder-Decoder)模型是一种常用于序列到序列(sequence-to-sequence)任务的深度学习模型。它由两个主要的组件组成:编码器(Encoder)和解码器(Decoder)。

编码器负责将输入序列编码成一个固定长度的向量,该向量包含了输入序列的语义信息。常用的编码器模型有循环神经网络(RNN)和Transformer。RNN编码器通过逐个时间步处理输入序列,并将最后一个时间步的隐藏状态作为向量输出。而Transformer编码器则将输入序列同时输入到多个注意力头中,从而捕捉输入的全局依赖关系。

解码器将编码器输出的向量作为输入,逐步生成目标序列。在每一个时间步,解码器通过学习到的语境信息和当前的输入,预测下一个输出的标记。常用的解码器模型也有RNN和Transformer。RNN解码器通常使用循环单元,每个时间步依次生成输出。而Transformer解码器可以同时计算目标序列中每个位置的输出,从而并行生成序列。

Seq2Seq模型适用于多种任务,如机器翻译、文本摘要、对话生成等。在机器翻译任务中,输入序列是源语言句子,目标序列是目标语言句子,模型的目标是学习源语言到目标语言的翻译规则。在训练阶段,给定源语言句子,模型通过编码器将其编码为一个向量表示,然后通过解码器生成目标语言句子。在测试阶段,给定一个源语言句子,模型使用编码器生成其向量表示,然后利用解码器生成目标语言句子。

Seq2Seq模型的一些改进包括使用注意力机制(Attention)来解决长序列问题,以及引入更复杂的网络结构和训练技巧来提升模型性能。此外,Seq2Seq模型也可以通过引入强化学习技术进行训练优化,以更好地适应特定任务需求。

2.Seq2Seq(Encoder-Decoder)模型代码实战

机器翻译任务

2.1构建语料库 

# 构建语料库,每行包含中文、英文(解码器输入)和翻译成英文后的目标输出 3 个句子
sentences = [
    ['成哥 喜欢 小冰', '<sos> ChengGe likes XiaoBing', 'ChengGe likes XiaoBing <eos>'],
    ['我 爱 学习 人工智能', '<sos> I love studying AI', 'I love studying AI <eos>'],
    ['深度学习 改变 世界', '<sos> DL changed the world', 'DL changed the world <eos>'],
    ['自然 语言 处理 很 强大', '<sos> NLP is so powerful', 'NLP is so powerful <eos>'],
    ['神经网络 非常 复杂', '<sos> Neural-Nets are complex', 'Neural-Nets are complex <eos>']]
word_list_cn, word_list_en = [], []  # 初始化中英文词汇表
# 遍历每一个句子并将单词添加到词汇表中
for s in sentences:
    word_list_cn.extend(s[0].split())
    word_list_en.extend(s[1].split())
    word_list_en.extend(s[2].split())
# 去重,得到没有重复单词的词汇表
word_list_cn = list(set(word_list_cn))
word_list_en = list(set(word_list_en))
# 构建单词到索引的映射
word2idx_cn = {w: i for i, w in enumerate(word_list_cn)}
word2idx_en = {w: i for i, w in enumerate(word_list_en)}
# 构建索引到单词的映射
idx2word_cn = {i: w for i, w in enumerate(word_list_cn)}
idx2word_en = {i: w for i, w in enumerate(word_list_en)}
# 计算词汇表的大小
voc_size_cn = len(word_list_cn)
voc_size_en = len(word_list_en)
print(" 句子数量:", len(sentences)) # 打印句子数
print(" 中文词汇表大小:", voc_size_cn) # 打印中文词汇表大小
print(" 英文词汇表大小:", voc_size_en) # 打印英文词汇表大小
print(" 中文词汇到索引的字典:", word2idx_cn) # 打印中文词汇到索引的字典
print(" 英文词汇到索引的字典:", word2idx_en) # 打印英文词汇到索引的字典

2.2数据预处理

import numpy as np # 导入 numpy
import torch # 导入 torch
import random # 导入 random 库
# 定义一个函数,随机选择一个句子和词汇表生成输入、输出和目标数据
def make_data(sentences):
    # 随机选择一个句子进行训练
    random_sentence = random.choice(sentences)
    # 将输入句子中的单词转换为对应的索引
    encoder_input = np.array([[word2idx_cn[n] for n in random_sentence[0].split()]])
    # 将输出句子中的单词转换为对应的索引
    decoder_input = np.array([[word2idx_en[n] for n in random_sentence[1].split()]])
    # 将目标句子中的单词转换为对应的索引
    target = np.array([[word2idx_en[n] for n in random_sentence[2].split()]])
    # 将输入、输出和目标批次转换为 LongTensor
    encoder_input = torch.LongTensor(encoder_input)
    decoder_input = torch.LongTensor(decoder_input)
    target = torch.LongTensor(target)
    return encoder_input, decoder_input, target 
# 使用 make_data 函数生成输入、输出和目标张量
encoder_input, decoder_input, target = make_data(sentences)
for s in sentences: # 获取原始句子
    if all([word2idx_cn[w] in encoder_input[0] for w in s[0].split()]):
        original_sentence = s
        break
print(" 原始句子:", original_sentence) # 打印原始句子
print(" 编码器输入张量的形状:", encoder_input.shape)  # 打印输入张量形状
print(" 解码器输入张量的形状:", decoder_input.shape) # 打印输出张量形状
print(" 目标张量的形状:", target.shape) # 打印目标张量形状
print(" 编码器输入张量:", encoder_input) # 打印输入张量
print(" 解码器输入张量:", decoder_input) # 打印输出张量
print(" 目标张量:", target) # 打印目标张量

2.3定义编码层和解码层

import torch.nn as nn # 导入 torch.nn 库
# 定义编码器类,继承自 nn.Module
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Encoder, self).__init__()       
        self.hidden_size = hidden_size # 设置隐藏层大小       
        self.embedding = nn.Embedding(input_size, hidden_size) # 创建词嵌入层       
        self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True) # 创建 RNN 层    
    def forward(self, inputs, hidden): # 前向传播函数
        embedded = self.embedding(inputs) # 将输入转换为嵌入向量       
        output, hidden = self.rnn(embedded, hidden) # 将嵌入向量输入 RNN 层并获取输出
        return output, hidden
# 定义解码器类,继承自 nn.Module
class Decoder(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(Decoder, self).__init__()       
        self.hidden_size = hidden_size # 设置隐藏层大小       
        self.embedding = nn.Embedding(output_size, hidden_size) # 创建词嵌入层
        self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True)  # 创建 RNN 层       
        self.out = nn.Linear(hidden_size, output_size) # 创建线性输出层    
    def forward(self, inputs, hidden):  # 前向传播函数     
        embedded = self.embedding(inputs) # 将输入转换为嵌入向量       
        output, hidden = self.rnn(embedded, hidden) # 将嵌入向量输入 RNN 层并获取输出       
        output = self.out(output) # 使用线性层生成最终输出
        return output, hidden
n_hidden = 128 # 设置隐藏层数量
# 创建编码器和解码器
encoder = Encoder(voc_size_cn, n_hidden)
decoder = Decoder(n_hidden, voc_size_en)
print(' 编码器结构:', encoder)  # 打印编码器的结构
print(' 解码器结构:', decoder)  # 打印解码器的结构

2.4创建 Seq2Seq 架构 

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        # 初始化编码器和解码器
        self.encoder = encoder
        self.decoder = decoder
    def forward(self, enc_input, hidden, dec_input):    # 定义前向传播函数
        # 使输入序列通过编码器并获取输出和隐藏状态
        encoder_output, encoder_hidden = self.encoder(enc_input, hidden)
        # 将编码器的隐藏状态传递给解码器作为初始隐藏状态
        decoder_hidden = encoder_hidden
        # 使解码器输入(目标序列)通过解码器并获取输出
        decoder_output, _ = self.decoder(dec_input, decoder_hidden)
        return decoder_output
# 创建 Seq2Seq 架构
model = Seq2Seq(encoder, decoder)
print('S2S 模型结构:', model)  # 打印模型的结构

2.5模型训练

# 定义训练函数
def train_seq2seq(model, criterion, optimizer, epochs):
    for epoch in range(epochs):
       encoder_input, decoder_input, target = make_data(sentences) # 训练数据的创建
       hidden = torch.zeros(1, encoder_input.size(0), n_hidden) # 初始化隐藏状态      
       optimizer.zero_grad()# 梯度清零        
       output = model(encoder_input, hidden, decoder_input) # 获取模型输出        
       loss = criterion(output.view(-1, voc_size_en), target.view(-1)) # 计算损失        
       if (epoch + 1) % 40 == 0: # 打印损失
          print(f"Epoch: {epoch + 1:04d} cost = {loss:.6f}")         
       loss.backward()# 反向传播        
       optimizer.step()# 更新参数
# 训练模型
epochs = 400 # 训练轮次
criterion = nn.CrossEntropyLoss() # 损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 优化器
train_seq2seq(model, criterion, optimizer, epochs) # 调用函数训练模型

 

2.6进行测试

# 定义测试函数
def test_seq2seq(model, source_sentence):
    # 将输入的句子转换为索引
    encoder_input = np.array([[word2idx_cn[n] for n in source_sentence.split()]])
    # 构建输出的句子的索引,以 '<sos>' 开始,后面跟 '<eos>',长度与输入句子相同
    decoder_input = np.array([word2idx_en['<sos>']] + [word2idx_en['<eos>']]*(len(encoder_input[0])-1))
    # 转换为 LongTensor 类型
    encoder_input = torch.LongTensor(encoder_input)
    decoder_input = torch.LongTensor(decoder_input).unsqueeze(0) # 增加一维    
    hidden = torch.zeros(1, encoder_input.size(0), n_hidden) # 初始化隐藏状态    
    predict = model(encoder_input, hidden, decoder_input) # 获取模型输出    
    predict = predict.data.max(2, keepdim=True)[1] # 获取概率最大的索引
    # 打印输入的句子和预测的句子
    print(source_sentence, '->', [idx2word_en[n.item()] for n in predict.squeeze()])
# 测试模型
test_seq2seq(model, '成哥 喜欢 小冰')  
test_seq2seq(model, '自然 语言 处理 很 强大')

3.总结

        综上所述,Seq2Seq模型是一种用于序列到序列任务的深度学习模型,由编码器和解码器组成。编码器将输入序列编码为一个向量,解码器逐步生成目标序列。该模型在机器翻译、文本摘要、对话生成等任务中得到广泛应用,并有很多改进和扩展。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/343557.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

顶顶通呼叫中心中间件机器人压力测试配置(mod_cti基于FreeSWITCH)

介绍 顶顶通呼叫中心中间件机器人压力测试(mod_cit基于FreeSWITCH) 一、配置acl.conf 打开ccadmin-》点击配置文件-》点击acl.conf-》我这里是已经配置好了的&#xff0c;这里的192.168.31.145是我自己的内网IP&#xff0c;你们还需要自行修改 二、配置线路 打开ccadmin-&g…

社区分享|百果园选择DataEase搭档蜜蜂微搭实现企业数据应用一体化

百果园&#xff0c;全称为深圳百果园实业&#xff08;集团&#xff09;股份有限公司&#xff0c;2001年12月成立于深圳&#xff0c;2002年开出中国第一家水果专卖店。截至2022年11月&#xff0c;百果园全国门店数量超过5600家&#xff0c;遍布全国140多个城市&#xff0c;消费会…

实现单链表的增删改查

实现单链表的增删改查的功能&#xff1a;头部插入删除/尾部插入删除&#xff0c;查找&#xff0c;在指定位置之前插入数据&#xff0c;删除pos节点&#xff0c;在指定位置之后插入数据&#xff0c;删除pos之后的节点&#xff0c;销毁链表。 SListNode.h #pragma once #includ…

工程化代码管理高频面试题

1. git常用命令以及工作中都怎么工作 git init 初始化仓库 ​ git status 查看当前各个区域的代码状态。 ​ git log查看commit记录 ​ git reflog查看完整记录 ​ git add 添加工作区代码到暂存区 ​ Git commit 暂存区代码的提交 ​ git reset 代码的版本回退 ​ git stash …

Python中的open与JSON的使用

目录 1 使用 open 函数进行文件操作 2 使用 json 模块进行 JSON 数据处理&#xff1a; 2.1 写入JSON 文件 2.2 读取JSON 文件 在 Python 中&#xff0c;open 函数和 json 模块常用于文件的读写和 JSON 数据的处理。 1 使用 open 函数进行文件操作 open 函数用于打开文件…

docker安装Rabbitmq教程(详细图文)

目录 1.下载Rabbitmq的镜像 2.创建并运行rabbitmq容器 3.启动web客户端 4.访问rabbitmq的微博客户端 5.遇到的问题 问题描述&#xff1a;在rabbitmq的web客户端发现界面会弹出如下提示框Stats in management UI are disabled on this node 解决方法 &#xff08;1&#…

【JAVA语言-第14话】集合框架(一)——Collection集合,迭代器,增强for,泛型

目录 集合框架 1.1 概述 1.2 集合和数组的区别 1.3 Collection集合 1.3.1 概述 1.3.2 常用方法 1.4 迭代器 1.4.1 概述 1.4.2 常用方法 1.4.3 使用步骤 1.5 增强for循环 1.5.1 概述 1.5.2 使用 1.6 泛型 1.6.1 概述 1.6.2 使用泛型的利弊 1.6.2.1 好处 1…

06章【Eclipse与异常处理】

Eclipse开发环境使用入门 Eclipse开发环境使用入门 下载安装配置环境Eclipse入门 异常处理 异常 异常是阻止当前方法或作用域继续执行的问题&#xff0c;在程序中导致程序中断运行的一些指令 try与catch关键字 在程序中出现异常&#xff0c;就必须进行处理&#xff0c;处理格…

vue的模板语法-指令-事件绑定-条件渲染

VSCode代码片段生成 我们在前面练习Vue的过程中&#xff0c;有些代码片段是需要经常写的&#xff0c;我们再VSCode中我们可以生成一个代码片段&#xff0c;方便我们快速生成。 VSCode中的代码片段有固定的格式&#xff0c;所以我们一般会借助于一个在线工具来完成。 具体的步…

Shell脚本 2

一、变量 变量来源于数学&#xff0c;是计算机语言中能储存计算结果或能表示值的抽象概念。 保存将来会变化的数据&#xff0c;即使数据变化&#xff0c;直接调用变量即可&#xff0c;各种 Shell 环境中都使用到了“变量”的概念。Shell 变量用来存放系统和用户需要使用的特定…

基于JavaSwing+百度OCR开发的题库管理系统源码+数据库,能够将图片中的文字提取出来,保存题库中

题库管理系统 介绍 具备上传本地图片及截屏功能&#xff0c;并利用百度OCR技术&#xff0c;能够将图片中的文字提取出来&#xff0c;保存题库中&#xff0c;供以后查找。 技术方面&#xff0c;为制作exe可执行文件&#xff0c;该软件将JavaSwing,MybatisPlus,Spring三者进行集…

使用Python合并PPT文件

在日常工作和学习中&#xff0c;我们经常需要处理和管理大量的PPT文件。如果需要将多个PPT文件合并成一个文件&#xff0c;手动操作可能会非常繁琐和耗时。今天&#xff0c;我们将介绍如何使用Python编程语言和wxPython模块创建一个简单的GUI应用程序&#xff0c;来自动合并指定…

力扣精选算法100道——x的平方根(二分查找专题)

x的平方根 首先看到这个题目的时候&#xff0c;我们需要对上一个二分查找专题的题目进行深度理解&#xff0c;然后了解模板&#xff0c;这题是完全利用的上一题的模板知识进行&#xff0c;如果直接看这个题目可能是有点懵的&#xff0c;因为我这里直接利用模板进行解题。力扣…

架构篇14:高性能数据库集群-读写分离

文章目录 读写分离原理复制延迟分配机制小结高性能数据库集群的第一种方式是“读写分离”,其本质是将访问压力分散到集群中的多个节点,但是没有分散存储压力;第二种方式是“分库分表”,既可以分散访问压力,又可以分散存储压力。先来看看“读写分离”,下一篇我们再介绍“分…

函数递归的总结回顾

函数递归的本质就是其名字——递与归。先递出去&#xff0c; 再收回来。 而递归的思想就是为了让一个复杂的问题变成一个简单的问题 按照我目前的理解&#xff0c;函数递归有两点很重要。一个是它的限定条件&#xff0c;另一个就是函数体内“自调”&#xff08;就是自我调用语句…

ctfshow-命令执行(web53-web72)

目录 web53 web54 web55 web56 web57 web58 web59 web60 web61 web62 web63 web64 web65 web66 web67 web68 web69 web70 web71 web72 web53 …

微信小程序实现长按 识别图片二维码

第一种方案&#xff08;只需要在image里面加一个属性就可以了&#xff09; show-menu-by-longpress“{{true}}” <image show-menu-by-longpress"{{true}}" src"{{sysset.dyqewm}}" />第二种方案 放大预览图片&#xff0c;长按识别二维码 wxml <…

【GitHub项目推荐--一个简单的绘图应用程序(Rust + GTK4)】【转载】

一个用 Rust 和 GTK4 编写的简单的绘图应用程序来创建手写笔记。 Rnote 旨在成为一个简单但实用的笔记应用程序&#xff0c;用于手绘或注释图片或文档。它最终能够导入/导出各种媒体文件格式。而且输出的作品是基于矢量的&#xff0c;这使其在编辑和更改内容时非常灵活。 地址…

oracle古法unwrap手艺(oracle存储过程解码)

先说骚话 首先oracle官方是不支持解包的&#xff0c;见Doc ID 376303.1 但是需求来了。我就寄希望于民间大神的工具。很顺利&#xff0c;找到了几个&#xff0c;甚至还有网页版&#xff0c;以为是个easy money。 但是&#xff0c;我点背&#xff0c;总是能遇到精彩的情况。数…

C# .Net6搭建灵活的RestApi服务器

1、准备 C# .Net6后支持顶级语句&#xff0c;更简单的RestApi服务支持&#xff0c;可以快速搭建一个极为简洁的Web系统。推荐使用Visual Studio 2022&#xff0c;安装"ASP.NET 和Web开发"组件。 2、创建工程 关键步骤如下&#xff1a; 包添加了“Newtonsoft.Json”&…