基于卷积神经网络的高光谱图像分类详细教程(含python代码)

目录

一、背景

二、基于卷积神经网络的代码实现

1)建立卷积神经网络模型

2)训练函数代码

3)全图可视化

三、项目代码


一、背景

卷积神经网络(Convolutional Neural Networks, CNNs)在处理高光谱图像分类任务时,展现出了卓越的性能。高光谱图像因其丰富的光谱信息而具有极高的空间分辨率,能够捕捉到物体在不同波段上的细微差异。CNNs通过其特有的卷积层、池化层和全连接层结构,能够自动提取图像中的空间和光谱特征,有效地捕捉到高光谱数据中的复杂模式。在卷积层中,局部感受野和权值共享机制使得网络能够学习到图像的局部特征,而池化层则有助于减少参数数量,提高特征的不变性。通过层层堆叠的非线性变换,CNNs能够构建出深层次的特征表示,从而实现对高光谱图像中不同地物类别的高精度分类。此外,通过引入迁移学习、数据增强和正则化等技术,可以进一步提高模型的泛化能力和分类性能,使得卷积神经网络成为高光谱图像分析领域的一项重要工具。

深度学习是机器学习的一个分支,它通过构建多层的神经网络来模拟人脑处理信息的方式,从而实现对复杂数据的高效处理和模式识别。在深度学习中,2D卷积是一种核心操作,尤其在图像处理和计算机视觉领域中扮演着至关重要的角色。

2D卷积操作涉及一个卷积核(或滤波器),它在输入数据(如图像)上滑动,计算卷积核与输入数据局部区域的点积,从而生成输出特征图。这个过程可以捕捉到输入数据的空间层次结构,即从低级特征(如边缘和角点)到高级特征(如纹理和对象部分)的逐步抽象。

二、基于卷积神经网络的代码实现

下面我们以IP数据集为例子进行展开讲解。

1)建立卷积神经网络模型

import torch.nn as nn

class CNN_2D(nn.Module):
    def __init__(self, num_classes):
        super(CNN_2D, self).__init__()
        # 2D卷积块
        self.block_2D = nn.Sequential(
            nn.Conv2d(
                in_channels=30,out_channels=64,kernel_size=(3, 3),stride=(2,2),padding=(1,1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64,out_channels=128,kernel_size=(3, 3),stride=(2,2),padding=(1,1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=128,out_channels=256,kernel_size=(3, 3),stride=(2,2),padding=(1,1)),
            nn.ReLU(inplace=True)
        )

        # 全连接层
        self.classifier = nn.Sequential(
            nn.Linear(in_features=16*256,out_features=512
            ),
            nn.Dropout(p=0.4),
            nn.Linear(in_features=512,out_features=256
            ),
            nn.Dropout(p=0.4),
            nn.Linear(in_features=256,out_features=num_classes))

    def forward(self, x):
        y = self.block_2D(x)
        y = y.reshape(y.shape[0], -1)
        y = self.classifier(y)
        return y

2)训练函数代码

1、导入相关包

首先导入训练需要用到的相关函数和包

from utils import applyPCA,createImageCubes,splitTrainTestSet,reports
import scipy.io as sio
from net import CNN_2D
import torch
import numpy as np
import os
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import time

2、数据集加载

将下载的数据加载进内存,便于后续处理

X = sio.loadmat('./data/Indian_pines.mat')['indian_pines']
y = sio.loadmat('./data/Indian_pines_gt.mat')['indian_pines_gt']

3、PCA降维

由于数据的波段非常大,这里利用PCA进行数据降维。降维后数据维度得到减少,从220个波段降维到30个波段。

X_pca = applyPCA(X, numComponents=pca_components)

4、数据集的样本划分与标签分配

根据数据标签和数据,对其进行样本采样,并划分成训练集和验证集。这里以窗口为25的大小,训练集和测试集的占比分别为20%的训练,80%的验证。

X_pca, y = createImageCubes(X_pca, y, windowSize=patch_size)
Xtrain, Xtest, ytrain, ytest = splitTrainTestSet(X_pca, y, test_ratio)

5、转数据为torch张量

首先定义样本加载函数:

# 加载数据函数
class TrainDS(torch.utils.data.Dataset):
    def __init__(self):
        self.len = Xtrain.shape[0]
        self.x_data = torch.FloatTensor(Xtrain)
        self.y_data = torch.LongTensor(ytrain)
    def __getitem__(self, index):
        # 根据索引返回数据和对应的标签
        return self.x_data[index], self.y_data[index]
    def __len__(self):
        # 返回文件数据的数目
        return self.len

class TestDS(torch.utils.data.Dataset):
    def __init__(self):
        self.len = Xtest.shape[0]
        self.x_data = torch.FloatTensor(Xtest)
        self.y_data = torch.LongTensor(ytest)
    def __getitem__(self, index):
        # 根据索引返回数据和对应的标签
        return self.x_data[index], self.y_data[index]
    def __len__(self):
        # 返回文件数据的数目
        return self.len

然后将数据以批量的方式加载进来:

trainset = TrainDS()
testset  = TestDS()
train_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=128, shuffle=True)
test_loader  = torch.utils.data.DataLoader(dataset=testset,  batch_size=128, shuffle=False)

