基于Swin_Transformer的图像超分辨率系统

1.研究背景与意义

项目参考AAAI Association for the Advancement of Artificial Intelligence

研究背景与意义

随着科技的不断发展,图像超分辨率技术在计算机视觉领域中变得越来越重要。图像超分辨率是指通过使用计算机算法将低分辨率图像转换为高分辨率图像的过程。这项技术在许多领域都有广泛的应用,包括医学图像处理、监控摄像头、卫星图像处理等。

在过去的几十年里,图像超分辨率技术已经取得了显著的进展。早期的方法主要基于插值和滤波技术,但这些方法无法捕捉到图像中的细节和纹理。随着深度学习的兴起,基于深度学习的图像超分辨率方法开始受到关注。其中,基于卷积神经网络(CNN)的方法取得了很大的成功。

然而,传统的CNN方法在处理大尺寸图像时存在一些问题。首先,它们需要大量的计算资源和存储空间,限制了它们在实际应用中的可行性。其次,它们往往无法处理大尺寸图像中的细节和纹理,导致生成的高分辨率图像质量不佳。因此,寻找一种高效且准确的图像超分辨率方法是非常重要的。

近年来,Swin Transformer作为一种新兴的注意力机制模型,已经在自然语言处理和计算机视觉领域取得了显著的成果。Swin Transformer采用了一种分层的注意力机制,能够在处理大尺寸图像时保持较高的效率和准确性。因此,将Swin Transformer应用于图像超分辨率任务是非常有前景的研究方向。

基于Swin Transformer的图像超分辨率系统具有以下几个重要的意义:

首先,基于Swin Transformer的图像超分辨率系统可以提供更高质量的高分辨率图像。Swin Transformer的注意力机制能够更好地捕捉到图像中的细节和纹理,从而生成更加真实和清晰的图像。这对于许多应用领域,如医学图像处理和卫星图像处理,具有重要的意义。

其次,基于Swin Transformer的图像超分辨率系统可以提高计算效率。传统的CNN方法在处理大尺寸图像时需要大量的计算资源和存储空间,限制了它们在实际应用中的可行性。而Swin Transformer采用了一种分层的注意力机制,能够在处理大尺寸图像时保持较高的效率和准确性,从而降低了计算成本。

最后,基于Swin Transformer的图像超分辨率系统可以为其他相关领域的研究提供借鉴和参考。Swin Transformer作为一种新兴的注意力机制模型,已经在自然语言处理和计算机视觉领域取得了显著的成果。将其应用于图像超分辨率任务可以为其他领域的研究提供新的思路和方法。

综上所述,基于Swin Transformer的图像超分辨率系统具有重要的研究背景和意义。它可以提供更高质量的高分辨率图像,提高计算效率,并为其他相关领域的研究提供借鉴和参考。随着深度学习和注意力机制的不断发展,相信基于Swin Transformer的图像超分辨率系统将在未来取得更加广泛的应用和研究进展。

2.图片演示

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3.视频演示

基于Swin_Transformer的图像超分辨率系统_哔哩哔哩_bilibili

4.图像超分辨重建原理

为了对图像超分辨率重建原理有更深入的理解,本小节将对高分辨率图像到低分辨率的退化过程进行详细介绍。受硬件设备的限制、环境因素的干扰和传输条件的限制,人们采集所得的实际图像的分辨率往往很难达到预期,而这些低分辨率图像通常由高分辨率图像经过多种退化过程所产生,包括光线干扰、运动模糊、噪声、压缩等退化因素。由于图像超分辨率重建是一个典型的逆向问题,其核心概念是建立对应的退化模型来学习从高分辨率图像到低分辨率图像的退化关系,进一步恢复低分辨率图像的纹理细节,因此,建立合适的退化模型是解决超分辨率重建问题的关键。
在这里插入图片描述

通过对上述图像退化模型的分析,可知通常的图像退化过程可以描述为高分辨率图像历经一系列的退化因素的影响,产生模糊甚至失真的低分辨率图像。假设x为原始高分辨率图像,J为退化后的低分辨率图像,则图像退化模型可以表示为:
在这里插入图片描述

其中,H()表示整个退化过程,D()表示下采样操作,B()表示模糊操作,L()表示光线干扰,n表示随机噪声,一般为高斯噪声或泊松噪声。图像超分辨率重建的本质就是从岁反向求解x的过程,如式所示:
在这里插入图片描述

即构造相应的图像恢复函数H-(-),对低分辨率图像y进行逆向推算,去除由于软硬件技术、环境因素和人为因素所带来的模糊、下采样和噪声等影响,尽可能恢复出原始高分辨率图像x。

5.核心代码讲解

5.1 main_test_swin2sr.py

