深度学习之PyTorch实现卷积神经网络(CNN)

在深度学习领域,卷积神经网络(Convolutional Neural Networks,CNN)是一种非常强大的模型,专门用于处理图像数据。CNN通过卷积操作和池化操作来提取图像中的特征,具有较好的特征学习能力,特别适用于图像识别和计算机视觉任务。PyTorch作为一种流行的深度学习框架,提供了方便易用的工具来构建和训练CNN模型。本文将介绍如何使用PyTorch构建一个简单的CNN,并通过一个图像分类任务来演示其效果。

1. CNN的结构

典型的CNN结构包括卷积层、池化层和全连接层。卷积层通过卷积操作提取图像特征,池化层通过降采样操作减小特征图的尺寸,全连接层用于最终的分类。

在这里插入图片描述

2. 环境配置

在开始之前,确保已经安装了PyTorch和相关的Python库。可以通过以下命令安装:

pip install torch torchvision matplotlib

3. 数据集准备

在这个示例中,我们将使用PyTorch提供的CIFAR-10数据集,它包含了10个类别的60000张32x32彩色图像。我们将图像分为训练集和测试集,并加载到PyTorch的数据加载器中。

import torch
import torchvision
import torchvision.transforms as transforms

# 数据预处理
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

# 测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

4. 构建CNN模型

我们将构建一个简单的CNN模型,包括卷积层、池化层、全连接层和激活函数。这个模型将接受3通道的32x32图像作为输入,并输出10个类别的概率分布。

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
	super().__init__()
        # 定义网络层:卷积层+池化层
        self.conv1 = nn.Conv2d(3, 6, stride=1, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, stride=1, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 全连接层
        self.linear1 = nn.Linear(576, 120)
        self.linear2 = nn.Linear(120, 84)
        self.out = nn.Linear(84, 10)


    def forward(self, x):
        # 卷积+relu+池化
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        # 卷积+relu+池化
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        # 将特征图做成以为向量的形式:相当于特征向量
        x = x.reshape(x.size(0), -1)
        # 全连接层
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        # 返回输出结果
        return self.out(x)

net = Net()

5. 模型训练

定义损失函数和优化器,并在训练集上训练CNN模型。

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# 训练网络
for epoch in range(2):  # 多次遍历数据集
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:  # 每2000个小批量数据打印一次损失值
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('训练结束!!!')

6. 模型测试

在测试集上评估训练好的模型的性能。

correct = 0
total = 0
# 禁用梯度追踪
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('10000 测试图片的准确率: %d %%' % (100 * correct / total))

7. 结果分析

通过训练和测试过程,我们可以得到CNN模型在CIFAR-10数据集上的准确率。进一步分析模型在每个类别上的表现,以及可视化模型的特征图等,可以帮助我们更好地理解模型的行为。

%%’ % (100 * correct / total))


## 7. 结果分析

通过训练和测试过程,我们可以得到CNN模型在CIFAR-10数据集上的准确率。进一步分析模型在每个类别上的表现,以及可视化模型的特征图等,可以帮助我们更好地理解模型的行为。

这篇博客介绍了如何使用PyTorch构建和训练一个简单的卷积神经网络。通过实际的代码示例,读者可以了解CNN模型的基本原理,并掌握如何在PyTorch中实现和训练这样的模型。希望本文能够对你理解和应用深度学习模型有所帮助!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/546704.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

机器学习算法快速入门

文章目录 0.简介1.常用术语1) 模型2) 数据集3) 样本&特征4) 向量5) 矩阵6)假设函数&损失函数7)拟合&过拟合&欠拟合 2.线性回归3.梯度下降求极值4.Logistic回归算法(分类问题)5.KNN最邻近分类算法6.朴素贝叶斯分类算法7.决策树…

Python调用微信OCR识别文字和坐标

原理 在看雪看到一篇文章:逆向调用QQ截图NT与WeChatOCR-软件逆向。里面说了怎么调用微信和QQ本地的OCR模型,还有很详细的分析过程。 我稍微看了下文章,多的也看不懂。大概流程是使用mmmojo.dll这个dll来与WeChatOCR.exe做通信的&#xff0c…

1688拍立淘接口:图像识别技术引领电商搜索新革命,打造智能购物新体验!

1688拍立淘接口:技术解析与应用实践 一、引言 在电子商务蓬勃发展的今天,图像识别技术正逐渐成为各大电商平台提升用户体验、优化搜索效率的关键技术之一。作为阿里巴巴旗下的B2B采购批发平台,1688也紧跟技术潮流,推出了拍立淘接…

Dos命令的基础

雷迪斯and the乡亲们 欢迎你们来到 奇幻的编程世界 一、DOS命令基础 提示符 根目录:进入大到分区后,最外层的目录就是跟目录 工作目录:当前的所在位置/所在文件夹 二、cd命令 切换工作目录: cd 格式: cd 目标 …

巧避海森堡不确定性原理!量子比特读出技术重磅突破

内容来源:量子前哨(ID:Qforepost) 文丨浪味仙 排版丨沛贤 深度好文:1200字丨7分钟阅读 摘要:阿尔托大学研究人员用微测辐射热计替代传统参数放大器,以更少的附加噪声实现非侵入式量子比特测量…

排序链表 - LeetCode 热题 33

