LaMa 论文复现:Resolution-robust Large Mask Inpainting with Fourier Convolutions

 代码:GitHub - andy971022/auto-lama 

论文:https://arxiv.org/abs/2109.07161

1 LaMa 论文简介

2 LaMa代码复现

2.1 环境部署

 2.1.1 下载源码,创建环境,安装必需库

git clone https://github.com/advimman/lama
cd lama
conda env create -f conda_env.yml
conda activate lama
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch -y
pip install pytorch-lightning==1.2.9

2.2  公开数据集训练测试与结果可视化

2.2.1 LaMa 测试数据集和预训练模型下载

(1)预训练模型下载链接:

预训练模型 https://disk.yandex.ru/d/kHJkc7bs7mKIVAicon-default.png?t=N7T8https://disk.yandex.ru/d/kHJkc7bs7mKIVA

 预训练模型下载好后,存放在checkpoints文件夹下。

(2)测试数据集下载:

    # Download data from http://places2.csail.mit.edu/download.html
    # Places365-Standard: Train(105GB)/Test(19GB)/Val(2.1GB) from High-resolution images section
    wget http://data.csail.mit.edu/places/places365/train_large_places365standard.tar
    wget http://data.csail.mit.edu/places/places365/val_large.tar
    wget http://data.csail.mit.edu/places/places365/test_large.tar

http://data.csail.mit.edu/places/places365/val_large.tar
http://data.csail.mit.edu/places/places365/test_large.tar

2.2.2  place365 数据集训练

2.2.3  place365 数据集测试

预测性能,基于big-lama数据集中的LaMa_test_images。

运行以下命令,其中refine=true 表示将运行图像修复器。

(nerf) D:\0A_project\lama-main\bin> python predict.py refine=True model.path=$(pwd)../checkpoint/big-lama indir=$(pwd)../LaMa_test_images outdir=$(pwd)../output

model.path=$(pwd)/big-lama: 这部分是传递给predict.py脚本的命令行参数之一。它设置了一个参数model.path,并将其值设置为当前目录(通过$(pwd)获取)下的big-lama

indir=$(pwd)/LaMa_test_images: 这是另一个命令行参数,用于设置输入目录。它将indir参数的值设置为当前目录下的LaMa_test_images目录。

outdir=$(pwd)/output: 类似地,这是设置输出目录的参数。它将outdir参数的值设置为当前目录下的output目录。

出错如下:

Traceback (most recent call last):
  File "predict.py", line 24, in <module>
    from  saicinpainting.evaluation.utils import move_to_device
ModuleNotFoundError: No module named 'saicinpainting'

代码段引用模块包内容如下:

import logging
import os
import traceback
import sys
from saicinpainting.evaluation.refinement import refine_predict
from saicinpainting.evaluation.utils import move_to_device

文件结构如下,saicinpainting模块包位于lama-main 主文件夹下,predict.py位于bin文件夹中。 

       

因此,出现 ModuleNotFoundError: No module named 'saicinpainting'  错误是该包没有在搜索路径中找到,故需要把该路径添加到搜索路径中,代码更改如下:

import logging
import os
import traceback
import sys
sys.path.append(r'D:\0A_project\lama-main')  # 添加项目根目录到 sys.path
from saicinpainting.evaluation.refinement import refine_predict
from saicinpainting.evaluation.utils import move_to_device

再次运行,又报错(待解决)

(1)不要用GPU预测,尝试无法解决

(2)python predict.py refine=True model.path=$(pwd)../checkpoint/big-lama indir=$(pwd)../LaMa_test_images outdir=$(pwd)../output  HYDRA_FULL_ERROR=1   无法解决

(3)D:\0A_project\lama-main\configs\prediction\default.yaml   添加  HYDRA_FULL_ERROR=1   无法解决

(4)注释掉predict.py line 41    

# register_debug_signal_handlers()  # kill -10 <pid> will result in traceback dumped into log

(2)在predict.py line45 后面加上一句

train_config_path = os.path.join('<your_full_path_to_lama_base_directory>', train_config_path)

如下: 

在lin54 句后面加上
checkpoint_path = os.path.join('<your_full_path_to_lama_base_directory>', checkpoint_path)

2.2.4  测试结果和参数可视化

2.3 制作自己的数据集,训练测试与结果的可视化

2.3.1 制作自己的数据集

