BasicVSR++模型转JIT并用c++libtorch推理

BasicVSR++模型转JIT并用c++libtorch推理

文章目录

  • BasicVSR++模型转JIT并用c++libtorch推理
    • 安装BasicVSR++ 环境
      • 1.下载源码
      • 2. 新建一个conda环境
      • 3. 安装pytorch
      • 4. 安装 mim 和 mmcv-full
      • 5. 安装 mmedit
      • 6. 下载模型文件
      • 7. 测试一下能否正常运行
    • 转换为JIT模型
    • 用c++ libtorch推理
      • 效果

安装BasicVSR++ 环境

1.下载源码

git clone https://github.com/ckkelvinchan/BasicVSR_PlusPlus.git

2. 新建一个conda环境

conda create -n BasicVSRPLUSPLUS  python=3.8 -y
conda activate BasicVSRPLUSPLUS  

3. 安装pytorch

pytorch官网 安装合适的版本
我这里是CUDA11.6版本

conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia

4. 安装 mim 和 mmcv-full

pip install openmim
mim install mmcv-full

5. 安装 mmedit

pip install mmedit

6. 下载模型文件

下载模型文件放在这里chkpts/basicvsr_plusplus_reds4.pth

7. 测试一下能否正常运行

python demo/restoration_video_demo.py configs/basicvsr_plusplus_reds4.py chkpts/basicvsr_plusplus_reds4.pth data/demo_000 results/demo_000

在这里插入图片描述
在这里插入图片描述
OK ! 环境正常!下面开始转换工作

转换为JIT模型

在demo下新建一个转换脚本

import os
import cv2
import mmcv
import numpy as np
import torch
from mmedit.core import tensor2img
from mmedit.apis import init_model


def main():
    # 加载模型并设置为评估模式
    model = init_model("configs/basicvsr_plusplus_reds4.py",
                       "chkpts/basicvsr_plusplus_reds4.pth", device=torch.device('cuda', 0))
    model.eval()

    # 准备一个示例输入
    src1 = cv2.imread("./data/img/00000000.png")
    src = cv2.cvtColor(src1, cv2.COLOR_BGR2RGB)
    src = torch.from_numpy(src / 255.).permute(2, 0, 1).float()
    src = src.unsqueeze(0)

    input_arg = torch.stack([src], dim=1)
    input_arg = input_arg.to(torch.device('cuda', 0))  # 确保输入在GPU上

    # # 执行模型推理
    # with torch.no_grad():  # 在推理时不需要计算梯度
    #     result = model(input_arg, test_mode=True)['output'].cpu()
    # output_i = tensor2img(result)
    # mmcv.imwrite(output_i, "./test.png")

    # 模型转换
    traced_model = torch.jit.trace(model.generator, input_arg)
    torch.jit.save(traced_model, "basicvsrPP.pt")

    # 测试
    res = traced_model(input_arg)
    out = tensor2img(res)
    mmcv.imwrite(out, "./testoo.png")


if __name__ == '__main__':
    main()

用c++ libtorch推理

/*
 * @Author: Liangbaikai
 * @LastEditTime: 2024-03-29 11:28:42
 * @Description: 视频超分
 * Copyright (c) 2024 by Liangbaikai, All Rights Reserved.
 */

#pragma once
#include <iostream>
#include <torch/script.h>
#include <torch/torch.h>
#include <opencv2/opencv.hpp>
#include <vector>
#include <c10/cuda/CUDACachingAllocator.h>

namespace LIANGBAIKAI_BASE_MODEL_NAME
{
    class lbk_video_super_resolution_basicPP
    {
    public:
        lbk_video_super_resolution_basicPP() = default;

        virtual ~lbk_video_super_resolution_basicPP()
        {
            c10::cuda::CUDACachingAllocator::emptyCache();
            // cudaDeviceReset();
        }

