计算机视觉-LeNet

目录

LeNet

LeNet在手写数字识别上的应用

LeNet在眼疾识别数据集iChallenge-PM上的应用


LeNet

LeNet是最早的卷积神经网络之一。1998年,Yann LeCun第一次将LeNet卷积神经网络应用到图像分类上,在手写数字识别任务中取得了巨大成功。LeNet通过连续使用卷积和池化层的组合提取图像特征,其架构如 图1 所示,这里展示的是用于MNIST手写体数字识别任务中的LeNet-5模型:
 


图1:LeNet模型网络结构示意图


 

  • 第一模块:包含5×5的6通道卷积和2×2的池化。卷积提取图像中包含的特征模式(激活函数使用Sigmoid),图像尺寸从28减小到24。经过池化层可以降低输出特征图对空间位置的敏感性,图像尺寸减到12。

  • 第二模块:和第一模块尺寸相同,通道数由6增加为16。卷积操作使图像尺寸减小到8,经过池化后变成4。

  • 第三模块:包含4×4的120通道卷积。卷积之后的图像尺寸减小到1,但是通道数增加为120。将经过第3次卷积提取到的特征图输入到全连接层。第一个全连接层的输出神经元的个数是64,第二个全连接层的输出神经元个数是分类标签的类别数,对于手写数字识别的类别数是10。然后使用Softmax激活函数即可计算出每个类别的预测概率。


【提示】:

卷积层的输出特征图如何当作全连接层的输入使用呢?

卷积层的输出数据格式是[N,C,H,W][N, C, H, W][N,C,H,W],在输入全连接层的时候,会自动将数据拉平,

也就是对每个样本,自动将其转化为长度为KKK的向量,


LeNet在手写数字识别上的应用

LeNet网络的实现代码如下:

# 导入需要的包
import paddle
import numpy as np
from paddle.nn import Conv2D, MaxPool2D, Linear

## 组网
import paddle.nn.functional as F

# 定义 LeNet 网络结构
class LeNet(paddle.nn.Layer):
    def __init__(self, num_classes=1):
        super(LeNet, self).__init__()
        # 创建卷积和池化层
        # 创建第1个卷积层
        self.conv1 = Conv2D(in_channels=1, out_channels=6, kernel_size=5)
        self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)
        # 尺寸的逻辑:池化层未改变通道数;当前通道数为6
        # 创建第2个卷积层
        self.conv2 = Conv2D(in_channels=6, out_channels=16, kernel_size=5)
        self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)
        # 创建第3个卷积层
        self.conv3 = Conv2D(in_channels=16, out_channels=120, kernel_size=4)
        # 尺寸的逻辑:输入层将数据拉平[B,C,H,W] -> [B,C*H*W]
        # 输入size是[28,28],经过三次卷积和两次池化之后,C*H*W等于120
        self.fc1 = Linear(in_features=120, out_features=64)
        # 创建全连接层,第一个全连接层的输出神经元个数为64, 第二个全连接层输出神经元个数为分类标签的类别数
        self.fc2 = Linear(in_features=64, out_features=num_classes)
    # 网络的前向计算过程
    def forward(self, x):
        x = self.conv1(x)
        # 每个卷积层使用Sigmoid激活函数,后面跟着一个2x2的池化
        x = F.sigmoid(x)
        x = self.max_pool1(x)
        x = F.sigmoid(x)
        x = self.conv2(x)
        x = self.max_pool2(x)
        x = self.conv3(x)
        # 尺寸的逻辑:输入层将数据拉平[B,C,H,W] -> [B,C*H*W]
        x = paddle.reshape(x, [x.shape[0], -1])
        x = self.fc1(x)
        x = F.sigmoid(x)
        x = self.fc2(x)
        return x

飞桨会根据实际图像数据的尺寸和卷积核参数自动推断中间层数据的W和H等,只需要用户表达通道数即可。下面的程序使用随机数作为输入,查看经过LeNet-5的每一层作用之后,输出数据的形状。

# 输入数据形状是 [N, 1, H, W]
# 这里用np.random创建一个随机数组作为输入数据
x = np.random.randn(*[3,1,28,28])
x = x.astype('float32')

