卷积神经网络(含案例代码)

概述

        卷积神经网络(Convolutional Neural Network,CNN)是一类专门用于处理具有网格结构数据的神经网络。它主要被设计用来识别和提取图像中的特征,但在许多其他领域也取得了成功,例如自然语言处理中的文本分类任务。

        CNN 的主要特点是它使用了卷积层(convolutional layer)来处理输入数据。卷积层通过卷积操作在输入数据上滑动一个或多个卷积核(也称为滤波器),从而学习局部特征。这种局部感知能力使得 CNN 能够有效地捕捉输入数据中的空间结构和模式。

基本组成部分

        卷积层(Convolutional Layer)

        由多个卷积核组成,每个卷积核用于检测输入数据中的特定特征。卷积操作通过在输入数据上滑动卷积核并计算局部区域的加权和来提取特征。

        池化层(Pooling Layer)

        用于减小数据的空间维度,降低计算复杂度,并且增强模型对平移变化的鲁棒性。最大池化是常用的池化操作,它选择输入区域中的最大值作为输出。

        激活函数(Activation Function)

通常在卷积层之后应用,引入非线性特性。常用的激活函数包括ReLU(Rectified Linear Unit)。

        全连接层(Fully Connected Layer)

        在卷积层和输出层之间,用于整合卷积层提取的特征并生成最终的输出。全连接层将前一层的所有节点与当前层的每个节点连接。

        CNN 在图像处理任务中表现出色,因为它能够学习到图像的局部和全局特征,具有平移不变性(通过共享权重)、参数共享和稀疏交互等特性。这些特性使得 CNN 在图像分类、目标检测、图像生成等任务中取得了显著的成功。

基本实现原理

        卷积神经网络(CNN)的实现原理涉及卷积层、池化层、激活函数、全连接层等关键组件。

输入层(Input Layer)

        接收原始输入数据,通常是图像或其他具有网格结构的数据。

卷积层(Convolutional Layer)

        使用卷积核(filter)对输入数据进行卷积操作,通过在输入数据上滑动卷积核,提取局部特征。卷积操作通过计算局部区域的加权和来生成输出特征图(feature map)。

激活函数(Activation Function)

        在卷积操作后,应用激活函数引入非线性,增加网络的表示能力。常用的激活函数包括ReLU(Rectified Linear Unit)。

池化层(Pooling Layer)

        对卷积层的输出进行下采样,减小空间维度,提高计算效率,并增强网络对平移变化的鲁棒性。最大池化是常用的池化操作,选择局部区域中的最大值作为输出。

全连接层(Fully Connected Layer)

        将池化层的输出扁平化,并通过全连接层连接到输出层。全连接层负责整合卷积层和池化层提取的特征,并生成最终的输出。

输出层(Output Layer)

        输出层根据任务的性质确定,可以是分类问题的softmax层,回归问题的线性层,或者其他适当的输出层结构。

        在训练过程中,通过反向传播算法更新网络参数,以最小化损失函数。这个过程包括前向传播(计算预测输出)、计算损失、反向传播(计算梯度),以及使用优化算法(如梯度下降)来更新权重。

        CNN 的关键之一是参数共享,即卷积核在整个输入上共享权重,这减少了参数数量,提高了模型的效率和泛化能力。此外,卷积操作和池化操作的重复使用使得网络能够逐渐构建出对输入数据的抽象表示。卷积神经网络通过多层次的特征提取和抽象,能够学习到输入数据的有用表示,从而在图像分类、目标检测等任务中表现出色。

案例代码

        下面是一个使用PyTorch的简单卷积神经网络(CNN)的代码案例。

        在这个例子中,使用PyTorch来构建一个简单的CNN模型,以进行图像分类。确保你已经安装了PyTorch:

pip install torch torchvision

        代码:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 设置随机种子,以保证实验的可重复性
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# 定义CNN模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x

# 数据预处理和加载MNIST数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 5
for epoch in range(num_epochs):
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# 在测试集上评估模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print('Test Accuracy: {:.2%}'.format(accuracy))

        这个示例中,使用PyTorch构建了一个包含两个卷积层和两个全连接层的简单CNN模型,并在MNIST手写数字数据集上进行训练和测试。可以根据自己的需求修改模型架构、训练参数等。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 定义CNN模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x

# 数据预处理和加载MNIST数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 5
for epoch in range(num_epochs):
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# 保存模型
torch.save(model.state_dict(), 'simple_cnn_model.pth')
print("Model has been saved.")