(1)创建数据集图片对应的mask图,命名为images_name_maskxxx.png, 将images原图与对应的masks原图放在同一文件夹下。数据集文件格式如下:
    ```    
    image1_mask001.png
    image1.png
    image2_mask001.png
    image2.png
    ```
(2)利用(https://github.com/advimman/lama/blob/main/bin/gen_mask_dataset.py) 生成随机的mask图片。

将自己图像的数据集存放在myown_dataset文件夹下面。

将configs/prediction/default.yaml 文件中的`image_suffix` 声明为png或jpg或_input.jpg,如下
indir: no  # 将在CLI中被覆盖
outdir: no  # 将在CLI中被覆盖

model:
  path: no  # 将在CLI中被覆盖
  checkpoint: best.ckpt

dataset:
  kind: default
  img_suffix: .png
  pad_out_to_modulo: 8  # 输出图像将被填充到8的倍数

device: cuda  # 使用CUDA设备
out_key: inpainted  # 输出键:inpainted

refine: False  # 如果为True,将运行图像修复器
refiner:
  gpu_ids: 0,1  # 使用的GPU编号。如果只使用单个GPU,使用:"0,"
  modulo: ${dataset.pad_out_to_modulo}  # 与数据集的填充模数一致
  n_iters: 15  # 每个尺度的迭代修复次数
  lr: 0.002  # 学习率
  min_side: 512  # 所有尺度的图像边缘都应 >= min_side / sqrt(2)
  max_scales: 3  # 图像-掩码金字塔的最大降尺度数量
  px_budget: 1800000  # 像素预算。任何图像都将调整大小以满足高*宽 <= px_budget
运行命令
python3 bin/gen_mask_dataset.py indir=$(pwd)/myown_dataset outdir=$(pwd)/myown_dataset   

gen_mask_dataset.py解读如下
#!/usr/bin/env python3

import glob  # 用于查找文件
import os  # 提供文件和目录操作的功能
import shutil  # 用于文件复制和移动
import traceback  # 用于处理异常信息

import PIL.Image as Image  # 用于处理图像的Python库
import numpy as np  # 用于数值计算的Python库
from joblib import Parallel, delayed  # 用于并行处理任务的库

from saicinpainting.evaluation.masks.mask import SegmentationMask, propose_random_square_crop  # 导入特定的图像处理功能
from saicinpainting.evaluation.utils import load_yaml, SmallMode  # 导入加载YAML配置和小模式处理的功能
from saicinpainting.training.data.masks import MixedMaskGenerator  # 导入混合掩码生成器

# 创建一个包装器,用于生成多个掩码变体
class MakeManyMasksWrapper:
    def __init__(self, impl, variants_n=2):
        self.impl = impl
        self.variants_n = variants_n

    def get_masks(self, img):
        img = np.transpose(np.array(img), (2, 0, 1))
        return [self.impl(img)[0] for _ in range(self.variants_n)]

# 处理图像
def process_images(src_images, indir, outdir, config):
    # 根据配置选择掩码生成器
    if config.generator_kind == 'segmentation':
        mask_generator = SegmentationMask(**config.mask_generator_kwargs)
    elif config.generator_kind == 'random':
        variants_n = config.mask_generator_kwargs.pop('variants_n', 2)
        mask_generator = MakeManyMasksWrapper(MixedMaskGenerator(**config.mask_generator_kwargs),
                                              variants_n=variants_n)
    else:
        raise ValueError(f'Unexpected generator kind: {config.generator_kind}')

    max_tamper_area = config.get('max_tamper_area', 1)

    for infile in src_images:
        try:
            # 获取文件相对路径
            file_relpath = infile[len(indir):]
            img_outpath = os.path.join(outdir, file_relpath)
            os.makedirs(os.path.dirname(img_outpath), exist_ok=True)

            # 打开输入图像并转换为RGB格式
            image = Image.open(infile).convert('RGB')

            # 将输入图像缩放到输出分辨率,并过滤小图像
            if min(image.size) < config.cropping.out_min_size:
                handle_small_mode = SmallMode(config.cropping.handle_small_mode)
                if handle_small_mode == SmallMode.DROP:
                    continue
                elif handle_small_mode == SmallMode.UPSCALE:
                    factor = config.cropping.out_min_size / min(image.size)
                    out_size = (np.array(image.size) * factor).round().astype('uint32')
                    image = image.resize(out_size, resample=Image.BICUBIC)
            else:
                factor = config.cropping.out_min_size / min(image.size)
                out_size = (np.array(image.size) * factor).round().astype('uint32')
                image = image.resize(out_size, resample=Image.BICUBIC)

            # 生成和选择掩码
            src_masks = mask_generator.get_masks(image)

            filtered_image_mask_pairs = []
            for cur_mask in src_masks:
                if config.cropping.out_square_crop:
                    (crop_left,
                     crop_top,
                     crop_right,
                     crop_bottom) = propose_random_square_crop(cur_mask,
                                                               min_overlap=config.cropping.crop_min_overlap)
                    cur_mask = cur_mask[crop_top:crop_bottom, crop_left:crop_right]
                    cur_image = image.copy().crop((crop_left, crop_top, crop_right, crop_bottom))
                else:
                    cur_image = image

                if len(np.unique(cur_mask)) == 0 or cur_mask.mean() > max_tamper_area:
                    continue

                filtered_image_mask_pairs.append((cur_image, cur_mask))

            mask_indices = np.random.choice(len(filtered_image_mask_pairs),
                                            size=min(len(filtered_image_mask_pairs), config.max_masks_per_image),
                                            replace=False)

            # 剪裁掩码并保存掩码和输入图像
            mask_basename = os.path.join(outdir, os.path.splitext(file_relpath)[0])
            for i, idx in enumerate(mask_indices):
                cur_image, cur_mask = filtered_image_mask_pairs[idx]
                cur_basename = mask_basename + f'_crop{i:03d}'
                Image.fromarray(np.clip(cur_mask * 255, 0, 255).astype('uint8'),
                                mode='L').save(cur_basename + f'_mask{i:03d}.png')
                cur_image.save(cur_basename + '.png')
        except KeyboardInterrupt:
            return
        except Exception as ex:
            print(f'Could not make masks for {infile} due to {ex}:\n{traceback.format_exc()}')

# 主函数
def main(args):
    if not args.indir.endswith('/'):
        args.indir += '/'

    os.makedirs(args.outdir, exist_ok=True)

    config = load_yaml(args.config)

    in_files = list(glob.glob(os.path.join(args.indir, '**', f'*.{args.ext}'), recursive=True))
    if args.n_jobs == 0:
        process_images(in_files, args.indir, args.outdir, config)
    else:
        in_files_n = len(in_files)
        chunk_size = in_files_n // args.n_jobs + (1 if in_files_n % args.n_jobs > 0 else 0)
        Parallel(n_jobs=args.n_jobs)(
            delayed(process_images)(in_files[start:start+chunk_size], args.indir, args.outdir, config)
            for start in range(0, len(in_files), chunk_size)
        )

# 如果这个脚本被直接执行
if __name__ == '__main__':
    import argparse

    aparser = argparse.ArgumentParser()
    aparser.add_argument('config', type=str, help='Path to config for dataset generation')
    aparser.add_argument('indir', type=str, help='Path to folder with images')
    aparser.add_argument('outdir', type=str, help='Path to folder to store aligned images and masks to')
    aparser.add_argument('--n-jobs', type=int, default=0, help='How many processes to use')
    aparser.add_argument('--ext', type=str, default='jpg', help='Input image extension')

    main(aparser.parse_args())
    用于处理图像并生成掩码。它包含了一些配置选项、图像处理功能以及处理多个图像的能力。主要的功能包括处理输入图像,生成掩码,剪裁图像和掩码,然后将它们保存到指定的输出目录。这个脚本还支持多进程处理,可以加快处理大量图像。

在上述代码中,"掩码" 是指一个二值图像,通常表示了一些区域的存在或缺失。数学形式表示掩码通常是一个矩阵(或图像),其中每个元素可以是二进制值(0或1),表示相应位置是否包含某种特征或信息。

具体地,如果我们考虑一个二维掩码矩阵,其中每个元素 (i, j) 的值为 1 表示该位置被覆盖或包含信息,值为 0 表示该位置没有信息或被遮挡。掩码通常用于图像处理和计算机视觉任务中,用于标识感兴趣的区域或对象。

例如,一个简单的数学形式的表示可以是:

  • 对于一个 2D 图像,M(i, j) 表示掩码矩阵中的元素,其中 (i, j) 是矩阵的坐标,M(i, j) 的值为 1 表示该位置包含信息,M(i, j) 的值为 0 表示该位置不包含信息。

掩码通常用于图像分割、遮挡区域检测、图像处理等任务,以便识别和操作图像中的感兴趣区域。在代码中,掩码用二维数组(NumPy数组)来表示,其中元素的值为0或1,这样可以方便地进行图像处理操作。

2.3.2  训练自己的数据集

python3 bin/predict.py model.path=$(pwd)/big-lama indir=$(pwd)/LaMa_test_images outdir=$(pwd)/output

2.3.3  测试自己的数据集  

python3 bin/predict.py model.path=$(pwd)/big-lama indir=$(pwd)/LaMa_test_images outdir=$(pwd)/output

2.3.4  测试结果参数及可视化

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/122653.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Figma转Sketch文件教程,超简单!

相信大家做设计的都多多少少听过一点Figma和Sktech&#xff0c;这2个设计软件是目前市场上很受欢迎的专业UI设计软件&#xff0c;在全球各地都有很多粉丝用户。但是相对来说&#xff0c;Figma与Sketch只支持iOS系统有所不同&#xff0c;Figma是一个在线设计软件&#xff0c;不限…

TikTok shop美国小店适合哪些卖家做?附常见运营问题解答

一、Tiktok shop小店分类 大家都知道&#xff0c;美国小店可以分为5 种&#xff1a; 美国本土个人店: 最灵活&#xff0c;有扶持政策&#xff1b;美国法人企业店&#xff1a;要求高&#xff0c;有扶持政策&#xff1b;美国公司中国人占股店 (ACCU店) : 权重相对低&#xff0c…

Java版本spring cloud + spring boot企业电子招投标系统源代码

项目说明 随着公司的快速发展&#xff0c;企业人员和经营规模不断壮大&#xff0c;公司对内部招采管理的提升提出了更高的要求。在企业里建立一个公平、公开、公正的采购环境&#xff0c;最大限度控制采购成本至关重要。符合国家电子招投标法律法规及相关规范&#xff0c;以及审…

《向经典致敬》第二届粤港澳大湾区著名歌唱家音乐会完美落幕

百年经典 歌坛盛会 “《向经典致敬》第二届粤港澳大湾区著名歌唱家音乐会暨2023福田人才之夜”完美落幕 2023年11月4日&#xff0c;阳光普照&#xff0c;秋意正浓&#xff0c;由中共深圳市福田区委宣传部、深圳市福田区文学艺术界联合会主办&#xff0c;深圳歌唱家协会承办&…

SpringBoot测试类启动web环境

1.坐标修改 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency> 2.测试类测试 说明&#xff1a;SpringBootTest()中的webEnvironment值的说明&#xff1b; 2.1不启…

VMware 虚拟机如何修改虚拟机系统的网卡速率为万兆——筑梦之路

1. 找到虚拟机系统安装目录 比如E:\vmware-system\kali\ 2. 找到vmx文件&#xff0c;用记事本打开 将 ethernet0.virtualDev "e1000" 这行改为 ethernet0.virtualDev "vmxnet3" 后保存&#xff08;注意vmxnet3全为小写&#xff09;&#xff0c;如果没…

Babylonjs学习笔记(九)——第一人称控制器

书接上回&#xff0c;实现第一人称控制器&#xff01;&#xff01;&#xff01; 以下步骤&#xff0c;缺一不可 相机相关设置 camera.applyGravity true; // 应用重力 camera.checkCollisions true; // 开启碰撞检测 const camera new FreeCamera("camera",ne…

Yakit工具篇:WebFuzzer模块之序列操作

简介 Web Fuzzer 序列就是将多个 Web Fuzzer 节点串联起来&#xff0c;实现更复杂的逻辑与功能。例如我们需要先进行登录&#xff0c;然后再进行其他操作&#xff0c;这时候我们就可以使用 Web Fuzzer 序列功能。或者是我们在一次渗透测试中需要好几个步骤才能验证是否有漏洞这…

Next.js 项目——从入门到入门(Eslint+Prettier)

Next.js官方文档地址 什么是 Next.js 这是一个用于生产环境的 React 框架。 Next.js 为您提供生产环境所需的所有功能以及最佳的开发体验&#xff1a;包括静态及服务器端融合渲染、 支持 TypeScript、智能化打包、 路由预取等功能&#xff0c;无需任何配置。 功能&#xff…

uniapp小程序接入腾讯云【增强版人脸核身接入】

文档地址&#xff1a;https://cloud.tencent.com/document/product/1007/56812 企业申请注册这边就不介绍了&#xff0c;根据官方文档去申请注册。 申请成功后&#xff0c;下载【微信小程序sdk】 一、解压sdk&#xff0c;创建wxcomponents文件夹 sdk解压后发现是原生小程序代…

数据结构-堆

一、什么是堆 先了解两种特别的二叉树 满二叉树 除最后一层无任何子节点外&#xff0c;每一层上的所有结点都有两个子结点的二叉树 完全二叉树 完全二叉树相对于满二叉树来说&#xff0c;最后一层叶子节点从左到右中间没有空缺的&#xff0c;像这样&#xff1a; 计算机科学…

驾考在线答题系统源码:含PC+手机版驾考宝典多题库

安装说明&#xff1a; 1、上传到网站根目录 2、用 phpMyadmin 导入数据库文件 db.sql 3、修改数据库链接文件 /ThinkPHP/Conf/convention.php# &#xff08;记得不要用记事本修改&#xff0c;否则可能会出现验证码显示不了问题&#xff0c;建议用 Notepad 4、 帐号 admin 密码…

Git 入门使用 —— 建库、代码上下传、常用命令

目录 一、Git 入门 1.1 Git简介 1.2 Git安装 1.3 创建码云仓库 二、Git 使用 2.1 git初始化操作 2.2 代码上传 2.3 代码下载 2.4 代码更新 2.4.1 仓库管理者 2.4.1 仓库使用者 三、Git 常用命令 一、Git 入门 1.1 Git简介 Git是一个开源的分布式版本控制系统&am…

Wireshark分析tcp交互过程

三次握手 客户端发起请求 Tcp段长度为575字节&#xff0c;seq1&#xff0c;ack1&#xff0c;next_seq576 服务器响应&#xff1a; Tcp段长度为175字节&#xff0c;seq1&#xff0c;ack576&#xff0c;next_seq176 客户端响应&#xff1a; Tcp段长度523字节&#xff0c;seq576&…

Lazarus安装和入门资料

azarus-2.2.6-fpc-3.2.2-win64 下载地址 Lazarus 基础教程 - Lazarus Tutorials for Beginners Lazarus Tutorial #1 - Learning programming_哔哩哔哩_bilibili https://www.devstructor.com/index.php?pagetutorials Lazarus是一款开源免费的object pascal语言RAD IDE&…

数据结构: 哈希桶

目录 1.概念 2.模拟实现 2.1框架 2.2哈希桶结构 2.3相关功能 Modify --Insert --Erase --Find 2.4非整型数据入哈希桶 1.仿函数 2.BKDR哈希 1.概念 具有相同地址的key值归于同一集合中,这个集合称为一个桶,各个桶的元素通过单链表链接 2.模拟实现 2.1框架 a.写出…

H5横屏适配方案

横屏模式一般使用场景比较少&#xff0c;特殊情况除外&#xff0c;一般用于游戏、操作性比较大的网页会采用横屏 整体代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" conte…

【第2章 Node.js基础】2.2 Node.js回调函数

学习目标 &#xff08;1&#xff09;理解Node.js的回调函数&#xff1b; &#xff08;2&#xff09;掌握回调函数的使用。 什么是回调函数 回调函数是一种特殊的函数&#xff0c;它作为参数传递给另一个函数&#xff0c;并在特定的事件或条件发生时被调用。回调函数通常用于异…

【Git】深入了解Git及其常用命令

&#x1f973;&#x1f973;Welcome Huihuis Code World ! !&#x1f973;&#x1f973; 接下来看看由辉辉所写的关于Git的相关操作吧 目录 &#x1f973;&#x1f973;Welcome Huihuis Code World ! !&#x1f973;&#x1f973; 一.Git是什么 二.SVN和Git的区别 三.Git的…

二十、泛型(5)

本章概要 边界通配符 编译器有多聪明逆变无界通配符捕获转换 边界 边界&#xff08;bounds&#xff09;在本章的前面进行了简要介绍。边界允许我们对泛型使用的参数类型施加约束。尽管这可以强制执行有关应用了泛型类型的规则&#xff0c;但潜在的更重要的效果是我们可以在…