计算机视觉的应用11-基于pytorch框架的卷积神经网络与注意力机制对街道房屋号码的识别应用

大家好,我是微学AI,今天给大家介绍一下计算机视觉的应用11-基于pytorch框架的卷积神经网络与注意力机制对街道房屋号码的识别应用,本文我们借助PyTorch,快速构建和训练卷积神经网络(CNN)等模型,以实现街道房屋号码的准确识别。引入并注意力机制,它是一种模仿人类视觉注意机制的方法,在图像处理任务中具有广泛应用。通过引入注意力机制,模型可以自动关注图像中与房屋号码相关的区域,提高识别的准确性和鲁棒性。

一、项目介绍

街道房屋号码识别是计算机视觉中的一个重要任务,通过对街道房屋号码的自动识别,可以对街道图像进行更好的理解和分析。本文将介绍如何使用PyTorch框架和注意力机制,结合SVHN数据集,来实现街道房屋号码的分类识别。

二、SVHN数据集

SVHN(Street View House Numbers)是一个公开的大规模街道数字图像数据集。该数据集包含了从Google Street View中获取的房屋门牌号码图像,可以用于训练和测试机器学习模型,以实现自动识别街道房屋号码的任务。

2.1 数据集下载和加载

首先,我们需要下载并加载SVHN数据集。在PyTorch中,我们可以使用torchvision库中的datasets模块来实现这一步。

数据集的下载与查看:

train_dataset = datasets.SVHN(root='./data', split='train', download=True)

images = train_dataset.data[:10]  # shape: (10, 3, 32, 32)
labels = train_dataset.labels[:10]

images = np.transpose(images, (0, 2, 3, 1))

# Plot the images
fig, axs = plt.subplots(2, 5, figsize=(12, 6))
axs = axs.ravel()

for i in range(10):
    axs[i].imshow(images[i])
    axs[i].set_title(f"Label: {labels[i]}")
    axs[i].axis('off')

plt.tight_layout()
plt.show()

在这里插入图片描述

数据集的加载,预处理,便于输入模型训练:

import torch
from torchvision import datasets, transforms

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

