【神经网络与深度学习】Long short-term memory网络(LSTM)

简单介绍

在这里插入图片描述
API介绍:

nn.LSTM(input_size=100, hidden_size=10, num_layers=1,batch_first=True, bidirectional=True)

inuput_size: embedding_dim
hidden_size: 每一层LSTM单元的数量
num_layers: RNN中LSTM的层数
batch_first: True对应[batch_size, seq_len, embedding_dim]
bidiectional: True对应使用双向LSTM

在这里插入图片描述
实例化LSTM对象后,不仅要传入数据,还有传入前一次的h_0和c_0
lstm(input, (h_0, c_0))
LSTM默认输出(output, (h_n, c_n))
output: [ seq_len, batch, hidden_size*num_directions ] (若batch_first=false)
h_n: [num_directions, batch, hidden_size]
c_n : [num_directions, batch, hidden_size]

import torch.nn as nn
import torch.nn.functional as F
import torch

batch_size = 10
seq_len =20 #句子长度
vocab_size = 100 # 词典数量
embedding_dim = 30 # 用embedding_dim长度的向量表示一个词语
hidden_size = 18

input = torch.randint(0, 100, [batch_size, seq_len])
print(input.size())
print("*"*100)
# 经过embedding
embed = nn.Embedding(vocab_size, embedding_dim)

input_embed = embed(input)  # [bs, seq_len, embedding_dim]
print(input_embed.size())
print("*"*100)
lstm = nn.LSTM(embedding_dim, hidden_size=hidden_size, num_layers=1, batch_first=True)
output,(h_n, c_n) = lstm(input_embed)
print(output.size())
print("*"*100)
print(h_n.size())
print("*"*100)
print(c_n.size())

通常由最后一个输出代替整个句子

使用双向LSTM实现

