基于可微分渲染器的相机位置优化【PyTorch3D】

在这个教程中,我们将使用可微渲染学习给定参考图像的相机的 [x, y, z] 位置。

我们将首先使用相机的起始位置初始化渲染器。 然后,我们将使用它来生成图像,使用参考图像计算损失,最后通过整个管道进行反向传播以更新相机的位置。

NSDT在线工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器

本教程展示如何:

  • 从 .obj 文件加载网格
  • 初始化相机、着色器和渲染器,
  • 渲染网格
  • 使用损失函数和优化器设置优化循环

首先确保已安装torch和torchvision,并使用以下代码安装pytorch3d:

import os
import sys
import torch
need_pytorch3d=False
try:
    import pytorch3d
except ModuleNotFoundError:
    need_pytorch3d=True
if need_pytorch3d:
    if torch.__version__.startswith("2.1.") and sys.platform.startswith("linux"):
        # We try to install PyTorch3D via a released wheel.
        pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
        version_str="".join([
            f"py3{sys.version_info.minor}_cu",
            torch.version.cuda.replace(".",""),
            f"_pyt{pyt_version_str}"
        ])
        !pip install fvcore iopath
        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
    else:
        # We try to install PyTorch3D from source.
        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'

导入模块:

import os
import torch
import numpy as np
from tqdm.notebook import tqdm
import imageio
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from skimage import img_as_ubyte

# io utils
from pytorch3d.io import load_obj

# datastructures
from pytorch3d.structures import Meshes

# 3D transformations functions
from pytorch3d.transforms import Rotate, Translate

# rendering components
from pytorch3d.renderer import (
    FoVPerspectiveCameras, look_at_view_transform, look_at_rotation, 
    RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
    SoftSilhouetteShader, HardPhongShader, PointLights, TexturesVertex,
)

1、加载obj模型

我们将加载一个 obj 文件并创建一个 Meshes 对象。 网格是 PyTorch3D 中提供的独特数据结构,用于处理批量不同大小的网格。 它有几个在渲染管道中使用的有用的类方法:

# Set the cuda device 
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")

# Load the obj and ignore the textures and materials.
verts, faces_idx, _ = load_obj("./data/teapot.obj")
faces = faces_idx.verts_idx

# Initialize each vertex to be white in color.
verts_rgb = torch.ones_like(verts)[None]  # (1, V, 3)
textures = TexturesVertex(verts_features=verts_rgb.to(device))

# Create a Meshes object for the teapot. Here we have only one mesh in the batch.
teapot_mesh = Meshes(
    verts=[verts.to(device)],   
    faces=[faces.to(device)], 
    textures=textures
)

如果在克隆 PyTorch3D 存储库后在本地运行此笔记本,则网格将已经可用。 如果使用 Google Colab,请获取网格并将其保存在路径 data/ 中:

!mkdir -p data
!wget -P data https://dl.fbaipublicfiles.com/pytorch3d/data/teapot/teapot.obj

2、优化设置

PyTorch3D 中的渲染器由光栅器和着色器组成,每个组件都有许多子组件,例如相机(正交/透视)。 在这里,我们初始化其中一些组件,并对其余组件使用默认值。

2.1 创建渲染器

为了优化相机位置,我们将使用渲染器,它仅生成对象的轮廓,而不应用任何照明或阴影。 我们还将初始化另一个应用完整 Phong 着色的渲染器,并使用它来可视化输出。

# Initialize a perspective camera.
cameras = FoVPerspectiveCameras(device=device)

# To blend the 100 faces we set a few parameters which control the opacity and the sharpness of 
# edges. Refer to blending.py for more details. 
blend_params = BlendParams(sigma=1e-4, gamma=1e-4)

# Define the settings for rasterization and shading. Here we set the output image to be of size
# 256x256. To form the blended image we use 100 faces for each pixel. We also set bin_size and max_faces_per_bin to None which ensure that 
# the faster coarse-to-fine rasterization method is used. Refer to rasterize_meshes.py for 
# explanations of these parameters. Refer to docs/notes/renderer.md for an explanation of 
# the difference between naive and coarse-to-fine rasterization. 
raster_settings = RasterizationSettings(
    image_size=256, 
    blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma, 
    faces_per_pixel=100, 
)

# Create a silhouette mesh renderer by composing a rasterizer and a shader. 
silhouette_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras, 
        raster_settings=raster_settings
    ),
    shader=SoftSilhouetteShader(blend_params=blend_params)
)


