5-pytorch-torch.nn.Sequential()快速搭建神经网络

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录

  • 前言
  • torch.nn.Sequential()快速搭建网络法
    • 1 生成数据
    • 2 快速搭建网络
    • 3 训练、输出结果
  • 总结


前言

本文内容还是基于4-pytorch前馈网络简单(分类)问题搭建这篇的相同例子,只是为了介绍另一种更加快速搭建网络的方法,看个人喜好用哪一种。
【注】:建议先看完上面链接的博客4,在来看本篇。
这里的这种搭建方法是使用**torch.nn.Sequential()**快速搭建,不用我们在继承重写net类了。

torch.nn.Sequential()快速搭建网络法

1 生成数据

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

n_data = torch.ones(100,2)
x0 = torch.normal(2*n_data,1)
y0 = torch.zeros(100,1)
x1 = torch.normal(-2*n_data,1)
y1 = torch.ones(100,1)

x = torch.cat((x0,x1),0)
# 在分类问题中标签必须用一维tensor,回归中则没有这个要求
y = torch.cat((y0,y1),0).reshape(-1)
# 在分类问题中标签还需要用torch.LongTensor类型
# 将张量 y 的类型转换为 long,这是因为在 PyTorch 中,分类问题的标签通常是整数类型(long),以便与模型输出的类别概率进行比较,从而计算损失。
y = y.long()


fig = plt.figure()
plt.scatter(x.data.numpy()[:,0],x.data.numpy()[:,1],c=y.data.numpy())
# 给画出来的每一个点标上标签,有点难看,注了吧
# 循环遍历每个数据点,根据其对应的标签添加标签文本
for i in range(len(x)):
    plt.text(x[i][0], x[i][1], str(int(y[i].item())), fontsize=8)
plt.show()

输出:
在这里插入图片描述

2 快速搭建网络

## 搭建网络method1
# class Net(torch.nn.Module):
#     def __init__(self,n_features,n_hidden,n_output):
#         # 继承原来结构体的全部init属性及方法
#         super(Net,self).__init__()
#         # 线性层就是全连接层
#         self.hidden = torch.nn.Linear(n_features,n_hidden)
#         self.predict = torch.nn.Linear(n_hidden,n_output)
#         
#     def forward(self,x):
#         # 重写继承类的向前传播方法,就是在这个里面选择激活函数的
#         x = F.relu(self.hidden(x))
#         # 分类中输出层也可以不用激活函数,我们最后在对输出结果进行softmax处理
#         x = self.predict(x)
#         return x
#         
# net = Net(2,10,2)
# # 输出层定义2个输出,对输出在进行softmax处理,取出概率最大的元素的下标就是我们分类的类别;与回归有所不同
# # 有点类似机器学习里面的独热编码
# print(net)


## 快速搭建法,和前面注释掉的效果是一样的。
net = torch.nn.Sequential(
    torch.nn.Linear(2,10),
    torch.nn.ReLU(), # 这里激活函数大写了要
    torch.nn.Linear(10,2)
)
print(net)

输出:
在这里插入图片描述

3 训练、输出结果

optimizer = torch.optim.SGD(net.parameters(),lr=0.02)
# 分类用交叉熵损失函数
loss_func = torch.nn.CrossEntropyLoss()

# 开启matplotlib的交换模式
plt.ion()
for t in range(100):
    # 这一步其实是调用了类里面的 __call__魔术方法,又学到一个魔术方法
    out = net(x)
    loss = loss_func(out,y)
    # 梯度清零
    optimizer.zero_grad()
    # 误差反向传播,求梯度
    loss.backward()
    # 进行优化器优化
    optimizer.step()
    if t%5 == 0:
        plt.cla()
        prediction = torch.max(F.softmax(out,1),1)[1]
        pred_y = prediction.data.numpy().reshape(-1)
        target_y = y.data.numpy().reshape(-1)
        plt.scatter(x.data.numpy()[:,0],x.data.numpy()[:,1],c=pred_y)
        accuracy = sum(pred_y==target_y)/200
        plt.text(1.2,-4,'accuracy=%.2f' % accuracy, fontdict={'size':20,'color':'red'})
        plt.pause(0.1)
