【深度学习入门篇 ⑧】关于卷积神经网络

【🍊易编橙:一个帮助编程小伙伴少走弯路的终身成长社群🍊】

大家好,我是小森( ﹡ˆoˆ﹡ ) ! 易编橙·终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官、CSDN人工智能领域优质创作者 。


关于卷积神经网络,你还有哪些不知道的知识点呢,之前我们介绍了大部分,今天再来补充一下~

卷积神经网络基础

什么是卷积

Convolution,输入信息与核函数(滤波器)的乘积

  • 一维信号的时间卷积:输入x,核函数w,输出是一个连续时间段t的加权平均结果。
  • 二维图像的空间卷积:输入图像I,卷积核K,输出图像O。

单个二维图片卷积 :输入为单通道图像,输出为单通道图像。

图像的数据存储是多通道的二维矩阵:

灰度图(Gray)只有一个通道(一层),RGB彩色图就是三个通道(Red,Green,Blue),而RGBA彩色图就是四个通道(Red,Green,Blue,Alpha)。

如何表达每一个网络层中高维的图像数据?

特征图包含:通道,宽度,高度,其中输入特征图Ci,输出特征图C0,输出特征图的每一个通道,由输入图的所有通道和相同数量的卷积核先一一对应各自进行卷积计算,然后求和

 

卷积相关操作与参数

填充

padding :给卷积前的输入图像边界添加额外的行,列

  • 控 制 卷 积 后 图 像分 辨 率 , 方便计算特征图尺寸的变化
  • 弥 补 边 界 信 息 “丢 失 ” 

步长

步长(stride):卷积核在图像上移动的步子

卷积的核心思想 

为什么要进行局部连接?

  • 局部连接可以更好地利用图像中的结构信息,空间距离越相近的像素其相互影响越大

权重共享:保证不变性,图像从一个局部区域学习到的信息应用到其他区域 ,减少参数,降低学习难度。

ANN与CNN比较

传统神经网络为有监督的机器学习,输入为特征;卷积神经网络为无监督特征学习,输入为最原始的图像。


案例-图像分类

CIFAR10 数据集

CIFAR-10数据集5万张训练图像、1万张测试图像、10个类别、每个类别有6k个图像,图像大小32×32×3。

 PyTorch 中的 torchvision.datasets 计算机视觉模块封装了 CIFAR10 数据集:

from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader


def func1():

    # 加载数据集
    train = CIFAR10(root='data', train=True, transform=Compose([ToTensor()]))
    valid = CIFAR10(root='data', train=False, transform=Compose([ToTensor()]))

    print('训练集数量:', len(train.targets))
    print('测试集数量:', len(valid.targets))

    print("数据集形状:", train[0][0].shape)

    print("数据集类别:", train.class_to_idx)


# 数据加载器
def func2():

    train = CIFAR10(root='data', train=True, transform=Compose([ToTensor()]))
    dataloader = DataLoader(train, batch_size=8, shuffle=True)
    for x, y in dataloader:
        print(x.shape)
        print(y)
        break


if __name__ == '__main__':
    func1()
    func2()

我们要搭建的网络结构:

我们在每个卷积计算之后应用 relu 激活函数来给网络增加非线性因素。

网络代码实现:

class ImageClassification(nn.Module):


    def __init__(self):

        super(ImageClassification, self).__init__()

        self.conv1 = nn.Conv2d(3, 6, stride=1, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, stride=1, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.linear1 = nn.Linear(576, 120)
        self.linear2 = nn.Linear(120, 84)
        self.out = nn.Linear(84, 10)


    def forward(self, x):

        x = F.relu(self.conv1(x))
        x = self.pool1(x)

        x = F.relu(self.conv2(x))
        x = self.pool2(x)

        x = x.reshape(x.size(0), -1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))

        return self.out(x)

 编写训练函数

训练时,使用多分类交叉熵损失函数,Adam 优化器:

