【DA-CLIP】test.py解读,调用DA-CLIP和IRSDE模型复原计算复原图与GT图SSIM、PSNR、LPIPS

文件路径daclip-uir-main/universal-image-restoration/config/daclip-sde/test.py

代码有部分修改

导包

import argparse
import logging
import os.path
import sys
import time
from collections import OrderedDict
import torchvision.utils as tvutils

import numpy as np
import torch
from IPython import embed
import lpips

import options as option
from models import create_model

sys.path.insert(0, "../../")
import open_clip
import utils as util
from data import create_dataloader, create_dataset
from data.util import bgr2ycbcr

注意open_clip使用的是项目里的代码,而非环境里装的那个。data、util、option同样是项目里有的包

声明

#### options
parser = argparse.ArgumentParser()
parser.add_argument("-opt", type=str, default='options/test.yml', help="Path to options YMAL file.")
opt = option.parse(parser.parse_args().opt, is_train=False)

opt = option.dict_to_nonedict(opt)

配置文件 

设置配置文件相对地址options/test.yml

在该配置文件中配置GT和LQ图像文件地址

datasets:
  test1:
   name: Test
   mode: LQGT
   dataroot_GT: C:\Users\86136\Desktop\LQ_test\shadow\GT
   dataroot_LQ: C:\Users\86136\Desktop\LQ_test\shadow\LQ

设置results_root结果地址,每次计算结束这个地址保存要求记录的计算结果

该目录下Test文件夹将保存一张GT一张LQ一张复原图像  。

不设置也会默认在项目内 daclip-uir-main\results\daclip-sde\universal-ir

#### path
path:
  pretrain_model_G: E:\daclip\pretrained\universal-ir.pth
  daclip: E:\daclip\pretrained\daclip_ViT-B-32.pt
  results_root: C:\Users\86136\Desktop\daclip-uir-main\results\daclip-sde\universal-ir
  log: 

 

#### mkdir and logger
util.mkdirs(
    (
        path
        for key, path in opt["path"].items()
        if not key == "experiments_root"
        and "pretrain_model" not in key
        and "resume" not in key
    )
)

# os.system("rm ./result")
# os.symlink(os.path.join(opt["path"]["results_root"], ".."), "./result")

 报错执行代码没有删除再创建权限?我把相关os操作注释了,全部保存到result对我影响不大

加载创建数据对

#### Create test dataset and dataloader
test_loaders = []
for phase, dataset_opt in sorted(opt["datasets"].items()):
    test_set = create_dataset(dataset_opt)
    test_loader = create_dataloader(test_set, dataset_opt)
    logger.info(
        "Number of test images in [{:s}]: {:d}".format(
            dataset_opt["name"], len(test_set)
        )
    )
    test_loaders.append(test_loader)

 自定义包含复原IR-SDE模型的外层类model,参考app.py

# load pretrained model by default
model = create_model(opt)
device = model.device

 加载DA-CLIP、IR-SDE

# clip_model, _preprocess = clip.load("ViT-B/32", device=device)
if opt['path']['daclip'] is not None:
    clip_model, preprocess = open_clip.create_model_from_pretrained('daclip_ViT-B-32', pretrained=opt['path']['daclip'])
else:
    clip_model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
tokenizer = open_clip.get_tokenizer('ViT-B-32')
clip_model = clip_model.to(device)

else是直接使用CLIP的ViT-B-32模型进行测试的代码。与我测DA-CLIP无关。

想使用的话 目测要预先下载对应模型权重并手动修改pretrained为文件地址,否则报错hf无法连接

sde = util.IRSDE(max_sigma=opt["sde"]["max_sigma"], T=opt["sde"]["T"], schedule=opt["sde"]["schedule"], eps=opt["sde"]["eps"], device=device)
sde.set_model(model.model)
lpips_fn = lpips.LPIPS(net='alex').to(device)

scale = opt['degradation']['scale']

加载IR-SDE、LPIPS

如果不指定crop_border后续crop_border=scale

处理并计算


