SSF-CNN:空间光谱融合的卷积光谱图像超分网络

SSF-CNN: SPATIAL AND SPECTRAL FUSION WITH CNN FOR HYPERSPECTRAL IMAGE SUPER-RESOLUTION

文章目录

  • SSF-CNN: SPATIAL AND SPECTRAL FUSION WITH CNN FOR HYPERSPECTRAL IMAGE SUPER-RESOLUTION
    • 简介
    • 解决问题
    • 网络框架
    • 代码实现
    • 训练部分
    • 运行结果

简介

​ 本文提出了一种利用空间和光谱进行高光谱融合图像超分辨率的新型CNN架构,首先是对高光谱图像进行双三次插值,使其空间分辨率大小和多光谱一致,然后进行concat操作。使用类似于SRCNN的网络框架对融合超分的图像进行优化,最后输出高分辨率高光谱超分图像。

​ 对于PDCon,也就是引入了部分密集连接,将输入concat到每一个卷积层后面。
Hyperspectral-Image-Super-Resolution-Benchmark——光谱图像超分基准-CSDN博客
Paper: IEEE
Code:https://github.com/miraclefan777/SSFCNN

2023-11-25_16-06-09

解决问题

  1. 传统方法通过基于优化的方法恢复 HR-HS 图像的质量在很大程度上取决于预定义的约束。此外,由于约束项数量较多,优化过程通常涉及较高的计算成本。
  2. 执行HSI SR的一个直接想法是直接应用这样的网络来放大LR-HS图像的空间维度或HR-RGB图像的光谱维度,我们称之为Spatial-CNN和Spectral-CNN,这两种单图像方法忽略了两种图像特有的信息互补优势。

网络框架

  1. 原始的SRCNN是将图片映射到Ycbcr空间,并只使用其中的 Y 分量作为输入来预测 HR Y 图像,该论文则是将图片的通道信息以及空间信息整个进行输入
  2. 原始SRCNN卷积核大小第1,2修改为3*3,增加上下文信息,同时为了避免高维数据(padding为same,保持和原有特征图大小一致)

代码实现

class SSFCNNnet(nn.Module):
    def __init__(self, num_spectral=31, scale_factor=8, pdconv=False):
        super(SSFCNNnet, self).__init__()
        self.scale_factor = scale_factor
        self.pdconv = pdconv

        self.Upsample = nn.Upsample(mode='bicubic', scale_factor=self.scale_factor)

        self.conv1 = nn.Conv2d(num_spectral + 3, 64, kernel_size=3, padding="same")
        if pdconv:
            self.conv2 = nn.Conv2d(64 + 3, 32, kernel_size=3, padding="same")
            self.conv3 = nn.Conv2d(32 + 3, num_spectral, kernel_size=5, padding="same")
        else:
            self.conv2 = nn.Conv2d(64, 32, kernel_size=3, padding="same")
            self.conv3 = nn.Conv2d(32, num_spectral, kernel_size=5, padding="same")

        self.relu = nn.ReLU(inplace=True)

    def forward(self, lr_hs, hr_ms):
        """
            :param lr_hs:LR-HSI低分辨率的高光谱图像
            :param hr_ms:高分辨率的多光谱图像
            :return:
        """
        # 对LR-HSI低分辨率图像进行上采样,让其分辨率更高
        lr_hs_up = self.Upsample(lr_hs)
        # 将上采样后的LR-HSI低分辨率图像与高分辨率的多光谱图像进行拼接
        x = torch.cat((lr_hs_up, hr_ms), dim=1)

        x = self.relu(self.conv1(x))
        if self.pdconv:
            x = torch.cat((x, hr_ms), dim=1)
            x = self.relu(self.conv2(x))
            x = torch.cat((x, hr_ms), dim=1)
        else:
            x = self.relu(self.conv2(x))

        out = self.conv3(x)
        return out

如果需要使用密集连接,只需要在初始化网络模型时,传参pdconv=True

训练部分

未提供自定义dataset类,根据自己的dateset进行参数的修改即可。

