|
|
from io import BytesIO |
|
|
import math |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import importlib |
|
|
from plyfile import PlyData, PlyElement |
|
|
|
|
|
import copy |
|
|
|
|
|
class EmbedContainer(nn.Module):
    """Hold one tensor as a learnable parameter.

    Calling the module performs no computation: it simply hands back the
    wrapped parameter, which is stored under the attribute ``tensor``.
    """

    def __init__(self, tensor):
        super().__init__()
        # Wrapping in nn.Parameter registers the tensor with the module, so it
        # shows up in ``parameters()`` and ``state_dict()``.
        wrapped = nn.Parameter(tensor)
        self.tensor = wrapped

    def forward(self):
        """Return the stored parameter unchanged."""
        return self.tensor
|
|
|
|
|
@torch.no_grad()
def zero_init(module):
    """Zero the weight (and bias, if any) of a Conv2d or Linear layer in place.

    Only modules whose exact type is ``torch.nn.Conv2d`` or ``torch.nn.Linear``
    are modified; any other module (including subclasses of those two) is
    returned untouched.

    Args:
        module: The module to zero-initialise.

    Returns:
        The same ``module`` object, so the call can be used inline.
    """
    if type(module) in (torch.nn.Conv2d, torch.nn.Linear):
        module.weight.zero_()
        # Layers created with ``bias=False`` expose ``bias`` as None.
        if module.bias is not None:
            module.bias.zero_()
    return module
|
|
|
|
|
def import_str(string):
    """Resolve a dotted path such as ``"package.module.Name"`` to an object.

    Everything before the last dot is imported as a module; the part after the
    last dot is then fetched from that module as an attribute.

    Args:
        string: Dotted import path containing at least one dot.

    Returns:
        The attribute named by the final path component.
    """
    module_name, attr_name = string.rsplit(".", maxsplit=1)
    owner = importlib.import_module(module_name)
    return getattr(owner, attr_name)
|
|
|
|
|
""" |
|
|
from https://github.com/Kai-46/minFM/blob/main/utils/ema.py |
|
|
Exponential Moving Average (EMA) utilities for PyTorch models. |
|
|
|
|
|
This module provides utilities for maintaining and updating EMA models, |
|
|
which are commonly used to improve model stability and generalization |
|
|
in training deep neural networks. It supports both regular tensors and |
|
|
DTensors (from FSDP-wrapped models). |
|
|
""" |
|
|
class EMA_FSDP:
    """Exponential moving average of a module's parameters.

    The averaged ("shadow") weights are kept as float32 CPU tensors in
    ``self.shadow``, keyed by parameter name. Both plain ``nn.Module``
    instances and modules wrapped in ``FullyShardedDataParallel`` are
    accepted; for FSDP modules the full parameters are summoned before they
    are read or written, and names are taken from the wrapped ``.module``.
    """

    def __init__(self, fsdp_module: torch.nn.Module, decay: float = 0.999):
        """Snapshot the current parameters of ``fsdp_module``.

        Args:
            fsdp_module: Module (FSDP-wrapped or not) to track.
            decay: Weight given to the previous average on each update.
        """
        self.decay = decay
        self.shadow = {}
        self._init_shadow(fsdp_module)

    @torch.no_grad()
    def _init_shadow(self, fsdp_module):
        """Fill ``self.shadow`` with float32 CPU copies of every parameter."""
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

        if isinstance(fsdp_module, FSDP):
            with FSDP.summon_full_params(fsdp_module, writeback=False):
                for n, p in fsdp_module.module.named_parameters():
                    self.shadow[n] = p.detach().clone().float().cpu()
        else:
            for n, p in fsdp_module.named_parameters():
                self.shadow[n] = p.detach().clone().float().cpu()

    @torch.no_grad()
    def update(self, fsdp_module):
        """Blend the current parameters into the shadow weights in place.

        For every parameter: ``shadow = decay * shadow + (1 - decay) * param``.
        """
        d = self.decay
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

        if isinstance(fsdp_module, FSDP):
            with FSDP.summon_full_params(fsdp_module, writeback=False):
                for n, p in fsdp_module.module.named_parameters():
                    self.shadow[n].mul_(d).add_(p.detach().float().cpu(), alpha=1. - d)
        else:
            for n, p in fsdp_module.named_parameters():
                self.shadow[n].mul_(d).add_(p.detach().float().cpu(), alpha=1. - d)

    def state_dict(self):
        """Return the shadow dictionary itself (not a copy)."""
        return self.shadow

    def load_state_dict(self, sd):
        """Replace the shadow weights with clones of the tensors in ``sd``."""
        self.shadow = {k: v.clone() for k, v in sd.items()}

    @torch.no_grad()
    def copy_to(self, fsdp_module):
        """Overwrite the module's parameters with the shadow weights.

        Parameters that have no entry in ``self.shadow`` are left unchanged.
        Each shadow tensor is converted to the parameter's dtype and device
        before it is copied.
        """
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

        if isinstance(fsdp_module, FSDP):
            # writeback=True so the modified full parameters are written back
            # when the context exits.
            with FSDP.summon_full_params(fsdp_module, writeback=True):
                for n, p in fsdp_module.module.named_parameters():
                    if n in self.shadow:
                        p.data.copy_(self.shadow[n].to(device=p.device, dtype=p.dtype))
        else:
            # Plain modules: same handling as in _init_shadow / update.
            for n, p in fsdp_module.named_parameters():
                if n in self.shadow:
                    p.data.copy_(self.shadow[n].to(device=p.device, dtype=p.dtype))