class Swin2SR:
    def __init__(self, args):
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.define_model()
        self.model.eval()
        self.model = self.model.to(self.device)
        self.test_results = OrderedDict()
        self.test_results['psnr'] = []
        self.test_results['ssim'] = []
        self.test_results['psnr_y'] = []
        self.test_results['ssim_y'] = []
        self.test_results['psnrb'] = []
        self.test_results['psnrb_y'] = []
        self.psnr, self.ssim, self.psnr_y, self.ssim_y, self.psnrb, self.psnrb_y = 0, 0, 0, 0, 0, 0

    def define_model(self):
        # 001 classical image sr
        if self.args.task == 'classical_sr':
            model = net(upscale=self.args.scale, in_chans=3, img_size=self.args.training_patch_size, window_size=8,
                        num_classes=3, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                        mlp_ratio=4, upsampler='pixelshuffle', upsampler_params={'scale': self.args.scale})
        # 002 lightweight image sr
        elif self.args.task == 'lightweight_sr':
            model = net(upscale=self.args.scale, in_chans=3, img_size=self.args.training_patch_size, window_size=8,
                        num_classes=3, embed_dim=48, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                        mlp_ratio=4, upsampler='pixelshuffle', upsampler_params={'scale': self.args.scale})
        # 003 real image sr
        elif self.args.task == 'real_sr':
            if self.args.large_model:
                model = net(upscale=self.args.scale, in_chans=3, img_size=self.args.training_patch_size, window_size=8,
                            num_classes=3, embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24],
                            mlp_ratio=4, upsampler='pixelshuffle', upsampler_params={'scale': self.args.scale})
            else:
                model = net(upscale=self.args.scale, in_chans=3, img_size=self.args.training_patch_size, window_size=8,
                            num_classes=3, embed_dim=48, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24],
                            mlp_ratio=4, upsampler='pixelshuffle', upsampler_params={'scale': self.args.scale})
        # 004 grayscale denoising
        elif self.args.task == 'gray_dn':
            model = net(upscale=1, in_chans=1, img_size=self.args.training_patch_size, window_size=8,
                        num_classes=1, embed_dim=48, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                        mlp_ratio=4, upsampler='pixelshuffle', upsampler_params={'scale': 1})
        # 005 color denoising
        elif self.args.task == 'color_dn':
            model = net(upscale=1, in_chans=3, img_size=self.args.training_patch_size, window_size=8,
                        num_classes=3, embed_dim=48, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                        mlp_ratio=4, upsampler='pixelshuffle', upsampler_params={'scale': 1})
        # 006 jpeg compression artifact reduction
        elif self.args.task == 'jpeg_car':
            model = net(upscale=1, in_chans=3, img_size=self.args.training_patch_size, window_size=8,
                        num_classes=3, embed_dim=48, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                        mlp_ratio=4, upsampler='pixelshuffle', upsampler_params={'scale': 1})
        # 007 color jpeg compression artifact reduction
        elif self.args.task == 'color_jpeg_car':
            model = net(upscale=1, in_chans=3, img_size=self.args.training_patch_size, window_size=8,
                        num_classes=3, embed_dim=48, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                        mlp_ratio=4, upsampler='pixelshuffle', upsampler_params={'scale': 1})
        else:
            raise NotImplementedError(f'Task [{self.args.task}] is not implemented.')

        return model

    def setup(self):
        folder, save_dir, border, window_size = self.args.folder_lq, './outputs/', 0, self.args.training_patch_size
        return folder, save_dir, border, window_size

    def get_image_pair(self, path):
        imgname = os.path.splitext(os.path.basename(path))[0]
        img_lq = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        img_gt = None
        if self.args.folder_gt is not None:
            img_gt = cv2.imread(os.path.join(self.args.folder_gt, f'{imgname}.png'), cv2.IMREAD_UNCHANGED)
        return imgname, img_lq, img_gt

    def test(self, img_lq, model, args, window_size):
        _, _, h_old, w_old = img_lq.size()
        h_pad = (h_old // window_size + 1) * window_size - h_old
        w_pad = (w_old // window_size + 1) * window_size - w_old
        img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :]
        img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
        output = model(img_lq)
        if args.task == 'compressed_sr':
            output = output[0][..., :h_old * args.scale, :w_old * args.scale]
        else:
            output = output[..., :h_old * args.scale, :w_old * args.scale]
        return output

    def evaluate(self, output, img_gt, border):
        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        if output.ndim == 3:
            output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HCW-BGR
        output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
        cv2.imwrite(f'{save_dir}/{imgname}_Swin2SR.png', output)

        if img_gt is not None:
            img_gt = (img_gt * 255.0).round().astype(np.uint8)  # float32 to uint8
            img_gt = img_gt[:h_old * args.scale, :w_old * args.scale, ...]  # crop gt
            img_gt = np.squeeze(img_gt)

            psnr = util.calculate_psnr(output, img_gt, crop_border=border)
            ssim = util.calculate_ssim(output, img_gt, crop_border=border)
            self.test_results['psnr'].append(psnr)
            self.test_results['ssim'].append(ssim)
            if img_gt.ndim == 3:  # RGB image
                psnr_y = util.calculate_psnr(output, img_gt, crop_border=border, test_y_channel=True)
                ssim_y = util.calculate_ssim(output, img_gt, crop_border=border, test_y_channel=True)
                self.test_results['psnr_y'].append(psnr_y)
                self.test_results['ssim_y'].append(ssim_y)
            if args.task in ['jpeg_car', 'color_jpeg_car']:
                psnrb = util.calculate_psnrb(output, img_gt, crop_border=border, test_y_channel=False)
                self.test_results['psnrb'].append(psnrb)
                if args.task in ['color_jpeg_car']:
                    psnrb_y = util.calculate_psnrb(output, img_gt, crop_border=border, test_y_channel=True)
                    self.test_results['psnrb_y'].append(psnrb_y)
            print('Testing {:d} {:20s} - PSNR: {:.2f} dB; SSIM: {:.4f}; PSNRB: {:.2f} dB;'

该程序文件是一个用于图像超分辨率重建的测试程序。程序首先通过命令行参数解析器解析输入参数,包括任务类型、尺度因子、噪声水平、JPEG压缩因子等。然后加载模型并设置设备。接下来,程序设置文件夹和路径,并创建一个用于保存结果的文件夹。然后,程序遍历输入文件夹中的所有图像,读取图像并进行预处理。然后,程序使用模型对图像进行推理,并将结果保存为图像文件。最后,程序计算并打印出PSNR和SSIM等评估指标的平均值。

该程序文件依赖于其他模块和函数,包括argparse、cv2、glob、numpy、collections、os、torch、requests等。其中,models.network_swin2sr模块定义了Swin2SR模型,utils模块包含了计算PSNR和SSIM的函数。

总体而言,该程序文件实现了图像超分辨率重建的测试功能,包括加载模型、预处理图像、进行推理、保存结果和计算评估指标等步骤。

5.2 predict.py

class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        print("Loading pipeline...")

        self.device = "cuda:0"

        args = argparse.Namespace()
        args.scale = 4
        args.large_model = False

        tasks = ["classical_sr", "compressed_sr", "real_sr"]
        paths = [
            "weights/Swin2SR_ClassicalSR_X4_64.pth",
            "weights/Swin2SR_CompressedSR_X4_48.pth",
            "weights/Swin2SR_RealworldSR_X4_64_BSRGAN_PSNR.pth",
        ]
        sizes = [64, 48, 128]

        self.models = {}
        for task, path, size in zip(tasks, paths, sizes):
            args.training_patch_size = size
            args.task, args.model_path = task, path
            self.models[task] = define_model(args)
            self.models[task].eval()
            self.models[task] = self.models[task].to(self.device)

    def predict(
        self,
        image: Path = Input(description="Input image"),
        task: str = Input(
            description="Choose a task",
            choices=["classical_sr", "real_sr", "compressed_sr"],
            default="real_sr",
        ),
    ) -> Path:
        """Run a single prediction on the model"""

        model = self.models[task]

        window_size = 8
        scale = 4

        img_lq = cv2.imread(str(image), cv2.IMREAD_COLOR).astype(np.float32) / 255.0
        img_lq = np.transpose(
            img_lq if img_lq.shape[2] == 1 else img_lq[:, :, [2, 1, 0]], (2, 0, 1)
        )  # HCW-BGR to CHW-RGB
        img_lq = (
            torch.from_numpy(img_lq).float().unsqueeze(0).to(self.device)
        )  # CHW-RGB to NCHW-RGB

        # inference
        with torch.no_grad():
            # pad input image to be a multiple of window_size
            _, _, h_old, w_old = img_lq.size()
            h_pad = (h_old // window_size + 1) * window_size - h_old
            w_pad = (w_old // window_size + 1) * window_size - w_old
            img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[
                :, :, : h_old + h_pad, :
            ]
            img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[
                :, :, :, : w_old + w_pad
            ]

            output = model(img_lq)

            if task == "compressed_sr":
                output = output[0][..., : h_old * scale, : w_old * scale]
            else:
                output = output[..., : h_old * scale, : w_old * scale]

        # save image
        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        if output.ndim == 3:
            output = np.transpose(
                output[[2, 1, 0], :, :], (1, 2, 0)
            )  # CHW-RGB to HCW-BGR
        output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
        output_path = "/tmp/out.png"
        cv2.imwrite(output_path, output)

        return Path(output_path)

这个程序文件是一个用于图像超分辨率预测的预测器。它使用了Swin2SR模型来进行预测。文件中定义了一个名为Predictor的类,继承自BasePredictor类。在setup方法中,加载了模型并将其放入内存中以提高多次预测的效率。在predict方法中,通过传入一个输入图像和一个任务类型,可以运行单个预测。预测过程中,首先将输入图像进行预处理,然后使用模型进行推理,最后将输出图像保存到指定路径并返回。

5.3 ui.py


class Swin2SR:
    def __init__(self, args):
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.define_model()
        self.model.eval()
        self.model = self.model.to(self.device)
    
    def define_model(self):
        if self.args.task == 'classical_sr':
            model = net(upscale=self.args.scale, in_chans=3, img_size=self.args.training_patch_size, window_size=8,
                        img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                        mlp_ratio=2, upsampler='pixelshuffle', resi_connection='1conv')
            param_key_g = 'params'
        elif self.args.task in ['lightweight_sr']:
            model = net(upscale=self.args.scale, in_chans=3, img_size=64, window_size=8,
                        img_range=1., depths=[6, 6, 6, 6], embed_dim=60, num_heads=[6, 6, 6, 6],
                        mlp_ratio=2, upsampler='pixelshuffledirect', resi_connection='1conv')
            param_key_g = 'params'
        elif self.args.task == 'compressed_sr':
            model = net(upscale=self.args.scale, in_chans=3, img_size=self.args.training_patch_size, window_size=8,
                        img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                        mlp_ratio=2, upsampler='pixelshuffle_aux', resi_connection='1conv')
            param_key_g = 'params'
        elif self.args.task == 'real_sr':
            if not self.args.large_model:
                model = net(upscale=self.args.scale, in_chans=3, img_size=64, window_size=8,
                            img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                            mlp_ratio=2, upsampler='nearest+conv', resi_connection='1conv')
            else:
                model = net(upscale=self.args.scale, in_chans=3, img_size=64, window_size=8,
                            img_range=1., depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], embed_dim=240,
                            num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
                            mlp_ratio=2, upsampler='nearest+conv', resi_connection='3conv')
            param_key_g = 'params_ema'
        elif self.args.task == 'jpeg_car':
            model = net(upscale=1, in_chans=1, img_size=126, window_size=7,
                        img_range=255., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                        mlp_ratio=2, upsampler='', resi_connection='1conv')
            param_key_g = 'params'
        elif self.args.task == 'color_jpeg_car':
            model = net(upscale=1, in_chans=3, img_size=126, window_size=7,
                        img_range=255., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                        mlp_ratio=2, upsampler='', resi_connection='1conv')
            param_key_g = 'params'
        pretrained_model = torch.load(self.args.model_path)
        model.load_state_dict(pretrained_model[param_key_g] if param_key_g in pretrained_model.keys() else pretrained_model,
                              strict=True)
        return model
    
    def setup(self):
        if self.args.task in ['classical_sr', 'lightweight_sr', 'compressed_sr']:
            save_dir = f'results/swin2sr_{self.args.task}_x{self.args.scale}'
            if self.args.save_img_only:
                folder = self.args.folder_lq
            else:
                folder = self.args.folder_gt
            border = self.args.scale
            window_size = 8
       

ui.py是一个用于图像超分辨率的PyQt5界面程序。它导入了PyQt5和其他一些必要的库,并定义了一些函数来加载模型、设置参数、获取图像对和进行测试。主要的函数是main()函数,它接受一个图像路径作为输入,并根据指定的参数加载模型并对图像进行超分辨率处理。处理结果将保存在指定的文件夹中。

5.4 models\network_swin2sr.py
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
        pretrained_window_size (int): Window size in pre-training.
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
            pretrained_window_size=to_2tuple(pretrained_window_size))

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            attn_mask = self.calculate_mask(self.input_resolution)
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def calculate_mask(self, x_size):
        # calculate attention mask for SW-MSA
        H, W = x_size
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x):
        """
        Args:
            x: input features with shape of (B, N, C).
        """
        B, N, C = x.shape
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, N, C)
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
        x = self.attn(x, mask=self.attn_mask)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