import argparse
from calculate_metrics import Loss_SAM, Loss_RMSE, Loss_PSNR
from models.SSFCNNnet import SSFCNNnet
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from train_dataloader import CAVEHSIDATAprocess
from utils import create_F, fspecial,AverageMeter
import os
import copy
import torch
import torch.nn as nn

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default="SSFCNNnet")
    parser.add_argument('--train-file', type=str, required=True)
    parser.add_argument('--eval-file', type=str, required=True)
    parser.add_argument('--outputs-dir', type=str, required=True)
    parser.add_argument('--scale', type=int, default=2)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--num-workers', type=int, default=0)
    parser.add_argument('--num-epochs', type=int, default=400)
    parser.add_argument('--seed', type=int, default=123)
    args = parser.parse_args()

    assert args.model in ['SSFCNNnet', 'PDcon_SSF']

    outputs_dir = os.path.join(args.outputs_dir, '{}'.format(args.model))
    if not os.path.exists(outputs_dir):
        os.makedirs(outputs_dir)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(args.seed)

    # 训练参数
    # loss_func = nn.L1Loss(reduction='mean').cuda()
    criterion = nn.MSELoss()


    #################数据集处理#################
    R = create_F()
    PSF = fspecial('gaussian', 8, 3)
    downsample_factor = 8
    training_size = 64
    stride = 32
    stride1 = 32

    train_dataset = CAVEHSIDATAprocess(args.train_file, R, training_size, stride, downsample_factor, PSF, 20)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True)

    eval_dataset = CAVEHSIDATAprocess(args.eval_file, R, training_size, stride, downsample_factor, PSF, 12)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    #################数据集处理#################

    # 模型
    if args.model == 'SSFCNNnet':
        model = SSFCNNnet().cuda()
    else:
        model = SSFCNNnet(pdconv=True).cuda()

    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_psnr = 0.0

    # 模型初始化
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    optimizer = torch.optim.Adam([
        {'params': model.conv1.parameters()},
        {'params': model.conv2.parameters()},
        {'params': model.conv3.parameters(), 'lr': args.lr * 0.1}
    ], lr=args.lr)

    start_epoch = 0


    for epoch in range(start_epoch, args.num_epochs):
        model.train()
        epoch_losses = AverageMeter()

        with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t:
            t.set_description('epoch:{}/{}'.format(epoch, args.num_epochs - 1))

            for data in train_dataloader:
                label, lr_hs, hr_ms = data

                label = label.to(device)
                lr_hs = lr_hs.to(device)
                hr_ms = hr_ms.to(device)
                lr = optimizer.param_groups[0]['lr']
                pred = model(hr_ms, lr_hs)
                loss = criterion(pred, label)

                epoch_losses.update(loss.item(), len(label))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg), lr='{0:1.8f}'.format(lr))
                t.update(len(label))

        # torch.save(model.state_dict(), os.path.join(outputs_dir, 'epoch_{}.pth'.format(epoch)))

        if epoch % 5 == 0:
            model.eval()
            val_loss = AverageMeter()

            SAM = Loss_SAM()
            RMSE = Loss_RMSE()
            PSNR = Loss_PSNR()

            sam = AverageMeter()
            rmse = AverageMeter()
            psnr = AverageMeter()

            for data in eval_dataloader:
                label, lr_hs, hr_ms = data
                lr_hs = lr_hs.to(device)
                hr_ms = hr_ms.to(device)
                label = label.cpu().numpy()

                with torch.no_grad():
                    preds = model(hr_ms, lr_hs).cpu().numpy()

                sam.update(SAM(preds, label), len(label))
                rmse.update(RMSE(preds, label), len(label))
                psnr.update(PSNR(preds, label), len(label))

            if psnr.avg > best_psnr:
                best_epoch = epoch
                best_psnr = psnr.avg
                best_weights = copy.deepcopy(model.state_dict())

            print('eval psnr: {:.2f}  RMSE: {:.2f}  SAM: {:.2f} '.format(psnr.avg, rmse.avg, sam.avg))

运行结果

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/189402.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

YOLOv5轻量化改进之MobileNetv3

目录 一、原理 二、代码 三、应用到YOLOv5 一、原理 我们提出了基于互补搜索技术和新颖架构设计相结合的下一代mobilenet。MobileNetV3通过硬件网络架构搜索(NAS)和NetAdapt算法的结合来调整到移动电话cpu,然后通过新的架构进步进行改进。本文开始探索自动搜索算法和网络设计…

5 个适用于 Windows 的顶级免费数据恢复软件

对于计算机来说,最重要的是用户数据。除了您的数据之外,有关计算机的其他所有内容都是可替换的。这三个是数据丢失的最常见原因: 文件/文件夹删除丢失分区分区损坏 文件/文件夹删除 文件/文件夹删除是最常见的数据丢失类型。大多数时候&am…

Matplotlib网格子图_Python数据分析与可视化

Matplotlib网格子图 plt.subplot()绘制子图调整子图之间的间隔plt.subplots创建网格 plt.subplot()绘制子图 若干彼此对齐的行列子图是常见的可视化任务,matplotlib拥有一些可以轻松创建它们的简便方法。最底层且最常用的方法是plt.subplot()。 这个函数在一个网格…

零基础学Linux内核:1、Linux源码组织架构

文章目录 前言一、Linux内核的特征二、Linux操作系统结构1.Linux在系统中的位置2.Linux内核的主要子系统3、Linux系统主要数据结构 三、linux内核源码组织1、下载Linux源码2、Linux版本号3、linux源码架构目录讲解 前言 这里将是我们从零开始学习Linux的第一节,这节…

【Kotlin】类与接口

文章目录 类的定义创建类的实例构造函数主构造函数次构造函数init语句块 数据类的定义数据类定义了componentN方法 继承AnyAny:非空类型的根类型Any?:所有类型的根类型 覆盖方法覆盖属性覆盖 抽象类接口:使用interface关键字函数:funUnit:让…

C++ 通过SQLite实现命令行工具

本文介绍了一个基于 C、SQLite 和 Boost 库的简单交互式数据库操作 Shell。该 Shell 允许用户通过命令行输入执行各种数据库操作,包括添加、删除主机信息,设置主机到特定主机组,以及显示主机和主机组列表。通过调用 SQLite3 库实现数据库连接…