# We will also create a Phong renderer. This is simpler and only needs to render one face per pixel.
raster_settings = RasterizationSettings(
    image_size=256, 
    blur_radius=0.0, 
    faces_per_pixel=1, 
)
# We can add a point light in front of the object. 
lights = PointLights(device=device, location=((2.0, 2.0, -2.0),))
phong_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras, 
        raster_settings=raster_settings
    ),
    shader=HardPhongShader(device=device, cameras=cameras, lights=lights)
)

2.2 创建参考图像

我们将首先定位茶壶并生成图像。 我们使用辅助函数将茶壶旋转到所需的视角。 然后我们可以使用渲染器来生成图像。 在这里,我们将使用两个渲染器并可视化轮廓和全着色图像。

世界坐标系定义为+Y向上、+X向左和+Z向内。世界坐标中的茶壶的壶嘴指向左侧。

我们定义了一个位于 z 轴正方向上的相机,因此可以看到右侧的喷口。

# Select the viewpoint using spherical angles  
distance = 3   # distance from camera to the object
elevation = 50.0   # angle of elevation in degrees
azimuth = 0.0  # No rotation so the camera is positioned on the +Z axis. 

# Get the position of the camera based on the spherical angles
R, T = look_at_view_transform(distance, elevation, azimuth, device=device)

# Render the teapot providing the values of R and T. 
silhouette = silhouette_renderer(meshes_world=teapot_mesh, R=R, T=T)
image_ref = phong_renderer(meshes_world=teapot_mesh, R=R, T=T)

silhouette = silhouette.cpu().numpy()
image_ref = image_ref.cpu().numpy()

plt.figure(figsize=(10, 10))
plt.subplot(1, 2, 1)
plt.imshow(silhouette.squeeze()[..., 3])  # only plot the alpha channel of the RGBA image
plt.grid(False)
plt.subplot(1, 2, 2)
plt.imshow(image_ref.squeeze())
plt.grid(False)

输出如下:

2.3 创建基础模型

这里我们创建一个简单的模型类并初始化相机位置的参数。

class Model(nn.Module):
    def __init__(self, meshes, renderer, image_ref):
        super().__init__()
        self.meshes = meshes
        self.device = meshes.device
        self.renderer = renderer
        
        # Get the silhouette of the reference RGB image by finding all non-white pixel values. 
        image_ref = torch.from_numpy((image_ref[..., :3].max(-1) != 1).astype(np.float32))
        self.register_buffer('image_ref', image_ref)
        
        # Create an optimizable parameter for the x, y, z position of the camera. 
        self.camera_position = nn.Parameter(
            torch.from_numpy(np.array([3.0,  6.9, +2.5], dtype=np.float32)).to(meshes.device))

    def forward(self):
        
        # Render the image using the updated camera position. Based on the new position of the 
        # camera we calculate the rotation and translation matrices
        R = look_at_rotation(self.camera_position[None, :], device=self.device)  # (1, 3, 3)
        T = -torch.bmm(R.transpose(1, 2), self.camera_position[None, :, None])[:, :, 0]   # (1, 3)
        
        image = self.renderer(meshes_world=self.meshes.clone(), R=R, T=T)
        
        # Calculate the silhouette loss
        loss = torch.sum((image[..., 3] - self.image_ref) ** 2)
        return loss, image

3、初始化模型和优化器

现在我们可以创建上述模型的实例并为相机位置参数设置优化器。

# We will save images periodically and compose them into a GIF.
filename_output = "./teapot_optimization_demo.gif"
writer = imageio.get_writer(filename_output, mode='I', duration=0.3)

# Initialize a model using the renderer, mesh and reference image
model = Model(meshes=teapot_mesh, renderer=silhouette_renderer, image_ref=image_ref).to(device)

# Create an optimizer. Here we are using Adam and we pass in the parameters of the model
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)

可视化起始位置和参考位置:

plt.figure(figsize=(10, 10))

_, image_init = model()
plt.subplot(1, 2, 1)
plt.imshow(image_init.detach().squeeze().cpu().numpy()[..., 3])
plt.grid(False)
plt.title("Starting position")

plt.subplot(1, 2, 2)
plt.imshow(model.image_ref.cpu().numpy().squeeze())
plt.grid(False)
plt.title("Reference silhouette");

输出如下:

4、运行优化

我们运行前向和后向传递的多次迭代,并每 10 次迭代保存输出。