# 创建LeNet类的实例,指定模型名称和分类的类别数目
model = LeNet(num_classes=10)
# 通过调用LeNet从基类继承的sublayers()函数,
# 查看LeNet中所包含的子层
print(model.sublayers())
x = paddle.to_tensor(x)
for item in model.sublayers():
    # item是LeNet类中的一个子层
    # 查看经过子层之后的输出数据形状
    try:
        x = item(x)
    except:
        x = paddle.reshape(x, [x.shape[0], -1])
        x = item(x)
    if len(item.parameters())==2:
        # 查看卷积和全连接层的数据和参数的形状,
        # 其中item.parameters()[0]是权重参数w,item.parameters()[1]是偏置参数b
        print(item.full_name(), x.shape, item.parameters()[0].shape, item.parameters()[1].shape)
    else:
        # 池化层没有参数
        print(item.full_name(), x.shape)

卷积Conv2D的padding参数默认为0,stride参数默认为1,当输入形状为[Bx1x28x28]时,B是batch_size,经过第一层卷积(kernel_size=5, out_channels=6)和maxpool之后,得到形状为[Bx6x12x12]的特征图;经过第二层卷积(kernel_size=5, out_channels=16)和maxpool之后,得到形状为[Bx16x4x4]的特征图;经过第三层卷积(out_channels=120, kernel_size=4)之后,得到形状为[Bx120x1x1]的特征图,在FC层计算之前,将输入特征从卷积得到的四维特征reshape到格式为[B, 120x1x1]的特征,这也是LeNet中第一层全连接层输入shape为120的原因。

# -*- coding: utf-8 -*-
# LeNet 识别手写数字
import os
import random
import paddle
import numpy as np
import paddle
from paddle.vision.transforms import ToTensor
from paddle.vision.datasets import MNIST

# 定义训练过程
def train(model, opt, train_loader, valid_loader):
    # 开启0号GPU训练
    use_gpu = True
    paddle.device.set_device('gpu:0') if use_gpu else paddle.device.set_device('cpu')
    print('start training ... ')
    model.train()
    for epoch in range(EPOCH_NUM):
        for batch_id, data in enumerate(train_loader()):
            img = data[0]
            label = data[1] 
            # 计算模型输出
            logits = model(img)
            # 计算损失函数
            loss_func = paddle.nn.CrossEntropyLoss(reduction='none')
            loss = loss_func(logits, label)
            avg_loss = paddle.mean(loss)

            if batch_id % 2000 == 0:
                print("epoch: {}, batch_id: {}, loss is: {:.4f}".format(epoch, batch_id, float(avg_loss.numpy())))
            avg_loss.backward()
            opt.step()
            opt.clear_grad()

        model.eval()
        accuracies = []
        losses = []
        for batch_id, data in enumerate(valid_loader()):
            img = data[0]
            label = data[1] 
            # 计算模型输出
            logits = model(img)
            pred = F.softmax(logits)
            # 计算损失函数
            loss_func = paddle.nn.CrossEntropyLoss(reduction='none')
            loss = loss_func(logits, label)
            acc = paddle.metric.accuracy(pred, label)
            accuracies.append(acc.numpy())
            losses.append(loss.numpy())
        print("[validation] accuracy/loss: {:.4f}/{:.4f}".format(np.mean(accuracies), np.mean(losses)))
        model.train()

    # 保存模型参数
    paddle.save(model.state_dict(), 'mnist.pdparams')


# 创建模型
model = LeNet(num_classes=10)
# 设置迭代轮数
EPOCH_NUM = 5
# 设置优化器为Momentum,学习率为0.001
opt = paddle.optimizer.Momentum(learning_rate=0.001, momentum=0.9, parameters=model.parameters())
# 定义数据读取器
train_loader = paddle.io.DataLoader(MNIST(mode='train', transform=ToTensor()), batch_size=10, shuffle=True)
valid_loader = paddle.io.DataLoader(MNIST(mode='test', transform=ToTensor()), batch_size=10)
# 启动训练过程
train(model, opt, train_loader, valid_loader)

