【深度学习】CodeFormer训练过程,如何训练人脸修复模型CodeFormer

文章目录

  • BasicSR介绍
  • 环境
  • 数据
  • 阶段 I - VQGAN
  • 阶段 II - CodeFormer (w=0)
  • 阶段 III - CodeFormer (w=1)

代码地址:https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0

论文的一些简略介绍:
https://qq742971636.blog.csdn.net/article/details/134562550

BasicSR介绍

CodeFormer整个项目都沿袭BasicSR,了解一下BasicSR很有必要:

https://mp.csdn.net/mp_blog/creation/success/135674803

环境

# git clone this repository
git clone https://github.com/sczhou/CodeFormer
cd CodeFormer

# create new anaconda env
conda create -n codeformer python=3.8 -y
conda activate codeformer

conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia

# install python dependencies
pip3 install -r requirements.txt
python basicsr/setup.py develop

conda install -c conda-forge dlib (only for face detection or cropping with dlib)

数据

找一些高清人脸数据1024*1024。

人脸数据需要对齐,对齐方式为: https://qq742971636.blog.csdn.net/article/details/135521146

阶段 I - VQGAN

训练VQGAN:

python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/VQGAN_512_ds32_nearest_stage1.yml --launcher pytorch
CUDA_VISIBLE_DEVICES=0,2,3 python -m torch.distributed.launch --nproc_per_node=3 --master_port=4321 basicsr/train.py -opt options/VQGAN_512_ds32_nearest_stage1.yml --launcher pytorch # 指定三张显卡训练,对应VQGAN_512_ds32_nearest_stage1.yaml也是需要修改的

训练完VQGAN后,可以通过下面代码预先获得训练数据集的密码本序列,从而加速后面阶段的训练过程:

python scripts/generate_latent_gt.py

如果你不需要训练自己的VQGAN,可以在Release v0.1.0文档中找到预训练的VQGAN (vqgan_code1024.pth)和对应的密码本序列 (latent_gt_code1024.pth): https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0

打开日志查看训练过程:

tensorboard --logdir="/ssd/xiedong/CodeFormer/tb_logger/20240116_182107_VQGAN-512-ds32-nearest-stage1" --bind_all

在这里插入图片描述

VQGAN本身就是一个图生图的网络,在中间使用transformer将特征图转为embedding. 而 CodeFormer就是要利用这每张图的embedding来进行面部修复。

下面代码里用vqgan_code1024.pth获取训练数据的密码本,vqgan_code1024.pth的encoder输出的是2563232的特征图,由embedding给到1*1024,最终所有图保存为一个pytorch文件。

import argparse
import glob
import numpy as np
import os
import cv2
import torch
from torchvision.transforms.functional import normalize
from tqdm import tqdm

from basicsr.utils import imwrite, img2tensor, tensor2img

