实现多层感知机

目录

多层感知机:

介绍:

代码实现:

运行结果:

问题答疑:

线性变换与非线性变换

参数含义

为什么清除梯度?

反向传播的作用

为什么更新权重?


多层感知机:

介绍:

缩写:MLP,这是一种人工神经网络,由一个输入层、一个或多个隐藏层以及一个输出层组成,每一层都由多个节点(神经元)构成。在MLP中,节点之间只有前向连接,没有循环连接,这使得它属于前馈神经网络的一种。每个节点都应用一个激活函数,如sigmoid、ReLU等,以引入非线性,从而使网络能够拟合复杂的函数和数据分布。

代码实现:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Step 1: Define the MLP model
class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(784, 128)  # Input layer to hidden layer
        self.fc2 = nn.Linear(128, 64)   # Hidden layer to another hidden layer
        self.fc3 = nn.Linear(64, 10)    # Hidden layer to output layer
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(-1, 784)             # Flatten the input from 28x28 to 784
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Step 2: Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Step 3: Define loss function and optimizer
model = SimpleMLP()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Step 4: Train the model
num_epochs = 5
for epoch in range(num_epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

# Step 5: Evaluate the model on the test set (optional)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))

运行结果:

问题答疑:

线性变换与非线性变换

在神经网络中

线性变换通常指的是权重矩阵和输入数据的矩阵乘法,再加上偏置向量。数学上,对于一个输入向量𝑥x和权重矩阵𝑊W,加上偏置向量𝑏b,线性变换可以表示为: 𝑧=𝑊𝑥+𝑏z=Wx+b

非线性变换是指在神经网络的每一层之后应用的激活函数,如ReLU、sigmoid或tanh等。这些函数引入了非线性,使神经网络能够学习和表达复杂的函数关系。没有非线性变换,无论多少层的神经网络最终都将简化为一个线性模型。

参数含义

在上述模型中,参数如784, 128, 64, 10并不是字节,而是神经网络层的尺寸,具体来说是神经元的数量:

  • 784: 这是输入层的神经元数量,对应于MNIST数据集中每个图片的像素数量。MNIST的图片是28x28像素,因此总共有784个像素点。
  • 128 和 64: 这是两个隐藏层的神经元数量。它们代表了第一层和第二层的宽度,即这一层有多少个神经元。
  • 10: 这是输出层的神经元数量,对应于MNIST数据集中的10个数字类别(0到9)。

为什么清除梯度?

在每一次前向传播和反向传播过程中,梯度会被累积在张量的.grad属性中。如果不手动清零,这些梯度将会被累加,导致不正确的梯度值。因此,在每次迭代开始之前,都需要调用optimizer.zero_grad()来清空梯度。

反向传播的作用

反向传播(Backpropagation)是一种算法,用于计算损失函数相对于神经网络中所有权重的梯度。它的目的是为了让神经网络知道,当损失函数值较高时,哪些权重需要调整,以及调整的方向和幅度。这些梯度随后被用于权重更新,以最小化损失函数。

为什么更新权重?

权重更新是基于梯度下降算法进行的。在反向传播计算出梯度后,权重通过optimizer.step()函数更新,以朝着减小损失函数的方向移动。

这是训练神经网络的核心,即通过不断调整权重和偏置,使模型能够更好地拟合训练数据,从而提高预测准确性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/798678.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

LabVIEW红外热波图像缺陷检

开发使用LabVIEW开发的红外热波图像缺陷检测系统。该系统结合红外热像仪、工业相机和高效的数据采集硬件,实现对工件表面缺陷的自动检测和分析。通过LabVIEW的强大功能,系统能够实时采集、处理和显示红外热波图像,有效提高了检测的精度和效率…

【Playwright+Python】系列 Pytest 插件在Playwright中的使用

一、命令行使用详解 使用 Pytest 插件在Playwright 中来编写端到端的测试。 1、命令行执行测试 pytest --browser webkit --headed 2、使用 pytest.ini 文件配置 内容如下: [pytest] # Run firefox with UIaddopts --headed --browser firefox效果&#xff1…

机器人相关工科专业课程体系

机器人相关工科专业课程体系 前言传统工科专业机械工程自动化/控制工程计算机科学与技术 新兴工科专业智能制造人工智能机器人工程 总结Reference: 前言 机器人工程专业是一个多领域交叉的前沿学科,涉及自然科学、工程技术、社会科学、人文科学等相关学科的理论、方…

STM32MP135裸机编程:定时器内核时钟频率计算方法

0 工具准备 STM32MP13xx参考手册 1 定时器内核时钟频率计算方法 1.1 定时器分组 STM32MP135的定时器按照时钟源不同分成了三组,如下: APB1: APB2: APB6: 1.2 定时器内核时钟频率计算方法 APB1DIV是APB1的分频系数,APB2DIV、…

Flink Window 窗口【更新中】

Flink Window 窗口 在Flink流式计算中,最重要的转换就是窗口转换Window,在DataStream转换图中,可以发现处处都可以对DataStream进行窗口Window计算。 窗口(window)就是从 Streaming 到 Batch 的一个桥梁。窗口将无界流…

制作显卡版docker并配置TensorTR环境

感谢阅读 相关概念docker准备下载一个自己电脑cuda匹配的docker镜像拉取以及启动镜像安装cudaTensorRT部署教程 相关概念 TensorRT是可以在NVIDIA各种GPU硬件平台下运行的一个模型推理框架,支持C和Python推理。即我们利用Pytorch,Tensorflow或者其它框架…

汽车的驱动力,是驱动汽车行驶的力吗?

一、地面对驱动轮的反作用力? 汽车发动机产生的转矩,经传动系传至驱动轮上。此时作用于驱动轮上的转矩Tt产生一个对地面的圆周力F0,地面对驱动轮的反作用力Ft(方向与F0相反)即是驱动汽车的外力,此外力称为汽车的驱动力。 即汽车…

Codeforces Round 957 (Div. 3)(A~D题)

A. Only Pluses 思路: 优先增加最小的数&#xff0c;它们的乘积会是最优,假如只有两个数a和b&#xff0c;b>a&#xff0c;那么a 1&#xff0c;就增加一份b。如果b 1&#xff0c;只能增加1份a。因为 b > a&#xff0c;所以增加小的数是最优的。 代码: #include<bi…

最新PHP自助商城源码,彩虹商城源码

演示效果图 后台效果图 运行环境&#xff1a; Nginx 1.22.1 Mysql5.7 PHP7.4 直接访问域名即可安装 彩虹自助下单系统二次开发 拥有供货商系统 多余模板删除 保留一套商城,两套发卡 源码无后门隐患 已知存在的BUG修复 彩虹商城源码&#xff1a;下载 密码:chsc 免责声明&…

搞定ES6同步与异步机制、async/await的使用以及Promise的使用!

文章目录 同步和异步async/awaitPromisePromise的概念 同步和异步 ​ 同步&#xff1a;代码按照编写顺序逐行执行&#xff0c;后续的代码必须等待当前正在执行的代码完成之后才能执行&#xff0c;当遇到耗时的操作&#xff08;如网络请求等&#xff09;时&#xff0c;主线程会…

解决fidder小黑怪倒出JMeter文件缺失域名、请求头

解决fidder小黑怪倒出JMeter文件缺失域名、请求头 1、目录结构&#xff1a; 2、代码 coding:utf-8 Software:PyCharm Time:2024/7/10 14:02 Author:Dr.zxyimport zipfile import os import xml.etree.ElementTree as ET import re#定义信息头 headers_to_extract [Host, Conn…

C语言 | Leetcode C语言题解之第236题二叉树的最近公共祖先

题目&#xff1a; 题解&#xff1a; /*** Definition for a binary tree node.* struct TreeNode {* int val;* struct TreeNode *left;* struct TreeNode *right;* };*/typedef struct road_t {struct TreeNode *road_node; // 途径路径struct road_t *p_next; }…

[IDEA插件] JarEditor 编辑jar包(直接新增、修改、删除jar包内的class文件)

文章目录 1. 安装插件 JarEditor2. 在IDEA中添加外部JAR包3. JarEditor 使用介绍 之前我们需要修改jar内文件的时候需要解压jar包&#xff0c;反编译class&#xff0c;新建java源文件&#xff0c;修改代码&#xff0c;再编译成class&#xff0c;替换jar包内的class文件。 现在…

MongoDB教程(三):mongoDB用户管理

&#x1f49d;&#x1f49d;&#x1f49d;首先&#xff0c;欢迎各位来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里不仅可以有所收获&#xff0c;同时也能感受到一份轻松欢乐的氛围&#xff0c;祝你生活愉快&#xff01; 文章目录 引言一、MongoD…

P2p网络性能测度及监测系统模型

P2p网络性能测度及监测系统模型 网络IP性能参数 IP包传输时延时延变化误差率丢失率虚假率吞吐量可用性连接性测度单向延迟测度单向分组丢失测度往返延迟测度 OSI中的位置-> 网络层 用途 面相业务的网络分布式计算网络游戏IP软件电话流媒体分发多媒体通信 业务质量 通过…

JavaSE 面向对象程序设计进阶 IO 压缩流 解压缩流

目录 解压缩流 压缩流 解压缩流 压缩包 压缩包里面的每一个文件在java中都是一个ZipEntry对象 把每一个ZipEntry按照层级拷贝到另一个文件夹当中 import java.io.*; import java.util.Date; import java.util.zip.ZipEntry; import java.util.zip.ZipInputStream;public cl…

水表数字识别2:Pytorch DBNet实现水表数字检测(含训练代码和数据集)

水表数字识别2&#xff1a;Pytorch DBNet实现水表数字检测(含训练代码和数据集) 目录 水表数字识别2&#xff1a;Pytorch DBNet实现水表数字检测(含训练代码和数据集) 1.前言 2. 水表数字识别的方法 3. 水表数字识别数据集 4. 水表数字分割模型训练 &#xff08;1&#x…

OpenCV解决验证码(数字和字母)识别(Python)

文章目录 前言一、准备验证码图片 前言 OpenCV是一个基于Apache2.0许可&#xff08;开源&#xff09;发行的跨平台计算机视觉和机器学习软件库。它支持Windows、Linux、Mac OS、Android和iOS等多个操作系统&#xff0c;提供了丰富的图像处理和计算机视觉功能&#xff0c;包括但…

基于JAVA的网上招聘系统的设计与实现

点击下载源码 网上招聘系统的设计与实现 摘 要 随着时代的发展&#xff0c;中国的互联网技术愈加成熟&#xff0c;已经有越来越多的社会群体开始学会使用互联网技术&#xff0c;整个社会正在朝着智能化、信息化的方向前进。有了互联网&#xff0c;用户便可以足不出户地利用互…

【TOOLS】Chrome扩展开发

Chrome Extension Development 1. 入门教程 入门案例&#xff0c;可以访问【 谷歌插件官网官方文档 】查看官方入门教程&#xff0c;这里主要讲解大概步骤 Chrome Extenson 没有固定的脚手架&#xff0c;所以项目的搭建需要根据开发者自己根据需求搭建项目&#xff08;例如通过…