通过运行结果可以看出,LeNet在手写数字识别MNIST验证数据集上的准确率高达92%以上。那么对于其它数据集效果如何呢?我们通过眼疾识别数据集iChallenge-PM验证一下。

LeNet在眼疾识别数据集iChallenge-PM上的应用

iChallenge-PM是百度大脑和中山大学中山眼科中心联合举办的iChallenge比赛中,提供的关于病理性近视(Pathologic Myopia,PM)的医疗类数据集,包含1200个受试者的眼底视网膜图片,训练、验证和测试数据集各400张。下面我们详细介绍LeNet在iChallenge-PM上的训练过程。


说明:

如今近视已经成为困扰人们健康的一项全球性负担,在近视人群中,有超过35%的人患有重度近视。近视会拉长眼睛的光轴,也可能引起视网膜或者络网膜的病变。随着近视度数的不断加深,高度近视有可能引发病理性病变,这将会导致以下几种症状:视网膜或者络网膜发生退化、视盘区域萎缩、漆裂样纹损害、Fuchs斑等。因此,及早发现近视患者眼睛的病变并采取治疗,显得非常重要。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/98665.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

电脑组装教程分享!

案例:如何自己组装电脑? 【看到身边的小伙伴组装一台自己的电脑,我也想试试。但是我对电脑并不是很熟悉,不太了解具体的电脑组装步骤,求一份详细的教程!】 电脑已经成为我们日常生活中不可或缺的一部分&a…

vue3组合式api bus总线式通信

vue2中可以创建一个 vue 实例, 做为 总结来完成组件间的通信 但是在vue3中, 这种方法是不能使用的。 因为vue3中main.js中, 使用的createApp() 没有机会再写 new Vue了 但是我们可以使用 mitt 的插件来解决这个问题 vue3 bus组件的用法 安装…

Python序列类型

序列(Sequence)是有顺序的数据列,Python 有三种基本序列类型:list, tuple 和 range 对象,序列(Sequence)是有顺序的数据列,二进制数据(bytes) 和 文本字符串&…

MATLAB中isequal函数转化为C语言

背景 有项目算法使用matlab中isequal函数进行运算,这里需要将转化为C语言,从而模拟算法运行,将算法移植到qt。 MATLAB中isequal简单介绍 语法 tf isequal(A,B) tf isequal(A1,A2,...,An) 说明 如果 A 和 B 等效,则 tf is…

飞书接入ChatGPT,实现智能化问答助手功能,提供高效的解答服务

文章目录 前言环境列表1.飞书设置2.克隆feishu-chatgpt项目3.配置config.yaml文件4.运行feishu-chatgpt项目5.安装cpolar内网穿透6.固定公网地址7.机器人权限配置8.创建版本9.创建测试企业10. 机器人测试 前言 在飞书中创建chatGPT机器人并且对话,在下面操作步骤中…

DataLoader的使用

示例代码: import torchvision from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter# 准备的测试数据集 test_data torchvision.datasets.CIFAR10("./dataset", trainFalse, transformtorchvision.transforms.…

postgresql-窗口函数

postgresql-窗口函数 简介窗口函数的定义分区选项(PARTITION BY)排序选项(ORDER BY)窗口选项(frame_clause) 聚合窗口函数排名窗口函数演示了 CUME_DIST 和 NTILE 函数 取值窗口函数 简介 常见的聚合函数&…

java八股文面试[数据库]——MySql聚簇索引和非聚簇索引区别

