from io import BytesIO import math import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import importlib from plyfile import PlyData, PlyElement import copy class EmbedContainer(nn.Module): def __init__(self, tensor): super().__init__() self.tensor = nn.Parameter(tensor) def forward(self): return self.tensor @torch.no_grad def zero_init(module): if type(module) is torch.nn.Conv2d or type(module) is torch.nn.Linear: module.weight.zero_() module.bias.zero_() return module def import_str(string): # From https://github.com/CompVis/taming-transformers module, cls = string.rsplit(".", 1) return getattr(importlib.import_module(module, package=None), cls) """ from https://github.com/Kai-46/minFM/blob/main/utils/ema.py Exponential Moving Average (EMA) utilities for PyTorch models. This module provides utilities for maintaining and updating EMA models, which are commonly used to improve model stability and generalization in training deep neural networks. It supports both regular tensors and DTensors (from FSDP-wrapped models). """ class EMA_FSDP: def __init__(self, fsdp_module: torch.nn.Module, decay: float = 0.999): self.decay = decay self.shadow = {} self._init_shadow(fsdp_module) @torch.no_grad() def _init_shadow(self, fsdp_module): # 判断是否是FSDP模型 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP if isinstance(fsdp_module, FSDP): with FSDP.summon_full_params(fsdp_module, writeback=False): for n, p in fsdp_module.module.named_parameters(): self.shadow[n] = p.detach().clone().float().cpu() else: for n, p in fsdp_module.named_parameters(): self.shadow[n] = p.detach().clone().float().cpu() @torch.no_grad() def update(self, fsdp_module): d = self.decay from torch.distributed.fsdp import FullyShardedDataParallel as FSDP if isinstance(fsdp_module, FSDP): with FSDP.summon_full_params(fsdp_module, writeback=False): for n, p in fsdp_module.module.named_parameters(): self.shadow[n].mul_(d).add_(p.detach().float().cpu(), alpha=1. - d) else: for n, p in fsdp_module.named_parameters(): print(n, self.shadow[n]) self.shadow[n].mul_(d).add_(p.detach().float().cpu(), alpha=1. - d) # Optional helpers --------------------------------------------------- def state_dict(self): return self.shadow # picklable def load_state_dict(self, sd): self.shadow = {k: v.clone() for k, v in sd.items()} def copy_to(self, fsdp_module): # load EMA weights into an (unwrapped) copy of the generator from torch.distributed.fsdp import FullyShardedDataParallel as FSDP with FSDP.summon_full_params(fsdp_module, writeback=True): for n, p in fsdp_module.module.named_parameters(): if n in self.shadow: p.data.copy_(self.shadow[n].to(p.dtype, device=p.device)) def create_raymaps(cameras, h, w): rays_o, rays_d = create_rays(cameras, h, w) raymaps = torch.cat([rays_d, rays_o - (rays_o * rays_d).sum(dim=-1, keepdim=True) * rays_d], dim=-1) return raymaps # def create_raymaps(cameras, h, w): # rays_o, rays_d = create_rays(cameras, h, w) # raymaps = torch.cat([rays_d, torch.cross(rays_d, rays_o, dim=-1)], dim=-1) # return raymaps class EMANorm(nn.Module): def __init__(self, beta): super().__init__() self.register_buffer('magnitude_ema', torch.ones([])) self.beta = beta def forward(self, x): if self.training: magnitude_cur = x.detach().to(torch.float32).square().mean() self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema.to(torch.float32), self.beta)) input_gain = self.magnitude_ema.rsqrt() x = x.mul(input_gain) return x class TimestepEmbedding(nn.Module): def __init__(self, dim, max_period=10000, time_factor: float = 1000.0, zero_weight: bool = True): super().__init__() self.max_period = max_period self.time_factor = time_factor self.dim = dim if zero_weight: self.weight = nn.Parameter(torch.zeros(dim)) else: self.weight = None def forward(self, t): if self.weight is None: return timestep_embedding(t, self.dim, self.max_period, self.time_factor) else: return timestep_embedding(t, self.dim, self.max_period, self.time_factor) * self.weight.unsqueeze(0) @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True) def timestep_embedding(t, dim, max_period=10000, time_factor: float = 1000.0): """ Create sinusoidal timestep embeddings. :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional. :param dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. :return: an (N, D) Tensor of positional embeddings. """ t = time_factor * t half = dim // 2 freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) if torch.is_floating_point(t): embedding = embedding.to(t) return embedding def quaternion_to_matrix(quaternions): """ Convert rotations given as quaternions to rotation matrices. Args: quaternions: quaternions with real part first, as tensor of shape (..., 4). Returns: Rotation matrices as tensor of shape (..., 3, 3). """ r, i, j, k = torch.unbind(quaternions, -1) two_s = 2.0 / (quaternions * quaternions).sum(-1) o = torch.stack( ( 1 - two_s * (j * j + k * k), two_s * (i * j - k * r), two_s * (i * k + j * r), two_s * (i * j + k * r), 1 - two_s * (i * i + k * k), two_s * (j * k - i * r), two_s * (i * k - j * r), two_s * (j * k + i * r), 1 - two_s * (i * i + j * j), ), -1, ) return o.reshape(quaternions.shape[:-1] + (3, 3)) # from https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#matrix_to_quaternion def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: """ Convert a unit quaternion to a standard form: one in which the real part is non negative. Args: quaternions: Quaternions with real part first, as tensor of shape (..., 4). Returns: Standardized quaternions as tensor of shape (..., 4). """ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: """ Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0. """ ret = torch.zeros_like(x) positive_mask = x > 0 if torch.is_grad_enabled(): ret[positive_mask] = torch.sqrt(x[positive_mask]) else: ret = torch.where(positive_mask, torch.sqrt(x), ret) return ret def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: """ Convert rotations given as rotation matrices to quaternions. Args: matrix: Rotation matrices as tensor of shape (..., 3, 3). Returns: quaternions with real part first, as tensor of shape (..., 4). """ if matrix.size(-1) != 3 or matrix.size(-2) != 3: raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") batch_dim = matrix.shape[:-2] m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( matrix.reshape(batch_dim + (9,)), dim=-1 ) q_abs = _sqrt_positive_part( torch.stack( [ 1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22, ], dim=-1, ) ) # we produce the desired quaternion multiplied by each of r, i, j, k quat_by_rijk = torch.stack( [ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and # `int`. torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and # `int`. torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and # `int`. torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and # `int`. torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), ], dim=-2, ) # We floor here at 0.1 but the exact level is not important; if q_abs is small, # the candidate won't be picked. flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) # if not for numerical problems, quat_candidates[i] should be same (up to a sign), # forall i; we pick the best-conditioned one (with the largest denominator) indices = q_abs.argmax(dim=-1, keepdim=True) expand_dims = list(batch_dim) + [1, 4] gather_indices = indices.unsqueeze(-1).expand(expand_dims) out = torch.gather(quat_candidates, -2, gather_indices).squeeze(-2) return standardize_quaternion(out) @torch.amp.autocast(device_type="cuda", enabled=False) def normalize_cameras(cameras, return_meta=False, ref_w2c=None, T_norm=None, n_frame=None): B, N = cameras.shape[:2] c2ws = torch.zeros(B, N, 3, 4, device=cameras.device) c2ws[..., :3, :3] = quaternion_to_matrix(cameras[..., 0:4]) c2ws[..., :, 3] = cameras[..., 4:7] _c2ws = c2ws ref_w2c = torch.inverse(matrix_to_square(_c2ws[:, :1])) if ref_w2c is None else ref_w2c _c2ws = (ref_w2c.repeat(1, N, 1, 1) @ matrix_to_square(_c2ws))[..., :3, :] if n_frame is not None: T_norm = _c2ws[..., :n_frame, :3, 3].norm(dim=-1).max(dim=1)[0][..., None, None] if T_norm is None else T_norm else: T_norm = _c2ws[..., :3, 3].norm(dim=-1).max(dim=1)[0][..., None, None] if T_norm is None else T_norm _c2ws[..., :3, 3] = _c2ws[..., :3, 3] / (T_norm + 1e-2) R = matrix_to_quaternion(_c2ws[..., :3, :3]) T = _c2ws[..., :3, 3] cameras = torch.cat([R.float(), T.float(), cameras[..., 7:]], dim=-1) if return_meta: return cameras, ref_w2c, T_norm else: return cameras def create_rays(cameras, h, w, uv_offset=None): prefix_shape = cameras.shape[:-1] cameras = cameras.flatten(0, -2) device = cameras.device N = cameras.shape[0] c2w = torch.eye(4, device=device)[None].repeat(N, 1, 1) c2w[:, :3, :3] = quaternion_to_matrix(cameras[:, :4]) c2w[:, :3, 3] = cameras[:, 4:7] # fx, fy, cx, cy should be divided by original H, W fx, fy, cx, cy = cameras[:, 7:].chunk(4, -1) fx, cx = fx * w, cx * w fy, cy = fy * h, cy * h inds = torch.arange(0, h*w, device=device).expand(N, h*w) i = inds % w + 0.5 j = torch.div(inds, w, rounding_mode='floor') + 0.5 u = i / cx + (uv_offset[..., 0].reshape(N, h*w) if uv_offset is not None else 0) v = j / cy + (uv_offset[..., 1].reshape(N, h*w) if uv_offset is not None else 0) zs = - torch.ones_like(i) xs = - (u - 1) * cx / fx * zs ys = (v - 1) * cy / fy * zs directions = torch.stack((xs, ys, zs), dim=-1) rays_d = F.normalize(directions @ c2w[:, :3, :3].transpose(-1, -2), dim=-1) rays_o = c2w[..., :3, 3] # [B, 3] rays_o = rays_o[..., None, :].expand_as(rays_d) rays_o = rays_o.reshape(*prefix_shape, h, w, 3) rays_d = rays_d.reshape(*prefix_shape, h, w, 3) return rays_o, rays_d def matrix_to_square(mat): l = len(mat.shape) if l==3: return torch.cat([mat, torch.tensor([0,0,0,1]).repeat(mat.shape[0],1,1).to(mat.device)],dim=1) elif l==4: return torch.cat([mat, torch.tensor([0,0,0,1]).repeat(mat.shape[0],mat.shape[1],1,1).to(mat.device)],dim=2) def export_ply_for_gaussians(path, gaussians, opacity_threshold=0.00, T_norm=None): sh_degree = int(math.sqrt((gaussians.shape[-1] - sum([3, 1, 3, 4])) / 3 - 1)) xyz, opacity, scale, rotation, feature = gaussians.float().split([3, 1, 3, 4, (sh_degree + 1)**2 * 3], dim=-1) means3D = xyz.contiguous().float() opacity = opacity.contiguous().float() scales = scale.contiguous().float() rotations = rotation.contiguous().float() shs = feature.contiguous().float() # [N, 1, 3] # print(means3D.shape, opacity.shape, scales.shape, rotations.shape, shs.shape) # prune by opacity if opacity_threshold > 0: mask = opacity[..., 0] >= opacity_threshold means3D = means3D[mask] opacity = opacity[mask] scales = scales[mask] rotations = rotations[mask] shs = shs[mask] print("Gaussian percentage: ", mask.float().mean()) if T_norm is not None: means3D = means3D * T_norm.item() scales = scales * T_norm.item() # invert activation to make it compatible with the original ply format opacity = torch.log(opacity/(1-opacity)) scales = torch.log(scales + 1e-8) xyzs = means3D.detach() # .cpu().numpy() f_dc = shs.detach().flatten(start_dim=1).contiguous() #.cpu().numpy() opacities = opacity.detach() #.cpu().numpy() scales = scales.detach() #.cpu().numpy() rotations = rotations.detach() #.cpu().numpy() l = ['x', 'y', 'z'] # All channels except the 3 DC for i in range(f_dc.shape[1]): l.append('f_dc_{}'.format(i)) l.append('opacity') for i in range(scales.shape[1]): l.append('scale_{}'.format(i)) for i in range(rotations.shape[1]): l.append('rot_{}'.format(i)) dtype_full = [(attribute, 'f4') for attribute in l] # 最优化方案:使用numpy的recarray直接创建 attributes = torch.cat((xyzs, f_dc, opacities, scales, rotations), dim=1).cpu().numpy() # 使用recarray直接创建,避免循环和类型转换 elements = np.rec.fromarrays([attributes[:, i] for i in range(attributes.shape[1])], names=l, formats=['f4'] * len(l)) el = PlyElement.describe(elements, 'vertex') print(path) PlyData([el]).write(path) # plydata = PlyData([el]) # vert = plydata["vertex"] # sorted_indices = np.argsort( # -np.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"]) # / (1 + np.exp(-vert["opacity"])) # ) # buffer = BytesIO() # for idx in sorted_indices: # v = plydata["vertex"][idx] # position = np.array([v["x"], v["y"], v["z"]], dtype=np.float32) # scales = np.exp( # np.array( # [v["scale_0"], v["scale_1"], v["scale_2"]], # dtype=np.float32, # ) # ) # rot = np.array( # [v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], # dtype=np.float32, # ) # SH_C0 = 0.28209479177387814 # color = np.array( # [ # 0.5 + SH_C0 * v["f_dc_0"], # 0.5 + SH_C0 * v["f_dc_1"], # 0.5 + SH_C0 * v["f_dc_2"], # 1 / (1 + np.exp(-v["opacity"])), # ] # ) # buffer.write(position.tobytes()) # buffer.write(scales.tobytes()) # buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes()) # buffer.write( # ((rot / np.linalg.norm(rot)) * 128 + 128) # .clip(0, 255) # .astype(np.uint8) # .tobytes() # ) # with open(path + '.splat', "wb") as f: # f.write(buffer.getvalue()) @torch.amp.autocast(device_type="cuda", enabled=False) def quaternion_slerp( q0, q1, fraction, spin: int = 0, shortestpath: bool = True ): """Return spherical linear interpolation between two quaternions. Args: quat0: first quaternion quat1: second quaternion fraction: how much to interpolate between quat0 vs quat1 (if 0, closer to quat0; if 1, closer to quat1) spin: how much of an additional spin to place on the interpolation shortestpath: whether to return the short or long path to rotation """ d = (q0 * q1).sum(-1) if shortestpath: # invert rotation d[d < 0.0] = -d[d < 0.0] q1[d < 0.0] = q1[d < 0.0] _d = d.clamp(0, 1.0) # theta = torch.arccos(d) * fraction # q2 = q1 - q0 * d # q2 = q2 / (q2.norm(dim=-1) + 1e-10) # return torch.cos(theta) * q0 + torch.sin(theta) * q2 angle = torch.acos(_d) + spin * math.pi isin = 1.0 / (torch.sin(angle)+ 1e-10) q0_ = q0 * (torch.sin((1.0 - fraction) * angle) * isin)[..., None] q1_ = q1 * (torch.sin(fraction * angle) * isin)[..., None] q = q0_ + q1_ q[angle < 1e-5] = q0[angle < 1e-5] # q[fraction < 1e-5] = q0[fraction < 1e-5] # q[fraction > 1 - 1e-5] = q1[fraction > 1 - 1e-5] # q[(d.abs() - 1).abs() < 1e-5] = q0[(d.abs() - 1).abs() < 1e-5] return q def sample_from_two_pose(pose_a, pose_b, fraction, noise_strengths=[0, 0]): """ Args: pose_a: first pose pose_b: second pose fraction """ quat_a = pose_a[..., :4] quat_b = pose_b[..., :4] dot = torch.sum(quat_a * quat_b, dim=-1, keepdim=True) quat_b = torch.where(dot < 0, -quat_b, quat_b) quaternion = quaternion_slerp(quat_a, quat_b, fraction) quaternion = torch.nn.functional.normalize(quaternion + torch.randn_like(quaternion) * noise_strengths[0], dim=-1) T = (1 - fraction)[:, None] * pose_a[..., 4:] + fraction[:, None] * pose_b[..., 4:] T = T + torch.randn_like(T) * noise_strengths[1] new_pose = pose_a.clone() new_pose[..., :4] = quaternion new_pose[..., 4:] = T return new_pose def sample_from_dense_cameras(dense_cameras, t, noise_strengths=[0, 0, 0, 0]): N, C = dense_cameras.shape M = t.shape left = torch.floor(t * (N-1)).long().clamp(0, N-2) right = left + 1 fraction = t * (N-1) - left a = torch.gather(dense_cameras, 0, left[..., None].repeat(1, C)) b = torch.gather(dense_cameras, 0, right[..., None].repeat(1, C)) new_pose = sample_from_two_pose(a[:, :7], b[:, :7], fraction, noise_strengths=noise_strengths[:2]) new_ins = (1 - fraction)[:, None] * a[:, 7:] + fraction[:, None] * b[:, 7:] return torch.cat([new_pose, new_ins], dim=1)