from basicsr.utils.registry import ARCH_REGISTRY

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--test_path', type=str, default='/ssd/xiedong/FFHQ/faces_hq_sr')
    parser.add_argument('-o', '--save_root', type=str, default='/ssd/xiedong/FFHQ/lt_output')
    parser.add_argument('--codebook_size', type=int, default=1024)
    parser.add_argument('--ckpt_path', type=str, default='/ssd/xiedong/CodeFormer/weights/vqgan/vqgan_code1024.pth')
    args = parser.parse_args()

    if args.save_root.endswith('/'):  # solve when path ends with /
        args.save_root = args.save_root[:-1]
    dir_name = os.path.abspath(args.save_root)
    os.makedirs(dir_name, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    test_path = args.test_path
    save_root = args.save_root
    ckpt_path = args.ckpt_path
    codebook_size = args.codebook_size

    vqgan = ARCH_REGISTRY.get('VQAutoEncoder')(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',
                                               codebook_size=codebook_size).to(device)
    checkpoint = torch.load(ckpt_path)['params_ema']

    vqgan.load_state_dict(checkpoint)
    vqgan.eval()

    sum_latent = np.zeros((codebook_size)).astype('float64')
    size_latent = 32
    latent = {}
    latent['orig'] = {}
    latent['hflip'] = {}
    for i in ['orig', 'hflip']:
        # for i in ['hflip']:
        for img_path in tqdm(sorted(glob.glob(os.path.join(test_path, '*.[jp][pn]g')))):
            img_name = os.path.basename(img_path)
            img = cv2.imread(img_path)
            if i == 'hflip':
                cv2.flip(img, 1, img)
            img = img2tensor(img / 255., bgr2rgb=True, float32=True)
            normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            img = img.unsqueeze(0).to(device)
            with torch.no_grad():
                # output = net(img)[0]
                # x, feat_dict = vqgan.encoder(img, True)
                x = vqgan.encoder(img)
                x, _, log = vqgan.quantize(x)
            # del output
            torch.cuda.empty_cache()

            min_encoding_indices = log['min_encoding_indices']
            min_encoding_indices = min_encoding_indices.view(size_latent, size_latent)
            latent[i][img_name[:-4]] = min_encoding_indices.cpu().numpy()
            print(img_name, latent[i][img_name[:-4]].shape)

    latent_save_path = os.path.join(save_root, f'latent_gt_code{codebook_size}.pth')
    torch.save(latent, latent_save_path)
    print(f'\nLatent GT code are saved in {save_root}')

阶段 II - CodeFormer (w=0)

w=0 是需要模型完全追求抽象美学,w=1 是需要模型完全追求与原图相似。

在第一个阶段,得到了每张图对应的embedding。

训练密码本训练预测模块:

python -m torch.distributed.launch --nproc_per_node=8 --master_port=4322 basicsr/train.py -opt options/CodeFormer_stage2.yml --launcher pytorch

预训练CodeFormer第二阶段模型 (codeformer_stage2.pth)可以在Releases v0.1.0文档里下载: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0

阶段 III - CodeFormer (w=1)

训练可调模块:

python -m torch.distributed.launch --nproc_per_node=8 --master_port=4323 basicsr/train.py -opt options/CodeFormer_stage3.yml --launcher pytorch

预训练CodeFormer模型 (codeformer.pth)可以在Releases v0.1.0文档里下载: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/347894.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Mysql索引相关学习笔记:B+ Tree、索引分类、索引优化、索引失效场景及其他常见面试题

前言 索引是Mysql中常用到的一个功能,可以大大加快查询速度,同时面试中也是经常碰到。本文是学习Mysql索引的归纳总结。 索引采用的数据结构——B 树 本部分主要是参考自小林Coding B树的由来 二分查找可以每次缩减一半,从而提高查找效率…

【mongoDB】数据库的创建和删除

目录 1. 查看所有数据库 2.创建数据库 3.查看当前连接的数据库 4.删除数据库 1. 查看所有数据库 show dbs 2.创建数据库 use 数据库名 例如创建一个名为 aaa 的数据库 3.查看当前连接的数据库 db 4.删除数据库 use 数据库名 db.dropDataBase() 比如删除数据库 aaa

1.25号c++

1.引用 引用就是给变量起别名 格式: 数据类型 &引用名 同类型的变量名 (& 引用符号) eg: int a 10; int &b a; //b引用a,或者给a变量取个别名叫b int *p; //指针可以先定义 后指向 p &a; //int &a…

【MySQL】如何通过DDL去创建和修改员工信息表

🌈个人主页: Aileen_0v0 🔥热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 ​💫个人格言:“没有罗马,那就自己创造罗马~” #mermaid-svg-fmKISDBsFq74ab2Z {font-family:"trebuchet ms",verdana,arial,sans-serif;font-siz…

docker(第二部分)

来自尚硅谷杨哥 少一点胡思乱想,心中无女人,编码自然神,忘掉心上人,抬手灭红尘。人间清醒,赚钱第一。好好学习,天天向上。听懂六六六。 7.Dokcer容器数据卷 1,)坑:容器卷记得加入 …

shared_ptr 与 unique_ptr 的转换 笔记

推荐B站文章: 6.shared_ptr与unique_ptr_哔哩哔哩_bilibilihttps://www.bilibili.com/video/BV18B4y187uL?p6&vd_sourcea934d7fc6f47698a29dac90a922ba5a3我的往期文章: 独占指针:unique_ptr 与 函数调用-CSDN博客https://blog.csdn.n…

银行数据仓库体系实践(5)--数据转换

数据转换作业主要是指在数据仓库内的结构化数据批量加工,对于非结构化数据以及在线查询接口、数据流的开发主要是遵循代码开发规范以及各中间件的开发规范,如使用java来开发遵守java开发规范,使用Kafka需要遵循Kafka的使用和设计规范。同时做…

对话泛能网程路:能源产业互联网,行至中程