# 关闭matplotlib的交换模式
plt.ioff()
plt.show()

输出:
在这里插入图片描述

# 输出out经softmax处理过后才变成概率
out2probability = F.softmax(out,1)
#print(out2probability.round(decimals=2))
# 取出概率向量里面概率最大的下标就是最终的分类结果
prediction = torch.max(F.softmax(out,1),1)[1]
print(prediction)在这里插入代码片

输出:
在这里插入图片描述

总结

选择那种方法搭建,看个人喜好,效果完全一样。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/550455.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

滤波器笔记(杂乱)

线性相位是时间平移,相位不失真 零、基础知识 1、用相量表示正弦量 https://zhuanlan.zhihu.com/p/345546880 https://www.zhihu.com/question/347763932/answer/1103938667 A s i n ( ω t θ ) ⇔ A e j θ ⇔ A ∠ θ Asin(\omega t\theta) {\Leftrightarrow…

IBM SPSS Statistics for Mac中文激活版:强大的数据分析工具

IBM SPSS Statistics for Mac是一款功能强大的数据分析工具,为Mac用户提供了高效、精准的数据分析体验。 IBM SPSS Statistics for Mac中文激活版下载 该软件拥有丰富的统计分析功能,无论是描述性统计、推论性统计,还是高级的多元统计分析&am…

企业邮箱迁移是什么?如何通过IMAP/POP协议进行邮箱迁移?

使用公司邮箱工作的过程中,公司可能遇到公司规模的扩大或技术架构升级,可能要换公司邮箱。假如马上使用新的公司邮箱,业务处理要被终断。企业邮箱转移是公司更换邮箱不可或缺的一步,不仅是技术操作,更是企业信息安全、…

Unity MySql安装部署与Unity连接 下篇

一、前言 上篇讲到了如何安装与部署本地MySql;本篇主要讲Unity与MySql连接、创建表、删除表,然后就是对表中数据的增、删、改、查等操作。再讲这些之前会说一些安装MySql碰到的一些问题和Unity连接的问题。 当把本地MySql部署好之后,我们可能…

Pytorch搭建GoogleNet神经网络

一、创建卷积模板文件 因为每次使用卷积层都需要调用Con2d和relu激活函数,每次都调用非常麻烦,就将他们打包在一起写成一个类。 in_channels:输入矩阵深度作为参数输入 out_channels: 输出矩阵深度作为参数输入 经过卷积层和relu激活函数…

AI:156-利用Python进行自然语言处理(NLP):情感分析与文本分类

本文收录于专栏:精通AI实战千例专栏合集 从基础到实践,深入学习。无论你是初学者还是经验丰富的老手,对于本专栏案例和项目实践都有参考学习意义。 每一个案例都附带关键代码,详细讲解供大家学习,希望可以帮到大家。正…

JDK5.0新特性

目录 1、JDK5特性 1.1、静态导入 1.2 增强for循环 1.3 可变参数 1.4 自动装箱/拆箱 1.4.1 基本数据类型包装类 1.5 枚举类 1.6 泛型 1.6.1 泛型方法 1.6.2 泛型类 1.6.3 泛型接口 1.6.4 泛型通配符 1、JDK5特性 JDK5中新增了很多新的java特性,利用这些新…

你的RPCvs佬的RPC

