机器学习——RNN、LSTM

RNN

特点:输入层是层层相关联的,输入包括上一个隐藏层的输出h1和外界输入x2,然后融合一个张量,通过全连接得到h2,重复
优点:结构简单,参数总量少,在短序列任务上性能好
缺点:在长序列中效果不好,容易梯度提升或者爆炸

import torch
import torch.nn as nn
import torch.nn.functional as F


# 参数一: 输入张量的词嵌入维度5, 参数二: 隐藏层的维度(也就是神经元的个数), 参数三: 网络层数
rnn = nn.RNN(5, 6, 2)

# 参数一: sequence_length序列长度, 参数二: batch_size样本个数, 参数三: 词嵌入的维度, 和RNN第一个参数匹配
input1 = torch.randn(1, 3, 5)

# 参数一: 网络层数, 和RNN第三个参数匹配, 参数二: batch_size样本个数, 参数三: 隐藏层的维度, 和RNN第二个参数匹配
h0 = torch.randn(2, 3, 6)


output, hn = rnn(input1, h0)
# print(output.shape)
# torch.Size([1, 3, 6])
# print(hn.shape)
# torch.Size([2, 3, 6])



LSTM

解决了RNN的缺点,在长序列中效果好,现在仔细研究中间图的结构

在这里插入图片描述

最左边是的黄色矩形部分是遗忘门,就是结合前一层的的h1+输入x2拼接,然后经过全连接层后输出ft,就是把之前的一些信息遗忘一部分,
在这里插入图片描述

第二第三是一部分输入门,拼接完过后经过全连接结合σ激活函数it,以及拼接后用一个tanh激活函数ct,然后和上一层的结合起来
在这里插入图片描述
第三部分是输出门,图的右边黄色的矩形到结尾
在这里插入图片描述
此外Bi-LSTM,是双向的,相当于运用了两层LSTM但是方向不同,前面是单向的,信息从左到右的的传递相当于考虑前面的信息,Bi-LSTM是左右信息都考虑,然后拼接结果

# -------------------------------------
import torch
import torch.nn as nn

lstm=nn.LSTM(5,6,2)
input=torch.randn(1,3,5)
h0=torch.randn(2,3,6)
c0=torch.randn(2,3,6)
output,(hn,cn)=lstm(input,(hn,cn))

#-------------------------




class Attention(nn.Module):
    def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
        super(Attention, self).__init__()
        self.query_size = query_size
        self.key_size = key_size
        self.value_size1 = value_size1
        self.value_size2 = value_size2
        self.output_size = output_size
        
        self.attn = nn.Linear(self.query_size + self.key_size, self.value_size1)
        
        self.attn_combine = nn.Linear(self.query_size + self.value_size2, self.output_size)
        
    def forward(self, Q, K, V):
        attn_weights = F.softmax(self.attn(torch.cat((Q[0], K[0]), 1)), dim=1)
        
        attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)
        
        output = torch.cat((Q[0], attn_applied[0]), 1)
        
        output1 = self.attn_combine(output).unsqueeze(0)
        
        return output1, attn_weights


query_size = 32
key_size = 32
value_size1 = 32
value_size2 = 64
output_size = 64
attn = Attention(query_size, key_size, value_size1, value_size2, output_size)
Q = torch.randn(1, 1, 32)
K = torch.randn(1, 1, 32)
V = torch.randn(1, 32, 64)
res = attn(Q, K, V)
# print(res[0])
# print(res[0].shape)
# print('*****')
# print(res[1])
# print(res[1].shape)

# --------------------------------------------

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/728546.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

浅谈医工交叉方向SCI写作

笔者因为工作性质原因,这几年写了不少医学人工智能方向的SCI论文,顺带每年相关的论文的阅读量也有小几百篇,特别是在医学影像AI方向,也算是小有心得,今天就简单聊一下医工交叉(影像AI)方向的SCI论文写作与投稿问题。 首…

计算机网络:应用层 - 域名系统 DNS

计算机网络:应用层 - 域名系统 DNS 域名结构域名服务器域名解析迭代查询递归查询 互联网中的每台设备都有一个唯一的IP地址,但这些地址通常是复杂的数字组合,例如 172.217.160.142,难以记忆和识别。域名系统将这些复杂的IP地址与易…

kafka基础概念

目录 1、kafka简介 2、kafka使用场景 3、kafka基础概念 3.1、消息 3.1.1、消息构成详解 3.1.2、消息存储设计 3.2、topic 3.3、partition 3.4、offset 3.5、replication 3.5.1、replication简介 3.5.2、副本角色 3.5.3、副本类型 3.5.3.1、副本类型简介 3.5.3.2、…

低代码平台实践:打造高效动态表单解决方案的探索与思考

🔥需求背景 我司业务同事在抓取到候选人的简历之后,经常会出现,很多意向候选人简历信息不完整,一个个打电话确认的情况,严重影响了HR的工作效率,于是提出我们可以通过发送邮件、短信、H5链接的方式来提醒候…

【Linux】使用ntpdate同步

ntpdate 是一个在 Linux 系统中用于同步系统时间的命令行工具,它通过与 NTP 服务器通信来调整本地系统时钟。然而,需要注意的是,ntpdate 已经被许多现代 Linux 发行版弃用。 安装 yum install -y ntpdate 查看时间 date同步时间 ntpdate ntp…

端口已被占用 1080

