from pathlib import Path

import numpy as np
import torch
import moviepy.editor as mpy

from src.models.models.rasterization import GaussianSplatRenderer
from src.models.utils.sh_utils import RGB2SH, SH2RGB
from src.utils.gs_effects import GSEffects
from src.utils.color_map import apply_color_map_to_image
from tqdm import tqdm


def rotation_matrix_to_quaternion(R):
    """Convert rotation matrices [..., 3, 3] to quaternions [..., 4] in (w, x, y, z) order."""
    trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
    
    q = torch.zeros(R.shape[:-2] + (4,), device=R.device, dtype=R.dtype)
    
    # Case where trace > 0
    mask1 = trace > 0
    s = torch.sqrt(trace[mask1] + 1.0) * 2  # s=4*qw 
    q[mask1, 0] = 0.25 * s  # qw
    q[mask1, 1] = (R[mask1, 2, 1] - R[mask1, 1, 2]) / s  # qx
    q[mask1, 2] = (R[mask1, 0, 2] - R[mask1, 2, 0]) / s  # qy
    q[mask1, 3] = (R[mask1, 1, 0] - R[mask1, 0, 1]) / s  # qz
    
    # Case where R[0,0] > R[1,1] and R[0,0] > R[2,2]
    mask2 = (~mask1) & (R[..., 0, 0] > R[..., 1, 1]) & (R[..., 0, 0] > R[..., 2, 2])
    s = torch.sqrt(1.0 + R[mask2, 0, 0] - R[mask2, 1, 1] - R[mask2, 2, 2]) * 2  # s=4*qx
    q[mask2, 0] = (R[mask2, 2, 1] - R[mask2, 1, 2]) / s  # qw
    q[mask2, 1] = 0.25 * s  # qx
    q[mask2, 2] = (R[mask2, 0, 1] + R[mask2, 1, 0]) / s  # qy
    q[mask2, 3] = (R[mask2, 0, 2] + R[mask2, 2, 0]) / s  # qz
    
    # Case where R[1,1] > R[2,2]
    mask3 = (~mask1) & (~mask2) & (R[..., 1, 1] > R[..., 2, 2])
    s = torch.sqrt(1.0 + R[mask3, 1, 1] - R[mask3, 0, 0] - R[mask3, 2, 2]) * 2  # s=4*qy
    q[mask3, 0] = (R[mask3, 0, 2] - R[mask3, 2, 0]) / s  # qw
    q[mask3, 1] = (R[mask3, 0, 1] + R[mask3, 1, 0]) / s  # qx
    q[mask3, 2] = 0.25 * s  # qy
    q[mask3, 3] = (R[mask3, 1, 2] + R[mask3, 2, 1]) / s  # qz
    
    # Remaining case
    mask4 = (~mask1) & (~mask2) & (~mask3)
    s = torch.sqrt(1.0 + R[mask4, 2, 2] - R[mask4, 0, 0] - R[mask4, 1, 1]) * 2  # s=4*qz
    q[mask4, 0] = (R[mask4, 1, 0] - R[mask4, 0, 1]) / s  # qw
    q[mask4, 1] = (R[mask4, 0, 2] + R[mask4, 2, 0]) / s  # qx
    q[mask4, 2] = (R[mask4, 1, 2] + R[mask4, 2, 1]) / s  # qy
    q[mask4, 3] = 0.25 * s  # qz
    
    return q


def quaternion_to_rotation_matrix(q):
    """Convert (w, x, y, z) quaternions [..., 4] to rotation matrices [..., 3, 3]."""
    w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
    
    # Normalize quaternion
    norm = torch.sqrt(w*w + x*x + y*y + z*z)
    w, x, y, z = w/norm, x/norm, y/norm, z/norm
    
    R = torch.zeros(q.shape[:-1] + (3, 3), device=q.device, dtype=q.dtype)
    
    R[..., 0, 0] = 1 - 2*(y*y + z*z)
    R[..., 0, 1] = 2*(x*y - w*z)
    R[..., 0, 2] = 2*(x*z + w*y)
    R[..., 1, 0] = 2*(x*y + w*z)
    R[..., 1, 1] = 1 - 2*(x*x + z*z)
    R[..., 1, 2] = 2*(y*z - w*x)
    R[..., 2, 0] = 2*(x*z - w*y)
    R[..., 2, 1] = 2*(y*z + w*x)
    R[..., 2, 2] = 1 - 2*(x*x + y*y)
    
    return R
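

# Illustrative sanity check, added for documentation only (the helper name is ours and
# nothing in the pipeline calls it): the two conversions above should round-trip on
# random unit quaternions, up to the inherent q / -q sign ambiguity.
def _check_quaternion_round_trip(num: int = 8, tol: float = 1e-4) -> bool:
    q = torch.randn(num, 4)
    q = q / q.norm(dim=-1, keepdim=True)
    q_rec = rotation_matrix_to_quaternion(quaternion_to_rotation_matrix(q))
    # q and -q encode the same rotation, so compare against both signs.
    err = torch.minimum((q - q_rec).abs().max(dim=-1).values,
                        (q + q_rec).abs().max(dim=-1).values)
    return bool((err < tol).all())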

    
def slerp_quaternions(q1, q2, t):
    """Spherical linear interpolation from q1 to q2 at fraction t (0 gives q1, 1 gives q2)."""
    # Compute dot product
    dot = (q1 * q2).sum(dim=-1, keepdim=True)
    
    # If dot product is negative, slerp won't take the shorter path.
    # Note that q and -q represent the same rotation, so we can flip one.
    mask = dot < 0
    q2 = torch.where(mask, -q2, q2)
    dot = torch.where(mask, -dot, dot)
    
    # If the inputs are too close for comfort, linearly interpolate
    # and normalize the result.
    DOT_THRESHOLD = 0.9995
    mask_linear = dot > DOT_THRESHOLD
    
    result = torch.zeros_like(q1)
    
    # Linear interpolation for close quaternions
    if mask_linear.any():
        result_linear = q1 + t * (q2 - q1)
        norm = torch.norm(result_linear, dim=-1, keepdim=True)
        result_linear = result_linear / norm
        result = torch.where(mask_linear, result_linear, result)
    
    # Spherical interpolation for distant quaternions
    mask_slerp = ~mask_linear
    if mask_slerp.any():
        theta_0 = torch.acos(torch.abs(dot))
        sin_theta_0 = torch.sin(theta_0)
        
        theta = theta_0 * t
        sin_theta = torch.sin(theta)
        
        s0 = torch.cos(theta) - dot * sin_theta / sin_theta_0
        s1 = sin_theta / sin_theta_0
        
        result_slerp = (s0 * q1) + (s1 * q2)
        result = torch.where(mask_slerp, result_slerp, result)
    
    return result
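

# Worked example, added for documentation only (not called anywhere in the pipeline):
# slerp between the identity quaternion and a 90-degree rotation about z at t = 0.5
# should give a 45-degree rotation about z, i.e. roughly (0.9239, 0, 0, 0.3827).
def _slerp_example() -> torch.Tensor:
    q_identity = torch.tensor([1.0, 0.0, 0.0, 0.0])
    angle = torch.tensor(torch.pi / 2)  # a 90-degree rotation about z
    q_z90 = torch.stack([torch.cos(angle / 2),
                         torch.tensor(0.0),
                         torch.tensor(0.0),
                         torch.sin(angle / 2)])
    return slerp_quaternions(q_identity, q_z90, 0.5)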


def render_interpolated_video(gs_renderer: GaussianSplatRenderer,
                              splats: dict,
                              camtoworlds: torch.Tensor,
                              intrinsics: torch.Tensor,
                              hw: tuple[int, int],
                              out_path: Path,
                              interp_per_pair: int = 20,
                              loop_reverse: bool = True,
                              effects: GSEffects = None,
                              effect_type: int = 2,
                              save_mode: str = "split") -> None:
    # camtoworlds: [B, S, 4, 4], intrinsics: [B, S, 3, 3]
    b, s, _, _ = camtoworlds.shape
    h, w = hw

    # Build interpolated trajectory
    def build_interpolated_traj(index, nums):
        exts, ints = [], []
        tmp_camtoworlds = camtoworlds[:, index]
        tmp_intrinsics = intrinsics[:, index]
        for i in range(len(index)-1):
            exts.append(tmp_camtoworlds[:, i:i+1])
            ints.append(tmp_intrinsics[:, i:i+1])
            # Extract rotation and translation
            R0, t0 = tmp_camtoworlds[:, i, :3, :3], tmp_camtoworlds[:, i, :3, 3]
            R1, t1 = tmp_camtoworlds[:, i + 1, :3, :3], tmp_camtoworlds[:, i + 1, :3, 3]
            
            # Convert rotations to quaternions
            q0 = rotation_matrix_to_quaternion(R0)
            q1 = rotation_matrix_to_quaternion(R1)
            
            # Interpolate using smooth quaternion slerp
            for j in range(1, nums + 1):
                alpha = j / (nums + 1)
                
                # Linear interpolation for translation
                t_interp = (1 - alpha) * t0 + alpha * t1
                
                # Spherical interpolation for rotation
                q_interp = slerp_quaternions(q0, q1, alpha)
                R_interp = quaternion_to_rotation_matrix(q_interp)
                
                # Create interpolated extrinsic matrix
                ext = torch.eye(4, device=R_interp.device, dtype=R_interp.dtype)[None].repeat(b, 1, 1)
                ext[:, :3, :3] = R_interp
                ext[:, :3, 3] = t_interp
                
                # Linear interpolation for intrinsics
                K0 = tmp_intrinsics[:, i]
                K1 = tmp_intrinsics[:, i + 1]
                K = (1 - alpha) * K0 + alpha * K1
                
                exts.append(ext[:, None])
                ints.append(K[:, None])

        # Include the final keyframe so the trajectory ends exactly at the last input pose.
        exts.append(tmp_camtoworlds[:, -1:])
        ints.append(tmp_intrinsics[:, -1:])
        exts = torch.cat(exts, dim=1)[:1]
        ints = torch.cat(ints, dim=1)[:1]
        return exts, ints
    
    # Build wobble trajectory
    def build_wobble_traj(nums, delta):
        assert s==1
        t = torch.linspace(0, 1, nums, dtype=torch.float32, device=camtoworlds.device)
        t = (torch.cos(torch.pi * (t + 1)) + 1) / 2
        tf = torch.eye(4, dtype=torch.float32, device=camtoworlds.device)
        radius = delta * 0.15
        tf = tf.broadcast_to((*radius.shape, t.shape[0], 4, 4)).clone()
        radius = radius[..., None]
        radius = radius * t
        tf[..., 0, 3] = torch.sin(2 * torch.pi * t) * radius
        tf[..., 1, 3] = -torch.cos(2 * torch.pi * t) * radius
        exts = camtoworlds @ tf
        ints = intrinsics.repeat(1, exts.shape[1], 1, 1)
        return exts, ints
    
    if s > 1:
        all_ext, all_int = build_interpolated_traj([i for i in range(s)], interp_per_pair)
    else:
        all_ext, all_int = build_wobble_traj(interp_per_pair * 12, splats["means"][0].median(dim=0).values.norm(dim=-1)[None])

    rendered_rgbs, rendered_depths = [], []
    chunk = 40 if effects is None else 1  # cameras rasterized per call (one at a time when the splats change every frame)
    t = 0  # effect time, advanced in 0.04 steps per rendered frame
    t_skip = 0
    if effects is not None:
        # Prune the splats if the renderer supports it; otherwise fall back to the full set.
        try:
            pruned_splats = gs_renderer.prune_gs(splats, gs_renderer.voxel_size)
        except Exception:
            pruned_splats = splats
        # indices = [x for x in range(0, all_ext.shape[1], 2)][:4]
        # add_ext, add_int = build_interpolated_traj(indices, 150)
        # add_ext = torch.flip(add_ext, dims=[1])
        # add_int = torch.flip(add_int, dims=[1])
        # Hold the first camera pose for up to 320 frames while the effect warms up.
        add_ext = all_ext[:, :1, :, :].repeat(1, 320, 1, 1)
        add_int = all_int[:, :1, :, :].repeat(1, 320, 1, 1)
        # Center and rescale the camera positions (and, below, the splats) into a
        # normalized frame before applying the effect.
        shift = pruned_splats["means"][0].median(dim=0).values
        scale_factor = (pruned_splats["means"][0] - shift).abs().quantile(0.95, dim=0).max()
        all_ext[0, :, :3, -1] = (all_ext[0, :, :3, -1] - shift) / scale_factor
        add_ext[0, :, :3, -1] = (add_ext[0, :, :3, -1] - shift) / scale_factor
        flag = None
        # Keep a handle on the runner's original splats (if the rasterizer exposes one)
        # so they can be restored after the effect pass.
        try:
            raw_splats = gs_renderer.rasterizer.runner.splats
        except Exception:
            raw_splats = None
        for st in range(0, add_ext.shape[1]):
            ed = min(st + 1, add_ext.shape[1])
            assert gs_renderer.sh_degree == 0
            if flag is not None and (flag < 0.99).any():
                break
            sample_gsplat = {"means": (pruned_splats["means"][0] - shift)/scale_factor, "quats": pruned_splats["quats"][0], "scales": pruned_splats["scales"][0]/scale_factor, 
                            "opacities": pruned_splats["opacities"][0],"colors": SH2RGB(pruned_splats["sh"][0].reshape(-1, 3))}
            effects_splats, flag = effects.apply_effect(sample_gsplat, t, effect_type=effect_type)
            t += 0.04
            effects_splats["sh"] = RGB2SH(effects_splats["colors"]).reshape(-1, 1, 3)
            try:
                # If the rasterizer wraps a gsplat runner, hand it the full splat dict
                # (activations inverted back to raw log-scale / logit-opacity parameters).
                gs_renderer.rasterizer.runner.splats
                effects_splats["sh0"] = effects_splats["sh"][:, :1, :]
                effects_splats["shN"] = effects_splats["sh"][:, 1:, :]
                effects_splats["scales"] = effects_splats["scales"].log()
                effects_splats["opacities"] = torch.logit(torch.clamp(effects_splats["opacities"], 1e-6, 1 - 1e-6))
                gs_renderer.rasterizer.runner.splats = effects_splats
                colors, depths, _ = gs_renderer.rasterizer.rasterize_batches(
                    None, None, None,
                    None, None,
                    add_ext[:, st:ed].to(torch.float32), add_int[:, st:ed].to(torch.float32),
                    width=w, height=h, sh_degree=gs_renderer.sh_degree,
                )
            except Exception:
                # Otherwise pass the splat tensors to the rasterizer explicitly.
                colors, depths, _ = gs_renderer.rasterizer.rasterize_batches(
                    effects_splats["means"][None], effects_splats["quats"][None], effects_splats["scales"][None],
                    effects_splats["opacities"][None], effects_splats["sh"][None],
                    add_ext[:, st:ed].to(torch.float32), add_int[:, st:ed].to(torch.float32),
                    width=w, height=h, sh_degree=gs_renderer.sh_degree if "sh" in pruned_splats else None,
                )
            
            # Start recording output only after the first ~14% of warm-up frames.
            if st > add_ext.shape[1] * 0.14:
                t_skip = t if t_skip == 0 else t_skip
                # break
                rendered_rgbs.append(colors)
                rendered_depths.append(depths)
            # if (flag == 0).all():
            #     break
    t_st = t              # effect time at the end of the warm-up pass
    t_ed = 0              # set below to the time at which the effect has played out
    loop_dir = 1          # +1 / -1 direction for stepping the effect time
    ignore_scale = False
    for st in tqdm(range(0, all_ext.shape[1], chunk)):
        ed = min(st + chunk, all_ext.shape[1])
        if effects is not None:
            sample_gsplat = {"means": (pruned_splats["means"][0] - shift) / scale_factor,
                             "quats": pruned_splats["quats"][0],
                             "scales": pruned_splats["scales"][0] / scale_factor,
                             "opacities": pruned_splats["opacities"][0],
                             "colors": SH2RGB(pruned_splats["sh"][0].reshape(-1, 3))}
            effects_splats, flag = effects.apply_effect(sample_gsplat, t, effect_type=effect_type, ignore_scale=ignore_scale)
            if loop_dir < 0:
                t -= 0.04
            else:
                t += 0.04
            if flag.mean() < 0.01 and t_ed == 0:
                t_ed = t
            effects_splats["sh"] = RGB2SH(effects_splats["colors"]).reshape(-1, 1, 3)
            effects_splats["sh0"] = effects_splats["sh"][:, :1, :]
            effects_splats["shN"] = effects_splats["sh"][:, 1:, :]
            try:
                gs_renderer.rasterizer.runner.splats
                effects_splats["sh0"] = effects_splats["sh"][:, :1, :]
                effects_splats["shN"] = effects_splats["sh"][:, 1:, :]
                effects_splats["scales"] = effects_splats["scales"].log()
                effects_splats["opacities"] = torch.logit(torch.clamp(effects_splats["opacities"], 1e-6, 1 - 1e-6))
                gs_renderer.rasterizer.runner.splats = effects_splats
                colors, depths, _ = gs_renderer.rasterizer.rasterize_batches(
                None, None, None, 
                None, None,
                all_ext[:, st:ed].to(torch.float32), all_int[:, st:ed].to(torch.float32),
                width=w, height=h, sh_degree=gs_renderer.sh_degree,
                )
            except:
                colors, depths, _ = gs_renderer.rasterizer.rasterize_batches(
                effects_splats["means"][None], effects_splats["quats"][None], effects_splats["scales"][None], 
                effects_splats["opacities"][None], effects_splats["sh"][None],
                all_ext[:, st:ed].to(torch.float32), all_int[:, st:ed].to(torch.float32),
                width=w, height=h, sh_degree=gs_renderer.sh_degree if "sh" in pruned_splats else None,
                )
            
            # Reverse the effect-time direction once t leaves its allowed window, so the
            # effect plays back and forth over the remaining camera frames.
            if t > (all_ext.shape[1]) * 0.04 + t_st - (t_ed - t_st) * 2 - 15 * 0.04 or t < t_st:
                # ignore_scale = True
                loop_dir *= -1
                t = t_ed if loop_dir == -1 else t
        else:
            colors, depths, _ = gs_renderer.rasterizer.rasterize_batches(
                splats["means"][:1], splats["quats"][:1], splats["scales"][:1], splats["opacities"][:1],
                splats["sh"][:1] if "sh" in splats else splats["colors"][:1],
                all_ext[:, st:ed].to(torch.float32), all_int[:, st:ed].to(torch.float32),
                width=w, height=h, sh_degree=gs_renderer.sh_degree if "sh" in splats else None,
            )
        rendered_rgbs.append(colors)
        rendered_depths.append(depths)


    rgbs = torch.cat(rendered_rgbs, dim=1)[0]          # [N, H, W, 3]
    depths = torch.cat(rendered_depths, dim=1)[0, ..., 0]     # [N, H, W]


    def depth_vis(d: torch.Tensor) -> torch.Tensor:
        # Normalize log-depth between its 1st and 99th percentiles and map it to the
        # turbo colormap (inverted so that nearer surfaces appear warmer).
        valid = d > 0
        if valid.any():
            near = d[valid].float().quantile(0.01).log()
        else:
            near = torch.tensor(0.0, device=d.device)
        far = d.flatten().float().quantile(0.99).log()
        x = d.float().clamp(min=1e-9).log()
        x = 1.0 - (x - near) / (far - near + 1e-9)
        return apply_color_map_to_image(x, "turbo")
    
    frames = []
    rgb_frames = []
    depth_frames = []

    for rgb, dep in zip(rgbs, depths):
        rgb_img = rgb.permute(2, 0, 1)  # [3, H, W]
        depth_img = depth_vis(dep)      # [3, H, W]

        if save_mode == 'both':
            combined = torch.cat([rgb_img, depth_img], dim=1)  # [3, 2*H, W]
            frames.append(combined)
        elif save_mode == 'split':
            rgb_frames.append(rgb_img)
            depth_frames.append(depth_img)
        else:
            raise ValueError("save_mode must be 'both' or 'split'")

    def _make_video(frames, path):
        video = torch.stack(frames).clamp(0, 1)  # [N, 3, H, W]
        video = video.permute(0, 2, 3, 1)  # [N, H, W, 3] for moviepy
        video = (video * 255).to(torch.uint8).cpu().numpy()
        if loop_reverse and video.shape[0] > 1:
            video = np.concatenate([video, video[::-1][1:-1]], axis=0)
        clip = mpy.ImageSequenceClip(list(video), fps=30)
        clip.write_videofile(str(path), logger=None)

    # Save videos
    if save_mode == 'both':
        _make_video(frames, f"{out_path}.mp4")
    elif save_mode == 'split':
        _make_video(rgb_frames, f"{out_path}_rgb.mp4")
        _make_video(depth_frames, f"{out_path}_depth.mp4")

    print(f"Video saved to {out_path} (mode: {save_mode})")

    if effects is not None and raw_splats is not None:
        # Restore the runner's original splats after the effect pass.
        try:
            gs_renderer.rasterizer.runner.splats = raw_splats
        except Exception:
            pass
    torch.cuda.empty_cache()
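

# ---------------------------------------------------------------------------
# Illustrative usage sketch, added for documentation only. The camera shapes follow
# the comment inside render_interpolated_video ([B, S, 4, 4] extrinsics and
# [B, S, 3, 3] intrinsics); the splat tensor shapes and the degree-0 SH layout are
# assumptions inferred from how the function indexes `splats`. The call itself is
# left commented out because it needs a GaussianSplatRenderer constructed elsewhere
# in the project.
if __name__ == "__main__":
    B, S, N = 1, 2, 10_000  # batch size, keyframe cameras, number of Gaussians
    H, W = 256, 256
    camtoworlds = torch.eye(4).expand(B, S, 4, 4).clone()  # [B, S, 4, 4]
    intrinsics = torch.eye(3).expand(B, S, 3, 3).clone()   # [B, S, 3, 3]
    intrinsics[..., 0, 0] = intrinsics[..., 1, 1] = 200.0  # focal lengths (arbitrary)
    intrinsics[..., 0, 2], intrinsics[..., 1, 2] = W / 2, H / 2  # principal point
    quats = torch.randn(B, N, 4)
    splats = {
        "means": torch.randn(B, N, 3),
        "quats": quats / quats.norm(dim=-1, keepdim=True),
        "scales": torch.rand(B, N, 3) * 1e-2,
        "opacities": torch.rand(B, N),
        "sh": RGB2SH(torch.rand(B, N, 3)).reshape(B, N, 1, 3),  # assumed [B, N, 1, 3] layout
    }
    print({k: tuple(v.shape) for k, v in splats.items()})
    # gs_renderer = ...  # a GaussianSplatRenderer built by the surrounding project
    # render_interpolated_video(gs_renderer, splats, camtoworlds, intrinsics,
    #                           (H, W), Path("demo"), save_mode="split")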