【Pytorch】RNN for Image Classification

在这里插入图片描述

文章目录

  • 1 RNN 的定义
  • 2 RNN 输入 input, h_0
  • 3 RNN 输出 output, h_n
  • 4 多层
  • 5 小试牛刀

学习参考来自

  • pytorch中nn.RNN()总结
  • RNN for Image Classification(RNN图片分类–MNIST数据集)
  • pytorch使用-nn.RNN

1 RNN 的定义

在这里插入图片描述

nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity=tanh, bias=True, batch_first=False, dropout=0, bidirectional=False)

参数说明

  • input_size输入特征的维度, 一般 rnn 中输入的是词向量,那么 input_size 就等于一个词向量的维度
  • hidden_size隐藏层神经元个数,或者也叫输出的维度(因为rnn输出为各个时间步上的隐藏状态)
  • num_layers网络的层数
  • nonlinearity激活函数
  • bias是否使用偏置
  • batch_first输入数据的形式,默认是 False,就是这样形式,(seq(num_step), batch, input_dim),也就是将序列长度放在第一位,batch 放在第二位
  • dropout是否应用dropout, 默认不使用,如若使用将其设置成一个0-1的数字即可
  • birdirectional是否使用双向的 rnn,默认是 False

2 RNN 输入 input, h_0

input 形状: 当设置 batch_first = False 时, ( L , N , H i n ) (L , N , H_{ i n}) (L,N,Hin) —— [时间步数, 批量大小, 特征维度]

当设置 batch_first = True时, ( N , L , H i n ) (N , L , H_{ i n}) (N,L,Hin)

当输入只有两个维度且 batch_size 为 1 时 :( L , H i n ) (L, H_{in})(L,H in ) 时,需要调用 torch.unsqueeze() 增加维度。

h_0 形状: ( D ∗ n u m _ l a y e r s , N , H o u t ) ( D ∗ n u m \_ l a y e r s , N , H _{o u t} ) (Dnum_layers,N,Hout), D 代表单向 RNN 还是双向 RNN。

在这里插入图片描述

3 RNN 输出 output, h_n

output 形状:当设置 batch_first = False 时, ( L , N , D ∗ H o u t ) (L, N, D * H_{out}) (L,N,DHout)—— [时间步数, 批量大小, 隐藏单元个数];
当设置 batch_first = True 时, ( N , L , D ∗ H o u t ) (N, L, D * H_{out}) (N,L,DHout)

h_n 形状 ( D ∗ num_layers , N , H o u t ) (D * \text{num\_layers}, N, H_{out}) (Dnum_layers,N,Hout)

4 多层

在这里插入图片描述

5 小试牛刀

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

# -------------
# MNIST dataset
# -------------
batch_size = 128
train_dataset = torchvision.datasets.MNIST(root='./',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)
test_dataset = torchvision.datasets.MNIST(root='./',
                                          train=False,
                                          transform=transforms.ToTensor())
# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)


