改进系列(5):在ResNet网络添加SelfAttention自注意力层实现的遥感卫星下的土地利用情况图像分类

目录

1. ResNet介绍

2. SelfAttention 层

3. ResNet34 + SelfAttention

4. 遥感卫星下的土地使用情况分类

4.1 土地使用情况数据集 

4.2 训练

4.3 训练结果

4.4 推理


1. ResNet介绍

ResNet(残差网络)是一种深度卷积神经网络模型,由Kaiming He等人于2015年提出。它的提出解决了深度神经网络的梯度消失和梯度爆炸问题,使得深层网络的训练变得更加容易和有效。

在深度神经网络中,随着网络层数的增加,梯度在反向传播过程中逐渐变小,导致网络的训练变得困难。这是因为在传统的网络结构中,每个网络层都是通过直接逐层堆叠来进行信息的传递。当网络层数增加时,信息的传递路径变得更长,导致梯度逐渐消失。为了解决这个问题,ResNet提出了“残差学习”的概念。

ResNet引入了“残差块”(residual block)的概念,其中每个残差块包含一个跳跃连接(skip connection),将输入直接添加到输出中。这个跳跃连接允许梯度直接通过残差块传递,避免了梯度的消失问题。通过残差块的堆叠,ResNet可以构建非常深的网络,如ResNet-50、ResNet-101等。

ResNet的提出极大地促进了深度神经网络的发展。它在多个视觉任务上取得了非常好的性能,成为了目标检测、图像分类、图像分割等领域的重要基准模型。同时,ResNet的思想也影响了后续的深度神经网络架构设计,被广泛应用于各种深度学习任务中。

2. SelfAttention 层

自注意机制基于Vaswani等人在2017年提出的变压器架构。它计算所有输入单词的嵌入加权和,其中权重由每个单词与序列中其他单词的相关性决定。这些权重是通过嵌入之间的一系列点积运算计算的,然后是一个softmax函数来归一化权重。

与传统的序列模型相比,自注意机制有几个优点。它允许模型更有效地捕获长距离依赖关系,因为序列中任何单词的信息都可以直接影响任何其他单词的表示。它还支持并行计算,因为可以为每个单词独立计算注意力权重。这使得自我关注模型高效且可扩展。

python 实现的代码如下:

# 定义自注意力层
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        query = self.query_conv(x).view(batch_size, -1, height * width).permute(0, 2, 1)
        key = self.key_conv(x).view(batch_size, -1, height * width)
        energy = torch.bmm(query, key)
        attention = torch.softmax(energy, dim=-1)
        value = self.value_conv(x).view(batch_size, -1, height * width)
        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(batch_size, channels, height, width)
        out = self.gamma * out + x
        return out

3. ResNet34 + SelfAttention

这里只对resnet34做了添加,事实上其他版本的resnet网络添加自注意力机制是一样的,只需要把resnet34换成52、101之类的即可

关键代码如下:

添加后的效果如下:

4. 遥感卫星下的土地使用情况分类

下载链接在下面:

Resnet网络改进实战(添加SelfAttention自注意力机制):遥感卫星下的土地利用图像分类资源-CSDN文库

解压后的完整目录如下,data是数据集,runs是训练好的结果

4.1 土地使用情况数据集 

总共有21类别,分别放在不同的目录下,训练集有1470张图片,验证集有630张数据

标签类别如下:

{
    "0": "agricultural",
    "1": "airplane",
    "2": "baseballdiamond",
    "3": "beach",
    "4": "buildings",
    "5": "chaparral",
    "6": "denseresidential",
    "7": "forest",
    "8": "freeway",
    "9": "golfcourse",
    "10": "harbor",
    "11": "intersection",
    "12": "mediumresidential",
    "13": "mobilehomepark",
    "14": "overpass",
    "15": "parkinglot",
    "16": "river",
    "17": "runway",
    "18": "sparseresidential",
    "19": "storagetanks",
    "20": "tenniscourt"
}

可视化结果:

4.2 训练

这里训练了30个epoch,参数如下:

    "train parameters": {
        "model": "resnet34",
        "pretrained": true,
        "freeze_layers": true,
        "batch_size": 8,
        "epochs": 30,
        "optim": "SGD",
        "lr": 0.001,
        "lrf": 0.0001
    },
    "Datasets": {
        "trainSets number": 1470,
        "validSets number": 630
    },
    "model": {
        "total parameters": 21731845.0,
        "train parameters": 621001,
        "flops": 3742463488.0
    },

想要更改训练超参数的可以在train脚本更改

4.3 训练结果