这是一个实现Swin Transformer模型的Python程序文件。Swin Transformer是一种用于压缩图像超分辨率和恢复的模型,具体细节可以参考论文https://arxiv.org/abs/2209.11345。

程序文件中定义了一些辅助函数和模块,包括Mlp、window_partition、window_reverse、WindowAttention和SwinTransformerBlock。

Mlp是一个多层感知机模块,用于对输入进行线性变换和激活函数处理。

window_partition和window_reverse函数用于将输入图像划分为窗口,并将窗口恢复为原始图像。

WindowAttention是一个基于窗口的多头自注意力模块,支持相对位置偏置。

SwinTransformerBlock是Swin Transformer的一个基本模块,包括窗口注意力和多层感知机。

整个程序文件实现了Swin Transformer模型的核心组件,可以用于图像超分辨率和恢复任务。

5.5 utils\plots.py


class ImageLoader:
    def __init__(self, debug=False, norm=True, resize=None):
        self.debug = debug
        self.norm = norm
        self.resize = resize
    
    def load_img(self, filename):
        img = cv2.imread(filename)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.norm:   
            img = img / 255.
            img = img.astype(np.float32)
        if self.debug:
            print (img.shape, img.dtype, img.min(), img.max())
            
        if self.resize:
            img = cv2.resize(img, (self.resize[0], self.resize[1]))
            
        return img
    
    def plot_all(self, images, axis='off', figsize=(16, 8)):
        fig = plt.figure(figsize=figsize, dpi=80)
        nplots = len(images)
        for i in range(nplots):
            plt.subplot(1,nplots,i+1)
            plt.axis(axis)
            plt.imshow(images[i])
        plt.show()