        /**
         * @description: 初始化
         * @param {string} &modelpath 模型文件
         * @param {int} gpuid GPU的id
         * @return {*}成功返回0,失败返回-1
         */
        int init(const std::string &modelpath, int gpuid = 0)
        {
            try
            {
                _mymodule = std::make_unique<torch::jit::script::Module>(torch::jit::load(modelpath));
            }
            catch (const c10::Error &e)
            {
                std::cerr << "Error loading the model " << modelpath << std::endl;
                std::cerr << "Error " << e.what() << std::endl;
                return -1;
            }
            _gpuid = gpuid;

            if ((_gpuid < 0) || (!torch::cuda::is_available()))
            {
                _device = std::make_unique<torch::Device>(torch::kCPU);
                _mymodule->to(at::kCPU);
            }
            else
            {
                _device = std::make_unique<torch::Device>(torch::kCUDA, _gpuid);
                _mymodule->to(at::kCUDA, _gpuid);
            }

            _mymodule->eval();
            _modelsuccess = true;
            return 0;
        }

        /**
         * @description: 推理
         * @param {Mat} &inputpic 输入图片
         * @param {Mat} &outputpic 输出结果
         * @param {bool} showlog  是否打印日志
         * @return {*} 成功返回0,失败返回-1
         */
        int inference(cv::Mat &inputpic, cv::Mat &outputpic, bool showlog = false)
        {
            if (inputpic.empty() || (inputpic.channels() != 3))
            {
                std::cout << "input data ERROR" << std::endl;
                return -1;
            }

            if (!_modelsuccess)
            {
                std::cout << "model has not been inited!" << std::endl;
                return -1;
            }

            // torch::DeviceGuard 是一个类,它的作用是确保在使用完设备(如CPU或GPU)后,能够正确地将设备恢复到使用前的状态。
            torch::DeviceGuard device_guard(*_device); // 作用域内所有操作都在指定设备上运行,离开此作用域恢复

            cv::transpose(inputpic, inputpic); // 顺时针旋转

            // 将图片转换为tensor
            cv::Mat img_float;
            inputpic.convertTo(img_float, CV_32FC3, 1.0 / 255);
            torch::Tensor img_tensor = torch::from_blob(img_float.data, {img_float.rows, img_float.cols, 3}, torch::kFloat32).permute({2, 1, 0});
            img_tensor = (img_tensor - 0.5) / 0.5;
            img_tensor = (img_tensor + 1) / 2;
            img_tensor = torch::clamp(img_tensor, 0, 1);

            torch::Tensor src_unsqueezed = img_tensor.unsqueeze(0).to(*_device); // 将tensor转移到GPU上

            std::vector<torch::Tensor> tensors_to_stack = {src_unsqueezed}; // 创建一个包含 src 的 vector

            torch::Tensor input_arg = torch::stack(tensors_to_stack, 1); // 沿着维度1堆叠tensors

            if (showlog)
            {
                std::cout << input_arg.sizes() << std::endl;
            }

            torch::NoGradGuard no_grad; // 暂时禁用梯度计算
            auto output_dict = _mymodule->forward({input_arg});

            torch::Tensor output_data;
            if (output_dict.isTensor())
            {
                output_data = output_dict.toTensor().to(at::kCPU); // 如果是Tensor,则通过toTensor()方法获取它
                if (showlog)
                {
                    std::cout << "out shape: " << output_data.sizes() << std::endl;
                }
            }
            else
            {
                if (showlog)
                {
                    std::cerr << "The IValue does not contain a Tensor." << std::endl;
                }
            }

            float *f = output_data.data_ptr<float>();

            int output_width = output_data.size(3);
            int output_height = output_data.size(4);
            int size_pic = output_width * output_height;

            std::vector<cv::Mat> rgbChannels(3);
            rgbChannels[0] = cv::Mat(output_width, output_height, CV_32FC1, f);
            rgbChannels[1] = cv::Mat(output_width, output_height, CV_32FC1, f + size_pic);
            rgbChannels[2] = cv::Mat(output_width, output_height, CV_32FC1, f + size_pic + size_pic);

            rgbChannels[0].convertTo(rgbChannels[0], CV_8UC1, 255);
            rgbChannels[1].convertTo(rgbChannels[1], CV_8UC1, 255);
            rgbChannels[2].convertTo(rgbChannels[2], CV_8UC1, 255);

            merge(rgbChannels, outputpic);

            return 0;
        }