loop = tqdm(range(200))
for i in loop:
    optimizer.zero_grad()
    loss, _ = model()
    loss.backward()
    optimizer.step()
    
    loop.set_description('Optimizing (loss %.4f)' % loss.data)
    
    if loss.item() < 200:
        break
    
    # Save outputs to create a GIF. 
    if i % 10 == 0:
        R = look_at_rotation(model.camera_position[None, :], device=model.device)
        T = -torch.bmm(R.transpose(1, 2), model.camera_position[None, :, None])[:, :, 0]   # (1, 3)
        image = phong_renderer(meshes_world=model.meshes.clone(), R=R, T=T)
        image = image[0, ..., :3].detach().squeeze().cpu().numpy()
        image = img_as_ubyte(image)
        writer.append_data(image)
        
        plt.figure()
        plt.imshow(image[..., :3])
        plt.title("iter: %d, loss: %0.2f" % (i, loss.data))
        plt.axis("off")
    
writer.close()

迭代期间输出如下:

完成后可以查看 ./teapot_optimization_demo.gif,优化过程的炫酷 gif:

5、结束语

在本教程中,我们学习了如何从 obj 文件加载网格,初始化名为 Meshes 的 PyTorch3D 数据结构,设置由光栅化器和着色器组成的渲染器,设置包括模型和损失函数的优化循环,并运行优化。


原文链接:用可微渲染优化相机位置 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/194290.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【开源】基于Vue+SpringBoot的企业项目合同信息系统

项目编号&#xff1a; S 046 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S046&#xff0c;文末获取源码。} 项目编号&#xff1a;S046&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 合同审批模块2.3 合…

Flask Echarts 实现历史图形查询

Flask前后端数据动态交互涉及用户界面与服务器之间的灵活数据传递。用户界面使用ECharts图形库实时渲染数据。它提供了丰富多彩、交互性强的图表和地图&#xff0c;能够在网页上直观、生动地展示数据。ECharts支持各种常见的图表类型&#xff0c;包括折线图、柱状图、饼图、散点…

FFmepg 核心开发库及重要数据结构与API

文章目录 前言一、FFmpeg 核心开发库二、FFmpeg 重要数据结构与 API1、简介2、FFmpeg 解码流程①、FFmpeg2.x 解码流程②、FFmpeg4.x 解码流程 3、FFMpeg 中比较重要的函数以及数据结构①、数据结构②、初始化函数③、音视频解码函数④、文件操作⑤、其他函数 三、FFmpeg 流程1…

【活动回顾】sCrypt在柏林B2029开发者周

B2029 是柏林的一个区块链爱好者、艺术家和建设者聚会&#xff0c;学习、讨论和共同构建比特币区块链地方。 在2023年6月9日至11日&#xff0c;举行了第7次Hello Metanet研讨会。本次研讨会旨在为参与者提供一个学习、讨论和共同构建比特币区块链的平台。 在这个充满激情和创意…

C语言:输出所有“水仙花数”。“水仙花数”是指一个3位数,其各位数字的立方和等于该数本身,如153=1^3 +5^3+3^3

分析&#xff1a; 在主函数 main 中&#xff0c;程序首先定义四个整型变量 m、a、b 和 c&#xff0c;并用于计算和判断水仙花数。然后使用 printf 函数输出提示信息。 接下来&#xff0c;程序使用 for 循环结构&#xff0c;从 100 到 999 遍历所有三位数。对于每个遍历到的数 m…

Vue简易的车牌输入键盘,可以根据需要修改

效果图如下&#xff1a; 代码如下&#xff1a; <template><div><div class"carNoBoxInput"><div style"padding: 6px;border: 2px solid #fff;border-radius: 6px;margin: 6px 3px 6px 6px;"><input class"inputBox"…

Make sure bypassing Vue built-in sanitization is safe here.

一、问题描述 二、问题分析 XSS(跨站脚本攻击) XSS攻击通常指的是通过利用网页开发时留下的漏洞&#xff0c;通过巧妙的方法注入恶意指令代码到网页&#xff0c;使用户加载并执行攻击者恶意制造的网页程序。这些恶意网页程序通常是JavaScript&#xff0c;但实际上也可以包括J…

https到底把什么加密了?

首先直接说结论&#xff0c; https安全通信模式&#xff0c;是使用TLS加密传输所有的http协议。再重复一遍&#xff0c;是所有&#xff01; 通常将TLS加密传输http这个通信过程称为https&#xff0c;如果使用协议封装的逻辑结构来表达就是&#xff1a; IP TCP TLS 【 HTTP 】…

大连大学2023年11月程序设计竞赛(同步赛)

B、爆wa种子!&#xff08;数学&#xff09; 一、题目要求 链接&#xff1a;登录—专业IT笔试面试备考平台_牛客网 来源&#xff1a;牛客网 爆wa种子发现了上次玩游戏时你和妙wa种子的py交易&#xff0c;所以他要求这次玩游戏你来当爆wa种子的枪手&#xff0c;为他写个程序…