for test_loader in test_loaders:
    test_set_name = test_loader.dataset.opt["name"]  # path opt['']
    logger.info("\nTesting [{:s}]...".format(test_set_name))
    test_start_time = time.time()
    dataset_dir = os.path.join(opt["path"]["results_root"], test_set_name)
    util.mkdir(dataset_dir)

    test_results = OrderedDict()
    test_results["psnr"] = []
    test_results["ssim"] = []
    test_results["psnr_y"] = []
    test_results["ssim_y"] = []
    test_results["lpips"] = []
    test_times = []

    for i, test_data in enumerate(test_loader):
        single_img_psnr = []
        single_img_ssim = []
        single_img_psnr_y = []
        single_img_ssim_y = []
        need_GT = False if test_loader.dataset.opt["dataroot_GT"] is None else True
        img_path = test_data["GT_path"][0] if need_GT else test_data["LQ_path"][0]
        img_name = os.path.splitext(os.path.basename(img_path))[0]

        #### input dataset_LQ
        LQ, GT = test_data["LQ"], test_data["GT"]
        img4clip = test_data["LQ_clip"].to(device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            image_context, degra_context = clip_model.encode_image(img4clip, control=True)
            image_context = image_context.float()
            degra_context = degra_context.float()

        noisy_state = sde.noise_state(LQ)

        model.feed_data(noisy_state, LQ, GT, text_context=degra_context, image_context=image_context)
        tic = time.time()
        model.test(sde, save_states=False)
        toc = time.time()
        test_times.append(toc - tic)

        visuals = model.get_current_visuals()
        SR_img = visuals["Output"]
        output = util.tensor2img(SR_img.squeeze())  # uint8
        LQ_ = util.tensor2img(visuals["Input"].squeeze())  # uint8
        GT_ = util.tensor2img(visuals["GT"].squeeze())  # uint8
        
        suffix = opt["suffix"]
        if suffix:
            save_img_path = os.path.join(dataset_dir, img_name + suffix + ".png")
        else:
            save_img_path = os.path.join(dataset_dir, img_name + ".png")
        util.save_img(output, save_img_path)

        # remove it if you only want to save output images
        LQ_img_path = os.path.join(dataset_dir, img_name + "_LQ.png")
        GT_img_path = os.path.join(dataset_dir, img_name + "_HQ.png")
        util.save_img(LQ_, LQ_img_path)
        util.save_img(GT_, GT_img_path)

        if need_GT:
            gt_img = GT_ / 255.0
            sr_img = output / 255.0

            crop_border = opt["crop_border"] if opt["crop_border"] else scale
            if crop_border == 0:
                cropped_sr_img = sr_img
                cropped_gt_img = gt_img
            else:
                cropped_sr_img = sr_img[
                    crop_border:-crop_border, crop_border:-crop_border
                ]
                cropped_gt_img = gt_img[
                    crop_border:-crop_border, crop_border:-crop_border
                ]

            psnr = util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
            ssim = util.calculate_ssim(cropped_sr_img * 255, cropped_gt_img * 255)
            lp_score = lpips_fn(
                GT.to(device) * 2 - 1, SR_img.to(device) * 2 - 1).squeeze().item()

            test_results["psnr"].append(psnr)
            test_results["ssim"].append(ssim)
            test_results["lpips"].append(lp_score)

            if len(gt_img.shape) == 3:
                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    if crop_border == 0:
                        cropped_sr_img_y = sr_img_y
                        cropped_gt_img_y = gt_img_y
                    else:
                        cropped_sr_img_y = sr_img_y[
                            crop_border:-crop_border, crop_border:-crop_border
                        ]
                        cropped_gt_img_y = gt_img_y[
                            crop_border:-crop_border, crop_border:-crop_border
                        ]
                    psnr_y = util.calculate_psnr(
                        cropped_sr_img_y * 255, cropped_gt_img_y * 255
                    )
                    ssim_y = util.calculate_ssim(
                        cropped_sr_img_y * 255, cropped_gt_img_y * 255
                    )

                    test_results["psnr_y"].append(psnr_y)
                    test_results["ssim_y"].append(ssim_y)

                    logger.info(
                        "img{:3d}:{:15s} - PSNR: {:.6f} dB; SSIM: {:.6f}; LPIPS: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.".format(
                            i, img_name, psnr, ssim, lp_score, psnr_y, ssim_y
                        )
                    )
            else:
                logger.info(
                    "img:{:15s} - PSNR: {:.6f} dB; SSIM: {:.6f}.".format(
                        img_name, psnr, ssim
                    )
                )

                test_results["psnr_y"].append(psnr)
                test_results["ssim_y"].append(ssim)
        else:
            logger.info(img_name)


    ave_lpips = sum(test_results["lpips"]) / len(test_results["lpips"])
    ave_psnr = sum(test_results["psnr"]) / len(test_results["psnr"])
    ave_ssim = sum(test_results["ssim"]) / len(test_results["ssim"])
    logger.info(
        "----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n".format(
            test_set_name, ave_psnr, ave_ssim
        )
    )
    if test_results["psnr_y"] and test_results["ssim_y"]:
        ave_psnr_y = sum(test_results["psnr_y"]) / len(test_results["psnr_y"])
        ave_ssim_y = sum(test_results["ssim_y"]) / len(test_results["ssim_y"])
        logger.info(
            "----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n".format(
                ave_psnr_y, ave_ssim_y
            )
        )

    logger.info(
            "----average LPIPS\t: {:.6f}\n".format(ave_lpips)
        )

    print(f"average test time: {np.mean(test_times):.4f}")

开头往log记录了相应配置文件内容,不需要可以注释。

遍历测试数据集(test_loaders)计算各种评价指标,如峰值信噪比(PSNR)、结构相似性(SSIM)和感知损失(LPIPS)。

在处理过程中,代码首先会创建一个目录来保存测试结果。

然后,对于每个测试图像,代码会加载对应的图像(如果可用),并使用一个名为clip_model的模型对图像进行编码。

接下来,代码会使用一个名为sde的随机微分方程模型和名为model的深度学习模型来处理带有噪声的图像,并生成复原图像(SR_img)。额可能作者拿了以前做超分的代码没改变量名

在这个过程中,text_contextimage_context被用作模型的输入,

图像都会被保存到之前创建的目录中。

此外,代码还会计算并记录每个图像的PSNR、SSIM和LPIPS分数,并在最后打印出这些分数的平均值。 代码中还包含了一些用于图像处理的实用函数,如util.tensor2img用于将张量转换为图像,util.save_img用于保存图像,以及util.calculate_psnrutil.calculate_ssim用于计算PSNR和SSIM分数。psnr_y和ssim_y 不用可以把相关代码注释。

最后,代码还计算了平均测试时间,并将其打印出来。

结果

log处理的单张图像报错的信息 0是该处理的图像排序序号,即正在处理第0张图

24-04-03 17:28:24.697 - INFO: img  0:_MG_2374_no_shadow - PSNR: 27.779773 dB; SSIM: 0.863140; LPIPS: 0.078669; PSNR_Y: 29.135256 dB; SSIM_Y: 0.869278.

 

可以给复原结果图加个后缀方便区分。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/516259.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

ruoyi-nbcio-plus基于vue3的flowable流程元素选择区面板的升级修改

更多ruoyi-nbcio功能请看演示系统 gitee源代码地址 前后端代码: https://gitee.com/nbacheng/ruoyi-nbcio 演示地址:RuoYi-Nbcio后台管理系统 http://122.227.135.243:9666/ 更多nbcio-boot功能请看演示系统 gitee源代码地址 后端代码&#xff1a…

工业项目中你连SCADA都没见过?

什么是SCADA SCADA是一种监控和数据采集系统,全称是Supervisor.Contro.an.Dat.Acquisition。SCADA系统在工业项目中具有广泛应用,包括生产线监控、工艺控制、设备维护、能源管理、安全监控和产量跟踪等多个场景。通过实时监测、数据采集和远程控制等功能…

网络协议栈--数据链路层

目录 对比理解“数据链路层”和“网络层”一、认识以太网1.1 以太网帧格式1.2 认识MAC地址1.3 对比理解MAC地址和IP地址1.4 认识MTU1.5 MTU对IP协议的影响1.6 MTU对UDP协议的影响1.7 MTU对于TCP协议的影响1.8 查看硬件地址和MTU 二、ARP协议2.1 ARP协议的作用2.2 ARP协议的工作…

15、Scalable Diffusion Models with Transformers

简介 官网 DiT(Diffusuion Transformer)将扩散模型的 UNet backbone 换成 Transformer,并且发现通过增加 Transformer 的深度/宽度或增加输入令牌数量,具有较高 Gflops 的 DiT 始终具有较低的 FID(~2.27)…

springJPA如果利用注解的方式 进行多表关联操作

前言:上一篇我写了个用JPA的Specification这个接口怎么做条件查询并且进行分页的,想学的自己去找一下 地址:springJPA动态分页 今天我们来写个 利用jpa的Query注解实现多表联合查询的demo 注意: 不建议在实际项目中用这玩意. 因为: 1. 用Query写的sql 可读性极差,给后期维护这…

六角螺母缺陷分类数据集:3440张图像

六角螺母缺陷数据集:包含变形,划痕,断裂,生锈,以及优质螺母图片数据,共计3440张,无标注 一.变形螺母-1839 二.断裂螺母-287 三.划痕螺母-473 四.生锈螺母-529 五.优良螺母-312 适用于CV项目&am…

法律行业案例法模型出现,OPenAI公布与法律AI公司Harvey合作案例

Harvey与OpenAl合作,为法律专业人士构建了一个定制训练的案例法模型。该模型是具有复杂推理广泛领域知识以及超越单一模型调用能力的任务的AI系统,如起草法律文件、回答复杂诉讼场景问题以及识别数百份合同之间的重大差异。 Harvey公司由具有反垄断和证…

阿里云服务器199元一年,ECS u1实例性能测评

阿里云服务器ECS u1实例,2核4G,5M固定带宽,80G ESSD Entry盘优惠价格199元一年,性能很不错,CPU采用Intel Xeon Platinum可扩展处理器,购买限制条件为企业客户专享,实名认证信息是企业用户即可&a…

树莓派部署yolov5实现目标检测(ubuntu22.04.3)

最近两天搞了一下树莓派部署yolov5,有点难搞(这个东西有点老,版本冲突有些包废弃了等等) 最后换到ubuntu系统弄了,下面是我的整体步骤: 1.烧完ubuntu镜像后,接显示器按系统流程进行系统部署(大于…

win10+Intel显卡安装配置stable-diffusion-webui绘画网页

系列文章目录 提示:这里可以添加系列文章的所有文章的目录,目录需要自己手动添加 例如:第一章 Python 机器学习入门之pandas的使用 提示:写完文章后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目…

精准扶贫管理系统|基于Springboot的精准扶贫管理系统设计与实现(源码+数据库+文档)

精准扶贫管理系统目录 目录 基于Springboot的精准扶贫管理系统设计与实现 一、前言 二、系统功能设计 三、系统实现 1、管理员模块的实现 (1)用户信息管理 (2)贫困户信息管理 (3)新闻类型管理 &a…

openGauss学习笔记-256 openGauss性能调优-使用Plan Hint进行调优-优化器GUC参数的Hint

文章目录 openGauss学习笔记-256 openGauss性能调优-使用Plan Hint进行调优-优化器GUC参数的Hint256.1 功能描述256.2 语法格式256.3 参数说明 openGauss学习笔记-256 openGauss性能调优-使用Plan Hint进行调优-优化器GUC参数的Hint 256.1 功能描述 设置本次查询执行内生效的…

程序员沟通之道:TCP与UDP之辩,窥见有效沟通的重要性(day19)

程序员沟通的重要性: 今天被师父骂了一顿,说我不及时回复他,连最起码的有效沟通都做不到怎么当好一个程序员,想想还挺有道理,程序员需要知道用户到底有哪些需求,用户与程序员之间的有效沟通就起到了关键性作…

图DP

目录 有向无环图DP 力扣 329. 矩阵中的最长递增路径 力扣 2192. 有向无环图中一个节点的所有祖先 有向有环图DP 力扣 1306. 跳跃游戏 III 有向无环图DP 力扣 329. 矩阵中的最长递增路径 给定一个 m x n 整数矩阵 matrix ,找出其中 最长递增路径 的长度。 对…

Golang | Leetcode Golang题解之第3题无重复字符的最长子串

题目: 题解: func lengthOfLongestSubstring(s string) int {// 哈希集合,记录每个字符是否出现过m : map[byte]int{}n : len(s)// 右指针,初始值为 -1,相当于我们在字符串的左边界的左侧,还没有开始移动r…

50道Java经典面试题总结

1、那么请谈谈 AQS 框架是怎么回事儿? (1)AQS 是 AbstractQueuedSynchronizer 的缩写,它提供了一个 FIFO 队列,可以看成是一个实现同步锁的核心组件。 AQS 是一个抽象类,主要通过继承的方式来使用&#x…

Linux系统——网络管理

此文章以红帽Linux9版本为例进行讲解。 红帽Linux9版本的网络管理十分全面,可在多处进行网络配置的修改,但需要注意的是,在9版本内,用户可在配置文件内进行网络配置的修改,但系统不会执行修改的命令,而在9之…

C语言中的结构体:高级特性与扩展应用

前言 结构体在C语言中的应用不仅限于基本的定义和使用,还包含一些高级特性和扩展应用,这些特性和应用使得结构体在编程中发挥着更加重要的作用。 一、位字段(Bit-fields) 在结构体中,我们可以使用位字段来定义成员…

小林coding图解计算机网络|基础篇01|TCP/IP网络模型有哪几层?

小林coding网站通道:入口 本篇文章摘抄应付面试的重点内容,详细内容还请移步: 文章目录 应用层(Application Layer)传输层(Transport Layer)TCP段(TCP Segment) 网络层(Internet Layer)IP协议的寻址能力IP协议的路由能力 数据链路层(Link Lay…

Hadoop Yarn

首先先从Yarn开始讲起,Yarn是Hadoop架构的资源管理器,可以管理mapreduce程序的资源分配和任务调度。 Yarn主要有ResourceManager、NodeManage、ApplicationMaster,Container ResourceMange负责管理全局的资源 NodeManage(NM&a…