NLP Seq2Seq模型

🍨 本文为[🔗365天深度学习训练营学习记录博客
 
🍦 参考文章:365天深度学习训练营
 
🍖 原作者:[K同学啊 | 接辅导、项目定制]\n🚀 文章来源:[K同学的学习圈子](https://www.yuque.com/mingtian-fkmxf/zxwb45)

Seq2Seq模型是一种深度学习模型,用于处理序列到序列的任务,它由两个主要部分组成:编码器(Encoder)和解码器(Decoder)。

  1. 编码器(Encoder): 编码器负责将输入序列(例如源语言句子)编码成一个固定长度的向量,通常称为上下文向量或编码器的隐藏状态。编码器可以是循环神经网络(RNN)、长短期记忆网络(LSTM)或者变种如门控循环单元(GRU)等。编码器的目标是捕捉输入序列中的语义信息,并将其编码成一个固定维度的向量表示。

  2. 解码器(Decoder): 解码器接收编码器生成的上下文向量,并根据它来生成输出序列(例如目标语言句子)。解码器也可以是RNN、LSTM、GRU等。在训练阶段,解码器一次生成一个词或一个标记,并且其隐藏状态从一个时间步传递到下一个时间步。解码器的目标是根据上下文向量生成与输入序列对应的输出序列。

在训练阶段,Seq2Seq模型的目标是最大化目标序列的条件概率给定输入序列。为了实现这一点,通常使用了一种称为教师强制(Teacher Forcing)的技术,即将目标序列中的真实标记作为解码器的输入。但是,在推理阶段(即模型用于生成新的序列时),解码器则根据先前生成的标记生成下一个标记,直到生成一个特殊的终止标记或达到最大长度为止。

下面演示了如何使用PyTorch实现一个简单的Seq2Seq模型,用于将一个序列翻译成另一个序列。这里我们将使用一个虚构的数据集来进行简单的法语到英语翻译。

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

# 定义数据集
class SimpleDataset(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

# 定义Encoder
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hidden_dim):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, hidden_dim)
        
    def forward(self, src):
        embedded = self.embedding(src)
        outputs, hidden = self.rnn(embedded)
        return outputs, hidden

# 定义Decoder
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hidden_dim):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, input, hidden):
        input = input.unsqueeze(0)
        embedded = self.embedding(input)
        output, hidden = self.rnn(embedded, hidden)
        prediction = self.fc_out(output.squeeze(0))
        return prediction, hidden

# 定义Seq2Seq模型
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        
    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size = trg.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.fc_out.out_features
        
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
        
        encoder_outputs, hidden = self.encoder(src)
        
        input = trg[0,:]
        
        for t in range(1, trg_len):
            output, hidden = self.decoder(input, hidden)
            outputs[t] = output
            teacher_force = np.random.rand() < teacher_forcing_ratio
            top1 = output.argmax(1) 
            input = trg[t] if teacher_force else top1
        
        return outputs

# 设置参数
INPUT_DIM = 10
OUTPUT_DIM = 10
ENC_EMB_DIM = 32
DEC_EMB_DIM = 32
HID_DIM = 64
N_LAYERS = 1
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5

# 实例化模型
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Seq2Seq(enc, dec, device).to(device)

# 打印模型结构
print(model)

# 定义训练函数
def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    
    for i, batch in enumerate(iterator):
        src, trg = batch
        src = src.to(device)
        trg = trg.to(device)
        
        optimizer.zero_grad()
        
        output = model(src, trg)
        
        output_dim = output.shape[-1]
        
        output = output[1:].view(-1, output_dim)
        trg = trg[1:].view(-1)
        
        loss = criterion(output, trg)
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

# 定义测试函数
def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src, trg = batch
            src = src.to(device)
            trg = trg.to(device)

            output = model(src, trg, 0) # 关闭teacher forcing

            output_dim = output.shape[-1]

            output = output[1:].view(-1, output_dim)
            trg = trg[1:].view(-1)

            loss = criterion(output, trg)

            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

# 示例数据
train_data = [(torch.tensor([1, 2, 3]), torch.tensor([3, 2, 1])),
              (torch.tensor([4, 5, 6]), torch.tensor([6, 5, 4])),
              (torch.tensor([7, 8, 9]), torch.tensor([9, 8, 7]))]