    private:
        bool _modelsuccess = false;
        int _gpuid = 0;
        std::unique_ptr<torch::Device> _device;
        std::unique_ptr<torch::jit::script::Module> _mymodule;
    };

}
#include <unistd.h>
#include"lbk_video_super_resolution.hpp"
using namespace LIANGBAIKAI_BASE_MODEL_NAME;
int main(int argc,char *argv[])
{
    if(argc < 5){
        std::cout << "./test 模型  GPUid(cpu传-1) 输入图片 输出图片" << std::endl;
        return -1;
    }
    std::string modelfile = argv[1];
    int gpuid = atoi(argv[2]);
    std::string imgfile = argv[3];
    std::string outfile = argv[4];

    cv::Mat src = cv::imread(imgfile);


    lbk_video_super_resolution_basicPP test;
    if(0 > test.init(modelfile,gpuid)){
        std::cout << "init failed" << std::endl;
        return -1;
    }
    cv::Mat out;
    int rec = test.inference(src,out,true);
    if(rec >= 0){
        cv::imwrite(outfile, out);
    }

    return 0;
}

效果

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/500927.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

只出现一次的数字 II

题目链接 只出现一次的数字 II 题目描述 注意点 nums中&#xff0c;除某个元素仅出现一次外&#xff0c;其余每个元素都恰出现三次设计并实现线性时间复杂度的算法且使用常数级空间来解决此问题 解答思路 本题与只出现一次的数字的数字类似&#xff0c;区别是重复的数字会…

深度学习InputStreamReader类

咦咦咦&#xff0c;各位小可爱&#xff0c;我是你们的好伙伴——bug菌&#xff0c;今天又来给大家普及Java SE相关知识点了&#xff0c;别躲起来啊&#xff0c;听我讲干货还不快点赞&#xff0c;赞多了我就有动力讲得更嗨啦&#xff01;所以呀&#xff0c;养成先点赞后阅读的好…

SpringMVC注解及使用规则

文章目录 前言一、SpringMVC注解是什么&#xff1f;二、使用步骤1.注解使用2创建JSP3 SpringMVC视图1. 逻辑视图&#xff08;Logical View&#xff09;2. 物理视图&#xff08;Physical View&#xff09;区别和关系 4 SpringMVC注解总结 总结 前言 提示&#xff1a;这里可以添…

Java运算符-三元运算符,这你必须得会!

哈喽&#xff0c;各位小伙伴们&#xff0c;你们好呀&#xff0c;我是喵手。 今天我要给大家分享一些自己日常学习到的一些知识点&#xff0c;并以文字的形式跟大家一起交流&#xff0c;互相学习&#xff0c;一个人虽可以走的更快&#xff0c;但一群人可以走的更远。 我是一名后…

08-研发流程设计(上):如何设计Go项目的开发流程?

在Go 项目开发中&#xff0c;我们不仅要完成产品功能的开发&#xff0c;还要确保整个过程是高效的&#xff0c;代码是高质量的。 所以&#xff0c;Go 项目开发一定要设计一个合理的研发流程&#xff0c;来提高开发效率、减少软件维护成本。研发流程会因为项目、团队和开发模式…

基于SSM的百货中心供应链管理系统设计与实现(论文+源码)_kaic

摘 要 社会发展日新月异&#xff0c;用计算机应用实现数据管理功能已经算是很完善的了&#xff0c;但是随着移动互联网的到来&#xff0c;处理信息不再受制于地理位置的限制&#xff0c;处理信息及时高效&#xff0c;备受人们的喜爱。本次开发一套百货中心供应链管理系统有管理…

政安晨:【Keras机器学习实践要点】(九)—— 保存、序列化和导出模型

政安晨的个人主页&#xff1a;政安晨 欢迎 &#x1f44d;点赞✍评论⭐收藏 收录专栏: TensorFlow与Keras机器学习实战 希望政安晨的博客能够对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出指正&#xff01; 这篇文章是保存、序列化和导出模型的完整指南。 …

python实战之进阶篇(一)

定义类 1. 构造方法 2. 实例方法 3. 类方法 类似于Java中的静态方法, 使用方式: 类名.类方法 4. 私有变量 5. 私有方法 6. 使用属性set和get

vue3+threejs新手从零开发卡牌游戏(二十一):添加战斗与生命值关联逻辑

首先将双方玩家的HP存入store中&#xff0c;stores/common.ts代码如下&#xff1a; import { ref, computed } from vue import { defineStore } from piniaexport const useCommonStore defineStore(common, () > {const _font ref() // 字体const p1HP ref(4000) // 己…

4、Cocos Creator 动画系统