一、课程目标 了解常见系统库的hook了解frida_rpc 二、工具 教程Demo(更新)jadx-guiVS CodejebIDLE 三、课程内容 1.Hook_Libart libart.so: 在 Android 5.0(Lollipop)及更高版本中,libart.so 是 Android 运行时(ART&#x…

计算机网络----第十二天

交换机端口安全技术和链路聚合技术 1、端口隔离技术: 用于在同vlan内部隔离用户; 同一隔离组端口不能通讯,不同隔离组端口可以通讯; 2、链路聚合技术: 含义:把连接到同一台交换机的多个物理端口捆绑为一个逻辑端口…

【前后端的那些事】SpringBoot 基于内存的ip访问频率限制切面(RateLimiter)

文章目录 1. 什么是限流2. 常见的限流策略2.1 漏斗算法2.2 令牌桶算法2.3 次数统计 3. 令牌桶代码编写4. 接口测试5. 测试结果 1. 什么是限流 限流就是在用户访问次数庞大时,对系统资源的一种保护手段。高峰期,用户可能对某个接口的访问频率急剧升高&am…

十大排序——6.插入排序

这篇文章我们来介绍一下插入排序 目录 1.介绍 2.代码实现 3.总结与思考 1.介绍 插入排序的要点如下所示: 首先将数组分为两部分[ 0 ... low-1 ],[ low ... arr.length-1 ],然后,我们假设左边[ 0 ... low-1 ]是已排好序的部分…

Spring Boot 多环境配置:YML 文件的三种高效方法

🌟 前言 欢迎来到我的技术小宇宙!🌌 这里不仅是我记录技术点滴的后花园,也是我分享学习心得和项目经验的乐园。📚 无论你是技术小白还是资深大牛,这里总有一些内容能触动你的好奇心。🔍 &#x…

力扣:141. 环形链表

力扣:141. 环形链表 给你一个链表的头节点 head ,判断链表中是否有环。 如果链表中有某个节点,可以通过连续跟踪 next 指针再次到达,则链表中存在环。 为了表示给定链表中的环,评测系统内部使用整数 pos 来表示链表尾…

uni-app学习

目录 一、安装HBuilderX 二、创第一个uni-app 三、项目目录和文件作用 四、全局配置文件(pages.json) 4.1 globalStyle(全局样式) 导航栏:背景颜色、标题颜色、标题文本 导航栏:开启下拉刷新、下拉背…

LeetCode 409—— 最长回文串

阅读目录 1. 题目2. 解题思路3. 代码实现 1. 题目 2. 解题思路 要想组成回文串,那么只有最中间的字符可以是奇数个,其余字符都必须是偶数个。 所以,我们先遍历一遍字符串,统计出每个字符出现的次数。 然后如果某个字符出现了偶…

【数据分享】历次人口普查数据(一普到七普)

国之情,民之意,查人口,定大计。 第七次人口普查已经结束,那么,为了方便大家把七普数据与之前的数据做对比,地理遥感生态网整理了从一普到七普人口数据,并且把第七次人口普查的数据也一并分享给…

当全连接队列满了,tcp客户端收到服务端RST信令的模拟

当tcp服务端全连接队列满了后,并且服务端也不accept取出连接,客户端再次连接时,服务端能够看到SYN_RECV状态。但是客户端看到的是ESTABLISHED状态,所以客户端自认为成功建立了连接,故其写往服务端写数据,发…

JVM之本地方法栈和程序计数器和堆

本地方法栈 本地方法栈是为虚拟机执行本地方法时提供服务的 JNI:Java Native Interface,通过使用 Java 本地接口程序,可以确保代码在不同的平台上方便移植 不需要进行 GC,与虚拟机栈类似,也是线程私有的,…

C语言--函数递归

目录 1、什么是递归? 1.1 递归的思想 1.2 递归的限制条件 2. 递归举例 2.1 举例1:求n的阶乘 2.2 举例2:顺序打印⼀个整数的每⼀位 3. 递归与迭代 扩展学习: 早上好,下午好,晚上好 1、什么是递归&…

【鸿蒙开发】生命周期

1. UIAbility组件生命周期 UIAbility的生命周期包括Create、Foreground、Background、Destroy四个状态。 UIAbility生命周期状态 1.1 Create状态 Create状态为在应用加载过程中,UIAbility实例创建完成时触发,系统会调用onCreate()回调。可以在该回调中…