# 超参数
BATCH_SIZE = 3
N_EPOCHS = 10
LEARNING_RATE = 0.001
CLIP = 1

# 数据集与迭代器
train_dataset = SimpleDataset(train_data)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# 定义损失函数与优化器
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

# 训练模型
for epoch in range(N_EPOCHS):
    train_loss = train(model, train_loader, optimizer, criterion, CLIP)
    print(f'Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/421138.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

wireshark抓取localhost(127.0.0.1)数据包

打开wireshark中&#xff0c;在"capture"菜单中&#xff0c;选择"interfaces"子菜单&#xff0c;在列出的接口中选中"Adapter for loopback traffic capture"即可。 必须安装了Npcap才有此选项&#xff0c;否则需要重新安装wireshark。 抓包截图…

Vue+SpringBoot打造城市桥梁道路管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块三、系统展示四、核心代码4.1 查询城市桥梁4.2 新增城市桥梁4.3 编辑城市桥梁4.4 删除城市桥梁4.5 查询单个城市桥梁 五、免责说明 一、摘要 1.1 项目介绍 基于VueSpringBootMySQL的城市桥梁道路管理系统&#xff0c;支持…

gpt批量工具,gpt批量生成文章工具

GPT批量工具在今天的数字化时代扮演着越来越重要的角色&#xff0c;它们通过人工智能技术&#xff0c;可以自动批量生成各种类型的文章&#xff0c;为用户提供了便利和效率。本文将介绍5款不同的GPT批量工具&#xff0c;并介绍一款知名的147GPT生成工具&#xff0c;以及另外一款…

beets,一个有趣的 Python 音乐信息管理工具!

前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站AI学习网站。 目录 前言 什么是Beet库&#xff1f; 安装Beet库 使用Beet库 Beet库的功能特性 1. 多种音乐格式支持 2. 自动标签识…

ECMAScript-262 @2023版本中的关键字和保留字

1、什么是标识符&#xff1f; 所谓标识符&#xff0c;就是javascript里的变量、函数、属性或函数参数的名称&#xff0c;可由一个或多个字符组成&#xff0c;当然标识符有命名规范 标识符第一个字符必须是 一个字母、下划线&#xff08;_&#xff09;或美元符号&#xff08;$…

在 Rust 中实现 TCP : 1. 联通内核与用户空间的桥梁

内核-用户空间鸿沟 构建自己的 TCP栈是一项极具挑战的任务。通常&#xff0c;当用户空间应用程序需要互联网连接时&#xff0c;它们会调用操作系统内核提供的高级 API。这些 API 帮助应用程序 连接网络创建、发送和接收数据&#xff0c;从而消除了直接处理原始数据包的复杂性。…

烧脑问题解决办法:如何选择一款合适自己的手机流量卡

现在社会人们越来越离不开手机了&#xff0c;手机给我们生活带来了翻天覆地的变化&#xff0c;手机需要最多的就是流量了&#xff0c;所以选择一款合适自己的手机流量卡就显得尤为重要了&#xff0c;今天小编就给大家来分享一下我的经验&#xff0c;希望对大家能有帮助&#xf…

STM32合并烧录IAP+APP

STM32合并烧录IAPAPP 通过查找相关资料 有以下几种合并方法 第一种直接将二进制文件用记事本合并 而要合并的就是就将IAP最后的一行删除&#xff0c;然后将APP程序追加在后面。 &#xff08;修改前&#xff09; 把APP的.hex 全部内容拷贝复制到 刚才删掉结束语句的 IAP的.…

基于springboot+vue的公交线路查询系统

博主主页&#xff1a;猫头鹰源码 博主简介&#xff1a;Java领域优质创作者、CSDN博客专家、阿里云专家博主、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战&#xff0c;欢迎高校老师\讲师\同行交流合作 ​主要内容&#xff1a;毕业设计(Javaweb项目|小程序|Pyt…

如何用ChatGPT+GEE+ENVI+Python进行高光谱,多光谱成像遥感数据处理?

原文链接&#xff1a;如何用ChatGPTGEEENVIPython进行高光谱&#xff0c;多光谱成像遥感数据处理&#xff1f; 第一&#xff1a;遥感科学 从摄影侦察到卫星图像 遥感的基本原理 遥感的典型应用 第二&#xff1a;ChatGPT ChatGPT可以做什么&#xff1f; ChatGPT演示使用 …

CSS的弹性布局

CSS 的弹性布局 前言 前端中为了实现页面的布局效果&#xff0c;采用的一个技术手段&#xff0c;它在前端开发的技术场景是非常广泛的 实现上述区域的页面相关的布局效果&#xff0c;就可以使用弹性布局来完成 弹性布局(flex布局) flex 是 flexible box 的缩写&#xff0c;…

解锁AI大模型秘籍:未来科技的前沿探索

在当今这个技术高速发展的时代&#xff0c;人工智能&#xff08;AI&#xff09;已经成为了我们生活中不可或缺的一部分。从简单的个人助手到复杂的数据分析和决策制定&#xff0c;AI的应用范围日益扩大&#xff0c;其目的是为了让我们的生活变得更加智能化。本文旨在探讨AI如何…

让边缘智能助力配电房监测,P1600网关引领智慧新潮

科技与生活的交融 在现代社会的脉搏中&#xff0c;科技与生活紧密交融。我们的生活方式&#xff0c;正在由传统的模式&#xff0c;逐步向智能化、便捷化的方向迈进。配电房作为城市的重要基础设施&#xff0c;其稳定运行关系到千家万户的生活和工作。如何有效监控配电房的状态…

Linux技巧|centos7|重新认识和学习egrep和grep命令

前言&#xff1a; 相信提高文本检索工具&#xff0c;大家脑海里肯定有很多工具会自动跳出来&#xff0c;比如&#xff0c;grep&#xff0c;egrep&#xff0c;sed&#xff0c;cat&#xff0c;more&#xff0c;less&#xff0c;cut&#xff0c;awk&#xff0c;vim&#xff0c;vi…

剑指offer刷题记录Day 1 03.数组中重复的数字 ---> 06.从尾到头打印链表

名人说&#xff1a;莫道桑榆晚&#xff0c;为霞尚满天。——刘禹锡&#xff08;刘梦得&#xff0c;诗豪&#xff09; 创作者&#xff1a;Code_流苏(CSDN)&#xff08;一个喜欢古诗词和编程的Coder&#x1f60a;&#xff09; 目录 0、关于核心代码模式该怎么刷题&#xff1f;1、…

1_SQL

文章目录 前端复习SQL数据库的分类关系型数据库非关系型数据库&#xff08;NoSQL&#xff09; 数据库的构成软件架构MySQL内部数据组织方式 SQL语言登录数据库数据库操作查看库创建库删除库修改库 数据库中表的操作选择数据库创建表删除表查看表修改表 数据库中数据的操作添加数…

MATLAB练习题:排队论问题的模拟

​讲解视频&#xff1a;可以在bilibili搜索《MATLAB教程新手入门篇——数学建模清风主讲》。​ MATLAB教程新手入门篇&#xff08;数学建模清风主讲&#xff0c;适合零基础同学观看&#xff09;_哔哩哔哩_bilibili 下面我们来看一道排队论的题目。假设某银行工作时间内只有一个…

Java多线程导入Excel示例

在导入Excel的时候&#xff0c;如果文件比较大&#xff0c;行数很多&#xff0c;一行行读往往速度比较慢&#xff0c;为了加快导入速度&#xff0c;我们可以采用多线程的方式 话不多说直接上代码 首先是Controller import com.sakura.base.service.ExcelService; import com.s…

ADBMS1818芯片资料介绍(1)

ADBMS1818数据手册和产品信息 | Analog Devices 一、芯片简介  可测量多达18串电池电压  3 mV最大总测量误差  内置isoSPI接口  使用单根双绞线&#xff0c;长达100米  290 μs内可完成系统中所有单体电池电压测量 二、芯片内核和isoSPI状态 ADBMS1818内核状态说明…

Mac清理电脑垃圾工具CleanMyMac X4.15中文免费版下载

嘿&#xff0c;亲爱的Mac用户们&#xff0c;你们是否曾经想象过你的电脑是一座美丽的城市&#xff0c;而垃圾文件则是那些不速之客&#xff0c;悄悄堆积&#xff0c;影响着城市的整体美观。今天&#xff0c;我们就来聊聊Mac为什么会产生垃圾文件&#xff0c;这些垃圾文件会对你…