# ---------------------
# Exploring the dataset
# ---------------------
# function to show an image
def imshow(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

# get some random training images
dataiter = iter(train_loader)
images, labels = dataiter.next()

if 1:
    # show image
    imshow(torchvision.utils.make_grid(images, nrow=15))
    plt.show()


# ----------
# parameters
# ----------
N_STEPS = 28
N_INPUTS = 28  # 输入数据的维度
N_NEURONS = 150  # RNN中间的特征的大小
N_OUTPUT = 10  # 输出数据的维度(分类的个数)
N_EPHOCS = 10  # epoch的大小
N_LAYERS = 3


# ------
# models
# ------
class ImageRNN(nn.Module):
    def __init__(self, batch_size, n_inputs, n_neurons, n_outputs, n_layers):
        super(ImageRNN, self).__init__()
        self.batch_size = batch_size  # 输入的时候batch_size, 128
        self.n_inputs = n_inputs  # 输入的维度, 28
        self.n_outputs = n_outputs  # 分类的大小 10
        self.n_neurons = n_neurons  # RNN中输出的维度 150
        self.n_layers = n_layers  # RNN中的层数 3
        self.basic_rnn = nn.RNN(self.n_inputs, self.n_neurons, num_layers=self.n_layers)
        self.FC = nn.Linear(self.n_neurons, self.n_outputs)

    def init_hidden(self):
        # (num_layers, batch_size, n_neurons)
        # initialize hidden weights with zero values
        # 这个是net的memory, 初始化memory为0
        return (torch.zeros(self.n_layers, self.batch_size, self.n_neurons).to(device))

    def forward(self, x):  # torch.Size([128, 28, 28])
        # transforms x to dimensions : n_step × batch_size × n_inputs
        x = x.permute(1, 0, 2)  # 需要把n_step放在第一个, torch.Size([28, 128, 28])
        self.batch_size = x.size(1)  # 每次需要重新计算batch_size, 因为可能会出现不能完整方下一个batch的情况 128
        self.hidden = self.init_hidden()  # 初始化hidden state  torch.Size([3, 128, 150])
        rnn_out, self.hidden = self.basic_rnn(x, self.hidden)  # 前向传播  torch.Size([28, 128, 150]), torch.Size([3, 128, 150])
        out = self.FC(rnn_out[-1])  # 求出每一类的概率 torch.Size([128, 150])->torch.Size([128, 10])
        return out.view(-1, self.n_outputs)  # 最终输出大小 : batch_size X n_output  torch.Size([128, 10])


# --------------------
# Device configuration
# --------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# ------------------------------------
# Test the model(输入一张图片查看输出)
# ------------------------------------
# 定义模型
model = ImageRNN(batch_size, N_INPUTS, N_NEURONS, N_OUTPUT, N_LAYERS).to(device)
print(model)
"""
ImageRNN(
  (basic_rnn): RNN(28, 150, num_layers=3)
  (FC): Linear(in_features=150, out_features=10, bias=True)
)
"""

# 初始化模型的weight
model.basic_rnn.weight_hh_l0.data = torch.eye(n=N_NEURONS, m=N_NEURONS, out=None).to(device)
model.basic_rnn.weight_hh_l1.data = torch.eye(n=N_NEURONS, m=N_NEURONS, out=None).to(device)
model.basic_rnn.weight_hh_l2.data = torch.eye(n=N_NEURONS, m=N_NEURONS, out=None).to(device)


# 定义数据
dataiter = iter(train_loader)
images, labels = dataiter.next()
model.hidden = model.init_hidden()
logits = model(images.view(-1, 28, 28).to(device))
print(logits[0:2])
"""
tensor([[-0.2846, -0.1503, -0.1593,  0.5478,  0.6827,  0.3489, -0.2989,  0.4575,
         -0.2426, -0.0464],
        [-0.6708, -0.3025, -0.0205,  0.2242,  0.8470,  0.2654, -0.0381,  0.6646,
         -0.4479,  0.2523]], device='cuda:0', grad_fn=<SliceBackward>)
"""

# 产生对角线是1的矩阵
torch.eye(n=5, m=5, out=None)
"""
tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]])
"""


# --------
# Training
# --------
model = ImageRNN(batch_size, N_INPUTS, N_NEURONS, N_OUTPUT, N_LAYERS).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 初始化模型的weight
model.basic_rnn.weight_hh_l0.data = torch.eye(n=N_NEURONS, m=N_NEURONS, out=None).to(device)
model.basic_rnn.weight_hh_l1.data = torch.eye(n=N_NEURONS, m=N_NEURONS, out=None).to(device)
model.basic_rnn.weight_hh_l2.data = torch.eye(n=N_NEURONS, m=N_NEURONS, out=None).to(device)


def get_accuracy(logit, target, batch_size):
    """最后用来计算模型的准确率
    """
    corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()
    accuracy = 100.0 * corrects/batch_size
    return accuracy.item()


# ---------
# 开始训练
# ---------
for epoch in range(N_EPHOCS):
    train_running_loss = 0.0
    train_acc = 0.0
    model.train()
    # trainging round
    for i, data in enumerate(train_loader):
        optimizer.zero_grad()
        # reset hidden states
        model.hidden = model.init_hidden()
        # get inputs
        inputs, labels = data
        inputs = inputs.view(-1, 28, 28).to(device)
        labels = labels.to(device)
        # forward+backward+optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_running_loss = train_running_loss + loss.detach().item()
        train_acc = train_acc + get_accuracy(outputs, labels, batch_size)
    model.eval()
    print('Epoch : {:0>2d} | Loss : {:<6.4f} | Train Accuracy : {:<6.2f}%'.format(epoch, train_running_loss/i, train_acc/i))


# ----------------------------------------
# Computer accuracy on the testing dataset
# ----------------------------------------
test_acc = 0.0
for i,data in enumerate(test_loader,0):
    inputs, labels = data
    labels = labels.to(device)
    inputs = inputs.view(-1,28,28).to(device)
    outputs = model(inputs)
    thisBatchAcc = get_accuracy(outputs, labels, batch_size)
    print("Batch:{:0>2d}, Accuracy : {:<6.4f}%".format(i,thisBatchAcc))
    test_acc = test_acc + thisBatchAcc
print('============平均准确率===========')
print('Test Accuracy : {:<6.4f}%'.format(test_acc/i))
"""
Epoch : 00 | Loss : 0.6336 | Train Accuracy : 79.32 %
Epoch : 01 | Loss : 0.2363 | Train Accuracy : 93.00 %
Epoch : 02 | Loss : 0.1852 | Train Accuracy : 94.63 %
Epoch : 03 | Loss : 0.1516 | Train Accuracy : 95.69 %
Epoch : 04 | Loss : 0.1338 | Train Accuracy : 96.13 %
Epoch : 05 | Loss : 0.1198 | Train Accuracy : 96.67 %
Epoch : 06 | Loss : 0.1254 | Train Accuracy : 96.46 %
Epoch : 07 | Loss : 0.1128 | Train Accuracy : 96.88 %
Epoch : 08 | Loss : 0.1059 | Train Accuracy : 97.09 %
Epoch : 09 | Loss : 0.1048 | Train Accuracy : 97.10 %
Batch:00, Accuracy : 98.4375%
Batch:01, Accuracy : 98.4375%
Batch:02, Accuracy : 95.3125%
Batch:03, Accuracy : 98.4375%
Batch:04, Accuracy : 96.8750%
Batch:05, Accuracy : 93.7500%
Batch:06, Accuracy : 97.6562%
Batch:07, Accuracy : 95.3125%
Batch:08, Accuracy : 94.5312%
Batch:09, Accuracy : 92.9688%
Batch:10, Accuracy : 96.0938%
Batch:11, Accuracy : 96.0938%
Batch:12, Accuracy : 97.6562%
Batch:13, Accuracy : 96.8750%
Batch:14, Accuracy : 96.0938%
Batch:15, Accuracy : 95.3125%
Batch:16, Accuracy : 95.3125%
Batch:17, Accuracy : 96.0938%
Batch:18, Accuracy : 96.0938%
Batch:19, Accuracy : 97.6562%
Batch:20, Accuracy : 97.6562%
Batch:21, Accuracy : 98.4375%
Batch:22, Accuracy : 96.0938%
Batch:23, Accuracy : 96.8750%
Batch:24, Accuracy : 97.6562%
Batch:25, Accuracy : 99.2188%
Batch:26, Accuracy : 96.0938%
Batch:27, Accuracy : 94.5312%
Batch:28, Accuracy : 98.4375%
Batch:29, Accuracy : 94.5312%
Batch:30, Accuracy : 96.0938%
Batch:31, Accuracy : 93.7500%
Batch:32, Accuracy : 96.8750%
Batch:33, Accuracy : 96.0938%
Batch:34, Accuracy : 95.3125%
Batch:35, Accuracy : 96.8750%
Batch:36, Accuracy : 97.6562%
Batch:37, Accuracy : 93.7500%
Batch:38, Accuracy : 94.5312%
Batch:39, Accuracy : 100.0000%
Batch:40, Accuracy : 99.2188%
Batch:41, Accuracy : 100.0000%
Batch:42, Accuracy : 98.4375%
Batch:43, Accuracy : 98.4375%
Batch:44, Accuracy : 96.8750%
Batch:45, Accuracy : 99.2188%
Batch:46, Accuracy : 96.0938%
Batch:47, Accuracy : 98.4375%
Batch:48, Accuracy : 97.6562%
Batch:49, Accuracy : 100.0000%
Batch:50, Accuracy : 99.2188%
Batch:51, Accuracy : 91.4062%
Batch:52, Accuracy : 96.8750%
Batch:53, Accuracy : 99.2188%
Batch:54, Accuracy : 99.2188%
Batch:55, Accuracy : 100.0000%
Batch:56, Accuracy : 98.4375%
Batch:57, Accuracy : 98.4375%
Batch:58, Accuracy : 97.6562%
Batch:59, Accuracy : 100.0000%
Batch:60, Accuracy : 99.2188%
Batch:61, Accuracy : 96.0938%
Batch:62, Accuracy : 100.0000%
Batch:63, Accuracy : 97.6562%
Batch:64, Accuracy : 97.6562%
Batch:65, Accuracy : 96.8750%
Batch:66, Accuracy : 98.4375%
Batch:67, Accuracy : 100.0000%
Batch:68, Accuracy : 100.0000%
Batch:69, Accuracy : 100.0000%
Batch:70, Accuracy : 96.8750%
Batch:71, Accuracy : 98.4375%
Batch:72, Accuracy : 100.0000%
Batch:73, Accuracy : 99.2188%
Batch:74, Accuracy : 100.0000%
Batch:75, Accuracy : 96.0938%
Batch:76, Accuracy : 95.3125%
Batch:77, Accuracy : 96.8750%
Batch:78, Accuracy : 12.5000%
============平均准确率===========
Test Accuracy : 97.4559%
# """


# 定义hook
class SaveFeatures():
    """注册hook和移除hook
    """
    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)
    def hook_fn(self, module, input, output):
        self.features = output
    def close(self):
        self.hook.remove()