6、自动选择GPU还是CPU训练

# 使用GPU训练
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 网络放到GPU上
net = CNN_2D(num_classes=class_num).to(device)

7、设置相关的损失、优化函数以及迭代次数

EPOCH = 10
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

8、开始训练

训练过程中,保存验证集最优结果的权值,以便后续使用。

for epoch in range(EPOCH):
    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

9、绘制损失和准确率曲线

将训练过程中的损失和准确率保存,并进行绘制,查看他们之间的关系。

10、对验证集进行评估,获取最终结果

classification, confusion, oa, each_acc, aa, kappa, names = reports(ytest, y_pred_test)
print(classification)
print("混淆矩阵\n",confusion)
print("kapap:", kappa)
print("aa:", aa)
print("oa:", oa)
print("训练时间:",train_time_1-train_time_0,"验证时间:",test_time_1-test_time_0)

3)全图可视化

采用逐点预测的方式

for i in range(height):
    for j in range(width):
        image_patch = X[i:i+patch_size, j:j+patch_size, :]
        image_patch = image_patch.reshape(1,image_patch.shape[0],image_patch.shape[1], image_patch.shape[2])
        X_test_image = torch.FloatTensor(image_patch.transpose(0, 3, 1, 2)).to(device)
        prediction = net(X_test_image)
        prediction = np.argmax(prediction.detach().cpu().numpy(), axis=1)
        outputs[i][j] = prediction+1

预测结果比较不错,精度99%以上了。

三、项目代码

本项目的代码通过以下链接下载:基于卷积神经网络的高光谱图像分类代码

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/602095.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Mask RCNN(Mask_RCNN-master)简单部署

一.注意事项 1.本文主要是引用大佬的文章(侵权请联系,马上删除),做的工作为简单补充 二.介绍 ①简介: Mask R-CNN(Mask Region-based Convolutional Neural Network)是一种用于目标检测和语义…

两个手机在一起ip地址一样吗?两个手机是不是两个ip地址

在数字时代的浩瀚海洋中,手机已经成为我们生活中不可或缺的一部分。随着移动互联网的飞速发展,IP地址成为了连接手机与互联网的桥梁。那么,两个手机在一起IP地址一样吗?两个手机是不是两个IP地址?本文将带您一探究竟&a…

Apipost使用心得,让接口文档变得更清晰,更快捷

Idea和Apipost结合使用 Idea 安装插件Apipost-Helper-2.0 在【file】–>【settings】–>【Plugins】搜索 “Apipost-Helper-2.0”–>【install】,重启Idea 编写controller接口 在idea中编写业务功能及接口之后,在controller中鼠标【右键】单…

亚马逊Amazon商品详情和关键词搜索API接口分享

一、亚马逊Amazon商品详情API接口 亚马逊商品详情API接口是亚马逊平台为开发者提供的一项重要服务,它允许开发者通过程序调用API来获取亚马逊商品的相关数据。这个接口为获取商品数据提供了便利的途径,有助于用户进行商品搜索、商品分类以及数据分析等操…

Stable Diffusion基础:ControlNet之人体姿势控制

在AI绘画中精确控制图片是一件比较困难的事情,不过随着 ControlNet 的诞生,这一问题得到了很大的缓解。 今天我就给大家分享一个使用Stable Diffusion WebUI OpenPose ControlNet 复制照片人物姿势的方法,效果可以参考上图。 OpenPose 可以…

不得不聊的微服务Gateway

一、 什么是Gateway? 1.网关的由来 单体应用拆分成多个服务后,对外需要一个统一入口,解耦客户端与内部服务 2.网关的作用 Spring Cloud Gateway是Spring Cloud生态系统中的一员,它被设计用于处理所有微服务的入口流量。作为一…

Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels

文章目录 Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels摘要方法实验结果 Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels 摘要 Soft Dice Loss(SDL)在医学图像领域的许多自动分割中发挥了关键作用。在过…

【数据库原理及应用】期末复习汇总高校期末真题试卷07

试卷 一、填空题(每空1分,共10分) 1.数据库管理系统在外模式、模式和内模式这三级模式之间提供了两层映象,其中 映象保证了数据的逻辑独立性。 2. 数据模型通常由 、数据操作和完整性约束三部分组…

vue 文本中的\n 、<br>换行显示

