Parameter updates
Reference article: Mathematical derivations in 3DGS
Updating the covariance matrix parameters
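- Relying directly on PyTorch's built-in update mechanism (render, compute the loss, backpropagate) only yields updates for the 2D covariance matrix $\Sigma'$. The 3D covariance matrix $\Sigma$ then has to be recovered by inverting the projection formula. This involves many complicated matrix operations and is computationally inefficient.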
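- To improve efficiency, $\Sigma$ needs to be represented explicitly. As covered earlier, it is factored into a rotation matrix $R$ and a scaling matrix $S$:
$$\boldsymbol{\Sigma}=\boldsymbol{R}\boldsymbol{S}\boldsymbol{S}^\top\boldsymbol{R}^\top$$
A matrix of this form is positive semi-definite for any $R$ and $S$, which a freely updated $\Sigma$ is not guaranteed to be.
- Using a rotation quaternion, the 9 entries of $R$ that would otherwise need updating can be compressed into 4 parameters: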
$$\begin{gathered}\mathbf{q}=q_r+q_i\cdot i+q_j\cdot j+q_k\cdot k\\\mathbf{R}\left(\mathbf{q}\right)=2\begin{pmatrix}\frac12-\left(q_j^2+q_k^2\right)&(q_iq_j-q_rq_k)&(q_iq_k+q_rq_j)\\(q_iq_j+q_rq_k)&\frac12-\left(q_i^2+q_k^2\right)&(q_jq_k-q_rq_i)\\(q_iq_k-q_rq_j)&(q_jq_k+q_rq_i)&\frac12-\left(q_i^2+q_j^2\right)\end{pmatrix}\end{gathered}$$
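A rotation quaternion is a mathematical tool for representing rotations in 3D space. It is a special case of a quaternion (one with unit norm), made up of one real part and three imaginary parts.
It is usually written as q = w + xi + yj + zk, where w is the real part, (x, y, z) are the imaginary parts, and i, j, k are the imaginary units, which satisfy i² = j² = k² = ijk = -1. In the notation of the formula above, w is $q_r$ and (x, y, z) is $(q_i, q_j, q_k)$.
The core idea is to encode both the rotation angle and the rotation axis. For a rotation by an angle θ about a unit axis $(u_x, u_y, u_z)$, the real part is w = cos(θ/2), and the imaginary part is (x, y, z) = sin(θ/2) · $(u_x, u_y, u_z)$. The real part therefore stores the cosine of half the rotation angle, and the imaginary part stores the axis scaled by the sine of half the angle.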
- The scaling matrix does not need to be stored in full either: it is enough to record the scale factor along each of the three axes.
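The quaternion-to-rotation formula above can be sketched in a few lines of C++. This is only an illustration, not code from the official repository: the names `Quat`, `Mat3`, `fromAxisAngle`, and `quatToRotation` are hypothetical, and the explicit normalization step is an assumption added so that any four free parameters map to a valid rotation.

```cpp
#include <cmath>

struct Mat3 { double m[3][3]; };

// Quaternion stored real part first: (r, i, j, k) = (q_r, q_i, q_j, q_k).
struct Quat { double r, i, j, k; };

// Unit quaternion for a rotation by `angle` (radians) about the axis (ax, ay, az).
// The axis does not have to be normalized on input, but must be non-zero.
Quat fromAxisAngle(double ax, double ay, double az, double angle) {
    const double n = std::sqrt(ax * ax + ay * ay + az * az);
    const double s = std::sin(0.5 * angle) / n;
    return { std::cos(0.5 * angle), ax * s, ay * s, az * s };
}

// Rotation matrix R(q), entry by entry as in the formula above.
// The quaternion is normalized first (an assumption of this sketch).
Mat3 quatToRotation(Quat q) {
    const double n = std::sqrt(q.r * q.r + q.i * q.i + q.j * q.j + q.k * q.k);
    const double r = q.r / n, i = q.i / n, j = q.j / n, k = q.k / n;
    Mat3 R = {{
        { 1.0 - 2.0 * (j * j + k * k), 2.0 * (i * j - r * k),       2.0 * (i * k + r * j) },
        { 2.0 * (i * j + r * k),       1.0 - 2.0 * (i * i + k * k), 2.0 * (j * k - r * i) },
        { 2.0 * (i * k - r * j),       2.0 * (j * k + r * i),       1.0 - 2.0 * (i * i + j * j) }
    }};
    return R;
}
```

For example, `quatToRotation(fromAxisAngle(0, 0, 1, angle))` with `angle` equal to π/2 gives the usual 90° rotation about the z axis, which maps the x axis onto the y axis.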
In summary, updating the covariance matrix reduces to updating a rotation quaternion $q$ and a 3D vector $s$ that holds the scale factors.
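As a rough sketch of how the seven stored numbers turn back into a covariance matrix, the snippet below assembles $\Sigma = R S S^\top R^\top$ from $q$ and $s$. The helper names (`rotationFromQuat`, `covarianceFromQuatScale`) and the array layout are assumptions made for this example and are not taken from the official code.

```cpp
#include <cmath>

struct Mat3 { double m[3][3]; };

// Rotation matrix from a quaternion (r, i, j, k); the quaternion is normalized first.
Mat3 rotationFromQuat(double r, double i, double j, double k) {
    const double n = std::sqrt(r * r + i * i + j * j + k * k);
    r /= n; i /= n; j /= n; k /= n;
    Mat3 R = {{
        { 1.0 - 2.0 * (j * j + k * k), 2.0 * (i * j - r * k),       2.0 * (i * k + r * j) },
        { 2.0 * (i * j + r * k),       1.0 - 2.0 * (i * i + k * k), 2.0 * (j * k - r * i) },
        { 2.0 * (i * k - r * j),       2.0 * (j * k + r * i),       1.0 - 2.0 * (i * i + j * j) }
    }};
    return R;
}

// Sigma = R S S^T R^T with S = diag(s), computed as M M^T where M = R S.
// q holds (q_r, q_i, q_j, q_k); s holds the three per-axis scale factors.
Mat3 covarianceFromQuatScale(const double q[4], const double s[3]) {
    const Mat3 R = rotationFromQuat(q[0], q[1], q[2], q[3]);
    Mat3 M{};
    for (int a = 0; a < 3; ++a)
        for (int b = 0; b < 3; ++b)
            M.m[a][b] = R.m[a][b] * s[b];  // column b of R scaled by s_b
    Mat3 Sigma{};
    for (int a = 0; a < 3; ++a)
        for (int b = 0; b < 3; ++b)
            for (int c = 0; c < 3; ++c)
                Sigma.m[a][b] += M.m[a][c] * M.m[b][c];
    return Sigma;
}
```

With the identity quaternion (1, 0, 0, 0) and s = (1, 2, 3), the result is the diagonal matrix with entries 1, 4, 9. Rotating by 90° about the z axis swaps the first two diagonal entries.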
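3DGS derives the gradients with respect to $q$ and $s$ explicitly, which avoids the cost of automatic differentiation. For details, see the gradient backpropagation derivation in the appendix of the original 3DGS paper.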
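To give a feel for what such a hand-derived gradient looks like, here is a minimal sketch for the scale vector only. It is not the official CUDA kernel. It assumes an upstream gradient $G = \partial L / \partial \Sigma$ is already available, writes $M = RS$ so that $\Sigma = MM^\top$, and uses $\partial L/\partial M = (G + G^\top)M$ followed by $\partial L/\partial s_k = \sum_a (\partial L/\partial M)_{ak} R_{ak}$. All function names are hypothetical. A central finite difference is included so the analytic result can be checked.

```cpp
#include <cmath>

struct Mat3 { double m[3][3]; };

// Rotation matrix from a quaternion (r, i, j, k); the quaternion is normalized first.
Mat3 rotationFromQuat(double r, double i, double j, double k) {
    const double n = std::sqrt(r * r + i * i + j * j + k * k);
    r /= n; i /= n; j /= n; k /= n;
    Mat3 R = {{
        { 1.0 - 2.0 * (j * j + k * k), 2.0 * (i * j - r * k),       2.0 * (i * k + r * j) },
        { 2.0 * (i * j + r * k),       1.0 - 2.0 * (i * i + k * k), 2.0 * (j * k - r * i) },
        { 2.0 * (i * k - r * j),       2.0 * (j * k + r * i),       1.0 - 2.0 * (i * i + j * j) }
    }};
    return R;
}

// Test loss L = sum_ab G_ab * Sigma_ab, with Sigma = M M^T and M = R diag(s).
// G stands in for the upstream gradient dL/dSigma.
double loss(const Mat3& G, const Mat3& R, const double s[3]) {
    double L = 0.0;
    for (int a = 0; a < 3; ++a)
        for (int b = 0; b < 3; ++b) {
            double sigma_ab = 0.0;
            for (int c = 0; c < 3; ++c)
                sigma_ab += R.m[a][c] * s[c] * R.m[b][c] * s[c];
            L += G.m[a][b] * sigma_ab;
        }
    return L;
}

// Analytic gradient: dL/ds_k = sum_a [(G + G^T) M]_ak * R_ak.
void gradScale(const Mat3& G, const Mat3& R, const double s[3], double out[3]) {
    for (int k = 0; k < 3; ++k) {
        out[k] = 0.0;
        for (int a = 0; a < 3; ++a) {
            double dL_dM_ak = 0.0;
            for (int b = 0; b < 3; ++b)
                dL_dM_ak += (G.m[a][b] + G.m[b][a]) * R.m[b][k] * s[k];
            out[k] += dL_dM_ak * R.m[a][k];
        }
    }
}

// Central finite difference of the loss with respect to s_k, for comparison.
double numericGradScale(const Mat3& G, const Mat3& R, const double s[3], int k, double h) {
    double sp[3] = { s[0], s[1], s[2] };
    double sm[3] = { s[0], s[1], s[2] };
    sp[k] += h;
    sm[k] -= h;
    return (loss(G, R, sp) - loss(G, R, sm)) / (2.0 * h);
}
```

Comparing `gradScale` against `numericGradScale` for a few random inputs is a cheap way to validate a manual derivation before moving it into a kernel. The gradient with respect to $q$ follows the same pattern but additionally needs the derivative of each entry of $R$ with respect to the four quaternion components.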
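Updating the color parameters
The part of the official code that handles the color update (the backward pass from RGB to the spherical harmonic coefficients) is as follows: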
// Backward pass for conversion of spherical harmonics to RGB for
// each Gaussian.
__device__ void computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, const bool* clamped, const glm::vec3* dL_dcolor, glm::vec3* dL_dmeans, glm::vec3* dL_dshs)
{
    // Compute intermediate values, as it is done during forward
    glm::vec3 pos = means[idx];
    glm::vec3 dir_orig = pos - campos;
    glm::vec3 dir = dir_orig / glm::length(dir_orig);

    glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs;

    // Use PyTorch rule for clamping: if clamping was applied,
    // gradient becomes 0.
    glm::vec3 dL_dRGB = dL_dcolor[idx];
    dL_dRGB.x *= clamped[3 * idx + 0] ? 0 : 1;
    dL_dRGB.y *= clamped[3 * idx + 1] ? 0 : 1;
    dL_dRGB.z *= clamped[3 * idx + 2] ? 0 : 1;

    glm::vec3 dRGBdx(0, 0, 0);
    glm::vec3 dRGBdy(0, 0, 0);
    glm::vec3 dRGBdz(0, 0, 0);
    float x = dir.x;
    float y = dir.y;
    float z = dir.z;

    // Target location for this Gaussian to write SH gradients to
    glm::vec3* dL_dsh = dL_dshs + idx * max_coeffs;

    // No tricks here, just high school-level calculus.
    float dRGBdsh0 = SH_C0;
    dL_dsh[0] = dRGBdsh0 * dL_dRGB;
    if (deg > 0)
    {
        float dRGBdsh1 = -SH_C1 * y;
        float dRGBdsh2 = SH_C1 * z;
        float dRGBdsh3 = -SH_C1 * x;
        dL_dsh[1] = dRGBdsh1 * dL_dRGB;
        dL_dsh[2] = dRGBdsh2 * dL_dRGB;
        dL_dsh[3] = dRGBdsh3 * dL_dRGB;

        dRGBdx = -SH_C1 * sh[3];
        dRGBdy = -SH_C1 * sh[1];
        dRGBdz = SH_C1 * sh[2];

        if (deg > 1)
        {
            float xx = x * x, yy = y * y, zz = z * z;
            float xy = x * y, yz = y * z, xz = x * z;

            float dRGBdsh4 = SH_C2[0] * xy;
            float dRGBdsh5 = SH_C2[1] * yz;
            float dRGBdsh6 = SH_C2[2] * (2.f * zz - xx - yy);
            float dRGBdsh7 = SH_C2[3] * xz;
            float dRGBdsh8 = SH_C2[4] * (xx - yy);
            dL_dsh[4] = dRGBdsh4 * dL_dRGB;
            dL_dsh[5] = dRGBdsh5 * dL_dRGB;
            dL_dsh[6] = dRGBdsh6 * dL_dRGB;
            dL_dsh[7] = dRGBdsh7 * dL_dRGB;
            dL_dsh[8] = dRGBdsh8 * dL_dRGB;

            dRGBdx += SH_C2[0] * y * sh[4] + SH_C2[2] * 2.f * -x * sh[6] + SH_C2[3] * z * sh[7] + SH_C2[4] * 2.f * x * sh[8];
            dRGBdy += SH_C2[0] * x * sh[4] + SH_C2[1] * z * sh[5] + SH_C2[2] * 2.f * -y * sh[6] + SH_C2[4] * 2.f * -y * sh[8];
            dRGBdz += SH_C2[1] * y * sh[5] + SH_C2[2] * 2.f * 2.f * z * sh[6] + SH_C2[3] * x * sh[7];

            if (deg > 2)
            {
                float dRGBdsh9 = SH_C3[0] * y * (3.f * xx - yy);
                float dRGBdsh10 = SH_C3[1] * xy * z;
                float dRGBdsh11 = SH_C3[2] * y * (4.f * zz - xx - yy);
                float dRGBdsh12 = SH_C3[3] * z * (2.f * zz - 3.f * xx - 3.f * yy);
                float dRGBdsh13 = SH_C3[4] * x * (4.f * zz - xx - yy);
                float dRGBdsh14 = SH_C3[5] * z * (xx - yy);
                float dRGBdsh15 = SH_C3[6] * x * (xx - 3.f * yy);
                dL_dsh[9] = dRGBdsh9 * dL_dRGB;
                dL_dsh[10] = dRGBdsh10 * dL_dRGB;
                dL_dsh[11] = dRGBdsh11 * dL_dRGB;
                dL_dsh[12] = dRGBdsh12 * dL_dRGB;
                dL_dsh[13] = dRGBdsh13 * dL_dRGB;
                dL_dsh[14] = dRGBdsh14 * dL_dRGB;
                dL_dsh[15] = dRGBdsh15 * dL_dRGB;

                dRGBdx += (
                    SH_C3[0] * sh[9] * 3.f * 2.f * xy +
                    SH_C3[1] * sh[10] * yz +
                    SH_C3[2] * sh[11] * -2.f * xy +
                    SH_C3[3] * sh[12] * -3.f * 2.f * xz +
                    SH_C3[4] * sh[13] * (-3.f * xx + 4.f * zz - yy) +
                    SH_C3[5] * sh[14] * 2.f * xz +
                    SH_C3[6] * sh[15] * 3.f * (xx - yy));

                dRGBdy += (
                    SH_C3[0] * sh[9] * 3.f * (xx - yy) +
                    SH_C3[1] * sh[10] * xz +
                    SH_C3[2] * sh[11] * (-3.f * yy + 4.f * zz - xx) +
                    SH_C3[3] * sh[12] * -3.f * 2.f * yz +
                    SH_C3[4] * sh[13] * -2.f * xy +
                    SH_C3[5] * sh[14] * -2.f * yz +
                    SH_C3[6] * sh[15] * -3.f * 2.f * xy);

                dRGBdz += (
                    SH_C3[1] * sh[10] * xy +
                    SH_C3[2] * sh[11] * 4.f * 2.f * yz +
                    SH_C3[3] * sh[12] * 3.f * (2.f * zz - xx - yy) +
                    SH_C3[4] * sh[13] * 4.f * 2.f * xz +
                    SH_C3[5] * sh[14] * (xx - yy));
            }
        }
    }

    // The view direction is an input to the computation. View direction
    // is influenced by the Gaussian's mean, so SHs gradients
    // must propagate back into 3D position.
    glm::vec3 dL_ddir(glm::dot(dRGBdx, dL_dRGB), glm::dot(dRGBdy, dL_dRGB), glm::dot(dRGBdz, dL_dRGB));

    // Account for normalization of direction
    float3 dL_dmean = dnormvdv(float3{ dir_orig.x, dir_orig.y, dir_orig.z }, float3{ dL_ddir.x, dL_ddir.y, dL_ddir.z });

    // Gradients of loss w.r.t. Gaussian means, but only the portion
    // that is caused because the mean affects the view-dependent color.
    // Additional mean gradient is accumulated in below methods.
    dL_dmeans[idx] += glm::vec3(dL_dmean.x, dL_dmean.y, dL_dmean.z);
}
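- The input parameters are:
idx: index of the current Gaussian;
deg: degree of the spherical harmonics in use;
max_coeffs: number of spherical harmonic coefficients stored per Gaussian;
means: mean (3D center) of each Gaussian;
campos: camera position;
shs: spherical harmonic coefficients of each Gaussian;
clamped: three flags per Gaussian recording whether each RGB channel was clamped in the forward pass;
dL_dcolor: gradient of the loss with respect to each Gaussian's RGB color;
dL_dmeans: gradient of the loss with respect to the Gaussian means (output, accumulated into);
dL_dshs: gradient of the loss with respect to the spherical harmonic coefficients (output).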
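- The function mainly performs the following steps:
Compute the normalized direction vector from the camera to the mean of the current Gaussian;
Read the incoming gradient dL_dRGB and, following PyTorch's clamping rule, set it to 0 for every color channel that was clamped in the forward pass;
From the definition of the spherical harmonic basis, compute the derivative of the RGB color with respect to each coefficient and multiply it by dL_dRGB to obtain dL_dsh;
Compute the partial derivatives of the RGB color with respect to the x, y, and z components of the direction (dRGBdx, dRGBdy, dRGBdz);
Take the dot product of each with dL_dRGB to get the gradient with respect to the direction, dL_ddir;
Backpropagate dL_ddir through the normalization of the direction (dnormvdv) to get the gradient with respect to the Gaussian mean;
Accumulate this gradient into dL_dmeans.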
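The steps above can be followed more easily on a stripped-down version. The sketch below handles a single color channel and only degrees 0 and 1, and is not the official kernel. It assumes a forward pass that evaluates the degree-1 basis, adds 0.5, and clamps at 0, and it takes the values of `SH_C0` and `SH_C1` to be the usual real spherical harmonic constants, since the listing above does not define them. The names `Vec3`, `shToColor`, and `shToColorBackward` are hypothetical. The normalization gradient uses the identity $\partial(v/\lVert v\rVert)/\partial v = (I - dd^\top)/\lVert v\rVert$ instead of calling `dnormvdv`.

```cpp
#include <algorithm>
#include <cmath>

// Assumed values of the degree-0 and degree-1 SH constants.
const double SH_C0 = 0.28209479177387814;
const double SH_C1 = 0.4886025119029199;

struct Vec3 { double x, y, z; };

// Assumed forward pass for one channel, degree <= 1.
// Returns the clamped color and records whether clamping was applied.
double shToColor(const double sh[4], Vec3 mean, Vec3 campos, bool* clamped) {
    const Vec3 v = { mean.x - campos.x, mean.y - campos.y, mean.z - campos.z };
    const double len = std::sqrt(v.x * v.x + v.y * v.y + v.z * v.z);
    const double x = v.x / len, y = v.y / len, z = v.z / len;
    const double c = SH_C0 * sh[0] - SH_C1 * y * sh[1] + SH_C1 * z * sh[2]
                   - SH_C1 * x * sh[3] + 0.5;
    *clamped = c < 0.0;
    return std::max(c, 0.0);
}

// Backward pass: writes dL/dsh[0..3] and returns dL/dmean for one channel.
Vec3 shToColorBackward(const double sh[4], Vec3 mean, Vec3 campos,
                       bool clamped, double dL_dcolor, double dL_dsh[4]) {
    const Vec3 v = { mean.x - campos.x, mean.y - campos.y, mean.z - campos.z };
    const double len = std::sqrt(v.x * v.x + v.y * v.y + v.z * v.z);
    const double x = v.x / len, y = v.y / len, z = v.z / len;

    // Clamping rule: a clamped channel passes no gradient.
    const double dL_dc = clamped ? 0.0 : dL_dcolor;

    // Gradient with respect to the SH coefficients.
    dL_dsh[0] = SH_C0 * dL_dc;
    dL_dsh[1] = -SH_C1 * y * dL_dc;
    dL_dsh[2] = SH_C1 * z * dL_dc;
    dL_dsh[3] = -SH_C1 * x * dL_dc;

    // Gradient with respect to the normalized direction.
    const Vec3 g = { -SH_C1 * sh[3] * dL_dc, -SH_C1 * sh[1] * dL_dc, SH_C1 * sh[2] * dL_dc };

    // Backpropagate through the normalization: (g - d (d . g)) / |v|.
    const double dg = x * g.x + y * g.y + z * g.z;
    return { (g.x - x * dg) / len, (g.y - y * dg) / len, (g.z - z * dg) / len };
}
```

In the full kernel the same computation runs on all three channels at once through `glm::vec3`, and the returned mean gradient is added to `dL_dmeans[idx]` rather than overwriting it, because other parts of the backward pass contribute to the mean as well.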