# 绑定到model上
activations = SaveFeatures(model.basic_rnn)


# 定义数据
dataiter = iter(train_loader)
images, labels = dataiter.next()

# 前向传播
model.hidden = model.init_hidden()
logits = model(images.view(-1, 28, 28).to(device))
activations.close()  # 移除hook

# 这个是 28(step)*128(batch_size)*150(hidden_size)
print(activations.features[0].shape)
# torch.Size([28, 128, 150])
print(activations.features[0][-1].shape)
# torch.Size([128, 150])

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/791083.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

游戏视频是后期配音好还是边录边配 游戏视频怎么剪辑制作才能火 视频剪辑免费软件

游戏视频后期配音是先配还是先剪&#xff1f;游戏视频后期配音没有统一的准则&#xff0c;可以先配&#xff0c;也可以后配&#xff0c;主要是根据内容而定。游戏视频剪辑在游戏玩家中十分流行&#xff0c;那么&#xff0c;游戏视频怎么剪辑制作&#xff1f;下面让我们以具体的…

apache python使用

修改httpd.conf文件。 AddHandler cgi-script .cgi .py 代码 #!自己的python.exe #-*- coding:UTF-8 -*- print ("Content-type:text/html") print () print (<html>) print (<head>) print (<meta charset"gb2312">) print (<tit…