大家好!我是曾续缘😹 今天是《LeetCode 热题 100》系列 发车第 33 天 链表第 12 题 ❤️点赞 👍 收藏 ⭐再看,养成习惯 排序链表 给你链表的头结点 head ,请将其按 升序 排列并返回 排序后的链表 。 示例 1&#xff1a…

带你追踪 ICASSP 2024会议现场 韩国夜景令人陶醉

会议之眼 快讯 昨天,2024年的ICASSP(International Conference on Acoustics, Speech, and Signal Processing)即国际声学、语音和信号处理会议已经在韩国首尔拉开帷幕!吸引了众多热情的与会者!本届ICASSP会议举办日期…

实验笔记之——RGBD GS-ICP SLAM配置与测试

《RGBD GS-ICP SLAM》是最新开源的一个3DGS-SLAM工作,通过利用GICP来实现当前帧gaussian与已mapping的gaussian进行匹配进行位姿的估算,并通过关键帧的选择策略来进一步提升performance~ Use G-ICP to align the current frame with the 3D GS map whic…

Redis消息队列-基于PubSub的消息队列

7.3 Redis消息队列-基于PubSub的消息队列 PubSub(发布订阅)是Redis2.0版本引入的消息传递模型。顾名思义,消费者可以订阅一个或多个channel,生产者向对应channel发送消息后,所有订阅者都能收到相关消息。 SUBSCRIBE …

OpenHarmony实战开发-图片选择和下载保存案例。

介绍 本示例介绍图片相关场景的使用:包含访问手机相册图片、选择预览图片并显示选择的图片到当前页面,下载并保存网络图片到手机相册或到指定用户目录两个场景。 效果图预览 使用说明 从主页通用场景集里选择图片选择和下载保存进入首页。分两个场景点…

Linux的重要命令(二)+了解Linux目录结构

目录 一.Linux的目录结构 二.查看文件内容命令 1.cat 命令 2.more 命令 3.less 命令 4.head 命令 5.tail 命令 6.拓展 head 和 tail 的其他用法 ​编辑 三.统计文件内容的命令-wc ​编辑 四.检索和过滤文件内容的命令-grep ​编辑 ​编辑 五.压缩命令 gzip 和 bz…

Canvas 画布基本用法详解

Canvas 画布 HTML中的 <canvas> 标签用于动态绘制图形&#xff0c;所有在<canvas>中的画图必须用JavaScript完成。 <canvas>标签是透明的&#xff0c;它是图形的容器&#xff0c;必须使用脚本才能实际绘制图形。 绘制一个简单的矩形 <!-- canvas标签&a…

Python基于卷积神经网络的车牌识别系统

博主介绍&#xff1a;✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;…

【数据结构与算法】递推

来源&#xff1a;《信息学奥赛一本通》 所谓递推&#xff0c;是指从已知的初始条件出发&#xff0c;依据某种递推关系&#xff0c;逐次推出所要求的各中间结果及最后结果。其中初始条件或是问题本身已经给定&#xff0c;或是通过对问题的分析与化简后确定。 从已知条件出发逐…

jenkins(docker)安装及应用

jenkins Jenkins是一个开源的、提供友好操作界面的持续集成(CI)工具&#xff0c;起源于Hudson&#xff08;Hudson是商用的&#xff09;&#xff0c;主要用于持续、自动的构建/测试软件项目、监控外部任务的运行&#xff08;这个比较抽象&#xff0c;暂且写上&#xff0c;不做解…

论坛直击|发展新质生产力,高校怎么做?

新质生产力浪潮涌动 三大议题聚焦高校人才培养 今年全国两会的政府工作报告将“大力推进现代化产业体系建设&#xff0c;加快发展新质生产力”列在2024年政府工作任务首位&#xff0c;发展新质生产力的先导是培养拔尖创新人才&#xff0c;高等教育改革必须以立德树人为根本任…

幽灵漏洞进阶版来了

近日&#xff0c;网络安全研究人员披露了针对英特尔系统上 Linux 内核的首个原生 Spectre v2 漏洞&#xff0c;该漏洞是2018 年曝出的严重处理器「幽灵」&#xff08;Spectre&#xff09;漏洞 v2 衍生版本&#xff0c;利用该漏洞可以从内存中读取敏感数据&#xff0c;主要影响英…

Java怎么获取今天最早的时间

今天在实现项目里的一个功能的时候&#xff0c;需要获取今天最早的时间&#xff0c;比如今天是2024-4-15&#xff0c;则今天的开始时间为2024-4-14日24点之后&#xff08;2024-4-15零点&#xff09;的那个时间点。 这篇文章就分享一下博主获取这个时间的方法&#xff1a; 很简…

Python数据可视化库—Bokeh与Altair指南【第161篇—数据可视化】

&#x1f47d;发现宝藏 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。【点击进入巨牛的人工智能学习网站】。 在数据科学和数据分析领域&#xff0c;数据可视化是一种强大的工具&#xff0c;可以帮助我们…

代码随想录刷题day53|最长公共子序列不相交的线最大子序和

文章目录 day53学习内容一、最长公共子序列1.1、动态规划五部曲1.1.1、 确定dp数组&#xff08;dp table&#xff09;以及下标的含义1.1.2、确定递推公式1.1.3、 dp数组如何初始化1.1.4、确定遍历顺序1.1.5、输出结果 1.2、代码 二、不相交的线2.1、动态规划五部曲2.1.1、 确定…