|
|
|
|
|
def create_raymaps(cameras, h, w):
    """Build a 6-channel ray map for every pixel of every camera.

    The first three channels are the ray direction returned by
    ``create_rays``. The last three are ``o - (o . d) * d``, i.e. the ray
    origin with its component along the direction removed.

    Args:
        cameras: Camera tensor, forwarded unchanged to ``create_rays``.
        h: Image height in pixels.
        w: Image width in pixels.

    Returns:
        Tensor whose last dimension concatenates the two 3-vectors above.
    """
    origins, directions = create_rays(cameras, h, w)
    # Projection length of each origin onto its own ray direction.
    along = (origins * directions).sum(dim=-1, keepdim=True)
    perpendicular = origins - along * directions
    return torch.cat([directions, perpendicular], dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EMANorm(nn.Module):
    """Scale inputs by the inverse square root of a running mean-square value.

    The running value lives in the scalar buffer ``magnitude_ema`` (initialised
    to 1) and is only refreshed while the module is in training mode.
    """

    def __init__(self, beta):
        """
        Args:
            beta: Interpolation weight towards the previous running value when
                a new mean-square measurement is blended in.
        """
        super().__init__()
        self.beta = beta
        self.register_buffer('magnitude_ema', torch.ones([]))

    def forward(self, x):
        """Update the running magnitude (training only) and rescale ``x``."""
        if self.training:
            # Mean of squares over all elements, computed in float32 and
            # detached so the statistic never receives gradients.
            current = x.detach().to(torch.float32).square().mean()
            previous = self.magnitude_ema.to(torch.float32)
            # lerp(current, previous, beta) = current + beta * (previous - current)
            self.magnitude_ema.copy_(torch.lerp(current, previous, self.beta))
        gain = self.magnitude_ema.rsqrt()
        return x.mul(gain)
|
|
|
|
|
class TimestepEmbedding(nn.Module):
    """Module wrapper around ``timestep_embedding`` with an optional gain.

    When ``zero_weight`` is true the module owns a learnable per-channel
    ``weight`` initialised to zeros, which multiplies the sinusoidal embedding;
    otherwise ``weight`` is ``None`` and the raw embedding is returned.
    """

    def __init__(self, dim, max_period=10000, time_factor: float = 1000.0, zero_weight: bool = True):
        """
        Args:
            dim: Size of the embedding produced per timestep.
            max_period: Forwarded to ``timestep_embedding``.
            time_factor: Forwarded to ``timestep_embedding``.
            zero_weight: Whether to create the zero-initialised gain.
        """
        super().__init__()
        self.dim = dim
        self.max_period = max_period
        self.time_factor = time_factor
        self.weight = nn.Parameter(torch.zeros(dim)) if zero_weight else None

    def forward(self, t):
        """Embed timesteps ``t`` and apply the gain when one exists."""
        embedded = timestep_embedding(t, self.dim, self.max_period, self.time_factor)
        if self.weight is None:
            return embedded
        # unsqueeze(0) lets the (dim,) gain broadcast over the leading axis.
        return embedded * self.weight.unsqueeze(0)
|
|
|
|
|
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True) |
|
|
def timestep_embedding(t, dim, max_period=10000, time_factor: float = 1000.0): |
|
|
""" |
|
|
Create sinusoidal timestep embeddings. |
|
|
:param t: a 1-D Tensor of N indices, one per batch element. |
|
|
These may be fractional. |
|
|
:param dim: the dimension of the output. |
|
|
:param max_period: controls the minimum frequency of the embeddings. |
|
|
:return: an (N, D) Tensor of positional embeddings. |
|
|
""" |
|
|
t = time_factor * t |
|
|
half = dim // 2 |
|
|
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device) |
|
|
|
|
|
args = t[:, None].float() * freqs[None] |
|
|
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) |
|
|
if dim % 2: |
|
|
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) |
|
|
if torch.is_floating_point(t): |
|
|
embedding = embedding.to(t) |
|
|
return embedding |
|
|
|
|
|
def quaternion_to_matrix(quaternions):
    """Convert quaternions (real part first) to rotation matrices.

    The input does not have to be normalised: every term is scaled by
    ``2 / |q|^2``.

    Args:
        quaternions: Tensor of shape (..., 4), real part first.

    Returns:
        Rotation matrices as a tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    # Row-major entries of the 3x3 matrix, listed in order m00 .. m22.
    entries = (
        1 - two_s * (j * j + k * k),
        two_s * (i * j - k * r),
        two_s * (i * k + j * r),
        two_s * (i * j + k * r),
        1 - two_s * (i * i + k * k),
        two_s * (j * k - i * r),
        two_s * (i * k - j * r),
        two_s * (j * k + i * r),
        1 - two_s * (i * i + j * j),
    )
    flat = torch.stack(entries, -1)
    target_shape = quaternions.shape[:-1] + (3, 3)
    return flat.reshape(target_shape)
|
|
|
|
|
|
|
|
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """Flip quaternions so that their real part is non-negative.

    Args:
        quaternions: Tensor of shape (..., 4) with the real part first.

    Returns:
        Tensor of shape (..., 4) in which every quaternion whose real part was
        negative has been negated; all others are returned as they were.
    """
    # Slice with 0:1 (not 0) to keep the last axis for broadcasting.
    real_is_negative = quaternions[..., 0:1] < 0
    return torch.where(real_is_negative, -quaternions, quaternions)
|
|
|
|
|
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """Return ``sqrt(max(0, x))`` with a zero subgradient where ``x`` is 0.

    Entries that are not strictly positive map to 0.
    """
    is_positive = x > 0
    result = torch.zeros_like(x)
    if not torch.is_grad_enabled():
        return torch.where(is_positive, torch.sqrt(x), result)
    # With autograd on, sqrt is evaluated on the positive entries only, so the
    # remaining entries receive a zero gradient.
    result[is_positive] = torch.sqrt(x[is_positive])
    return result
|
|
|
|
|
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """Convert rotation matrices to quaternions (real part first).

    Args:
        matrix: Rotation matrices as a tensor of shape (..., 3, 3).

    Returns:
        Quaternions of shape (..., 4), passed through
        ``standardize_quaternion`` before being returned.

    Raises:
        ValueError: If the last two dimensions are not both 3.
    """
    if not (matrix.size(-1) == 3 and matrix.size(-2) == 3):
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    flat = matrix.reshape(batch_dim + (9,))
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(flat, dim=-1)

    # One magnitude per quaternion component (r, i, j, k).
    squared_magnitudes = torch.stack(
        [
            1.0 + m00 + m11 + m22,
            1.0 + m00 - m11 - m22,
            1.0 - m00 + m11 - m22,
            1.0 - m00 - m11 + m22,
        ],
        dim=-1,
    )
    q_abs = _sqrt_positive_part(squared_magnitudes)

    # Four candidate quaternions; row n is built around component n and is
    # still scaled by 2 * q_abs[n] at this point.
    from_r = torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1)
    from_i = torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1)
    from_j = torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1)
    from_k = torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1)
    quat_by_rijk = torch.stack([from_r, from_i, from_j, from_k], dim=-2)

    # Floor the divisor at 0.1 so candidates with a tiny magnitude do not
    # blow up; those candidates are never the one selected below.
    floor = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    divisor = 2.0 * torch.maximum(q_abs[..., None], floor)
    quat_candidates = quat_by_rijk / divisor

    # Keep, per matrix, the candidate with the largest magnitude.
    best = q_abs.argmax(dim=-1, keepdim=True)
    gather_shape = list(batch_dim) + [1, 4]
    gather_indices = best.unsqueeze(-1).expand(gather_shape)
    chosen = torch.gather(quat_candidates, -2, gather_indices).squeeze(-2)
    return standardize_quaternion(chosen)
|
|
|
|
|
@torch.amp.autocast(device_type="cuda", enabled=False)
def normalize_cameras(cameras, return_meta=False, ref_w2c=None, T_norm=None, n_frame=None):
    """Re-express camera poses relative to a reference and rescale translations.

    Entries ``0:4`` of the last axis are read as a quaternion (real part
    first), entries ``4:7`` as a translation, and anything from index 7 on is
    carried over unchanged.

    Args:
        cameras: Tensor whose first two dimensions are (B, N).
        return_meta: Also return the reference transform and the scale used.
        ref_w2c: Optional 4x4 transform applied on the left of every pose;
            defaults to the inverse of each batch item's first camera pose.
        T_norm: Optional translation scale; defaults to the largest
            translation norm over the cameras of each batch item.
        n_frame: When given, only the first ``n_frame`` cameras are used to
            derive the default ``T_norm``.

    Returns:
        The normalised cameras, or ``(cameras, ref_w2c, T_norm)`` when
        ``return_meta`` is true.
    """
    B, N = cameras.shape[:2]

    # Assemble 3x4 camera-to-world matrices [R | t].
    poses = torch.zeros(B, N, 3, 4, device=cameras.device)
    poses[..., :3, :3] = quaternion_to_matrix(cameras[..., 0:4])
    poses[..., :, 3] = cameras[..., 4:7]

    if ref_w2c is None:
        ref_w2c = torch.inverse(matrix_to_square(poses[:, :1]))
    relative = (ref_w2c.repeat(1, N, 1, 1) @ matrix_to_square(poses))[..., :3, :]

    if T_norm is None:
        if n_frame is not None:
            translations = relative[..., :n_frame, :3, 3]
        else:
            translations = relative[..., :3, 3]
        T_norm = translations.norm(dim=-1).max(dim=1)[0][..., None, None]

    # The 1e-2 offset keeps the division finite when all translations vanish.
    relative[..., :3, 3] = relative[..., :3, 3] / (T_norm + 1e-2)

    rotation = matrix_to_quaternion(relative[..., :3, :3])
    translation = relative[..., :3, 3]
    cameras = torch.cat([rotation.float(), translation.float(), cameras[..., 7:]], dim=-1)

    if return_meta:
        return cameras, ref_w2c, T_norm
    return cameras
|
|
|
|
|
def create_rays(cameras, h, w, uv_offset=None):
    """Generate one ray per pixel centre for every camera.

    Entries ``0:4`` of the last axis of ``cameras`` are read as a quaternion
    (real part first), ``4:7`` as the camera position, and the remainder is
    split into four equal chunks ``fx, fy, cx, cy`` which are multiplied by
    the image width / height before use.

    Args:
        cameras: Tensor of shape (..., C); leading dimensions are preserved.
        h: Image height in pixels.
        w: Image width in pixels.
        uv_offset: Optional per-pixel offsets; channel 0 is added to ``u`` and
            channel 1 to ``v`` after reshaping to (N, h * w).

    Returns:
        ``(rays_o, rays_d)``, each of shape (..., h, w, 3). Directions are
        unit length.
    """
    lead_shape = cameras.shape[:-1]
    cameras = cameras.flatten(0, -2)
    device = cameras.device
    N = cameras.shape[0]
    num_pixels = h * w

    # 4x4 camera-to-world matrices.
    c2w = torch.eye(4, device=device)[None].repeat(N, 1, 1)
    c2w[:, :3, :3] = quaternion_to_matrix(cameras[:, :4])
    c2w[:, :3, 3] = cameras[:, 4:7]

    fx, fy, cx, cy = cameras[:, 7:].chunk(4, -1)
    fx, cx = fx * w, cx * w
    fy, cy = fy * h, cy * h

    # Flat pixel index -> (column, row) of the pixel centre.
    flat_index = torch.arange(0, num_pixels, device=device).expand(N, num_pixels)
    col = flat_index % w + 0.5
    row = torch.div(flat_index, w, rounding_mode='floor') + 0.5

    if uv_offset is None:
        du = dv = 0
    else:
        du = uv_offset[..., 0].reshape(N, num_pixels)
        dv = uv_offset[..., 1].reshape(N, num_pixels)
    u = col / cx + du
    v = row / cy + dv

    # Camera-space directions with z fixed at -1.
    zs = - torch.ones_like(col)
    xs = - (u - 1) * cx / fx * zs
    ys = (v - 1) * cy / fy * zs
    directions = torch.stack((xs, ys, zs), dim=-1)

    # Rotate into world space (row vectors times R^T) and normalise.
    rays_d = F.normalize(directions @ c2w[:, :3, :3].transpose(-1, -2), dim=-1)

    rays_o = c2w[..., :3, 3]
    rays_o = rays_o[..., None, :].expand_as(rays_d)

    rays_o = rays_o.reshape(*lead_shape, h, w, 3)
    rays_d = rays_d.reshape(*lead_shape, h, w, 3)
    return rays_o, rays_d
|
|
|
|
|
def matrix_to_square(mat):
    """Append the homogeneous row ``[0, 0, 0, 1]`` below each matrix.

    Works for any number of leading batch dimensions (including none), so a
    (..., 3, 4) input becomes (..., 4, 4). The appended row is created with
    the dtype and device of ``mat``.

    Args:
        mat: Tensor of shape (..., R, 4).

    Returns:
        Tensor of shape (..., R + 1, 4).

    Raises:
        ValueError: If ``mat`` has fewer than two dimensions.
    """
    if mat.dim() < 2:
        raise ValueError(f"Expected a tensor of shape (..., R, 4), got {tuple(mat.shape)}.")
    # new_tensor inherits dtype/device from mat; expand adds the batch
    # dimensions without copying.
    bottom_row = mat.new_tensor([0, 0, 0, 1]).expand(*mat.shape[:-2], 1, 4)
    return torch.cat([mat, bottom_row], dim=-2)
|
|
|
|
|
def export_ply_for_gaussians(path, gaussians, opacity_threshold=0.00, T_norm=None):
    """Write a set of Gaussians to a PLY file with one ``vertex`` element.

    The last axis of ``gaussians`` is split into position (3), opacity (1),
    scale (3), rotation (4) and the remaining feature channels. Before
    writing, opacity is converted to a logit and scale to a logarithm.

    Args:
        path: Destination passed to ``PlyData.write``.
        gaussians: Tensor with the channel layout above; presumably of shape
            (N, C) -- TODO confirm against callers.
        opacity_threshold: When positive, Gaussians with opacity below this
            value are dropped and the kept fraction is printed.
        T_norm: Optional one-element tensor; positions and scales are
            multiplied by its value.
    """
    fixed_channels = sum([3, 1, 3, 4])
    sh_degree = int(math.sqrt((gaussians.shape[-1] - fixed_channels) / 3 - 1))
    feature_channels = (sh_degree + 1)**2 * 3

    pieces = gaussians.float().split([3, 1, 3, 4, feature_channels], dim=-1)
    means3D, opacity, scales, rotations, shs = [piece.contiguous().float() for piece in pieces]

    if opacity_threshold > 0:
        mask = opacity[..., 0] >= opacity_threshold
        means3D, opacity, scales, rotations, shs = [
            field[mask] for field in (means3D, opacity, scales, rotations, shs)
        ]
        print("Gaussian percentage: ", mask.float().mean())

    if T_norm is not None:
        factor = T_norm.item()
        means3D = means3D * factor
        scales = scales * factor

    # Store opacity as a logit and scale as a log (epsilon avoids log(0)).
    opacity = torch.log(opacity/(1-opacity))
    scales = torch.log(scales + 1e-8)

    xyzs = means3D.detach()
    f_dc = shs.detach().flatten(start_dim=1).contiguous()
    opacities = opacity.detach()
    scales = scales.detach()
    rotations = rotations.detach()

    # Column names, in the same order as the concatenation below.
    names = ['x', 'y', 'z']
    names.extend('f_dc_{}'.format(i) for i in range(f_dc.shape[1]))
    names.append('opacity')
    names.extend('scale_{}'.format(i) for i in range(scales.shape[1]))
    names.extend('rot_{}'.format(i) for i in range(rotations.shape[1]))

    attributes = torch.cat((xyzs, f_dc, opacities, scales, rotations), dim=1).cpu().numpy()

    columns = [attributes[:, idx] for idx in range(attributes.shape[1])]
    elements = np.rec.fromarrays(columns, names=names, formats=['f4'] * len(names))
    el = PlyElement.describe(elements, 'vertex')

    print(path)

    PlyData([el]).write(path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.amp.autocast(device_type="cuda", enabled=False)
def quaternion_slerp(
    q0, q1, fraction, spin: int = 0, shortestpath: bool = True
):
    """Return spherical linear interpolation between two quaternions.

    Args:
        q0: First quaternion(s), tensor of shape (..., 4).
        q1: Second quaternion(s), tensor of shape (..., 4). Not modified.
        fraction: How far to interpolate from ``q0`` (0) to ``q1`` (1); a
            float or a tensor matching the batch shape ``(...)``.
        spin: Number of additional half-turns added to the interpolation angle.
        shortestpath: When true, ``q1`` is negated wherever its dot product
            with ``q0`` is negative so the shorter arc is followed.

    Returns:
        Interpolated quaternion(s) of shape (..., 4). Where the interpolation
        angle is below 1e-5 the result is ``q0``.
    """
    d = (q0 * q1).sum(-1)
    if shortestpath:
        # Decide which pairs to flip BEFORE touching d, and negate q1
        # out of place so the caller's tensor is left intact.
        flip = d < 0.0
        d = torch.where(flip, -d, d)
        q1 = torch.where(flip[..., None], -q1, q1)

    # Guard acos against values marginally outside [-1, 1]; negative dot
    # products must survive when shortestpath is False.
    d = d.clamp(-1.0, 1.0)

    angle = torch.acos(d) + spin * math.pi
    # Small epsilon keeps the reciprocal finite when sin(angle) is 0.
    isin = 1.0 / (torch.sin(angle) + 1e-10)
    q0_ = q0 * (torch.sin((1.0 - fraction) * angle) * isin)[..., None]
    q1_ = q1 * (torch.sin(fraction * angle) * isin)[..., None]
    q = q0_ + q1_

    # Degenerate case: the two quaternions (nearly) coincide.
    degenerate = (angle < 1e-5)[..., None]
    return torch.where(degenerate, q0.to(q.dtype), q)
|
|
|
|
|
def sample_from_two_pose(pose_a, pose_b, fraction, noise_strengths=(0, 0)):
    """Interpolate between two batches of poses, optionally adding noise.

    The first four entries of the last axis are treated as a quaternion and
    interpolated with ``quaternion_slerp``; everything from index 4 on is
    interpolated linearly.

    Args:
        pose_a: First poses, tensor of shape (M, C) with C >= 4.
        pose_b: Second poses, same shape as ``pose_a``.
        fraction: 1-D tensor of M interpolation weights (0 gives ``pose_a``,
            1 gives ``pose_b``).
        noise_strengths: Pair ``(rotation, translation)`` of standard
            deviations for Gaussian noise added to the quaternion (before it
            is re-normalised) and to the linearly interpolated part.

    Returns:
        Tensor shaped like ``pose_a`` holding the sampled poses.
    """
    quat_a = pose_a[..., :4]
    quat_b = pose_b[..., :4]

    # q and -q encode the same rotation; align signs so interpolation follows
    # the shorter arc.
    dot = torch.sum(quat_a * quat_b, dim=-1, keepdim=True)
    quat_b = torch.where(dot < 0, -quat_b, quat_b)

    quaternion = quaternion_slerp(quat_a, quat_b, fraction)
    rotation_noise = torch.randn_like(quaternion) * noise_strengths[0]
    quaternion = torch.nn.functional.normalize(quaternion + rotation_noise, dim=-1)

    weight_b = fraction[:, None]
    weight_a = (1 - fraction)[:, None]
    T = weight_a * pose_a[..., 4:] + weight_b * pose_b[..., 4:]
    T = T + torch.randn_like(T) * noise_strengths[1]

    # Clone so the caller's pose_a is not modified.
    new_pose = pose_a.clone()
    new_pose[..., :4] = quaternion
    new_pose[..., 4:] = T
    return new_pose
|
|
|
|
|
def sample_from_dense_cameras(dense_cameras, t, noise_strengths=(0, 0, 0, 0)):
    """Sample cameras along a dense trajectory at normalised positions ``t``.

    Each ``t`` is mapped onto the index range ``[0, N - 1]``; the two
    neighbouring cameras are then blended. The first 7 channels go through
    ``sample_from_two_pose``; the remaining channels are interpolated
    linearly.

    Args:
        dense_cameras: Tensor of shape (N, C) with C >= 7.
        t: 1-D tensor of M positions; 0 maps to the first camera and 1 to the
            last.
        noise_strengths: Only the first two entries are used; they are
            forwarded to ``sample_from_two_pose``.

    Returns:
        Tensor of shape (M, C) with the sampled cameras.
    """
    N, C = dense_cameras.shape
    last = N - 1

    position = t * last
    # Upper bound max(N - 2, 0) keeps the index valid for a one-camera
    # trajectory, where both neighbours collapse onto camera 0.
    left = torch.floor(position).long().clamp(0, max(N - 2, 0))
    right = (left + 1).clamp(max=last)
    fraction = position - left

    a = dense_cameras[left]
    b = dense_cameras[right]

    new_pose = sample_from_two_pose(
        a[:, :7], b[:, :7], fraction, noise_strengths=noise_strengths[:2]
    )
    new_ins = (1 - fraction)[:, None] * a[:, 7:] + fraction[:, None] * b[:, 7:]

    return torch.cat([new_pose, new_ins], dim=1)
|
|
|
|
|
|