这个程序文件是一个用于绘制图像的工具文件,文件名为utils\plots.py。该文件包含了两个函数load_img和plot_all。

load_img函数用于加载图像文件。它接受一个文件名作为参数,并可选择是否进行调试、归一化和调整大小。函数首先使用OpenCV库的imread函数读取图像文件,然后将图像从BGR颜色空间转换为RGB颜色空间。如果选择进行归一化,则将图像的像素值除以255,并将其转换为32位浮点数类型。如果选择进行调试,则会打印图像的形状、数据类型、最小值和最大值。如果选择调整大小,则会使用OpenCV库的resize函数将图像调整为指定的大小。最后,函数返回加载和处理后的图像。

plot_all函数用于绘制多个图像。它接受一个图像列表作为参数,并可选择绘制轴的样式和图像的大小。函数首先创建一个matplotlib的Figure对象,并设置其大小和分辨率。然后,根据图像列表的长度,在Figure对象中创建相应数量的子图。对于每个子图,设置轴的样式,并使用imshow函数显示对应的图像。最后,调用show函数显示绘制的图像。

这个程序文件提供了方便的函数来加载和绘制图像,可以在图像处理和分析的过程中使用。

5.6 utils\util_calculate_psnr_ssim.py
import cv2
import torch
import numpy as np