# 下载并加载SVHN数据集
trainset = datasets.SVHN(root='./data', split='train', download=True, transform=transform)
testset = datasets.SVHN(root='./data', split='test', download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

三、卷积网络搭建

使用PyTorch搭建卷积神经网络。卷积神经网络(Convolutional Neural Network, CNN)是一种主要用于处理具有类似网格结构的数据的神经网络,如图像(2D网格的像素点)或者文本(1D网格的单词)。

3.1 网络结构定义

下面是一个基础的卷积神经网络模型,包含两个卷积层、两个最大池化层和两个全连接层。

from torch import nn

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.drop_out = nn.Dropout()
        self.fc1 = nn.Linear(7 * 7 * 64, 1000)
        self.fc2 = nn.Linear(1000, 10)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.drop_out(out)
        out = self.fc1(out)
        return self.fc2(out)

四、加入注意力机制

注意力机制是一种能够改进模型性能的技术。在我们的模型中,我们将添加一个注意力层来帮助模型更好地专注于输入图像中的重要部分。

4.1 注意力层定义

我将实现基本的注意力层,这个层将会生成一个和输入同样大小的注意力图,然后将输入和这个注意力图对应元素相乘,以此来实现对输入的加权。

注意力机制层的数学原理:
注意力机制的数学原理可以用以下公式表示:

给定输入张量 x ∈ R b × c × h × w x \in \mathbb{R}^{b \times c \times h \times w} xRb×c×h×w,其中 b b b 是批量大小, c c c 是通道数, h h h 是高度, w w w 是宽度。注意力机制分为两个阶段:特征提取和特征加权。
1.特征提取阶段:
首先,通过自适应平均池化层(AdaptiveAvgPool2d)将输入张量 x x x 在高度和宽度上进行平均池化,得到形状为 b × c × 1 × 1 b \times c \times 1 \times 1 b×c×1×1 的张量 y y y。这里使用自适应平均池化是为了使得张量 y y y 在不同尺寸的输入上也能产生相同的输出。
2.特征加权阶段:
接下来,通过全连接层(Linear)和非线性激活函数ReLU对张量 y y y 进行特征变换,减少通道数,并保留重要特征。然后再通过另一个全连接层和Sigmoid激活函数得到权重张量 y ′ ∈ R b × c × 1 × 1 y' \in \mathbb{R}^{b \times c \times 1 \times 1} yRb×c×1×1,表示每个通道的权重值。这里的权重值在0到1之间,用于控制每个通道在后续的计算中所占的比重。将权重张量 y ′ y' y 扩展成与输入张量 x x x 相同的形状,并将其与输入张量相乘,得到经过注意力加权的特征张量。这样就实现了对输入张量的自适应特征加权。

数学表示为:
y = AdaptiveAvgPool2d ( x ) y ′ = Sigmoid ( Linear ( ReLU ( Linear ( y ) ) ) ) output = x ⊙ y ′ y = \text{AdaptiveAvgPool2d}(x) \\ y' = \text{Sigmoid}(\text{Linear}(\text{ReLU}(\text{Linear}(y)))) \\ \text{output} = x \odot y' y=AdaptiveAvgPool2d(x)y=Sigmoid(Linear(ReLU(Linear(y))))output=xy

其中 ⊙ \odot 表示按元素相乘操作。

注意力机制层的搭建代码:

class AttentionLayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(AttentionLayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel// reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

4.2 在网络中加入注意力层

我们将注意力层加入到ConvNet模型中:

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            AttentionLayer(32))
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            AttentionLayer(64))
        self.drop_out = nn.Dropout()
        self.fc1 = nn.Linear(8 * 8 * 64, 1000)
        self.fc2 = nn.Linear(1000, 10)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.drop_out(out)
        out = self.fc1(out)
        return self.fc2(out)

五、模型训练与测试

接下来,我们将进行模型的训练和测试。

5.1 模型训练

import torch.optim as optim

model = ConvNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(10):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 20 == 0:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

5.2 模型测试

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

六、结论

这篇文章就像是一张奇妙的地图,引领你进入计算机视觉任务的神奇世界。在这个世界里,你将与PyTorch和注意力机制这两位强大的伙伴结伴前行,共同探索街道房屋号码识别的奥秘。

想象一下,你置身于繁忙的街道上,满目琳琅的房屋号码挑战着你的视力。而你却拥有了一种神奇的眼力,能轻松识别出每一个号码。这种超凡能力正是计算机视觉任务的魔法所在。

我们要携手PyTorch这位强大的工具,它如同一把巧妙的魔法棒,能帮助我们构建强大的神经网络模型。通过PyTorch,我们可以灵活地定义模型的结构,设置各种参数,并进行高效的训练和推理。

我们遇到了注意力机制,就像是一盏明亮的灯塔,照亮了我们前进的方向。注意力机制能够使神经网络集中注意力于图像中的重要区域,从而提高识别的准确性。利用这种机制,我们可以让模型更加聪明地注重街道房屋号码所在的位置和细节,从而更好地进行识别。而SVHN数据集则是我们探险的指南,其中包含了大量真实世界中的街道房屋号码图像。通过导入这些数据,我们可以让模型从中学习并提高自己的识别能力。这些图像将带领我们穿越城市的角落,感受不同场景下的挑战和变化。通过这篇文章,我们不仅可以更深入地理解计算机视觉任务的本质,还能获得启发。就像是一次奇妙的冒险,我们将学会如何使用PyTorch和注意力机制来实现街道房屋号码的识别任务。让我们一起跟随这个引人入胜的旅程,开拓视野,追寻新的可能性吧!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/81499.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Quivr 基于GPT和开源LLMs构建本地知识库 (更新篇)

