File size: 19,170 Bytes
babafa4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 |
from io import BytesIO
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import importlib
from plyfile import PlyData, PlyElement
import copy
class EmbedContainer(nn.Module):
    """Hold a single tensor as a learnable parameter; calling the module returns it."""

    def __init__(self, tensor):
        super().__init__()
        # Registered under the name "tensor" so it shows up in parameters()/state_dict().
        self.register_parameter("tensor", nn.Parameter(tensor))

    def forward(self):
        """Return the wrapped parameter itself (not a copy)."""
        return self.tensor
@torch.no_grad()
def zero_init(module):
    """Zero a Conv2d/Linear layer's weight (and bias, if present) in place and return it.

    Only modules whose exact type is ``torch.nn.Conv2d`` or ``torch.nn.Linear`` are
    modified; subclasses and any other module are returned unchanged. Layers built
    with ``bias=False`` (``module.bias is None``) are handled.
    """
    if type(module) in (torch.nn.Conv2d, torch.nn.Linear):
        module.weight.zero_()
        # bias is None when the layer was constructed with bias=False.
        if module.bias is not None:
            module.bias.zero_()
    return module
def import_str(string):
    """Resolve a dotted path like ``"pkg.module.Name"`` to the object it names.

    Everything before the last dot is imported as a module; the final component
    is then read from that module with ``getattr``.
    Adapted from https://github.com/CompVis/taming-transformers.
    """
    module_path, attr_name = string.rsplit(".", 1)
    target_module = importlib.import_module(module_path, package=None)
    return getattr(target_module, attr_name)