class ImageMetrics:
    def __init__(self, input_order='HWC'):
        self.input_order = input_order

    def calculate_psnr(self, img1, img2, crop_border, test_y_channel=False):
        assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
        if self.input_order not in ['HWC', 'CHW']:
            raise ValueError(f'Wrong input_order {self.input_order}. Supported input_orders are ' '"HWC" and "CHW"')
        img1 = self.reorder_image(img1)
        img2 = self.reorder_image(img2)
        img1 = img1.astype(np.float64)
        img2 = img2.astype(np.float64)

        if crop_border != 0:
            img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
            img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

        if test_y_channel:
            img1 = self.to_y_channel(img1)
            img2 = self.to_y_channel(img2)

        mse = np.mean((img1 - img2) ** 2)
        if mse == 0:
            return float('inf')
        return 20. * np.log10(255. / np.sqrt(mse))

    ......

该程序文件是一个用于计算图像质量评估指标的工具文件。主要包含以下几个函数:

  1. calculate_psnr(img1, img2, crop_border, input_order=‘HWC’, test_y_channel=False):计算图像的峰值信噪比(PSNR)指标。

  2. _ssim(img1, img2):计算图像的结构相似性(SSIM)指标。

  3. calculate_ssim(img1, img2, crop_border, input_order=‘HWC’, test_y_channel=False):计算图像的结构相似性(SSIM)指标。

  4. _blocking_effect_factor(im):计算图像的块效应因子。

  5. calculate_psnrb(img1, img2, crop_border, input_order=‘HWC’, test_y_channel=False):计算图像的PSNR-B指标。

  6. reorder_image(img, input_order=‘HWC’):重新排列图像的通道顺序。

  7. to_y_channel(img):将图像转换为Y通道。

  8. bgr2ycbcr(img, y_only=False):将BGR图像转换为YCbCr图像。

这些函数可以用于评估图像处理算法的

6.系统整体结构

整体功能和构架概述:

该图像超分辨率系统的整体功能是实现图像的超分辨率重建。它使用了基于Swin Transformer的模型进行图像超分辨率处理。系统包含了多个程序文件,每个文件负责不同的功能模块。主要的程序文件包括:

  1. main_test_swin2sr.py:用于图像超分辨率重建的测试程序。负责加载模型、预处理图像、进行推理、保存结果和计算评估指标等步骤。

  2. predict.py:图像超分辨率预测的预测器。使用Swin2SR模型进行预测,定义了Predictor类,负责加载模型并进行预测。

  3. ui.py:图像超分辨率的PyQt5界面程序。通过界面输入图像路径和参数,调用predict.py中的函数进行超分辨率处理。

  4. models\network_swin2sr.py:实现Swin Transformer模型的核心组件。定义了多个辅助函数和模块,包括Mlp、WindowAttention和SwinTransformerBlock等。

  5. utils\plots.py:用于绘制图像的工具文件。包含load_img和plot_all函数,用于加载和绘制图像。

  6. utils\util_calculate_psnr_ssim.py:用于计算图像质量评估指标的工具文件。包含多个函数,用于计算PSNR、SSIM和其他指标。

  7. utils_init_.py:空文件,用于标识utils文件夹为Python模块。

下表整理了每个文件的功能:

文件路径功能
E:\视觉项目\shop\基于Swin_Transformer的图像超分辨率系统\code\main_test_swin2sr.py图像超分辨率重建的测试程序,包括加载模型、预处理图像、进行推理、保存结果和计算评估指标等步骤
E:\视觉项目\shop\基于Swin_Transformer的图像超分辨率系统\code\predict.py图像超分辨率预测的预测器,定义了Predictor类,负责加载模型并进行预测
E:\视觉项目\shop\基于Swin_Transformer的图像超分辨率系统\code\ui.py图像超分辨率的PyQt5界面程序,通过界面输入图像路径和参数,调用predict.py中的函数进行超分辨率处理
E:\视觉项目\shop\基于Swin_Transformer的图像超分辨率系统\code\models\network_swin2sr.py实现Swin Transformer模型的核心组件,包括辅助函数和模块,如Mlp、WindowAttention和SwinTransformerBlock
E:\视觉项目\shop\基于Swin_Transformer的图像超分辨率系统\code\utils\plots.py用于绘制图像的工具文件,包含load_img和plot_all函数,用于加载和绘制图像
E:\视觉项目\shop\基于Swin_Transformer的图像超分辨率系统\code\utils\util_calculate_psnr_ssim.py用于计算图像质量评估指标的工具文件,包含多个函数,用于计算PSNR、SSIM和其他指标
E:\视觉项目\shop\基于Swin_Transformer的图像超分辨率系统\code\utils_init_.py空文件,用于标识utils文件夹为Python模块