一、前言 自从大模型被炒的越来越火之后,似乎国内涌现出很多希望基于大模型构建本地知识库的需求,大概在5月底的时候,当时Quivr发布了第一个0.0.1版本,第一个版本仅仅只是使用LangChain技术结合OpenAI的GPT模型实现了一个最基本的…

【搭建WebDAV服务手机ES文件浏览器远程访问】

文章目录 1. 安装启用WebDAV2. 安装cpolar3. 配置公网访问地址4. 公网测试连接5. 固定连接公网地址6. 使用固定地址测试连接 有时候我们想通过移动设备访问群晖NAS 中的文件,以满足特殊需求,我们在群辉中开启WebDav服务,结合cpolar内网工具生成的公网地址,通过移动客户端ES文件…

在ubuntu中将dict.txt导入到数据库sqlite3

将dict.txt导入到数据库 #include <head.h> #include <sqlite3.h> int do_insert(int i,char *str,sqlite3 *db); int main(int argc, const char *argv[]) {//创建泵打开一个数据库sqlite3 *db NULL;if(sqlite3_open("./my.db",&db) ! SQLITE_OK){…

G0第26章:微服务概述与gRPCprotocol buffers

Go微服务与云原生 1、微服务架构介绍 单体架构&#xff08;电商&#xff09; SOA架构&#xff08;电商&#xff09; 微服务架构&#xff08;电商&#xff09; 优势 挑战 拆分 发展史 第一代:基于RPC的传统服务架构 第二代:Service Mesh(istio) 微服务架构分层 核心组件 Summar…

【3D激光SLAM】LOAM源代码解析--scanRegistration.cpp

系列文章目录 【3D激光SLAM】LOAM源代码解析–scanRegistration.cpp 写在前面 本系列文章将对LOAM源代码进行讲解&#xff0c;在讲解过程中&#xff0c;涉及到论文中提到的部分&#xff0c;会结合论文以及我自己的理解进行解读&#xff0c;尤其是对于其中坐标变换的部分&…

回归预测 | MATLAB实现TSO-SVM金枪鱼群算法优化支持向量机多输入单输出回归预测(多指标,多图)

回归预测 | MATLAB实现TSO-SVM金枪鱼群算法优化支持向量机多输入单输出回归预测&#xff08;多指标&#xff0c;多图&#xff09; 目录 回归预测 | MATLAB实现TSO-SVM金枪鱼群算法优化支持向量机多输入单输出回归预测&#xff08;多指标&#xff0c;多图&#xff09;效果一览基…

springcloud3 hystrix实现服务监控显示3(了解)

一 hystrix的服务监控调用 1.1 hystrix的服务监控调用 hystrix提供了准实时的监控调用&#xff08;hystrix dashbord&#xff09;&#xff0c;Hystrix 会持续的记录所有通过hystrix发送的请求的执行信息&#xff0c;并以统计报表和图形的形式展示给用户&#xff0c;包括每秒执…

不是说嵌入式是风口吗,那为什么工作还那么难找?

最近确实有很多媒体、机构渲染嵌入式可以拿高薪&#xff0c;这在行业内也是事实&#xff0c;但前提是你有足够的竞争力&#xff0c;真的懂嵌入式。 时至今日&#xff0c;能做嵌入式程序开发的人其实相当常见&#xff0c;尤其是随着树莓派、Arduino等开发板的普及&#xff0c;甚…

【Java】智慧工地SaaS平台源码:AI/云计算/物联网/智慧监管

智慧工地是指运用信息化手段&#xff0c;围绕施工过程管理&#xff0c;建立互联协同、智能生产、科学管理的施工项目信息化生态圈&#xff0c;并将此数据在虚拟现实环境下与物联网采集到的工程信息进行数据挖掘分析&#xff0c;提供过程趋势预测及专家预案&#xff0c;实现工程…

互联网发展历程:保护与隔离,防火墙的安全壁垒