现在国内的ddos攻击趋势怎么样?想了解现在ddos的情况该去哪看?

目前&#xff0c;国内的DDoS攻击趋势显示出以下几个特征&#xff1a; 攻击频次显著增加&#xff1a;根据《快快网络2024年DDoS攻击趋势白皮书》&#xff0c;2023年DDoS攻击活动有显著攀升&#xff0c;总攻击次数达到1246.61万次&#xff0c;比前一年增长了18.1%。 攻击强度和规…

Spring Expression表达式使用

Spring Expression 简介 Spring Expression Language&#xff08;简称 “SpEL”&#xff09;是一种功能强大的表达式语言&#xff0c;支持在运行时查询和操作对象图。其中最显著的是&#xff1a;方法调用和基本的字符串模板功能。 虽然SpEL是Spring产品组合中表达式评估的基础…

树莓派pico入坑笔记,ssd1306使用

目录 说明 后附进阶玩法&#xff1a;显示中文&#xff0c;外加简单库实现 官方模块使用 使用样例 方法说明 下面是绘图支持的方法 进阶玩法&#xff0c;显示中文 方法&#xff0c;对汉字取字模&#xff0c;然后按像素对字模进行显示 说明 circuitpython8.x使用iic接口驱…

链表 OJ(一)

移除链表元素 题目连接&#xff1a; https://leetcode.cn/problems/remove-linked-list-elements/description/ 使用双指针法&#xff0c;开始时&#xff0c;一个指针指向头节点&#xff0c;另一个指针指向头节点的下一个结点&#xff0c;然后开始遍历链表删除结点。 这里要注…