前端 --- HTML

目录 一、网络的三大基石 ​二、什么是HTML 一、HTML 指的是超文本标记语言 二、HTML的作用 三、HTML的标准结构 四、IDE_HBuilder的使用 一、编码工具&#xff1a; 二、集成开发环境 三、HBuilder使用步骤&#xff1a; 五、HTML的标签的使用 一、html_head_body 二、head…

CSS新手入门笔记整理:CSS字体样式

字体类型&#xff1a;font-family 语法 font-family&#xff1a;字体1,字体2,...,字体n; font-family可以指定多种字体。使用多个字体时&#xff0c;将按从左到右的顺序排列&#xff0c;并且以英文逗号&#xff08;,&#xff09;隔开。如果我们不定义font-family&#xff0c…

viple模拟器使用(三):unity模拟器中实现沿右墙迷宫算法

沿右墙迷宫算法 引导 线控模拟可以使得通过用户手动操作&#xff0c;实现机器人在模拟环境下在迷宫中行走&#xff08;即&#xff1a;运动&#xff09;&#xff0c;算法可以使得机器人按照一定的策略自动行走&#xff0c;沿右墙迷宫算法就是其中的一种策略。 目的 运行程序后&…

Scrapy框架内置管道之图片视频和文件(一篇文章齐全)

1、Scrapy框架初识&#xff08;点击前往查阅&#xff09; 2、Scrapy框架持久化存储&#xff08;点击前往查阅&#xff09; 3、Scrapy框架内置管道 4、Scrapy框架中间件&#xff08;点击前往查阅&#xff09; Scrapy 是一个开源的、基于Python的爬虫框架&#xff0c;它提供了…

3D模型纹理集合并【Python|C#】

使用 Substance Painter 时&#xff0c;将模型的各个部分分成不同的纹理集非常有用。 这可以帮助遮罩&#xff0c;或者只是保持层栈干净。 不幸的是&#xff0c;Painter 无法将多个纹理集中的所有贴图导出为单个图集&#xff0c;即使在创建单独对象的 UV 时考虑到了这一点。 显…

Django创建基本的app应用并配置URL路径-成功运行服务

开发环境&#xff1a;Pycharm2021 Win11 首先创建虚拟环境: 可参考&#xff1a; Pycharm开发环境下创建python运行的虚拟环境&#xff08;自动执行安装依赖包&#xff09;_pycharm自动下载依赖包_heda3的博客-CSDN博客 1、安装 Django 在虚拟环境下安装pip install django …

ES 8.x开始(docker-compose安装、kibana使用、java操作)

学习文档地址 一、Docker安装 这里使用docker-compose来安装&#xff0c;方便后续迁移&#xff0c;Elasticserach和kibina一起安装。 1、创建安装目录 configdataplugins 2、配置文件 配置文件有两个&#xff0c;一个是ES的配置文件&#xff0c;一个docker-compose的配置文件 …

DS图—图的最短路径/Dijkstra算法【数据结构】

DS图—图的最短路径/Dijkstra算法【数据结构】 题目描述 给出一个图的邻接矩阵&#xff0c;输入顶点v&#xff0c;用迪杰斯特拉算法求顶点v到其它顶点的最短路径。 输入 第一行输入t&#xff0c;表示有t个测试实例 第二行输入顶点数n和n个顶点信息 第三行起&#xff0c;每行…

龙芯loongarch64服务器编译安装pyarrow

1、简介 pyarrow是一个高效的Python库,用于在Python应用程序和Apache Arrow之间进行交互。Arrow是一种跨语言的内存格式,可以快速高效地转移大型数据集合。它提供了一种通用的数据格式,将数据在内存中表示为表格,并支持诸如序列化和分布式读取等功能。 龙芯的Python仓库安…

Mo0n(月亮) MCGS触摸屏在野0day利用,强制卡死锁屏

项目:https://github.com/MartinxMax/Mo0n 后面还会不会在,我可就不知道了奥…还不收藏点赞关注 扫描存在漏洞的设备 #python3 Mo0n.py -scan 192.168.0.0/24 入侵锁屏 #python3 Mo0n.py -rhost 192.168.0.102 -lock 解锁 #python3 Mo0n.py -rhost 192.168.0.102 -unlock …

Linux(10):Shell scripts

什么是 Shell scripts shell script&#xff08;程序化脚本&#xff09;&#xff1a;shell 部分是一个文字接口下让我们与系统沟通的一个工具接口&#xff1b;script 是脚本的意思&#xff0c;shell script 就是针对 shell 写的脚本。 shell script 是利用 shell 的功能所写的…