目录 1、Clip 参数 2、动画编辑器 3、基本操作 更改时间轴缩放比例 移动显示区域 更改当前选中的时间轴节点 播放 / 暂停动画 修改 clip 属性 快捷键 4、模拟实验 5、动画事件 6、注意事项 参考 Animation 组件是节点上的一个组件。Clip 动画剪辑就是一份动画的声…

【edge浏览器无法登录某些网站,以及迅雷插件无法生效的解决办法】

edge浏览器无法登录某些网站&#xff0c;以及迅雷插件无法生效的解决办法 edge浏览器无法登录某些网站&#xff0c;但chrome浏览器可以登录浏览器插件无法使用&#xff0c;比如迅雷如果重装插件重装浏览器重装迅雷后仍然出现问题 edge浏览器无法登录某些网站&#xff0c;但chro…

InfoNCE loss

InfoNCE loss是一种用于自监督学习的损失函数&#xff0c;通常用于训练对比学习模型&#xff0c;如自编码器或神经网络。全称是"InfoNCE: Contrastive Estimation of Neural Entropy"&#xff0c;基于对比学习的思想&#xff0c;旨在最大化正样本的相似性&#xff0c…

QSplashScreen

以前打红警的时候进入游戏界面会有一个启动界面&#xff0c;比如美国是有伞兵&#xff0c;英国有狙击手&#xff0c;韩国有黑鹰战机的一些介绍&#xff0c;这些就是启动界面&#xff0c;就是由QSplashScreen这个类来实现的。 QSplashScreen 是 Qt 框架中的一个类&#xff0c;用…

【每日跟读】常用英语500句(300~400)

【每日跟读】常用英语500句 I had to take a shower. 我洗了个澡 Go on in. 赶紧进去吧 Hold up. 等一下 They seem like nice people. 他们看起来像好人 Such a wonderful age. 如此美好的年纪 That’s very impressive. 真厉害 I can see that. 看得出来 You should …

绘制多个box箱型图

1.首先生成随机数据 import random # 创建一个153629行&#xff0c;13列的数据&#xff0c; random_data np.random.randn(153629, 13) #创建数据标签&#xff0c;后续将根据数据标签绘制不同的箱型图 label [0,1,2,3] labels np.asarray(random.choices(label,k 153629))…

U盘文件突然消失:原因分析与恢复策略

U盘遭遇“幽灵”之手&#xff0c;文件不翼而飞 你是否曾遭遇过这样的诡异情况&#xff1a;前一天还好好存放在U盘里的文件&#xff0c;第二天却突然消失得无影无踪&#xff1f;这简直就像是一场无声的灾难&#xff0c;令人措手不及。U盘作为我们日常工作和生活中不可或缺的数据…

Gitea 的详细介绍

什么是 Gitea&#xff1f; Gitea 是一个开源、轻量级的自托管 Git 服务&#xff0c;它允许用户搭建类似于 GitHub 或 GitLab 的代码托管平台。由于采用 Go 语言开发&#xff0c;Gitea 具有高效的性能和跨平台特性&#xff0c;适合个人开发者或小团队使用。 Gitea 的特点 轻量…

java数据结构与算法刷题-----LeetCode34. 在排序数组中查找元素的第一个和最后一个位置

java数据结构与算法刷题目录&#xff08;剑指Offer、LeetCode、ACM&#xff09;-----主目录-----持续更新(进不去说明我没写完)&#xff1a;https://blog.csdn.net/grd_java/article/details/123063846 文章目录 二分查找 二分查找 解题思路&#xff1a;时间复杂度O( l o g 2 …

读所罗门的密码笔记05_新的力量平衡

1. 技术发展 1.1. 美国和中国俨然成为人工智能研究、开发和部署方面无可争议的领导者 1.2. 人类从未彻底阻止某种技术的发展&#xff0c;虽然不同国家在不同时间对克隆、化学武器和核武器等技术采取了暂停或禁止措施&#xff0c;但我们仍在继续推进最前…

mysql之MyBatis核心工作原理

MyBatis核心工作原理 一、源码环境 1.手动编译源码 工欲善其事必先利其器。为了方便我们在看源码的过程中能够方便的添加注释&#xff0c;我们可以自己来从官网下载源码编译生成对应的Jar包&#xff0c;然后上传到本地maven仓库&#xff0c;再引用这个Jar。大家可以自行去官…