聚集索引和非聚集索引 聚集索引和非聚集索引的根本区别是表记录的排列顺序和与索引的排列顺序是否一致。 1、聚集索引 聚集索引表记录的排列顺序和索引的排列顺序一致(以InnoDB聚集索引的主键索引来说,叶子节点中存储的就是行数据,行数据在…

音频母带制作::AAMS V4.0 Crack

自动音频母带制作简介。 使用 AAMS V4 让您的音乐听起来很美妙! 作为从事音乐工作的音乐家,您在向公众发布材料时需要尽可能最好的声音,而为所有音频扬声器系统提供良好的商业声音是一项困难且耗时的任务。AI掌握的力量! 掌控您…

Web安全——信息收集上篇

Web安全 一、信息收集简介二、信息收集的分类三、常见的方法四、在线whois查询在线网站备案查询 五、查询绿盟的whois信息六、收集子域名1、子域名作用2、常用方式3、域名的类型3.1 A (Address) 记录:3.2 别名(CNAME)记录:3.3 如何检测CNAME记录&#xf…

flutter自定义按钮-文本按钮

目录 前言 需求 实现 前言 最近闲着无聊学习了flutter的一下知识,发现flutter和安卓之间,页面开发的方式还是有较大的差异的,众所周知,android的页面开发都是写在xml文件中的,而flutter直接写在代码里(da…

Java 16进制字符串转换成GBK字符串

问题: 现在已知有一个16进制字符串 435550D3C3D3DAD4DABDBBD2D7CFECD3A6CFFBCFA2D6D0B4E6B7C5D5DBBFDBD0C5CFA2A3ACD5DBBFDBBDF0B6EE3130302E3036 而且知道这串的字符串对应的内容是: CUP用于在交易响应消息中存放折扣信息,折扣金额100.06 但…

如何为你的公司选择正确的AIGC解决方案?

如何为你的公司选择正确的AIGC解决方案? 摘要引言词汇解释(详细版本)详细介绍1. 确定需求2. 考虑技术能力3. 评估可行性4. 比较不同供应商 代码快及其注释注意事项知识总结 博主 默语带您 Go to New World. ✍ 个人主页—— 默语 的博客&…

java实现多文件压缩zip

1,需求 需求要求实现多个文件压缩为zip文件 2,代码 package com.example.demo;import java.io.*; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; import java.util.List; import…

三极管,MOS管开关应用总结

目录 一、符号二、共同点1.三级管,mos管符号中都有一个PN结,可以根据PN结方向区分型号 二、区别1. 三级管导通条件:PN结正偏2. mos管导通条件: PN结反偏这样就不需要记哪个极的电压, 一、符号 1.三极管 2.mos管 MOS…

微信小程序scroll-view隐藏滚动条参数不生效问题

如题&#xff0c;先来看看问题是怎么出现的。 先看文档如何隐藏滚动条&#xff1a; 再根据文档实现wxml文件&#xff1a; <scroll-view show-scrollbar"{{false}}" enhanced><view wx:for"{{1000}}">11111</view> </scroll-view>…

(笔记一)利用open_cv在图像上进行点标记,文字注记,画圆、多边形、椭圆

&#xff08;1&#xff09;CV2中的绘图函数&#xff1a; cv2.line() 绘制线条cv2.circle() 绘制圆cv2.rectangle() 绘制矩形cv2.ellipse() 绘制椭圆cv2.putText() 添加注记 &#xff08;2&#xff09;注释 img表示需要绘制的图像color表示线条的颜色&#xff0c;采用颜色矩阵…

【跟小嘉学 Rust 编程】二十、进阶扩展

系列文章目录 【跟小嘉学 Rust 编程】一、Rust 编程基础 【跟小嘉学 Rust 编程】二、Rust 包管理工具使用 【跟小嘉学 Rust 编程】三、Rust 的基本程序概念 【跟小嘉学 Rust 编程】四、理解 Rust 的所有权概念 【跟小嘉学 Rust 编程】五、使用结构体关联结构化数据 【跟小嘉学…

静态树提升对Vue生态系统的影响和发展

文章目录 1. 了解Vue 3的静态树提升介绍Vue 3的基本概念和优势解释静态树提升的作用和目标 2. 什么是静态树&#xff1f;解释静态树的概念和特点比较静态树和动态树的区别 3. Vue 3中的静态树提升解释Vue 3中静态树提升的原理和工作方式强调静态树提升对性能的影响和优化效果 4…

Spring MVC: 请求参数的获取

Spring MVC 前言通过 RequestParam 注解获取请求参数RequestParam用法 通过 ServletAPI 获取请求参数通过实体类对象获取请求参数附 前言 在 Spring MVC 介绍中&#xff0c;谈到前端控制器 DispatcherServlet 接收客户端请求&#xff0c;依据处理器映射 HandlerMapping 配置调…