def train():

    transgform = Compose([ToTensor()])
    cifar10 = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transgform)


    model = ImageClassification()

    criterion = nn.CrossEntropyLoss()

    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    # 训练轮数
    epoch = 100

    for epoch_idx in range(epoch):

        # 构建数据加载器
        dataloader = DataLoader(cifar10, batch_size=BATCH_SIZE, shuffle=True)
        # 样本数量
        sam_num = 0
        # 损失总和
        total_loss = 0.0
        # 开始时间
        start = time.time()
        correct = 0

        for x, y in dataloader:
            # 送入模型
            output = model(x)
            # 计算损失
            loss = criterion(output, y)
            # 梯度清零
            optimizer.zero_grad()
            # 反向传播
            loss.backward()
            # 参数更新
            optimizer.step()

            correct += (torch.argmax(output, dim=-1) == y).sum()
            total_loss += (loss.item() * len(y))
            sam_num += len(y)

        print('epoch:%2s loss:%.5f acc:%.2f time:%.2fs' %
              (epoch_idx + 1,
               total_loss / sam_num,
               correct / sam_num,
               time.time() - start))


    torch.save(model.state_dict(), 'model/image_classification.bin')

编写预测函数

我们加载训练好的模型,对测试集中的 1 万条样本进行预测,查看模型在测试集上的准确率

def test():


    transgform = Compose([ToTensor()])
    cifar10 = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transgform)
    # 构建数据加载器
    dataloader = DataLoader(cifar10, batch_size=BATCH_SIZE, shuffle=True)
    # 加载模型
    model = ImageClassification()
    model.load_state_dict(torch.load('model/image_classification.bin'))
    model.eval()


    total_correct = 0
    total_samples = 0
    for x, y in  dataloader:
        output = model(x)
        total_correct += (torch.argmax(output, dim=-1) == y).sum()
        total_samples += len(y)

    print('Acc: %.2f' % (total_correct / total_samples))

输出:

'Acc: 0.61

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/843245.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

多租户分库分表同步数据库DDL脚本

我们在实现多租户系统的时候,为了数据安全和性能,往往会把数据库设计成一个租户一个数据库,如下图,主库记录了租户信息和对应的数据库地址,租户数据库则存储了租户相关的数据,租户数据库的表结构都是一致的,这种方式有…

npm 安装报错(已解决)+ 运行 “wue-cli-service”不是内部或外部命令,也不是可运行的程序(已解决)

首先先说一下我这个项目是3年前的一个项目了,中间也是经过了多个人的修改惨咋了布置多少个人的思想,这这道我手里直接npm都安装不上,在网上也查询了多种方法,终于是找到问题所在了 问题1: 先是npm i 报错在下面图片&…

全新AI工具——PaintsUndo:一键自动还原图像绘画过程!

ControlNet 作者 Lvmin Zhang 又开始整活了!这次发布的PaintsUndo 只需要上传一张图片, 就能够一键生成绘画过程!快来了解学习! 1、核心技术 PaintsUndo 是一项突破性的技术,旨在通过输入静态图像,自动生…

SpringCloudAlibaba-Seata2.0.0与Nacos2.2.1

一、下载 ## 下载seata wget https://github.com/apache/incubator-seata/releases/download/v2.0.0/seata-server-2.0.0.tar.gz## 解压 tar zxvf seata-server-2.0.0.tar.gz二、执行sql文件 ## 取出sql文件执行 cd /seata/script/server/db/mysql ## 找个mysql数据库执行三、…

分布式服务框架zookeeper+消息队列kafka

一、zookeeper概述 zookeeper是一个分布式服务框架,它主要是用来解决分布式应用中经常遇到的一些数据管理问题,如:命名服务,状态同步,配置中心,集群管理等。 在分布式环境下,经常需要对应用/服…

钡铼分布式I/O系统边缘计算Modbus,MQTT,OPC UA耦合器BL206

BL206系列耦合器是一个数据采集和控制系统,基于强大的32 位微处理器设计,采用Linux操作系统,支持Modbus,MQTT,OPC UA协议,可以快速接入现场PLC、DCS、PAS、MES、Ignition和SCADA以及ERP系统,同时…

numpy(史上最全)

目录 numpy简介 性能对比 ndarray属性 numpy中的数组: 几个创建的函数: 1) np.ones(shape, dtypeNone, orderC) shape: 形状,使用元组表示 2) np.zeros(shape, dtypefloat, orderC) 3) np.full(shape, fill_value, dtypeNone, orderC)…

核函数支持向量机(Kernel SVM)