7.Swin_Transformer用于超分辨率重建

参考该博客提出的RefSR工作,主要观点是将Transformer作为一个attention,这样可以更好地将参考图像(Ref)的纹理信息转移到高质图像(HR)中。做法还是比较有意思的,如下图所示,将上采样的LR图像、依次向下/上采样的Ref图像、原始Ref图像中提取的纹理特征分别作为Q、K、V。纹理Transformer包含了4个结构:1)DNN实现的可学习的纹理提取器(learnable texture extractor)2)相关性嵌入模块( relevance embedding)3)用于纹理转换的硬注意力模块(hard-attention)4)用于纹理合成的软注意力模块(soft-attention)。此外整个纹理Transformer模块可以跨尺度的方式进一步堆叠,这使得能够从不同尺度(例如,从1x倍到4x倍放大率)恢复纹理。

在这里插入图片描述

网络的整体架构

如下图所示,将多个纹理Transformer(即上图)堆叠、上采下采融合来实现超分。
其中RBS为多个残差Block,CSFI为跨尺度特征集成模块(ross-scale feature integration )

纹理Transformer

在这里插入图片描述

即图,介绍一下他的四个组件。
1)DNN实现的可学习的纹理提取器。就是将图像送入DNN,然后DNN可以训练
2)相关性嵌入模块。使用归一化内积计算Q、K之间的相关性。获得矩阵r i , j r_{i,j}r
3)硬注意力。通过h i = a r g m a x ( r i , j ) h_{i}=argmax(r_{i,j})h
4)软注意力。获得软注意力图s i = a r g m a x ( r i , j ) s_{i}=argmax(r_{i,j})s

再分析一下这个公式,当S大的时候,说明当前块和T的相关性大,所以用更多的T的特征,如果S小,则使用更少的参考帧特征。
在这里插入图片描述

损失函数

L1 loss + GAN loss + Percepture Loss

网络结构

在这里插入图片描述
1)Shallow Feature Extraction 为一层3x3卷积。
2)HQ Image Reconstruction在SR任务中采用sub-pixel Conv,就是unpixelShuffle。denoise和JPEG去伪影用一层卷积。
3)对STL,就是Transformer的Encoder结构。将输入划分为M ∗ M M*MM∗M个块X,然后每个X映射为QKV,通过多头attention后将输出concat。MLP通过两层FC实现。作者还进行了划窗来避免图像块之间的信息不融合问题。步长为M / 2 M/2M/2

EMHA

主要是在获得QKV之后,将QKV特征分为s组,每组分别进行attention获得输出O,然后将输出Concat,这样可以将大矩阵相乘拆分为多个小矩阵相乘。这也是Transformer常见的减少参数操作。
在这里插入图片描述

HFM

此外该博客的作者还用了一个High-frequencyFiltering Module (HFM)提取高频信息,结构如下,仅供参考。

在这里插入图片描述

Microsoft Bing Turing ISR(T-ISR)

Introducing Turing Image Super Resolution: AI powered image enhancements for Microsoft Edge and Bing Maps
这篇不算论文,是微软介绍自家用于Microsoft Edge和Bing Maps上ISR的技术博客。但是效果非常Amazing啊,但缺点是有些地方没有仔细介绍。

设计原则

1)人类视觉为基准(Human eyes as the north star)
广泛使用的指标如PSNR,SSIM并不总是和人眼视觉的直观感受匹配的,同时也需要GT图。我们构建了一个并行评估工具匹配人眼判断,并将这个工具作为north star metric来引导模型训练。(可是作者没介绍这个工具是啥55555)
2)噪声建模(Noise modeling)
开始作者也是将HR图像降质然后构建HR-LR图相对训练。但这样有些case效果好,但是对真实的LR图像不鲁棒。因此随机对输入图像用blurring, compression 和 gaussian noise进行破坏可以恢复细节。
3)Perceptual and GAN loss
仅pixel loss不够,要引入感知和GAN loss,并用权重结合。
4)Transformers for vision
CNN和Transformer各有优缺点,因此未利用他们各自优点,将网络分为Enhance和Zoom,前者使用Transformer,后者使用CNN。(其实这段也没详细介绍各自优缺点是什么。整体四准则很对我胃口啊,果然英雄所见略同hhhh)

DeepEnhance – Cleaning and Enhancing Images

在处理高度压缩和从远程卫星拍摄的航拍照片等very noise图像时,Transformer清理噪声做的很好。如人脸的噪声和处理包含很多纹理的森林的特征就很不同。这是因为大数据集和Transformer卓越的远程记忆能力。我们先使用了一个稀疏Transformer,将其放大以支持非常大的序列长度来“Enhance”图像,产生干净的,crisper和更具吸引力,尺寸相同的图像。有些场景不需要放大图像,那到这里就可以停止了。

在这里插入图片描述

8.系统整合

下图完整源码&环境部署视频教程&自定义UI界面

在这里插入图片描述

参考博客《基于Swin_Transformer的图像超分辨率系统》

9.参考文献