"""
from https://github.com/Kai-46/minFM/blob/main/utils/ema.py
Exponential Moving Average (EMA) utilities for PyTorch models.
This module provides utilities for maintaining and updating EMA models,
which are commonly used to improve model stability and generalization
in training deep neural networks. It supports both regular tensors and
DTensors (from FSDP-wrapped models).
"""
class EMA_FSDP:
    """Exponential moving average of a model's parameters, stored on CPU in float32.

    Works with plain ``nn.Module``s and with FSDP-wrapped modules. For FSDP, the
    full (unsharded) parameters are gathered with ``FSDP.summon_full_params``
    before being read or written, and parameter names come from the wrapped
    ``.module``.

    Args:
        fsdp_module: model whose parameters seed the shadow copy.
        decay: EMA decay; each update computes ``shadow = decay * shadow + (1 - decay) * param``.
    """

    def __init__(self, fsdp_module: torch.nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = {}
        self._init_shadow(fsdp_module)

    @staticmethod
    def _full_params(module: torch.nn.Module, writeback: bool):
        """Return ``(context, inner)``; inside ``context``, ``inner.named_parameters()`` yields full params."""
        from contextlib import nullcontext
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

        # Decide whether this is an FSDP-wrapped model.
        if isinstance(module, FSDP):
            return FSDP.summon_full_params(module, writeback=writeback), module.module
        return nullcontext(), module

    @torch.no_grad()
    def _init_shadow(self, fsdp_module):
        """Copy every parameter into ``self.shadow`` as a detached float32 CPU tensor."""
        ctx, inner = self._full_params(fsdp_module, writeback=False)
        with ctx:
            for n, p in inner.named_parameters():
                self.shadow[n] = p.detach().clone().float().cpu()

    @torch.no_grad()
    def update(self, fsdp_module):
        """Blend the module's current parameters into the shadow copy in place."""
        d = self.decay
        ctx, inner = self._full_params(fsdp_module, writeback=False)
        with ctx:
            for n, p in inner.named_parameters():
                self.shadow[n].mul_(d).add_(p.detach().float().cpu(), alpha=1.0 - d)

    # Optional helpers ---------------------------------------------------
    def state_dict(self):
        """Return the shadow dict itself (not a copy); it is picklable."""
        return self.shadow

    def load_state_dict(self, sd):
        """Replace the shadow copy with clones of the tensors in ``sd``."""
        self.shadow = {k: v.clone() for k, v in sd.items()}

    @torch.no_grad()
    def copy_to(self, fsdp_module):
        """Load the EMA weights into ``fsdp_module`` (plain or FSDP-wrapped).

        Parameters without a shadow entry are left untouched. Values are cast to
        each parameter's dtype and moved to its device.
        """
        ctx, inner = self._full_params(fsdp_module, writeback=True)
        with ctx:
            for n, p in inner.named_parameters():
                if n in self.shadow:
                    p.data.copy_(self.shadow[n].to(device=p.device, dtype=p.dtype))
def create_raymaps(cameras, h, w):
    """Return a 6-channel per-pixel ray map of shape (..., h, w, 6).

    Channels 0-2 are the ray direction d from ``create_rays``. Channels 3-5 are
    ``o - (o . d) d``, the ray origin with its component along d removed.
    """
    origins, dirs = create_rays(cameras, h, w)
    along = (origins * dirs).sum(dim=-1, keepdim=True)
    return torch.cat([dirs, origins - along * dirs], dim=-1)
# def create_raymaps(cameras, h, w):
# rays_o, rays_d = create_rays(cameras, h, w)
# raymaps = torch.cat([rays_d, torch.cross(rays_d, rays_o, dim=-1)], dim=-1)
# return raymaps
class EMANorm(nn.Module):
    """Scale inputs by the inverse square root of a running mean of x**2.

    In training mode, each call first updates the scalar buffer ``magnitude_ema``
    to ``beta * ema + (1 - beta) * mean(x**2)``, computed in float32 on a
    detached copy of x. The input is then multiplied by ``magnitude_ema ** -0.5``.
    In eval mode the buffer is only read.
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        # Scalar buffer starting at 1, so the initial gain is 1.
        self.register_buffer('magnitude_ema', torch.ones([]))

    def forward(self, x):
        if self.training:
            current = x.detach().to(torch.float32).square().mean()
            previous = self.magnitude_ema.to(torch.float32)
            # lerp(start=current, end=previous, weight=beta) == (1 - beta) * current + beta * previous
            self.magnitude_ema.copy_(current.lerp(previous, self.beta))
        return x.mul(self.magnitude_ema.rsqrt())
class TimestepEmbedding(nn.Module):
    """Sinusoidal timestep embedding with an optional learnable per-channel gain.

    With ``zero_weight=True`` a gain vector of length ``dim`` is created and
    initialised to zeros, so the output starts at zero. The embedding is
    multiplied by this gain, broadcast over the batch. Otherwise the raw
    ``timestep_embedding`` output is returned.
    """

    def __init__(self, dim, max_period=10000, time_factor: float = 1000.0, zero_weight: bool = True):
        super().__init__()
        self.dim = dim
        self.max_period = max_period
        self.time_factor = time_factor
        self.weight = nn.Parameter(torch.zeros(dim)) if zero_weight else None

    def forward(self, t):
        emb = timestep_embedding(t, self.dim, self.max_period, self.time_factor)
        if self.weight is not None:
            emb = emb * self.weight.unsqueeze(0)
        return emb
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
def timestep_embedding(t, dim, max_period=10000, time_factor: float = 1000.0):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(t)
return embedding
def quaternion_to_matrix(quaternions):
    """
    Convert rotations given as quaternions to rotation matrices.
    Args:
        quaternions: quaternions with real part first, as tensor of shape (..., 4).
            They need not be unit length: every term is scaled by 2 / |q|^2.
    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    two_s = 2.0 / (quaternions * quaternions).sum(-1)
    ii, jj, kk = i * i, j * j, k * k
    ij, ik, jk = i * j, i * k, j * k
    ir, jr, kr = i * r, j * r, k * r
    entries = (
        1 - two_s * (jj + kk), two_s * (ij - kr), two_s * (ik + jr),
        two_s * (ij + kr), 1 - two_s * (ii + kk), two_s * (jk - ir),
        two_s * (ik - jr), two_s * (jk + ir), 1 - two_s * (ii + jj),
    )
    flat = torch.stack(entries, -1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
# from https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#matrix_to_quaternion
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Flip the sign of every quaternion whose real part is negative.

    q and -q describe the same rotation, so the result is the same rotation
    with a non-negative real part.
    Args:
        quaternions: Quaternions with real part first, shape (..., 4).
    Returns:
        Standardized quaternions, shape (..., 4).
    """
    negative_real = quaternions[..., 0:1] < 0
    return torch.where(negative_real, quaternions.neg(), quaternions)
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""
Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0.
"""
ret = torch.zeros_like(x)
positive_mask = x > 0
if torch.is_grad_enabled():
ret[positive_mask] = torch.sqrt(x[positive_mask])
else:
ret = torch.where(positive_mask, torch.sqrt(x), ret)
return ret
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.
    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
    Returns:
        Quaternions with real part first, shape (..., 4). Their sign is
        standardized so the real part is non-negative.
    Raises:
        ValueError: if the last two dimensions are not (3, 3).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    lead = matrix.shape[:-2]
    (m00, m01, m02,
     m10, m11, m12,
     m20, m21, m22) = torch.unbind(matrix.reshape(lead + (9,)), dim=-1)

    # For a proper rotation these diagonal combinations equal 4r^2, 4i^2, 4j^2 and 4k^2.
    # Negative values from round-off are clamped to 0 before the sqrt.
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # Row n is the desired quaternion multiplied by r, i, j, k respectively (up to scale).
    rows = [
        [q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01],
        [m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20],
        [m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21],
        [m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2],
    ]
    quat_by_rijk = torch.stack([torch.stack(row, dim=-1) for row in rows], dim=-2)

    # Floor the denominator at 0.1. A small q_abs row is never selected below,
    # so the exact floor value does not matter.
    floor = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(floor))

    # The candidates agree up to sign in exact arithmetic. Keep the best-conditioned
    # one, i.e. the one with the largest denominator.
    best = q_abs.argmax(dim=-1, keepdim=True)
    index = best.unsqueeze(-1).expand(list(lead) + [1, 4])
    picked = torch.gather(candidates, -2, index).squeeze(-2)
    return standardize_quaternion(picked)
@torch.amp.autocast(device_type="cuda", enabled=False)
def normalize_cameras(cameras, return_meta=False, ref_w2c=None, T_norm=None, n_frame=None):
    """Express cameras relative to a reference pose and rescale their translations.

    Args:
        cameras: (B, N, C) tensor. [..., 0:4] is a real-first quaternion,
            [..., 4:7] a translation, and [..., 7:] is passed through unchanged.
        return_meta: if True, also return ``ref_w2c`` and ``T_norm``.
        ref_w2c: (B, 1, 4, 4) world-to-reference transform. Defaults to the
            inverse of each batch's first camera-to-world matrix.
        T_norm: (B, 1, 1) translation scale. Defaults to the largest
            camera-center distance (after re-referencing) per batch.
        n_frame: if given, only the first ``n_frame`` cameras determine the default ``T_norm``.
    Returns:
        Normalized cameras (same layout), plus ``ref_w2c`` and ``T_norm`` when ``return_meta``.
    """
    B, N = cameras.shape[:2]
    c2ws = torch.zeros(B, N, 3, 4, device=cameras.device)
    c2ws[..., :3, :3] = quaternion_to_matrix(cameras[..., 0:4])
    c2ws[..., :, 3] = cameras[..., 4:7]

    if ref_w2c is None:
        ref_w2c = torch.inverse(matrix_to_square(c2ws[:, :1]))
    rel = (ref_w2c.repeat(1, N, 1, 1) @ matrix_to_square(c2ws))[..., :3, :]

    if T_norm is None:
        # ":None" keeps every camera, so this covers both the n_frame and the full case.
        centers = rel[..., :n_frame, :3, 3]
        T_norm = centers.norm(dim=-1).max(dim=1)[0][..., None, None]
    # The 1e-2 guards against division by ~0 when all cameras sit at the reference.
    rel[..., :3, 3] = rel[..., :3, 3] / (T_norm + 1e-2)

    quats = matrix_to_quaternion(rel[..., :3, :3])
    trans = rel[..., :3, 3]
    normalized = torch.cat([quats.float(), trans.float(), cameras[..., 7:]], dim=-1)
    return (normalized, ref_w2c, T_norm) if return_meta else normalized
def create_rays(cameras, h, w, uv_offset=None):
    """Generate one ray per pixel for each camera.

    Args:
        cameras: (..., C) tensor. [:4] is a real-first quaternion, [4:7] the
            camera position, and [7:11] fx, fy, cx, cy normalized by image
            width/height.
        h, w: output resolution in pixels.
        uv_offset: optional per-pixel offsets with [..., 0] / [..., 1] added to
            u = x / cx and v = y / cy. Must reshape to (num_cameras, h * w).
    Returns:
        ``(rays_o, rays_d)``, each of shape (..., h, w, 3). Directions are
        unit length. The camera-space direction has z = -1 and, without
        offsets, x = (px - cx) / fx and y = -(py - cy) / fy at pixel centres.
    """
    lead_shape = cameras.shape[:-1]
    flat = cameras.flatten(0, -2)
    device = flat.device
    n_cam = flat.shape[0]
    n_pix = h * w

    c2w = torch.eye(4, device=device)[None].repeat(n_cam, 1, 1)
    c2w[:, :3, :3] = quaternion_to_matrix(flat[:, :4])
    c2w[:, :3, 3] = flat[:, 4:7]

    # Intrinsics are stored normalized by the original width/height; convert to pixels.
    fx, fy, cx, cy = flat[:, 7:].chunk(4, -1)
    fx = fx * w
    cx = cx * w
    fy = fy * h
    cy = cy * h

    pix = torch.arange(0, n_pix, device=device).expand(n_cam, n_pix)
    px = pix % w + 0.5  # pixel-centre column coordinate
    py = torch.div(pix, w, rounding_mode='floor') + 0.5  # pixel-centre row coordinate

    if uv_offset is None:
        du = dv = 0
    else:
        du = uv_offset[..., 0].reshape(n_cam, n_pix)
        dv = uv_offset[..., 1].reshape(n_cam, n_pix)
    u = px / cx + du
    v = py / cy + dv

    zs = -torch.ones_like(px)
    xs = -(u - 1) * cx / fx * zs
    ys = (v - 1) * cy / fy * zs
    dirs_cam = torch.stack((xs, ys, zs), dim=-1)
    rays_d = F.normalize(dirs_cam @ c2w[:, :3, :3].transpose(-1, -2), dim=-1)
    rays_o = c2w[..., :3, 3][..., None, :].expand_as(rays_d)
    return rays_o.reshape(*lead_shape, h, w, 3), rays_d.reshape(*lead_shape, h, w, 3)
def matrix_to_square(mat):
    """Append the homogeneous row [0, 0, 0, 1] to (..., 3, 4) pose matrices.

    Args:
        mat: tensor of shape (..., 3, 4) with any number of leading dimensions.
    Returns:
        Tensor of shape (..., 4, 4) with the same dtype and device as ``mat``.
    Raises:
        ValueError: if ``mat`` has fewer than two dimensions.
    """
    if mat.dim() < 2:
        raise ValueError(f"matrix_to_square expects a (..., 3, 4) tensor, got shape {tuple(mat.shape)}")
    # new_tensor matches mat's dtype/device; expand broadcasts the row over all leading dims.
    bottom = mat.new_tensor([0, 0, 0, 1]).expand(*mat.shape[:-2], 1, 4)
    return torch.cat([mat, bottom], dim=-2)
def export_ply_for_gaussians(path, gaussians, opacity_threshold=0.00, T_norm=None):
    """Write Gaussians to a PLY file with one float32 'vertex' record per Gaussian.

    Args:
        path: output path, passed to ``PlyData.write``.
        gaussians: tensor whose last dimension packs, in order, xyz (3),
            opacity (1), scale (3), rotation (4) and SH features
            (3 * (sh_degree + 1) ** 2). Leading dimensions are flattened.
            Opacity and scale are inverted with logit/log below, so they are
            expected post-activation (opacity in [0, 1], positive scales).
        opacity_threshold: if > 0, Gaussians with opacity below it are dropped.
        T_norm: optional one-element tensor; positions and scales are
            multiplied by it. Presumably this is the translation scale from
            ``normalize_cameras`` — confirm with callers.
    Raises:
        ValueError: if the channel count does not match any SH degree.
    """
    n_channels = gaussians.shape[-1]
    n_sh_coeffs, remainder = divmod(n_channels - sum([3, 1, 3, 4]), 3)
    sh_degree = math.isqrt(max(n_sh_coeffs, 0)) - 1
    if remainder or sh_degree < 0 or (sh_degree + 1) ** 2 != n_sh_coeffs:
        raise ValueError(f"Cannot infer SH degree from {n_channels} Gaussian channels")

    flat = gaussians.detach().reshape(-1, n_channels).float()
    means3D, opacity, scales, rotations, shs = flat.split([3, 1, 3, 4, n_sh_coeffs * 3], dim=-1)

    # Prune by opacity.
    if opacity_threshold > 0:
        mask = opacity[:, 0] >= opacity_threshold
        means3D, opacity, scales, rotations, shs = [x[mask] for x in (means3D, opacity, scales, rotations, shs)]
        print("Gaussian percentage: ", mask.float().mean())

    if T_norm is not None:
        factor = T_norm.item()
        means3D = means3D * factor
        scales = scales * factor

    # Invert the activations to match the usual 3DGS PLY convention (raw logits / log-scales).
    # logit clamps opacity to [eps, 1 - eps], so opacities of exactly 0 or 1 do not become +-inf.
    opacity = torch.logit(opacity, eps=1e-6)
    scales = torch.log(scales + 1e-8)

    # NOTE(review): every SH coefficient is written as f_dc_*. Standard 3DGS loaders
    # expect only f_dc_0..2 plus f_rest_* when sh_degree > 0 — confirm before
    # exporting higher SH degrees.
    f_dc = shs.flatten(start_dim=1)
    names = ['x', 'y', 'z']
    names += ['f_dc_{}'.format(i) for i in range(f_dc.shape[1])]
    names.append('opacity')
    names += ['scale_{}'.format(i) for i in range(scales.shape[1])]
    names += ['rot_{}'.format(i) for i in range(rotations.shape[1])]

    attributes = torch.cat((means3D, f_dc, opacity, scales, rotations), dim=1).cpu().numpy()
    # Build the record array directly from the columns, avoiding per-row Python loops.
    elements = np.rec.fromarrays(list(attributes.T), names=names, formats=['f4'] * len(names))
    el = PlyElement.describe(elements, 'vertex')
    print(path)
    PlyData([el]).write(path)
# plydata = PlyData([el])
# vert = plydata["vertex"]
# sorted_indices = np.argsort(
# -np.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"])
# / (1 + np.exp(-vert["opacity"]))
# )
# buffer = BytesIO()
# for idx in sorted_indices:
# v = plydata["vertex"][idx]
# position = np.array([v["x"], v["y"], v["z"]], dtype=np.float32)
# scales = np.exp(
# np.array(
# [v["scale_0"], v["scale_1"], v["scale_2"]],
# dtype=np.float32,
# )
# )
# rot = np.array(
# [v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]],
# dtype=np.float32,
# )
# SH_C0 = 0.28209479177387814
# color = np.array(
# [
# 0.5 + SH_C0 * v["f_dc_0"],
# 0.5 + SH_C0 * v["f_dc_1"],
# 0.5 + SH_C0 * v["f_dc_2"],
# 1 / (1 + np.exp(-v["opacity"])),
# ]
# )
# buffer.write(position.tobytes())
# buffer.write(scales.tobytes())
# buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
# buffer.write(
# ((rot / np.linalg.norm(rot)) * 128 + 128)
# .clip(0, 255)
# .astype(np.uint8)
# .tobytes()
# )
# with open(path + '.splat', "wb") as f:
# f.write(buffer.getvalue())
@torch.amp.autocast(device_type="cuda", enabled=False)
def quaternion_slerp(
    q0, q1, fraction, spin: int = 0, shortestpath: bool = True
):
    """Return spherical linear interpolation between two quaternions.

    Args:
        q0: first quaternion(s), shape (..., 4).
        q1: second quaternion(s), shape (..., 4). The caller's tensor is not modified.
        fraction: interpolation weight; a float or a tensor broadcastable against
            ``q0[..., 0]``. 0 gives q0 and 1 gives q1 (or -q1 if it was flipped).
        spin: extra multiples of pi added to the interpolation angle.
        shortestpath: if True, q1 is negated wherever dot(q0, q1) < 0, so the
            interpolation follows the shorter arc. q and -q are the same rotation.
    Returns:
        Interpolated quaternions, shape (..., 4). Where the angle between the
        inputs is below 1e-5 rad, q0 is returned.
    """
    d = (q0 * q1).sum(-1)
    if shortestpath:
        # Negate both the dot product and q1 where they point into opposite hemispheres.
        flip = d < 0.0
        d = torch.where(flip, -d, d)
        q1 = torch.where(flip[..., None], -q1, q1)
    # Keep acos in its domain despite round-off; with shortestpath, d is already >= 0.
    _d = d.clamp(-1.0, 1.0)
    angle = torch.acos(_d) + spin * math.pi
    # The 1e-10 keeps the reciprocal finite when sin(angle) is ~0.
    isin = 1.0 / (torch.sin(angle) + 1e-10)
    w0 = torch.sin((1.0 - fraction) * angle) * isin
    w1 = torch.sin(fraction * angle) * isin
    q = q0 * w0[..., None] + q1 * w1[..., None]
    # For (nearly) identical inputs the weights are ill-conditioned; fall back to q0.
    return torch.where((angle < 1e-5)[..., None], q0, q)
def sample_from_two_pose(pose_a, pose_b, fraction, noise_strengths=(0, 0)):
    """
    Interpolate between two batches of poses and optionally add Gaussian noise.

    Args:
        pose_a: first pose(s), indexed as (..., C). [..., :4] is a real-first
            quaternion and [..., 4:] is interpolated linearly.
        pose_b: second pose(s), same layout as pose_a.
        fraction: 1-D tensor of interpolation weights (0 -> pose_a, 1 -> pose_b).
            It is indexed with ``[:, None]``.
        noise_strengths: (rotation_std, translation_std). The noise is added to
            the quaternion before re-normalization and to the interpolated tail.
    Returns:
        A new tensor shaped like pose_a; the inputs are not modified.
    """
    quat_a = pose_a[..., :4]
    quat_b = pose_b[..., :4]
    # q and -q are the same rotation; align b with a so slerp takes the short arc.
    dot = torch.sum(quat_a * quat_b, dim=-1, keepdim=True)
    quat_b = torch.where(dot < 0, -quat_b, quat_b)
    quaternion = quaternion_slerp(quat_a, quat_b, fraction)
    quaternion = F.normalize(quaternion + torch.randn_like(quaternion) * noise_strengths[0], dim=-1)
    T = (1 - fraction)[:, None] * pose_a[..., 4:] + fraction[:, None] * pose_b[..., 4:]
    T = T + torch.randn_like(T) * noise_strengths[1]
    new_pose = pose_a.clone()
    new_pose[..., :4] = quaternion
    new_pose[..., 4:] = T
    return new_pose
def sample_from_dense_cameras(dense_cameras, t, noise_strengths=(0, 0, 0, 0)):
    """Sample cameras at normalized times ``t`` from an ordered (N, C) camera sequence.

    Each t in [0, 1] selects the segment between two neighbouring cameras. The
    first 7 channels (quaternion + translation) go through ``sample_from_two_pose``;
    the remaining channels are interpolated linearly.

    Args:
        dense_cameras: (N, C) tensor of cameras, C >= 7.
        t: 1-D tensor of sample times.
        noise_strengths: only the first two entries are used (rotation and
            translation noise for ``sample_from_two_pose``).
    Returns:
        (len(t), C) tensor of sampled cameras. With a single camera (N == 1),
        every sample is that camera (plus any requested noise).
    """
    N, C = dense_cameras.shape
    # Clamp so that left/right stay valid indices; for N == 1 both collapse to 0.
    left = torch.floor(t * (N - 1)).long().clamp(0, max(N - 2, 0))
    right = (left + 1).clamp(max=N - 1)
    fraction = t * (N - 1) - left
    a = torch.gather(dense_cameras, 0, left[..., None].repeat(1, C))
    b = torch.gather(dense_cameras, 0, right[..., None].repeat(1, C))
    new_pose = sample_from_two_pose(a[:, :7], b[:, :7], fraction, noise_strengths=noise_strengths[:2])
    new_ins = (1 - fraction)[:, None] * a[:, 7:] + fraction[:, None] * b[:, 7:]
    return torch.cat([new_pose, new_ins], dim=1)
|