核函数支持向量机(Kernel SVM)是一种非常强大的分类器,能够在非线性数据集上实现良好的分类效果。以下是关于核函数支持向量机的详细数学模型理论知识推导、实施步骤与参数解读,以及两个多维数据实例(一个未优化模型&a…

IVI(In-Vehicle Infotainment,智能座舱的信息娱乐系统)

IVI能够实现包括三维导航、实时路况、辅助驾驶等在线娱乐功能。 IVI人机交互形式(三板斧):声音、图像、文字 IVI人机交互媒介I(四件套):中控屏幕(显示、触控)、仪表显示、语言、方…

吴恩达大模型系列课程《Prompt Compression and Query Optimization》中文学习打开方式

Prompt Compression and Query Optimization GPT-4o详细中文注释的Colab观看视频1 浏览器下载插件2 打开官方视频 GPT-4o详细中文注释的Colab 中文注释链接:https://github.com/Czi24/Awesome-MLLM-LLM-Colab/tree/master/Courses/1_Prompt-Compression-and-Query-…

基于 Three.js 的 3D 模型加载优化

作者:来自 vivo 互联网前端团队- Su Ning 作为一个3D的项目,从用户打开页面到最终模型的渲染需要经过多个流程,加载的时间也会比普通的H5项目要更长一些,从而造成大量的用户流失。为了提升首屏加载的转化率,需要尽可能…

ELK企业级日志分析

目 录 一、ELK简介 1.1 elasticsearch简介 1.2 logstash简介 1.3 kibana简介 1.4 ELK的好处 1.5 ELK的工作原理 二、部署ELK 2.1 部署elasticsearch(集群) 2.1.1 修改配置文件 2.1.2 修改系统参数 2.1.2.1 修改systemmd服务管理器 2.1.2.2 性能调优参数 2.1.2.3 …

9.11和9.9哪个大?

没问题 文心一言 通义千问

信创数据库沙龙(南京站 | 开启报名)

信创数据库沙龙: 是一个致力于推动数据库技术创新和发展的高端交流平台,旨在增强国内数据库产业的自主可控性和高质量发展。这个平台汇集了学术界和产业界的顶尖专家、学者以及技术爱好者,通过专题演讲、案例分享和技术研讨等丰富多样的活动形式&#x…

(ISPRS,2021)具有遥感知识图谱的鲁棒深度对齐网络用于零样本和广义零样本遥感图像场景分类

文章目录 Robust deep alignment network with remote sensing knowledge graph for zero-shot and generalized zero-shot remote sensing image scene classification相关资料摘要引言遥感知识图谱的表示学习遥感知识图谱的构建实体和关系的语义表示学习创建遥感场景类别的语…

【C#】计算两条直线的交点坐标

问题描述 计算两条直线的交点坐标,可以理解为给定坐标P1、P2、P3、P4,形成两条线,返回这两条直线的交点坐标? 注意区分:这两条线是否垂直、是否平行。 代码实现 斜率解释 斜率是数学中的一个概念,特别是…

TiDB实践—索引加速+分布式执行框架创建索引提升70+倍

作者: 数据源的TiDB学习之路 原文来源: https://tidb.net/blog/92d348c2 背景介绍 TiDB 采用在线异步变更的方式执行 DDL 语句,从而实现 DDL 语句的执行不会阻塞其他会话中的 DML 语句。按照是否需要操作 DDL 目标对象所包括的数据来划分…

QT样式美化 之 qss入门

样例一 *{font-size:13px;color:white;font-family:"宋体"; }CallWidget QLineEdit#telEdt {font-size:24px;}QMainWindow,QDialog{background: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1,stop: 0 #1B2534, stop: 0.4 #010101,stop: 0.5 #000101, stop: 1.0 #1F2B…

springboot项目中,yml文件乱码

项目场景: 在springboot项目的resource目录,新建yml文件,并且输入了中文,但是关闭idea,再打开,里面的中文乱码了 问题描述 原因分析: 编码设置相关 解决方案: 方案1&#xff0…

使用 XPath 定位 HTML 中的 img 标签

引言 随着互联网内容的日益丰富,网页数据的自动化处理变得愈发重要。图片作为网页中的重要组成部分,其获取和处理在许多应用场景中都显得至关重要。例如,在社交媒体分析、内容聚合平台、数据抓取工具等领域,图片的自动下载和处理…