SENet实现遥感影像场景分类

今天我们分享SENet实现遥感影像场景分类。

数据集

本次实验我们使用的是NWPU-RESISC45 Dataset。NWPU Dataset 是一个遥感影像数据集,其中 NWPU-RESISC45 Dataset 是由西北工业大学创建的遥感图像场景分类可用基准,该数据集包含像素大小为 256*256 共计 31500 张图像,涵盖 45 个场景类别,其中每个类别有 700 张图像。

这 45 个场景类别包括飞机、机场、棒球场、篮球场、海滩、桥梁、丛林、教堂、圆形农田、云、商业区、密集住宅、沙漠、森林、高速公路、高尔夫球场、地面田径、港口、工业地区、交叉口、岛、湖、草地、中型住宅、移动房屋公园、山、立交桥、宫、停车场、铁路、火车站、矩形农田、河、环形交通枢纽、跑道、海、船舶、雪山、稀疏住宅、体育场、储水箱、网球场、露台、火力发电站和湿地。

数据集划分

首先我们可以对数据集进行划分,按训练集、验证集、测试集比例7:1.5:1.5进行划分。

import os
import shutil
import random

# 设置数据集根目录
data_root = './datasets/NWPU-RESISC45'  

# 设置训练集、验证集、测试集的目录
train_dir = './datasets/train'
val_dir = './datasets/val'
test_dir = './datasets/test'

# 创建目录
os.makedirs(train_dir, exist_ok=True)
os.makedirs(val_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)

# 获取所有子文件夹列表
class_folders = sorted(os.listdir(data_root))

# 定义训练集、验证集、测试集比例
train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15

for class_folder in class_folders:
    class_path = os.path.join(data_root, class_folder)
    images = os.listdir(class_path)
    random.shuffle(images)  # 随机打乱顺序

    num_images = len(images)
    num_train = int(num_images * train_ratio)
    num_val = int(num_images * val_ratio)

    train_images = images[:num_train]
    val_images = images[num_train:num_train + num_val]
    test_images = images[num_train + num_val:]

    # 移动图像到对应目录
    for img in train_images:
        src = os.path.join(class_path, img)
        dest = os.path.join(train_dir, class_folder, img)
        os.makedirs(os.path.dirname(dest), exist_ok=True)
        shutil.copy(src, dest)

    for img in val_images:
        src = os.path.join(class_path, img)
        dest = os.path.join(val_dir, class_folder, img)
        os.makedirs(os.path.dirname(dest), exist_ok=True)
        shutil.copy(src, dest)

    for img in test_images:
        src = os.path.join(class_path, img)
        dest = os.path.join(test_dir, class_folder, img)
        os.makedirs(os.path.dirname(dest), exist_ok=True)
        shutil.copy(src, dest)

划分完毕后,数据集分别保存在train、val、test三个文件夹内。每个文件夹内有21个子文件夹分别对应21类。

SENet

SeNet(Squeeze-and-Excitation Networks)是一种卷积神经网络(CNN)架构,由Jie Hu、Li Shen和Gang Sun于2017年提出。SeNet旨在通过引入注意力机制来增强模型对重要特征的学习能力,从而提高CNN在图像分类等计算机视觉任务上的性能。 SeNet的关键创新在于引入了“Squeeze-and-Excitation”模块,这个模块可以在不增加网络复杂度的情况下,自适应地学习特征通道之间的相关性,并对每个通道进行加权,以增强重要特征的表示。它由两个关键步骤组成: Squeeze(压缩)阶段:通过全局池化操作(通常是全局平均池化),将特征图的每个通道的信息进行汇总,生成通道级别的描述信息。 Excitation(激发)阶段:在Squeeze阶段生成的描述信息基础上,引入了多层感知机(MLP)结构来学习每个通道的权重。这些权重用于重新加权特征图,以增强有助于任务的重要特征并抑制不重要的特征。 SeNet模块可以轻松地集成到各种CNN架构中,例如ResNet、Inception等,通过在这些网络中插入SeNet模块,可以提高模型的性能,使其更具有泛化能力。 SeNet的提出在图像分类、目标检测和语义分割等计算机视觉任务中取得了显著的性能提升,并成为了当时领域内的重要技术之一。 alt