适用于 Mac 和 Windows 的顶级U 盘数据恢复软件

由于意外删除或设备故障而丢失 USB 驱动器中的数据始终是一件令人压力很大的事情,检索该信息的最佳选择是使用优质数据恢复软件。为了让事情变得更容易,我们已经为您完成了所有研究并测试了工具,并且我们列出了最好的 USB 记忆棒恢复软件&…

elasticsearc DSL查询文档

文章目录 DSL查询文档DSL查询分类全文检索查询使用场景基本语法示例 精准查询term查询range查询总结 地理坐标查询矩形范围查询附近查询 复合查询相关性算分算分函数查询1)语法说明2)示例3)小结 布尔查询1)语法示例:2&…

基于C#实现Kruskal算法

这篇我们看看第二种生成树的 Kruskal 算法,这个算法的魅力在于我们可以打一下算法和数据结构的组合拳,很有意思的。 一、思想 若存在 M{0,1,2,3,4,5}这样 6 个节点,我们知道 Prim 算法构建生成树是从”顶点”这个角度来思考的,然…

车载电子电器架构 ——电子电气架构设计方案概述

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 注:本文1万多字,认证码字,认真看!!! 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 屏蔽力是信息过载时代一个人的特殊竞争力,任何消耗你的人和事,多看一眼都是你的不对。非必要不费力证…

机器学习探索计划——KNN算法流程的简易了解

文章目录 数据准备阶段KNN预测的过程1.计算新样本与已知样本点的距离2.按照举例排序3.确定k值4.距离最近的k个点投票 scikit-learn中的KNN算法 数据准备阶段 import matplotlib.pyplot as plt import numpy as np# 样本特征 data_X [[0.5, 2],[1.8, 3],[3.9, 1],[4.7, 4],[6.…

FreeRTOS入门教程(任务通知)

文章目录 前言一、什么是任务通知二、任务通知和队列,信号量的区别三、任务通知的优点和缺点1.优点2.缺点 四、任务状态和通知值五、任务通知相关的函数发出通知取出通知 六、任务通知具体使用1.实现轻量级信号量二进制信号量计数型信号量 2.实现轻量级队列 总结 前…

数仓中数据清洗的方法

在数据采集的过程中,需要从不同渠道获取数据并汇集在数仓中,采集的原始数据首先需要进行解析,然后对不准确、不完整、不合理、格式、字符等不规范数据进行过滤清洗,清洗过的数据才能更加符合需求,从而使后续的数据分析…

【数据库】执行计划中二元操作对一趟扫描算法的应用,理解代价评估的应用和优化,除了磁盘代价还有CPU计算代价不容忽略

二元操作的一趟算法 ​专栏内容: 手写数据库toadb 本专栏主要介绍如何从零开发,开发的步骤,以及开发过程中的涉及的原理,遇到的问题等,让大家能跟上并且可以一起开发,让每个需要的人成为参与者。 本专栏会定…

Java中的异常语法知识居然这么好玩!后悔没有早点学习

学习异常后,发现异常的知识是多么的吸引人!不仅可以用来标记错误,还可以自己定义一个异常,用来实现自己想完成的业务逻辑,接下来一起去学习吧 目录 一、异常的概念及体系结构 1.异常的概念 2.异常的体系结构 3.异常…

【数据处理】 -- 【两分钟】了解【最好】的方式 -- 【正则表达式】

直接匹配; 普通字符 元匹配: . 任意单字符 r’表示单引号里字符为其特殊含义,比如.不是句号是匹配符的意思 *任意次数(换行结束) 一次及以上 {3,4}指定次数,至少3次,最多4次|{3}固定4次 [\d.]单个任意…

软件工程简明教程

软件工程简明教程 何为软件工程? 1968 年 NATO(北大西洋公约组织)提出了软件危机(Software crisis)一词。同年,为了解决软件危机问题,“软件工程”的概念诞生了。一门叫做软件工程的学科也就应…

redis运维(二十)redis 的扩展应用 lua(二)

一 redis 的扩展应用 lua redis lua脚本语法 ① 什么是脚本缓存 redis 缓存lua脚本 说明: 重启redis,脚本缓存会丢失 下面讲解 SCRIPT ... 系列 SCRIPT ② LOAD 语法:SCRIPT LOAD lua代码 -->载入一个脚本,只是预加载,不执行思考1&#xff1…

吴恩达《机器学习》10-4-10-5:诊断偏差和方差、正则化和偏差/方差

一、诊断偏差和方差 在机器学习中,诊断偏差和方差是改进模型性能的关键步骤。通过了解这两个概念,能够判断算法的问题究竟是欠拟合还是过拟合,从而有针对性地调整模型。 1. 概念理解 偏差(Bias): 表示模…

《微信小程序开发从入门到实战》学习三十一

3.4 开发参与投票页面 3.4.9 显示投票结果 在实际使用中,一个用户不能对同一个投票进行重复提交,因此需要向服务器端提交投票结果和提交用户ID。另外页面,需要完善。用户提交完投票后 ,还需要显示投票目前的结果,提交…