"""
定义模型
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from lib import ws,max_len
from dataset import get_data
import lib
import os
import numpy as np
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.embedding = nn.Embedding(len(ws), 100)
        self.lstm = nn.LSTM(input_size=100, hidden_size=lib.hidden_size, num_layers=lib.num_layers,batch_first=True, bidirectional=lib.bidirectional, dropout=lib.dropout)
        self.fc = nn.Linear(lib.hidden_size*2, 2)
        



    def forward(self, input):
        """

        :param input: [batch_size, max_len]
        :return:
        """
        x = self.embedding(input) # [batch_size, max_len, 100]
        x,(h_n,c_n)= self.lstm(x)
        output = torch.cat([h_n[-2,:,:],h_n[-1,:,:]],dim=-1)
        output = self.fc(output)
        return F.log_softmax(output,dim=-1)

model = MyModel().to(lib.device)
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
if os.path.exists("./model0/model.pkl"):
    model.load_state_dict(torch.load("./model0/model.pkl"))
    optimizer.load_state_dict(torch.load("./model0/optimizer.pkl"))

def train(epoch):
    for idx,(input,target) in enumerate(get_data(train=True)):
        input = input.to(lib.device)
        target = target.to(lib.device)
        # 梯度清零
        optimizer.zero_grad()
        output= model(input)
        loss = F.nll_loss(output,target)
        loss.backward()
        optimizer.step()
        print(epoch, idx, loss.item())

        if idx%100==0:
            torch.save(model.state_dict(),"./model0/model.pkl")
            torch.save(optimizer.state_dict(),"./model0/optimizer.pkl")

def eval():
    loss_list = []
    acc_list = []
    for idx,(input,target) in enumerate(get_data(train=False, batch_size=lib.test_batch_size)):
        input = input.to(lib.device)
        target = target.to(lib.device)
        with torch.no_grad():
            output= model(input)
            loss = F.nll_loss(output,target)
            loss_list.append(loss.cpu().item())
            pre = output.max(dim=-1)[-1]
            acc = pre.eq(target).float().mean()
            acc_list.append(acc.cpu().item())

    print("total loss, acc:", np.mean(loss_list), np.mean(acc_list))




if __name__ == '__main__':
    for i in range(10):
        train(epoch=i)

    eval()






本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/546342.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

fpga基础|如何在XDC文件中使用get_pins/ports/cells/nets/clocks查找指定的对象

大家好,我是数字小熊饼干,一个练习时长两年半的ic打工人。我在两年前通过自学跨行社招加入了IC行业。现在我打算将这两年的工作经验和当初面试时最常问的一些问题进行总结,并通过汇总成文章的形式进行输出,相信无论你是在职的还是…

P1706 全排列问题

原题链接:全排列问题 - 洛谷 目录 1. 题目描述 2. 思路分析 3. 代码实现 1. 题目描述 2. 思路分析 dfs典题 3. 代码实现 #define _CRT_SECURE_NO_WARNINGS 1 #include<bits/stdc.h> using namespace std; #define ll long long #define endl \n const int N 2…

SaaS知识库工具推荐来了,有这些就够用了!

在快节奏、以信息为核心的工作环境中&#xff0c;如何有效管理和共享知识成为了很多企业面临的一大挑战。幸运的是&#xff0c;现在市面上有许多出色的SaaS&#xff08;Software as a Service&#xff09;知识库工具可以帮助我们。这些工具便捷、高效&#xff0c;能够满足各种规…

pycharm爬虫模块(scrapy)基础使用

今天学了个爬虫。在此记录 目录 一.通过scrapy在命令行创建爬虫项目 二.判断数据为静态还是动态 三.pycharm中的设置 三:爬虫主体 四.pipelines配置&#xff08;保存数据的&#xff09; 五.最终结果 一.通过scrapy在命令行创建爬虫项目 1.首先需要在cmd中进入到python文…

2024.4.15

代码&#xff1a;uart4.c #include "uart4.h"void uart4_config() {//使能GPIO\GPIOG\UART4的外设时钟RCC->MP_AHB4ENSETR | (0x1<<1);RCC->MP_AHB4ENSETR | (0x1<<6);RCC->MP_APB1ENSETR | (0x1<<16);//设置PB2和PG11为管脚复用功能//P…

游戏行业科普 (二)游戏是怎么做出来,怎么卖出去的?

游戏行业科普系列文章&#xff0c;大家可以关注起来&#xff0c;等我慢慢分享~~ 《蛋仔派对》 一、研运流程--游戏是怎么做出来的 一款游戏的开发和运营大体上可以分为预研立项、设计开发、测试调优、发行上线和成熟运营几个阶段。 1&#xff09;预研立项&#xff1a; 初始研…

YOLOv9有效改进专栏汇总|未来更新卷积、主干、检测头注意力机制、特征融合方式等创新![2024/4/14]

​ 专栏介绍&#xff1a;YOLOv9改进系列 | 包含深度学习最新创新&#xff0c;助力高效涨点&#xff01;&#xff01;&#xff01; 专栏介绍 YOLOv9作为最新的YOLO系列模型&#xff0c;对于做目标检测的同学是必不可少的。本专栏将针对2024年最新推出的YOLOv9检测模型&#xff0…

Matlab|基于广义Benders分解法的综合能源系统优化规划

目录 1 主要内容 广义benders分解法流程图&#xff1a; 优化目标&#xff1a; 约束条件&#xff1a; 2 部分代码 3 程序结果 4 下载链接 1 主要内容 该程序复现文章《综合能源系统协同运行策略与规划研究》第四章内容基于广义Benders分解法的综合能源系统优化规划&…

树莓派点亮双色LED

双色LED灯准确来说叫双基色LED灯,是指模块只能显示2种颜色,一般是红色和绿色,可以有三种状态 :灭,颜色1亮,颜色2亮,根据颜色组合的不同,分为红蓝双色,黄蓝双色,红绿双色等等。 接线:将引脚S(绿色)和中间引脚(红色)连接到Raspberry Pi的GPIO接口上,对Raspberry…

针对MaxCompute优化案例分享

声明 原文来源&#xff1a;微信公众号&#xff1a;阿里云开发者 前言 MaxCompute 是阿里巴巴集团推出的一种大数据计算平台&#xff0c;用于处理海量数据和进行数据分析。它提供了高可靠性、高扩展性和高性能的数据处理能力&#xff0c;支持 SQL 查询、MapReduce 计算和机器…

1113. 红与黑--Flood Fill 算法

目录 1113. 红与黑--Flood Fill 算法---宽搜&#xff08;BFS&#xff09;或DFS&#xff09; 输入格式 输出格式 数据范围 输入样例&#xff1a; 输出样例&#xff1a; 思路&#xff1a; 1.BFS 思路&#xff1a; 2.DFS 思路 方法一&#xff1a;&#xff08;BFS&#x…

hadoop最新详细版安装教程 2024 最新版

文章目录 hadoop安装教程 2024最新版提前准备工作用户配置安装 SSH Server免密登录设置编辑 SSH server 配置文件配置Java环境查看java 版本验证 环境变量设置安装Hadoop下载hadoop解压hadoop查看hadoop 版本hadoop 配置编辑编辑配置文件core-site.xml编辑配置文件hdfs-site.xm…

使用深度学习集成模型进行乳腺癌组织病理学图像分类

基于预训练的VGG16和VGG19架构训练了四种不同的模型&#xff08;即完全训练的 VGG16、微调的 VGG16、完全训练的 VGG19 和微调的 VGG19 模型&#xff09;。最初&#xff0c;我们对所有单独的模型进行了5倍交叉验证操作。然后&#xff0c;我们采用集成策略&#xff0c;取预测概率…

深度学习框架

深度学习框架 1 引言 在当今技术加速发展的时代&#xff0c;深度学习已经成为了人工智能领域内最为引人注目的子领域之一。其在图像识别、自然语言处理、自动驾驶等多个行业中的成功应用&#xff0c;已经证明了深度学习在解决复杂问题方面的巨大潜力。然而&#xff0c;深度学习…

package.java文件的作用

你查看springboot的源码&#xff0c;有很多类都有这个文件&#xff0c;在idea不能创建&#xff0c;因为不支持这种命名&#xff0c;只能用记事本创建后复制都项目中。 主要应用是给类添加正常&#xff0c;或者把公用的注解都放到这里&#xff0c;常量不合适&#xff0c;作用范…

动态消息系统设计

动态消息流是一个在你个人主页不同更定的故事列表&#xff0c;推特、mega和Instagram 的post消息都是典型的动态消息列表&#xff0c;和普通消息流系统的最大区别是消息流动态变化、实时更新&#xff0c;设计一个动态消息系统核心功能消息流的构建和消息的发布&#xff0c;需要…

蓝桥杯 — — 纯质数

纯质数 题目&#xff1a; 思路&#xff1a; 一个最简单的思路就是枚举出所有的质数&#xff0c;然后再判断这个质数是否是一个纯质数。 枚举出所有的质数&#xff1a; 可以使用常规的暴力求解法&#xff0c;其时间复杂度为&#xff08; O ( N N ) O(N\sqrt{N}) O(NN ​)&…

破译验证码reCAPTCHA 之 打码平台

由于登录需要验证码&#xff0c;除了日常的字符串&#xff0b;数字&#xff0c;此时就需要用第三方插件进行破译。 reCaptcha是Google公司的验证码服务&#xff0c;方便快捷&#xff0c;改变了传统验证码需要输入n位失真字符的特点。 1. reCAPTCHA 初识 reCaptcha是Google公司…

Oracle+11g+笔记(3)-SQL/Plus

Oracle11g笔记(3)-SQL/Plus 3、SQL/Plus 3.1 启动退出SQL/Plus > sqlplus 账号/密码数据库 # 示例 > sqlplus scott/tigerorcl> sqlplus /nolog -- 无日志登录&#xff1a;避免别人从日志中查询到登录信息 > conn soctt/socttorcl # 示例 > sqlplus /nolog &…

Numpy数组和列表list的区别

参考&#xff1a;Numpy Array vs List 在Python编程中&#xff0c;列表&#xff08;list&#xff09;和Numpy数组&#xff08;numpy array&#xff09;是两种常见的数据结构&#xff0c;它们都可以用来存储多个元素。但是它们在实际使用中有很大的区别&#xff0c;本文将详细比…