http://www.nirsoft.net/utils/cports.html#DownloadLinks 下载后解压,直接运行cports.exe. 这里写图片描述 找到被占用的端口,右键选择 “Close Selected TCP Connections”

python使用pyautogui自动化模拟鼠标、键盘操作、截屏、识别图片位置

🌈所属专栏:【python】✨作者主页: Mr.Zwq✔️个人简介:一个正在努力学技术的Python领域创作者,擅长爬虫,逆向,全栈方向,专注基础和实战分享,欢迎咨询! 您的…

Spring Boot集成tensorflow实现图片检测服务

1.什么是tensorflow? TensorFlow名字的由来就是张量(Tensor)在计算图(Computational Graph)里的流动(Flow),如图。它的基础就是前面介绍的基于计算图的自动微分,除了自动帮你求梯度之外,它也提供了各种常见的操作(op,…

Redis通用命令详解

文章目录 一、Redis概述1.1 KEYS:查看符合模板的所有 key1.2 DEL:删除一个指定的 key1.3 EXISTS:判断 key 是否存在1.4 EXPIRE:给一个 key 设置有效期,有效期到期时该 key 会被自动删除1.5 TTL:查看一个 ke…

《梦醒蝶飞:释放Excel函数与公式的力量》4.1if函数

第4章:逻辑与条件函数 第一节4.1 if函数 在Excel中,逻辑函数用于处理基于特定条件的真假判断,它们是构建复杂公式和进行高级数据分析的基础。本章将深入探讨逻辑函数的使用方法,特别是IF函数,这是Excel中最为常用的条…

Spring Boot程序打包docker镜像

1.将springboot程序使用maven package打包出jar。 2.创建dockerfile。 FROM openjdk:8 VOLUME /tmp EXPOSE 8601 #ADD 后面的参数是项目名字 / 后面的参数是自定义的别名 ADD webflux-hello-0.0.1-SNAPSHOT.jar /webflux-hello.jar #这里的最后一个变量需要和前面起的别名相同…

构建智慧高速公路:软件管理平台业务架构解析

随着交通网络的不断完善和技术的快速发展,智慧高速公路正成为交通领域的重要发展方向。在智慧高速公路系统中,软件管理平台扮演着关键的角色,它不仅是管理各种设备和系统的核心,还承担着数据监控、故障诊断、维护管理等重要任务。…

探索序列到序列模型:了解编码器和解码器架构的强大功能

目录 一、说明 二、什么是顺序数据? 三、编码器解码器架构的高级概述: 3.1 编码器和解码器架构的简要概述: 3.2 训练机制:编码器和解码器架构中的前向和后向传播: 四、编码器解码器架构的改进: 4.1.…

1.3自然语言的分布式表示-word2vec

文章目录 0基于计数的方法的问题1什么是基于推理的方法2神经网络中单词的表示2.1 MatMul 层的实现 3简单word2vec的实现3.1 CBOW模型的结构3.1.1神经元视角3.1.2层的视角3.1.3多层共享权重时存在的问题 3.2 CBOW模型的学习3.3单词的分布式表示 代码都位于:nlp&#…

Unity 工具 之 Azure 微软 【GPT4o】HttpClient 异步流式请求的简单封装

Unity 工具 之 Azure 微软 【GPT4o】HttpClient 异步流式请求的简单封装 目录 Unity 工具 之 Azure 微软 【GPT4o】HttpClient 异步流式请求的简单封装 一、简单介绍 二、实现原理 三、注意实现 四、简单效果预览 五、案例简单实现步骤 六、关键代码 一、简单介绍 Unit…

推荐5个AI辅助生成论文、降低查重率的网站【2024最新】

一、引言 对于忙碌的学生来说,毕业论文通常是一项艰巨的任务。幸运的是,随着人工智能技术的发展,现在有一些工具可以帮助学生轻松完成论文。本文将介绍五个免费的AI工具,它们能够一键帮助你生成毕业论文,让你的学术生…

Minecraft服务端配置教程

一、下载服务端核心文件 下载 | FastMirror 无极镜像 | 我的世界核心下载 Downloads for Minecraft Forge for MinecraftForge服务端下载 MCVersions.net - Minecraft Versions Download List原版 注意,这个网站可以下载Forge水桶等插件和模组端,如果…

华为---静态路由-浮动静态路由及负载均衡(二)

7.2 浮动静态路由及负载均衡 7.2.1 原理概述 浮动静态路由(Floating Static Route)是一种特殊的静态路由,通过配置去往相同的目的网段,但优先级不同的静态路由,以保证在网络中优先级较高的路由,即主路由失效的情况下&#xff0c…

AI界的“视频滤镜”(Stable Diffusion进阶篇-TemporalKit视频风格转化),手把手教你制作原创AI视频

大家好,我是向阳 在之前的文章中我也分享过如何进行AI视频的制作,说是AI视频其实也就是通过Stable Diffusion进行视频重绘,也就是将一个视频一帧一帧重绘为自己想要的画面,然后再连贯起来成为视频。 这个东西其实比较耗费时间和…

Linux中如何通过脚本文件实现对外流量的实时监测

while true #无限循环 doclear #清除屏幕flow$( cat /proc/net/dev | awk /ens160/{print $2} ) #查看receive流量trantsmit_flow$( cat /proc/net/dev | awk /ens160/{print $10} ) #查看transimit流量external_traffic$(( flow - trantsmit_flow )) #输入流量减去输…