这里最后一轮的指标如下:

    "epoch:29": {
        "train info": {
            "accuracy": 0.9836734693810635,
            "agricultural": {
                "Precision": 1.0,
                "Recall": 0.9857,
                "Specificity": 1.0,
                "F1 score": 0.9928
            },
            "airplane": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "baseballdiamond": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "beach": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "buildings": {
                "Precision": 0.9857,
                "Recall": 0.9857,
                "Specificity": 0.9993,
                "F1 score": 0.9857
            },
            "chaparral": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "denseresidential": {
                "Precision": 0.9286,
                "Recall": 0.9286,
                "Specificity": 0.9964,
                "F1 score": 0.9286
            },
            "forest": {
                "Precision": 0.9722,
                "Recall": 1.0,
                "Specificity": 0.9986,
                "F1 score": 0.9859
            },
            "freeway": {
                "Precision": 0.971,
                "Recall": 0.9571,
                "Specificity": 0.9986,
                "F1 score": 0.964
            },
            "golfcourse": {
                "Precision": 0.9853,
                "Recall": 0.9571,
                "Specificity": 0.9993,
                "F1 score": 0.971
            },
            "harbor": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "intersection": {
                "Precision": 1.0,
                "Recall": 0.9857,
                "Specificity": 1.0,
                "F1 score": 0.9928
            },
            "mediumresidential": {
                "Precision": 0.9559,
                "Recall": 0.9286,
                "Specificity": 0.9979,
                "F1 score": 0.9421
            },
            "mobilehomepark": {
                "Precision": 0.9718,
                "Recall": 0.9857,
                "Specificity": 0.9986,
                "F1 score": 0.9787
            },
            "overpass": {
                "Precision": 0.9577,
                "Recall": 0.9714,
                "Specificity": 0.9979,
                "F1 score": 0.9645
            },
            "parkinglot": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "river": {
                "Precision": 0.9718,
                "Recall": 0.9857,
                "Specificity": 0.9986,
                "F1 score": 0.9787
            },
            "runway": {
                "Precision": 0.9722,
                "Recall": 1.0,
                "Specificity": 0.9986,
                "F1 score": 0.9859
            },
            "sparseresidential": {
                "Precision": 0.9859,
                "Recall": 1.0,
                "Specificity": 0.9993,
                "F1 score": 0.9929
            },
            "storagetanks": {
                "Precision": 1.0,
                "Recall": 0.9857,
                "Specificity": 1.0,
                "F1 score": 0.9928
            },
            "tenniscourt": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "mean precision": 0.9837190476190478,
            "mean recall": 0.9836666666666668,
            "mean specificity": 0.9991952380952381,
            "mean f1 score": 0.9836380952380953
        },
        "valid info": {
            "accuracy": 0.8571428571292516,
            "agricultural": {
                "Precision": 0.8437,
                "Recall": 0.9,
                "Specificity": 0.9917,
                "F1 score": 0.8709
            },
            "airplane": {
                "Precision": 1.0,
                "Recall": 0.9667,
                "Specificity": 1.0,
                "F1 score": 0.9831
            },
            "baseballdiamond": {
                "Precision": 0.8529,
                "Recall": 0.9667,
                "Specificity": 0.9917,
                "F1 score": 0.9062
            },
            "beach": {
                "Precision": 0.7692,
                "Recall": 1.0,
                "Specificity": 0.985,
                "F1 score": 0.8695
            },
            "buildings": {
                "Precision": 0.7714,
                "Recall": 0.9,
                "Specificity": 0.9867,
                "F1 score": 0.8308
            },
            "chaparral": {
                "Precision": 0.9062,
                "Recall": 0.9667,
                "Specificity": 0.995,
                "F1 score": 0.9355
            },
            "denseresidential": {
                "Precision": 0.72,
                "Recall": 0.6,
                "Specificity": 0.9883,
                "F1 score": 0.6545
            },
            "forest": {
                "Precision": 0.8788,
                "Recall": 0.9667,
                "Specificity": 0.9933,
                "F1 score": 0.9207
            },
            "freeway": {
                "Precision": 0.7241,
                "Recall": 0.7,
                "Specificity": 0.9867,
                "F1 score": 0.7118
            },
            "golfcourse": {
                "Precision": 0.8387,
                "Recall": 0.8667,
                "Specificity": 0.9917,
                "F1 score": 0.8525
            },
            "harbor": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "intersection": {
                "Precision": 0.8889,
                "Recall": 0.8,
                "Specificity": 0.995,
                "F1 score": 0.8421
            },
            "mediumresidential": {
                "Precision": 0.8077,
                "Recall": 0.7,
                "Specificity": 0.9917,
                "F1 score": 0.75
            },
            "mobilehomepark": {
                "Precision": 0.8437,
                "Recall": 0.9,
                "Specificity": 0.9917,
                "F1 score": 0.8709
            },
            "overpass": {
                "Precision": 0.6897,
                "Recall": 0.6667,
                "Specificity": 0.985,
                "F1 score": 0.678
            },
            "parkinglot": {
                "Precision": 0.9355,
                "Recall": 0.9667,
                "Specificity": 0.9967,
                "F1 score": 0.9508
            },
            "river": {
                "Precision": 0.9,
                "Recall": 0.6,
                "Specificity": 0.9967,
                "F1 score": 0.72
            },
            "runway": {
                "Precision": 0.8571,
                "Recall": 1.0,
                "Specificity": 0.9917,
                "F1 score": 0.9231
            },
            "sparseresidential": {
                "Precision": 0.9,
                "Recall": 0.9,
                "Specificity": 0.995,
                "F1 score": 0.9
            },
            "storagetanks": {
                "Precision": 0.92,
                "Recall": 0.7667,
                "Specificity": 0.9967,
                "F1 score": 0.8364
            },
            "tenniscourt": {
                "Precision": 1.0,
                "Recall": 0.8667,
                "Specificity": 1.0,
                "F1 score": 0.9286
            },
            "mean precision": 0.8594095238095237,
            "mean recall": 0.857157142857143,
            "mean specificity": 0.9928714285714286,
            "mean f1 score": 0.8540666666666668
        }
    }