[1]盘展鸿,朱鉴,迟小羽,等.基于特征融合和注意力机制的图像超分辨率模型[J].计算机应用研究.2022,39(3).DOI:10.19734/j.issn.1001-3695.2021.07.0288 .

[2]邓焱文.基于深度学习的超分辨率重建在人脸识别中的应用[D].2019.

[3]Yu-Qi Liu,Xin Du,Hui-Liang Shen,等.Estimating Generalized Gaussian Blur Kernels for Out-of-Focus Image Deblurring[J].IEEE Transactions on Circuits & Systems for Video Technology.2020,31(3).829-843.DOI:10.1109/TCSVT.2020.2990623 .

[4]Shengxiang Zhang,Gaobo Liang,Shuwan Pan,等.A Fast Medical Image Super Resolution Method Based on Deep Learning Network[J].IEEE Access.2018.712319-12327.DOI:10.1109/ACCESS.2018.2871626 .

[5]Huihui Song,Qingshan Liu,Guojie Wang,等.Spatiotemporal Satellite Image Fusion Using Deep Convolutional Neural Networks[J].IEEE journal of selected topics in applied earth observations & remote sensing.2018,11(3).821-829.DOI:10.1109/JSTARS.2018.2797894 .

[6]Park, S.,Serpedin, E.,Qaraqe, K..Gaussian Assumption: The Least Favorable but the Most Useful [Lecture Notes][J].IEEE Signal Processing Magazine.2013,30(3).183-186.

[7]Mittal, A.,Soundararajan, R.,Bovik, A.C..Making a “Completely Blind” Image Quality Analyzer[J].Signal Processing Letters, IEEE.2013,20(3).209-212.DOI:10.1109/LSP.2012.2227726 .

[8]Ogawa, T.,Haseyama, M..Missing Intensity Interpolation Using a Kernel PCA-Based POCS Algorithm and its Applications[J].IEEE Transactions on Image Processing.2011,20(2).

[9]Yang, J.Wright, J.Huang, T.Ma, Y..Image Super-Resolution Via Sparse Representation[J].IEEE Transactions on Image Processing.2010,19(11).2861-2873.

[10]Bovik A.C.,Zhou Wang,Simoncelli E.P.,等.Image quality assessment: from error visibility to structural similarity[J].IEEE Transactions on Image Processing.2004,13(4).

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/231244.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

111.am40刷机折腾记4-firefly镜像-dp正常显示

1. 平台&#xff1a; rk3399 am40 4g32g 2. 内核&#xff1a;firefly的内核&#xff08;整体镜像&#xff09; 版本&#xff1a; linux4.4.194 3. 交叉编译工具 &#xff1a;暂时不编译 4. 宿主机&#xff1a;ubuntu18.04 5. 需要的素材和资料&#xff1a;boot-am40-202…

Uos打包工具最新

我司中标的桌面端项目&#xff08;electron开发的应用&#xff09;需兼容统信UOS&#xff0c;关键是要发布到应用商店&#xff0c;首先使用了debreateForUos工具进行打包&#xff0c;打包之后也通过了审核上架到了商店&#xff0c;本以为一切都是如此丝滑顺利&#xff0c;但后续…

10个电子工程师常用的测量仪器详解

之前我们聊了电子工程师常用的模电及数电&#xff0c;得到了很多粉丝朋友的追捧&#xff0c;所以今天主要讲讲电子工程师常用的测量仪器&#xff0c;希望对小伙伴们有所帮助&#xff0c;一起来看看吧&#xff01; 1、万用表 万用表是最基本的测量仪器之一&#xff0c;用于测量…

【Linux】cat 命令使用

cat 命令 cat&#xff08;英文全拼&#xff1a;concatenate&#xff09;命令用于连接文件并打印到标准输出设备上。 可以使用cat连接多个文件、创建新文件、将内容附加到现有文件、查看文件内容以及重定向终端或文件中的输出。 cat可用于在不同选项的帮助下格式化文件的输出…

志愿者小程序开发方案详解

志愿者服务小程序有三端&#xff1a;用户端商家端&#xff0c;管理员端&#xff0c;总管理后台。申请成为志愿者&#xff0c;参加志愿者活动&#xff0c;获得积分和服务时长&#xff0c;志愿者服务时长排名&#xff0c;积分可以兑换商品。社区管理员可以管理自己社区的志愿者和…

使用IDM批量下载NASA气象数据

写在前面:因为科研需要&#xff0c;所以需要批量下载NASA数据&#xff0c;但是nasa的文件会每天给一个url链接&#xff0c;手动下载起来很慢&#xff0c;所以特写此篇文章用以教学如何批量下载NASA气象数据 1.下载NASA数据: 首先我们先进入到官网&#xff1a; 官网链接 右上…

海外媒体发稿:软文发稿推广技巧解析超级实用-华媒舍

随着互联网时代的发展&#xff0c;软文发稿成为推广产品与服务的重要手段之一。本文将向大家介绍软文发稿推广的技巧&#xff0c;帮助您更好地利用软文推广商业活动。无论是拥有自己的品牌还是个人创业者&#xff0c;都可以从中受益。 1. 什么是软文&#xff1f; 软文是指以文…

Windows 系统,TortoiseSVN 无法修改 Log 信息解决方法