YOLOv10改进 | Conv篇 | CVPR2024最新DynamicConv替换下采样(解决低FLOPs陷阱)

一、本文介绍 本文给大家带来的改进机制是CVPR2024的最新改进机制DynamicConv其是CVPR2024的最新改进机制&#xff0c;这个论文中介绍了一个名为ParameterNet的新型设计原则&#xff0c;它旨在在大规模视觉预训练模型中增加参数数量&#xff0c;同时尽量不增加浮点运算&#x…

YOLOv10改进 | Conv篇 | 全新的SOATA轻量化下采样操作ADown(参数量下降百分之二十,附手撕结构图)

一、本文介绍 本文给大家带来的改进机制是利用2024/02/21号最新发布的YOLOv9其中提出的ADown模块来改进我们的Conv模块&#xff0c;其中YOLOv9针对于这个模块并没有介绍&#xff0c;只是在其项目文件中用到了&#xff0c;我将其整理出来用于我们的YOLOv10的项目&#xff0c;经…

Python 视频的色彩转换

这篇教学会介绍使用OpenCV 的cvtcolor() 方法&#xff0c;将视频的色彩模型从RGB 转换为灰阶、HLS、HSV...等。 因为程式中的OpenCV 会需要使用镜头或GPU&#xff0c;所以请使用本机环境( 参考&#xff1a;使用Python 虚拟环境) 或使用Anaconda Jupyter 进行实作( 参考&#x…

【TAROT学习日记】韦特体系塔罗牌学习(1)——愚者 THE FOOL 0

韦特体系塔罗牌学习&#xff08;1&#xff09;——愚者 THE FOOL 0 https://www.tarotchina.net/major-arcana0-vip/ 目录 韦特体系塔罗牌学习&#xff08;1&#xff09;——愚者 THE FOOL 0牌面分析1. 基础信息2. 图片元素 正位牌意1. 关键词/句2.爱情婚姻3. 学业事业4. 人际关…

android13 rom frameworks 蓝牙自动接收文件

总纲 android13 rom 开发总纲说明 目录 1.前言 2.源码查找 3.我们先实现第一种改法 4.实现第二种改法 5.第三种改法代码参考 6.编译测试 1.前言 我们从导航栏这里,点开这个蓝牙的接收框,弹出来的对话框,使用android studio 的layout inspector可以发现这个是 Bluetoo…