一、背景&#xff1a; 后端接口返回数据以\n 作为换行符&#xff0c;前端显示时候需要换行显示&#xff1b; demo&#xff1a; <p style"white-space: pre-wrap;">{{ info }}</p>data() {return {info: 1、优化图片\n 2、 优化时间\n}},项目上&#…

通配符证书价格350元

通配符SSL证书是一种特殊的域名SSL证书&#xff0c;这款SSL证书默认保护主域名以及主域名下的所有子域名&#xff0c;因此&#xff0c;子域名比较多的个人或者企事业单位开发者都倾向于选择通配符SSL证书来简化SSL证书管理过程&#xff0c;节省购买SSL证书的资金&#xff0c;降…

前端如何设置div可滚动,且设置滚动条颜色

在前端中&#xff0c;设置 div 为可滚动并通过 CSS 自定义滚动条的颜色并不是所有浏览器都直接支持的功能&#xff0c;因为滚动条的样式在很大程度上取决于操作系统和浏览器的默认样式。然而&#xff0c;你可以使用某些 CSS 属性来尝试自定义滚动条的外观&#xff0c;这些属性在…

JavaEE概述 + Maven

文章目录 一、JavaEE 概述二、工具 --- Maven2.1 Maven功能 仓库 坐标2.2 Maven之项目构建2.3 Maven之依赖管理 三、插件 --- Maven Helper 一、JavaEE 概述 Java SE、JavaEE&#xff1a; Java SE&#xff1a;指Java标准版&#xff0c;适用于各行各业&#xff0c;主要是Java…

2024 Flutter 一季度热门 issue/roadmap 进展和个人感触闲聊

因为最近的《Flutter&#xff1a;听说你最近到处和人说我解散了&#xff1f;》相关事件之后&#xff0c;不少人对于目前 Flutter 的一些进度情况比较关心&#xff0c;刚好这里做一个简要汇总&#xff0c;报告几个过去一季度相关的热门 issue/roadmap 情况&#xff0c;另外这些天…

邮件群发系统的效率怎么样?如何评估性能?

邮件群发系统的使用方法&#xff1f;邮件群发工具的关键功能&#xff1f; 邮件群发系统已成为企业、组织及个人进行信息沟通的重要工具。然而&#xff0c;当我们谈论邮件群发系统的效率时&#xff0c;我们需要从多个维度来全面分析和评估。AokSend就来介绍一下。 邮件群发系统…

ReactFlow的ReactFlow实例事件传参undefined处理状态切换

1.问题 ReactFlow的ReactFlow实例有些事件我们在不同的状态下并不需要&#xff0c;而且有时候传参会出现其它渲染效果&#xff0c;比如只读状态下我们不想要拖拉拽onEdgesChange连线重连或删除的功能。 2.思路 事件名称类型默认值onEdgesChange(changes: EdgeChange[]) >…

AI大模型探索之路-训练篇17:大语言模型预训练-微调技术之QLoRA

系列篇章&#x1f4a5; AI大模型探索之路-训练篇1&#xff1a;大语言模型微调基础认知 AI大模型探索之路-训练篇2&#xff1a;大语言模型预训练基础认知 AI大模型探索之路-训练篇3&#xff1a;大语言模型全景解读 AI大模型探索之路-训练篇4&#xff1a;大语言模型训练数据集概…

浅谈消息队列和云存储

1970年代末&#xff0c;消息系统用于管理多主机的打印作业&#xff0c;这种削峰解耦的能力逐渐被标准化为“点对点模型”和稍复杂的“发布订阅模型”&#xff0c;实现了数据处理的分布式协同。随着时代的发展&#xff0c;Kafka&#xff0c;Amazon SQS&#xff0c;RocketMQ&…

基于大数据+Hadoop的豆瓣电子图书推荐系统实现

&#x1f339;作者主页&#xff1a;青花锁 &#x1f339;简介&#xff1a;Java领域优质创作者&#x1f3c6;、Java微服务架构公号作者&#x1f604; &#x1f339;简历模板、学习资料、面试题库、技术互助 &#x1f339;文末获取联系方式 &#x1f4dd; 系列文章目录 基于大数…

组合模式(Composite)——结构型模式

组合模式(Composite)——结构型模式 组合模式是一种结构型设计模式&#xff0c; 你可以使用它将对象组合成树状结构&#xff0c; 并且能通过通用接口像独立整体对象一样使用它们。如果应用的核心模型能用树状结构表示&#xff0c; 在应用中使用组合模式才有价值。 例如一个场景…

新能源汽车充电站智慧充电电能服务综合解决方案

安科瑞薛瑶瑶18701709087/17343930412 ★解决方案 ✔目的地充电-EMS微电网平台 基于EMS解决方案从设备运维的角度解决本地充电的能量管理及运维问题&#xff0c;与充电管理平台打通数据&#xff0c;为企业微电网提供源、网、荷、储、充一体化解决方案。 ✔运营场站--电能服务…