使用SVN提交版本信息时&#xff0c;注释内容写的不全。通过右键TortoiseSVN的Show log看到提交的的注释&#xff0c;右键看到Edit log message的选项&#xff0c;然而提交后却给出错误提示&#xff1a; Repository has not been enabled to accept revision propchanges; ask …

Python:核心知识点整理大全9-笔记

目录 ​编辑 5.2.4 比较数字 5.2.5 检查多个条件 1. 使用and检查多个条件 2. 使用or检查多个条件 5.2.6 检查特定值是否包含在列表中 5.2.7 检查特定值是否不包含在列表中 banned_users.py 5.2.8 布尔表达式 5.3 if 语句 5.3.1 简单的 if 语句 5.3.2 if-else 语句 …

AI改写文章的软件,免费AI改写原创文章的工具

在当今数字化时代&#xff0c;人们的工作生活已经离不开信息技术的支持。特别是在文案创作领域&#xff0c;如何提高效率、保持文案质量成为了许多写作者关注的焦点&#xff0c;本文将深入探讨AI改写文案的工具&#xff0c;包括其原理、应用场景。 AI改写文案的背后原理 为了更…

Android View.inflate 和 LayoutInflater.from(this).inflate 的区别

前言 两个都是布局加载器&#xff0c;而View.inflate是对 LayoutInflater.from(context).inflate的封装&#xff0c;功能相同&#xff0c;案例使用了dataBinding。 View.inflate(context, layoutResId, root) LayoutInflater.from(context).inflate(layoutResId, root, fals…

mmdetection测试保存到新的文件夹,无需标签

这个是用demo这个代码测试的&#xff0c;需要先训练一个pth文件夹&#xff0c;训练之后再调用pth文件夹进行测试。测试的代码文件名是&#xff1a;image_demo_new.py&#xff0c;代码如系所示&#xff1a; # Copyright (c) OpenMMLab. All rights reserved. import asyncio fr…

【FPGA图像处理实战】- RGB与YUV互转

RGB颜色空间和YUV颜色空间是图像处理中经常遇到的两个颜色空间,但它们的特性不一样,应用的场景有差异,所以经常会遇到有RGB转YUV、YUV转RGB的需求。 前几天更新了FPGA数学运算的几节课程,今天我们来学习一下“RGB与YUV互转”,主要分为5个部分:RGB与YUV的介绍、RGB与YUV互…

uniapp实战 —— 竖排多级分类展示

效果预览 完整范例代码 页面 src\pages\category\category.vue <script setup lang"ts"> import { getCategoryTopAPI } from /apis/category import type { CategoryTopItem } from /types/category import { onLoad } from dcloudio/uni-app import { compu…

Java网络通信-第21章

Java网络通信-第21章 1.网络程序设计基础 网络程序设计基础涵盖了许多方面&#xff0c;包括网络协议、Web开发、数据库连接、安全性等。 1.1局域网与互联网 局域网&#xff08;LAN&#xff09;与互联网&#xff08;Internet&#xff09;是两个不同的概念&#xff0c;它们分…

咸鱼开店的经验分享

项目特点 1.无需囤货、积货、压货&#xff1b; 2.开店0门槛&#xff0c;有淘宝号即可&#xff0c;无需繁琐的开店流程&#xff1b; 3.免收店铺押金、保证金、平台佣金&#xff1b; 4.平台自带流量&#xff0c;无需砸钱推广。 准备的账号 1.不在于多&#xff0c;而在于精。…

CPU、MCU、MPU、DSP、FPGA各是什么?有什么区别?

1、CPU 中央处理器&#xff0c;简称 CPU&#xff08;Central Processing Unit&#xff09;&#xff0c;中央处理器主要包括两个部分&#xff0c;即控制器、运算器&#xff0c;其中还包括高速缓冲存储器及实现它们之间联系的数据、控制的总线。 电子计算机三大核心部件就是CPU…

使用Postman进行自动化集成测试

1 前言 笔者在使用Node开发HTTP接口的过程中&#xff0c;发现当接口数量越来越多&#xff0c;且接口之间互相依赖时&#xff0c;接口测试流程就会变得十分繁琐&#xff0c;且容易出错。那如何才能高效且全面地对接口进行测试呢&#xff1f; 通过实践&#xff0c;笔者发现可以…

激光雷达标定板提高扫地机器人感知环境能力和清洁效率

智能扫地机器人的激光雷达标定板是一种用于校准激光雷达的设备&#xff0c;它通常由不同反射率的涂料涂覆在板面上&#xff0c;用于接收激光雷达发出的激光束并将其反射回来&#xff0c;从而帮助校准激光雷达的测量参数。在自动驾驶和机器人领域&#xff0c;激光雷达和相机的联…

什么是Redis数据库,如何在 CentOS 7 上安装 Redis,看完你就懂了

目录 一、Redis简介 二、Redis特点 三、数据类型 四、Redis应用场景 五、Centos环境部署Redis 六、常见参数整理 一、Redis简介 Redis &#xff0c;是一个高性能(NOSQL)的key-value数据库,Redis是一个开源的使用ANSI C语言编写、支持网络、可基于内存亦可持久化的日志型…