曲线图:

混淆矩阵:

 

4.4 推理

推理结果如下:

想要更换数据集训练的话,参考readme文件即可

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/937313.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

ARM循环程序和子程序设计

1、计算下列两组数据的累加和并存入到sum1和 sum2 单元中。datal:0x12,0x935,0x17,0x100,0x95,0x345。 data2:0x357,0x778,0x129,0x188,0x190,0x155,0x167。 1.定义数据段 ;定义数据段,类型为data(表示为数据段),权限为可读可写(程序可以读取和修改这…

蓝桥杯刷题——day5

蓝桥杯刷题——day5 题目一题干解题思路一代码解题思路二代码 题目二题干解题思路代码 题目一 题干 给定n个整数 a1,a2,⋯ ,an,求它们两两相乘再相加的和,即: 示例一: 输入: 4 1 3 6 9 输出: 117 题目链…

1_linux系统网络性能如何优化——几种开源网络协议栈比较

之前合集《计算机网络从入门到放弃》第一阶段算是已经完成了。都是理论,没有实操,让“程序猿”很难受,操作性不如 Modbus发送的报文何时等到应答和 tcp通信测试报告单1——connect和send。开始是想看linux内核网络协议栈的源码,然…

opencv——识别图片颜色并绘制轮廓

图像边缘检测 本实验要用到Canny算法,Canny边缘检测方法常被誉为边缘检测的最优方法。 首先,Canny算法的输入端应为图像的二值化结果,接收到二值化图像后,需要按照如下步骤进行: 高斯滤波。计算图像的梯度和方向。非极…

Edge SCDN深度解析,边缘安全加速的创新实践

边缘安全加速(Edge Secure Content Delivery Network,SCDN)是酷盾安全推出的边缘集分布式 DDoS 防护、CC 防护、WAF 防护、BOT 行为分析为一体的安全加速解决方案。通过边缘缓存技术,智能调度使用户就近获取所需内容,为…

软考高级架构 —— 10.6 大型网站系统架构演化实例 + 软件架构维护

10.6 大型网站系统架构演化实例 大型网站的技术挑战主要来自于庞大的用户,高并发的访问和海量的数据,主要解决这类问题。 1. 单体架构 特点: 所有资源(应用程序、数据库、文件)集中在一台服务器上。适用场景: 小型网站&am…

Golang囊地鼠gopher

开发知识点-golang 介绍红队专题-Golang工具Fscan简介主要功能ubuntu 安装windows 安装常用命令:项目框架源文件common目录Plugins目录Webscan目录入口点插件扫描类型爆破插件common.ScantypeWebtitle函数webpoc扫描POC 执行CEL-GO 实践CEL指纹识别免杀源码特征参考链接红队专…

多分类交叉熵与稀疏分类交叉熵

总结: 标签为 One-hot 编码的多分类问题,用分类交叉熵对于标签为整数的多分类问题,用稀疏分类交叉熵稀疏分类交叉熵内部会将整数标签转换为 One-hot 编码,而如果标签已经是 One-hot 编码的形式,再使用稀疏分类交叉熵就会多此一举。 算例 假设我们有三个类别:A、B 和 C。…

360极速浏览器不支持看PDF

360安全浏览器采用的是基于IE内核和Chrome内核的双核浏览器。360极速浏览器是源自Chromium开源项目的浏览器,不但完美融合了IE内核引擎,而且实现了双核引擎的无缝切换。因此在速度上,360极速浏览器的极速体验感更佳。 展示自己的时候要在有优…

零基础微信小程序开发——小程序的宿主环境(保姆级教程+超详细)

🎥 作者简介: CSDN\阿里云\腾讯云\华为云开发社区优质创作者,专注分享大数据、Python、数据库、人工智能等领域的优质内容 🌸个人主页: 长风清留杨的博客 🍃形式准则: 无论成就大小,…

麒麟信安推出支持信创PC的新一代云桌面方案,助力政务信创高效安全运维

12月11日,在第二届国家新一代自主安全计算系统产业集群融通生态大会上,麒麟信安发布了支持信创PC的新一代云桌面方案,该方案是基于国际TCI架构实现国产PC机云化纳管在国内的首次发布,并与银河麒麟桌面操作系统、长城国产PC整机实现…

vim优化

1.编辑如下内容&#xff1a; cat > /root/.vimrc <<EOF set tabstop2 " 设置 Tab 为 2 个空格 set shiftwidth2 " 设置自动缩进为 2 个空格 set expandtab " 将 Tab 转换为空格 " 基本设置 set number syntax on" 快捷键设置…

hive—常用的日期函数

目录 1、current_date 当前日期 2、now() 或 current_timestamp() 当前时间 3、datediff(endDate, startDate) 计算日期相差天数 4、months_between(endDate, startDate) 日期相差月数 5、date_add(startDate, numDays) 日期加N天 6、date_sub(startDate, numDays) 日期减…

医学分割数据集肾结石分割数据集labelme格式359张1类别

数据集格式&#xff1a;labelme格式(不包含mask文件&#xff0c;仅仅包含jpg图片和对应的json文件) 图片数量(jpg文件个数)&#xff1a;359 标注数量(json文件个数)&#xff1a;359 标注类别数&#xff1a;1 标注类别名称:["kidney stone"] 每个类别标注的框数&…

Vulnstack红日安全内网域渗透靶场2实战攻略

一&#xff1a;环境搭建 新增的网卡VMnet2&#xff0c;ip调成10段。 PC配置如下&#xff1a; DC在该环境中充当是域控。DC配置如下 &#xff1a; WEB配置&#xff1a;需要两块网卡&#xff0c;相当于网关服务器。 作者把外网网段都写成了192.168.111.1/24&#xff0c;我们可以…

基础库urllib的使用

学习爬虫&#xff0c;其基本的操作便是模拟浏览器向服务器发出请求&#xff0c;那么我们需要从哪个地方做起呢?请求需要我们自己构造吗?我们需要关心请求这个数据结构怎么实现吗?需要了解 HTTP、TCP、IP层的网络传输通信吗?需要知道服务器如何响应以及响应的原理吗? 可能…

【大数据技术基础】【记录Ubuntu 16.04升级到18.04】Ubuntu的一个版本升级到另一个版本

在 Ubuntu 操作系统中进行软件更新和系统升级 Ubuntu Kylin 16.04 LTS 系统进行系统升级到 Ubuntu 18.04.6 LTS 版本 升级提示&#xff1a;系统弹出提示框&#xff0c;告知用户有新版本的 Ubuntu 可用&#xff0c;询问用户是否想要升级。 认证窗口&#xff1a;显示了一个认证…

【环境搭建】Python、PyTorch与cuda的版本对应表

一个愿意伫立在巨人肩膀上的农民...... 在深度学习的世界里&#xff0c;选择合适的工具版本是项目成功的关键。CUDA、PyTorch和Python作为深度学习的三大支柱&#xff0c;它们的版本匹配问题不容忽视。错误的版本组合可能导致兼容性问题、性能下降甚至项目失败。因此&#xff0…

SpringBoot 学习

SpringBoot 学习 什么是 Springboot Spring Boot 是 Spring 提供的一个子项目&#xff0c;用于快速构建 Spring 应用程序 传统的问题&#xff1a; 导入依赖繁琐项目配置繁琐 SpringBoot 的特性 起步依赖&#xff1a;整合所有 web 的依赖配置好了自动配置&#xff1a;bean…

计算机网络-传输层 TCP协议(下)

目录 TCP核心机制五&#xff1a;流量控制 TCP核心机制六&#xff1a;拥塞控制 TCP核心机制七&#xff1a;延迟应答 TCP核心机制八&#xff1a;捎带应答 TCP核心机制九&#xff1a;面向字节流 粘包问题 TCP核心机制十&#xff1a;异常处理 &#xff08;1&#xff09;进程…