import torch.nn as nn
from torch.nn import functional as F


class Residual(nn.Module):
    def __init__(self, in_channel, out_channel, use_1x1Conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, stride=strides)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channel)

        if use_1x1Conv:
            self.conv3 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=strides)
        else:
            self.conv3 = None

    def forward(self, X):
        out = F.relu(self.bn1(self.conv1(X)))
        out = self.bn2(self.conv2(out))
        if self.conv3:
            X = self.conv3(X)
        out += X
        return F.relu(out)


def residualBlock(in_channel, out_channel, num_residuals, first_block=False):
    blks = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blks.append(Residual(in_channel, out_channel, use_1x1Conv=True,
                                 strides=2))
        else:
            blks.append(Residual(out_channel, out_channel))

    return blks

class SEBlock(nn.Module):
    def __init__(self, C, r=16):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(C, C//r, bias=False),
            nn.ReLU(),
            nn.Linear(C//r, C, bias=False),
            nn.Sigmoid())

    def forward(self, x):
        bs, c, _, _ = x.shape
        s = self.squeeze(x).view(bs, c)
        e = self.excitation(s).view(bs, c, 1, 1)
        return x * e.expand_as(x)

class SENet(nn.Module):
    def __init__(self, input_channel, n_classes):
        super().__init__()
        self.b1 = nn.Sequential(
            nn.Conv2d(input_channel, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        self.b2 = nn.Sequential(*[SEBlock(C=64)])
        self.b3 = nn.Sequential(*residualBlock(64, 64, 2, first_block=True))
        self.b4 = nn.Sequential(*[SEBlock(C=64)])
        self.b5 = nn.Sequential(*residualBlock(64, 128, 2))
        self.b6 = nn.Sequential(*[SEBlock(C=128)])
        self.b7 = nn.Sequential(*residualBlock(128, 256, 2))
        self.b8 = nn.Sequential(*[SEBlock(C=256)])
        self.b9 = nn.Sequential(*residualBlock(256, 512, 2))
        self.b10 = nn.Sequential(*[SEBlock(C=512)])
        self.finalLayer = nn.Sequential(
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(512, n_classes))

        self.b1.apply(self.init_weights)
        self.b2.apply(self.init_weights)
        self.b3.apply(self.init_weights)
        self.b4.apply(self.init_weights)
        self.b5.apply(self.init_weights)
        self.b6.apply(self.init_weights)
        self.b7.apply(self.init_weights)
        self.b8.apply(self.init_weights)
        self.b9.apply(self.init_weights)
        self.b10.apply(self.init_weights)
        self.finalLayer.apply(self.init_weights)

    def init_weights(self, layer):
        if type(layer) == nn.Conv2d:
            nn.init.kaiming_normal_(layer.weight, mode='fan_out')
        if type(layer) == nn.Linear:
            nn.init.normal_(layer.weight, std=1e-3)
        if type(layer) == nn.BatchNorm2d:
            nn.init.constant_(layer.weight, 1)
            nn.init.constant_(layer.bias, 0)
    

    def forward(self, X):
        out = self.b1(X)
        out = self.b2(out)
        out = self.b3(out)
        out = self.b4(out)
        out = self.b5(out)
        out = self.b6(out)
        out = self.b7(out)
        out = self.b8(out)
        out = self.b9(out)
        out = self.finalLayer(out)

        return out

训练过程

alt

精度与测试

「精度」

import torch
import torchvision.transforms as transforms
from torchvision import datasets
from models.SENet import SENet

# 定义测试集目录
test_dir = './datasets/test'

# 加载测试集数据
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # 图像调整为模型输入大小
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

test_data = datasets.ImageFolder(root=test_dir, transform=transform)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 加载模型
model = SENet(input_channel=3, n_classes=45).to(device)
model.load_state_dict(torch.load(f'SENet.pt', map_location='cuda:0'))

model.eval()

# 对测试集进行验证
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100*correct / total
print(f"Accuracy on test set: {accuracy}")
alt

「测试」 这里我们从测试集中选取几张图片并在我们的GUI界面中进行测试看看

1

总结

感兴趣的可以按文末方式,免费获取数据集、完整代码与训练结果

获取方法

如有需要,请关注微信公众号「DataAssassin」后,后台回复「027」领取。

更多更多内容与代码请加入我们的星球! alt 加入前不要忘了领取优惠券哦! alt

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/308943.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

1. seaborn-可视化统计关系

统计分析是了解数据集中的变量如何相互关联以及这些关系如何依赖于其他变量的过程。可视化是此过程的核心组件,这是因为当数据被恰当地可视化时,人的视觉系统可以看到指示关系的趋势和模式。 这里介绍三个seaborn函数。我们最常用的是relplot()。这是一…

golang实现加密解密文档

golang实现加密解密文档 package mainimport ("bytes""crypto/aes""crypto/cipher""crypto/rand""encoding/base64""flag""fmt""io""io/ioutil" )func main() {encodePtr : flag.…

阿赵UE学习笔记——8、贴图导入设置

阿赵UE学习笔记目录 大家好,我是阿赵。   继续学习虚幻引擎的用法,这次来说一下贴图的导入设置。   在内容浏览器里面可以看到纹理类型的资源,就是贴图了,鼠标悬浮在上面可以看到这个纹理贴图的信息: 双击纹理贴图…

Vue3技术解析(小册子)

随着 Vue 3 正式版本的发布,未来 Vue 3 将会成为前端的主流框架,这个毋庸置疑。Vue 3 在使用方面会兼容部分 Vue 2.x 的特性,比如 options API。 所以,究竟是要先学习 Vue 2 打好基础,还是直接学习 Vue 3 呢&#xff…

基于书生·浦语大模型应用开发范式介绍

文章目录 大模型应用开发范式LangChain简介构建向量数据库搭建知识库助手RAG方案优化建议 大模型应用开发范式 通用大模型的优势: 强大的语言理解、指令跟随、语言生成的能力可以理解用户自然语言的指令具有强大的知识储备和一定的逻辑推理能力。 通用大模型局限…

MongoDB快速实战与基本原理

MongoDB 介绍 什么是 MongoDB MongoDB 是一个文档数据库(以 JSON 为数据模型),由 C 语言编写,旨在为 WEB 应用提供可扩展的高性能数据存储解决方案。文档来自于“JSON Document”,并非我们一般理解的 PDF、WORD 文档…

构建安全可靠的系统:第二十一章到附录 A

第二十一章:建立安全和可靠性文化 原文:21. Building a Culture of Security and Reliability 译者:飞龙 协议:CC BY-NC-SA 4.0 作者:Heather Adkins 与 Peter Valchev,Felix Grbert,Ana Oprea…

计算机体系结构----重排序缓冲(ROB)

ROB的思想:不按顺序完成指令,但在使结果对体系结构状态可见之前重新排序 当指令被解码时,它会在 ROB 中保留下一个顺序条目当指令完成时,它将结果写入 ROB 条目当指令在 ROB 中最早并且无一例外地完成时,其结果移动到…

Java-布隆过滤器的实现

文章目录 前言一、概述二、误差率三、hash 函数的选择四、手写布隆过滤器五、guava 中的布隆过滤器 前言 如果想要判断一个元素是不是在一个集合里,一般想到的是将所有元素保存起来,然后通过比较确定。链表,树等等数据结构都是这种思路&…

LeetCode 145. 二叉树的后序遍历

145. 二叉树的后序遍历 给你一棵二叉树的根节点 root ,返回其节点值的 后序遍历 。 示例 1: 输入:root [1,null,2,3] 输出:[3,2,1]示例 2: 输入:root [] 输出:[]示例 3: 输入&…

Vue3:vue-cli项目创建及vue.config.js配置

一、node.js检测或安装: node -v node.js官方 二、vue-cli安装: npm install -g vue/cli # OR yarn global add vue/cli/*如果安装的时候报错,可以尝试一下方法 删除C:\Users**\AppData\Roaming下的npm和npm-cache文件夹 删除项目下的node…

C语言基础语法跟练 day2

题源&#xff1a;牛客网 16、BoBo写了一个十六进制整数ABCDEF&#xff0c;他问KiKi对应的十进制整数是多少。 #include <stdio.h>int main() { //创建变量char arr[] "ABCDEF";int i;int sum0,c; //依次转换十六进制为十进制for(i0; arr[i]!\0; i){char b …

每日学习更新(LQR+iLQR)

一直想更新一下根据cost to go来推导LQR&#xff0c;之前的话可能会直接套问题&#xff0c;但是对于理论有些困惑&#xff0c;正好最近在学习ilqr轨迹生成/优化&#xff0c;因此来推一下公式&#xff0c;以下参考B站Dr_CAN&#xff0c;链接如下&#xff1a; 【最优控制】5_线性…

记录汇川:H5U与Fctory IO测试6

主程序&#xff1a; 子程序: IO映射 子程序&#xff1a; 辅助上料 子程序&#xff1a; 自动程序 Fctory IO配置&#xff1a; 实际动作如下&#xff1a; Fctory IO测试6

软件测试工具Robot Framework如何安装

安装文件准备 表1 安装文件准备 Robot框架结构 为了更好的了解环境安装&#xff0c;我们先看下框架结构&#xff1a; 图1 Robot Framework Architecture Robot Framework 通过导入不同的库&#xff0c;就可以使用库中所提供的关键字&#xff0c;从而时行相关的测试。有几个标…

2023年全国职业院校技能大赛(高职组)“云计算应用”赛项赛卷②

2023年全国职业院校技能大赛&#xff08;高职组&#xff09; “云计算应用”赛项赛卷2 目录 需要竞赛软件包环境以及备赛资源可私信博主&#xff01;&#xff01;&#xff01; 2023年全国职业院校技能大赛&#xff08;高职组&#xff09; “云计算应用”赛项赛卷2 模块一 …

大众汽车宣布将ChatGPT,批量集成在多种汽车中!

1月9日&#xff0c;大众汽车在官网宣布&#xff0c;将ChatGPT批量集成到电动、内燃机汽车中。 大众表示&#xff0c;将ChatGPT与其IDA语音助手相结合&#xff0c;用户通过自然语言就能与ChatGPT进行互动&#xff0c;例如&#xff0c;帮我看看最近的三星米其林饭店在哪里&#…

Redis系列-15.Redis的IO多路复用原理解析

&#x1f44f;作者简介&#xff1a;大家好&#xff0c;我是爱吃芝士的土豆倪&#xff0c;24届校招生Java选手&#xff0c;很高兴认识大家&#x1f4d5;系列专栏&#xff1a;Spring源码、JUC源码、Kafka原理、分布式技术原理、数据库技术&#x1f525;如果感觉博主的文章还不错的…

Java顺序表(1)

&#x1f435;本篇文章将对顺序表中的方法进行模拟实现 一、线性表 线性表是指在逻辑结构上呈连续的线性结构&#xff0c;而在物理结构上不一定是连续的结构&#xff0c;常见的线性表有&#xff1a;顺序表、链表、栈、队列等 二、顺序表 顺序表一般采用数组来存储数据&#x…

ELF文件格式解析二

使用objdump命令查看elf文件 objdump -x 查看elf文件所有头部的信息 所有的elf文件。 程序头部&#xff08;Program Header&#xff09;中&#xff0c;都以 PT_PHDR和PT_INTERP先开始。这两个段必须在所有可加载段项目的前面。 从上图中的INTERP段中&#xff0c;可以看到改段…