互联网的快速发展&#xff0c;不仅带来了便利和连接&#xff0c;也引发了越来越多的安全威胁。在数字时代&#xff0c;保护数据和网络安全变得尤为重要。然而&#xff0c;在早期的网络中&#xff0c;安全问题常常让人担忧。 安全问题的困扰&#xff1a;网络威胁日益增加 随着互…

分布式websocket解决方案

1、websocket问题由来 websocket基础请自行学习,本文章是解决在分布式环境下websocket通讯问题。 在单体环境下,所有web客户端都是连接到某一个微服务上,这样消息都是到达统一服务端,并且也是由一个服务端进行响应,所以不会出现问题。 但是在分布式环境下,我们很容易发现…

Postgresql源码(112)plpgsql执行sql时变量何时替换为值

相关 《Postgresql源码&#xff08;41&#xff09;plpgsql函数编译执行流程分析》 《Postgresql源码&#xff08;46&#xff09;plpgsql中的变量类型及对应关系》 《Postgresql源码&#xff08;49&#xff09;plpgsql函数编译执行流程分析总结》 《Postgresql源码&#xff08;5…

Android AppCompatActivity标题栏操作

使用 AndroidStudio 新建的工程默认用 AppCompatActivity &#xff0c;是带标题栏的。 记录下 修改标题栏名称 和 隐藏标题栏 的方法。 修改标题栏名称 Override protected void onCreate(Bundle savedInstanceState) {super.onCreate(savedInstanceState);setContentView(R…

Eureka注册中心

全部流程 注册服务中心 添加maven依赖 <!--引用注册中心--> <dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-starter-netflix-eureka-server</artifactId> </dependency> 配置Eureka 因为自…

W6100-EVB-PICO 做UDP Client 进行数据回环测试(八)

前言 上一章我们用开发板作为UDP Server进行数据回环测试&#xff0c;本章我们让我们的开发板作为UDP Client进行数据回环测试。 连接方式 使开发板和我们的电脑处于同一网段&#xff1a; 开发板通过交叉线直连主机开发板和主机都接在路由器LAN口 测试工具 网路调试工具&a…

prompt-engineering-note(面向开发者的ChatGPT提问工程学习笔记)

介绍&#xff1a; ChatGPT Prompt Engineering Learning Notesfor Developers (面向开发者的ChatGPT提问工程学习笔记) 课程简单介绍了语言模型的工作原理&#xff0c;提供了最佳的提示工程实践&#xff0c;并展示了如何将语言模型 API 应用于各种任务的应用程序中。 此外&am…

idea gerrit 插件使用指引

IDEA安装gerrit插件 在线安装&#xff08;推荐&#xff09; 直接搜索gerrit&#xff0c;安装即可离线安装 可以到github下载离线包&#xff1a;https://github.com/uwolfer/gerrit-intellij-plugin/releases&#xff0c;不过可能会有版本不兼容问题&#xff0c;还是推荐在线安装…

CSAPP Lab2:Bomb Lab

说明 6关卡&#xff0c;每个关卡需要输入相应的内容&#xff0c;通过逆向工程来获取对应关卡的通过条件 准备工作 环境 需要用到gdb调试器 apt-get install gdb系统: Ubuntu 22.04 本实验会用到的gdb调试器的指令如下 r或者 run或者run filename 运行程序,run filename就…

AIGC绘画:基于Stable Diffusion进行AI绘图

文章目录 AIGC深度学习模型绘画系统stable diffusion简介stable diffusion应用现状在线网站云端部署本地部署Stable Diffusion AIGC深度学习模型绘画系统 stable diffusion简介 Stable Diffusion是2022年发布的深度学习文本到图像生成模型&#xff0c;它主要用于根据文本的描述…

基于互斥锁的生产者消费者模型

文章目录 生产者消费者 定义代码实现 / 思路完整代码执行逻辑 / 思路 局部具体分析model.ccfunc&#xff08;消费者线程&#xff09; 执行结果 生产者消费者 定义 生产者消费者模型 是一种常用的 并发编程模型 &#xff0c;用于解决多线程或多进程环境下的协作问题。该模型包含…