有必要找第三方软件测评公司吗?如何选择靠谱软件测评机构?

软件测试是确保软件质量的重要环节&#xff0c;而在进行软件测试时&#xff0c;是否有必要找第三方软件测评公司呢?第三方软件测评公司是指独立于软件开发公司和用户之间的中立机构&#xff0c;专门从事软件测试和测评工作。与自身开发团队或内部测试团队相比&#xff0c;选择…

大白话讲解AI大模型

大白话讲解大模型 大模型的发展重要大模型发展时间线 大模型的简单原理-训练⼤模型是如何训练并应⽤到场景中的&#xff1f;如果训练私有化模型 模型&#xff1a;model 语料库&#xff1a;用于训练模型的数据 大模型的发展 详细信息来源&#xff1a;DataLearner 2022年11月底…

JVM相关知识点汇总

JDK,JRE以及JVM的关系 我们的编译器到底干了什么事? 仅仅是将我们的 .java 文件转换成了 .class 文件,实际上就是文件格式的转换,对等信息转换。 类加载机制是什么? > **所谓类加载机制就是** > ``` > 虚拟机把Class文件加载到内存 > 并对数据进行校验,转换…

web安全及内网安全知识

本文来源无问社区&#xff08;wwlib.cn&#xff09;更多详细内容可前往观看http://www.wwlib.cn/index.php/artread/artid/7506.html Web安全 1、sql注入 Web程序中对于用户提交的参数未做过滤直接拼接到SQL语句中执行&#xff0c;导致参数中的特殊字符破坏了SQL语句原有逻…

qt 用数据画一个图,并表示出来

1.概要 想用数据绘制一个画面&#xff0c;看有相机到播放的本质是啥。 要点 // 创建一个QImage对象&#xff0c;指定图像的宽度、高度和格式 QImage image(width, height, QImage::Format_Grayscale8); // 将像素数据复制到QImage对象中 memcpy(image.bits(), pixelD…

【Linux网络】IP协议{初识/报头/分片/网段划分/子网掩码/私网公网IP/认识网络世界/路由表}

文章目录 1.入门了解2.认识报头3.认识网段4.路由跳转相关指令路由 该文诸多理解参考文章&#xff1a;好文&#xff01; 1.入门了解 用户需求&#xff1a;将我的数据可靠的跨网络从A主机送到B主机 传输层TCP&#xff1a;由各种方法&#xff08;流量控制/超时重传/滑动窗口/拥塞…

PTC可复位保险丝 vs 传统型保险丝:全面对比分析

PTC可复位保险丝&#xff0c;又称为自恢复保险丝、自恢复熔断器或PPTC保险丝&#xff0c;是一种电子保护器件。它利用材料的正温度系数效应&#xff0c;即电阻值随温度升高而显著增加的特性&#xff0c;来实现电路保护。 当电路正常工作时&#xff0c;PTC保险丝呈现低阻态&…

最新浪子授权系统网站源码 全开源免授权版本

最新浪子授权系统网站源码 全开源免授权版本 此版本没有任何授权我已经去除授权&#xff0c;随意二开无任何加密。 更新日志 1.修复不能下载 2.修复不能更新 3.修复不能删除用户 4.修复不能删除授权 5.增加代理后台管理 6.重写授权读取文件 7.修复已经知道漏洞 源码下…

2-30 基于matlab的神经网路下身份证号码识别算法

基于matlab的神经网路下身份证号码识别算法&#xff0c;二值化、膨胀处理、边界区域划分、身份证字符分割&#xff0c;字符识别算法&#xff0c;输出识别结果。并保存识别结果。程序已调通&#xff0c;可直接运行。 2-30 神经网络 身份证识别 图像处理 - 小红书 (xiaohongshu.c…