# 加载模型
new_model = SimpleCNN()
new_model.load_state_dict(torch.load('simple_cnn_model.pth'))
new_model.eval()

# 在测试集上评估加载的模型
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = new_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print('Test Accuracy of Loaded Model: {:.2%}'.format(accuracy))

        这个示例中,添加了模型的保存和加载过程。模型在训练完成后被保存到simple_cnn_model.pth文件,然后通过加载这个文件,可以重新创建模型并在测试集上进行评估。这对于在训练后的应用部署和模型共享上都是非常有用的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/245135.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Paper Reading: (CCVC) 基于冲突的半监督语义分割跨视图一致性

目录 简介目标/动机工作重点方法CVC: 跨视图一致性CPL: 基于冲突的伪标记 实验设置comparison with SOTAAblation 总结 简介 题目:《Conflict-Based Cross-View Consistency for Semi-Supervised Semantic Segmentation》, CVPR’23, 基于冲突的半监督语…

HPM5300系列--第二篇 Visual Studio Code开发环境以及多种调试器调试模式

一、目的 在博文《HPM5300系列--第一篇 命令行开发调试环境搭建》、《HPM6750系列--第四篇 搭建Visual Studio Code开发调试环境》中我们介绍了命令行方式开发环境,也介绍了HPM6750evkmini开发板如何使用Visual Studio Code进行开发调试(其中调试方式使用…

AcWing 338. 计数问题

文章目录 题目描述问题分析代码 题目描述 AcWing 338.计数问题 给定两个整数 a a a 和 b b b, 求 a a a 和 b b b中所有数字中0~9的出现次数 数据范围&#xff1a; 0 < a, b < 100000000 输入格式&#xff1a; 输入包含多组测试数据。 每组测试数据占一行&#xff0c;包…

AI会干掉美图秀秀们吗?

网上流传着这样一个传说&#xff0c;亚洲有三大“邪术”&#xff0c;韩国整容术、日本化妆术&#xff0c;还有震惊世界的中国PS 术。虽然是网友的戏称&#xff0c;但也反映了PS美图技术在国内盛行一时。 而说起美图技术就不得不提到美图公司&#xff0c;但美图公司近些年的日子…

影视动画行业发展现状与方向:AI技术推动动画工业化体系新变革

工业化体系 是国产动画强国的必经之路 中国动画的百年历程不仅是创作者们展现艺术才华的历程&#xff0c;也是一代代中国动画人不懈追求动画工业体系建设的历程。 为什么现在的中国动画需要建立工业化体系呢&#xff1f; 举个例子&#xff0c;在建立工业化体系之前&#xff…

推荐一个小而全的第三方登录开源组件

大家好&#xff0c;我是 Java陈序员。 我们在企业开发中&#xff0c;常常需要实现登录功能&#xff0c;而有时候为了方便&#xff0c;就需要集成第三方平台的授权登录。如常见的微信登录、微博登录等&#xff0c;免去了用户注册步骤&#xff0c;提高了用户体验。 为了业务考虑…

网络基础(十):DHCP原理与配置

目录 1、DHCP的概念 2、使用DHCP的优势 3、DHCP的分配方式 4、可分配的地址信息 5、DHCP的工作原理&#xff08;租约过程&#xff09; 6、DHCP动态配置主机地址&#xff08;使用eNSP软件配置&#xff09; 1、DHCP的概念 DHCP(Dynamic HostConfiguration Protocol&#x…

SLAM学习——相机模型(针孔+鱼眼)

针孔相机模型 针孔相机模型是很常用&#xff0c;而且有效的模型&#xff0c;它描述了一束光线通过针孔之后&#xff0c;在针孔背面投影成像的关系&#xff0c;基于针孔的投影过程可以通过针孔和畸变两个模型来描述。 模型中有四个坐标系&#xff0c;分别为world&#xff0c;c…

vite(一)——基本了解和依赖预构建

文章目录 一、什么是构建工具&#xff1f;1.为什么使用构建工具&#xff1f;2.构建工具的作用&#xff1f;3.构建工具怎么用&#xff1f; 二、经典面试题&#xff1a;webpack和vite的区别1.编译方式不同2.基础概念不同3.开发效率不同4.扩展性不同5.应用场景不同6.总结&#xff…

孩子都能学会的FPGA:第三十一课——用FPGA实现SPI主机发送数据

&#xff08;原创声明&#xff1a;该文是作者的原创&#xff0c;面向对象是FPGA入门者&#xff0c;后续会有进阶的高级教程。宗旨是让每个想做FPGA的人轻松入门&#xff0c;作者不光让大家知其然&#xff0c;还要让大家知其所以然&#xff01;每个工程作者都搭建了全自动化的仿…

【Python】conda镜像配置,.condarc文件详解,channel镜像

1. conda 环境 安装miniconda即可&#xff0c;Miniconda 安装包可以到 http://mirrors.aliyun.com/anaconda/miniconda/ 下载。 .condarc是conda 应用程序的配置文件&#xff0c;在用户家目录&#xff08;windows&#xff1a;C:\users\username\&#xff09;&#xff0c;用于…

SAP PP 配置学习(一)

物料主数据 一、定义物料类型属性 未检查的外部号分配&#xff1a;这是指如果用户自己输入的物料编码超出了定义的编码范围&#xff0c;是否会提示错误。 勾上代表不检查。 用户部门&#xff1a;这里勾选需要进行维护的视图&#xff0c;如果不选择&#xff0c;那么在新建和维…

苍穹外卖项目笔记(12)— 数据统计、Excel报表

前言 代码链接&#xff1a; Echo0701/take-out⁤ (github.com) 1 工作台 需求分析和设计 产品原型 工作台是系统运营的数据看板&#xff0c;并提供快捷操作入口&#xff0c;可以有效提高商家的工作效率 接口设计 ① 今日数据接口&#xff1a; ② 订单管理接口&#xff1…

智慧灯杆技术应用分析

智慧灯杆是指在传统灯杆的基础上&#xff0c;通过集成多种先进技术实现城市智能化管理的灯杆。智慧灯杆技术应用的分析如下&#xff1a; 照明功能&#xff1a;智慧灯杆可以实现智能调光、时段控制等功能&#xff0c;根据不同的需求自动调节照明亮度&#xff0c;提高照明效果&am…

论文阅读《Parameterized Cost Volume for Stereo Matching》

论文地址&#xff1a;https://openaccess.thecvf.com/content/ICCV2023/papers/Zeng_Parameterized_Cost_Volume_for_Stereo_Matching_ICCV_2023_paper.pdf 源码地址&#xff1a;https://github.com/jiaxiZeng/Parameterized-Cost-Volume-for-Stereo-Matching 概述 现有的立体匹…

SpringBoot整合Lucene实现全文检索【详细步骤】【附源码】

笑小枫的专属目录 1. 项目背景2. 什么是Lucene3. 引入依赖&#xff0c;配置索引3.1 引入Lucene依赖和分词器依赖3.2 表结构和数据准备3.3 创建索引3.4 修改索引3.5删除索引 4. 数据检索4.1 基础搜索4.2 一个关键词&#xff0c;在多个字段里面搜索4.3 搜索结果高亮显示4.4 分页检…

数组相关的题目

数组相关的题目 128. 最长连续序列 128. 最长连续序列 题目&#xff1a;给定一个未排序的整数数组 nums &#xff0c;找出数字连续的最长序列&#xff08;不要求序列元素在原数组中连续&#xff09;的长度。 很容易就能想到要先排序&#xff0c;再进行后续的处理。有一个坑&a…

Java集合--Map

1、Map集合概述 在Java的集合框架中&#xff0c;Map为双列集合&#xff0c;在Map中的元素是成对以<K,V>键值对的形式存在的&#xff0c;通过键可以找对所对应的值。Map接口有许多的实现类&#xff0c;各自都具有不同的性能和用途。常用的Map接口实现类有HashMap、Hashtab…

交易历史记录20231206 记录

昨日回顾&#xff1a; select top 10000 * from dbo.CODEINFO A left join dbo.全部&#xff21;股20231206010101 B ON A.CODE B.代码 left join dbo.全部&#xff21;股20231206CONF D on A.CODED.代码left join dbo.全部&#xff21;股20231206 G on A.CODEG.代码 left…

铭飞CMS list 接口 SQL注入漏洞复现

0x01 产品简介 铭飞CMS是一款基于java开发的一套轻量级开源内容管理系统,铭飞CMS简洁、安全、开源、免费,可运行在Linux、Windows、MacOSX、Solaris等各种平台上,专注为公司企业、个人站长快速建站提供解决方案 0x02 漏洞概述 铭飞CMS在5.2.10版本以前list 接口处存在sql注入…