泛能网的能源产业互联网的标杆价值还不仅于此。其在产业互联之外,也更大的特殊性在于其也更在成为整个碳市场的“辅助运营商”,包括电力、碳等一系列被泛能网帮助企业改造和沉淀的要素资产,都在构成着碳交易市场的未来底层。 这恰是产业互联…

有关Quick BI中Case子句中多次使用lod函数返回空值问题分析

一、Quick BI中的lod_ include函数 lod_ include {维度1[,维度2]...:聚合表达式[:过滤条件]} 作用:将表达式中的维度一起作为分组依据进行订算。其中, 1) 维度1[,维度2]... :声明维度,指定聚合表达式要连接到的一个或多个维…

开源项目Git Commit规范与ChangeLog

一,conventional commit(约定式提交) Conventional Commits 是一种用于给提交信息增加人机可读含义的规范。它提供了一组用于创建清晰的提交历史的简单规则。 1.1 作用 自动化生成 CHANGELOG基于提交类型,自动决定语义化的版本变更向项目相关合作开发…

OpenCV书签 #互信息的原理与相似图片搜索实验

1. 介绍 互信息(Mutual Information) 是信息论中的一个概念,用于衡量两个随机变量之间的关联程度。在图像处理和计算机视觉中,互信息常被用来度量两幅图像之间的相似性。 互信息可以看成是一个随机变量中包含的关于另一个随机变…

【网站项目】基于SSM的251国外摇滚乐队交流和周边售卖系统

🙊作者简介:拥有多年开发工作经验,分享技术代码帮助学生学习,独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。🌹赠送计算机毕业设计600个选题excel文件,帮助大学选题。赠送开题报告模板&#xff…

DAY30:回溯算法332\51\37基本思路了解+总结

Leetcode: 332 重新安排行程 代码随想录 这道题目有几个难点: 一个行程中,如果航班处理不好容易变成一个圈,成为死循环,容易出现环路。有多种解法,字母序靠前排在前面,让很多同学望而退步,如…

yolov8上使用gpu教程

yolov8上使用gpu教程 安装Cuda和Cudnnyolov8上使用gpu 安装Cuda和Cudnn 1.查看支持的cuda版本,并去官网下载。 nvidia-smi2.网址:https://developer.nvidia.com/cuda-toolkit-archive 3.安装细节 安装的前提基础是,有vs的C环境。我电脑有…

多流转换 (分流,合流,基于时间的合流——双流联结 )

目录 一,分流 1.实现分流 2.使用侧输出流 二,合流 1,联合 2,连接 三,基于时间的合流——双流联结 1,窗口联结 1.1 窗口联结的调用 1.2 窗口联结的处理流程 2,间隔联结 2.1 间隔联…

Qt单选按钮

前言 本篇文章介绍Qt的单选按钮,就是QRadioButton QRadioButton是一个选项按钮,可以打开(选中)或关闭(取消选中)。单选按钮通常向用户提供“众多之一”的选择。 在一组单选按钮中,一次只能选中…

手动搭建koa+ts项目框架(apidoc文档篇)

文章目录 一、安装apidoc工具二、使用1、项目根目录新建apidoc.json2、定义接口路由上方注解对应信息3、配置静态文件访问目录4、生成api文档如有启发,可点赞收藏哟~ 一、安装apidoc工具 全局安装 npm i apidoc -g查看是否安装成功 apidoc -v二、使用 1、项目根…

注解@profile的使用

目录 profile介绍配置演示 profile介绍 profile 通常作用在controller类上,当它标记的环境有效时,该controller类才生效,例如:Profile(“dev”),当dev环境被激活时,Profile(“dev”)所注解的controller类才生效。 配置…

工作软技能第一弹,关于职场沟通、成长的那些事

引言 在谈绩效后,我收获了一些心得,在此梳理出来,加深印象并且共勉 基本信息 在步入职场后,你可能跟我一样虽然技术水平有在上升,但是在处理一些事情上可能偶尔没能获得预期的成果。我在通过绩效沟通以及自我反思后…

【学网攻】 第(10)节 -- 路由器单臂路由配置

系列文章目录 目录 系列文章目录 文章目录 前言 一、单臂路由是什么? 二、实验 1.引入 实验拓扑图 PC配置 Sw配置 Router配置 实验验证 总结 文章目录 【学网攻】 第(1)节 -- 认识网络【学网攻】 第(2)节 -- 交换机认识及使用【学网攻】 第(3)节 -- 交…