superrgod committed on
Commit babafa4 · verified · 1 Parent(s): a7df54f

Upload 4 files

Files changed (4)
  1. app.py +651 -0
  2. quant.py +195 -0
  3. requirements.txt +19 -0
  4. utils.py +531 -0
app.py ADDED
@@ -0,0 +1,651 @@
1
+ try:
2
+ import spaces
3
+ GPU = spaces.GPU
4
+ print("spaces GPU is available")
5
+ except ImportError:
6
+ def GPU(func):
7
+ return func
8
+
9
+ import os
10
+ import subprocess
11
+
12
+ # def install_cuda_toolkit():
13
+ # # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
14
+ # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run"
15
+ # CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
16
+ # subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
17
+ # subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
18
+ # subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
19
+
20
+ # os.environ["CUDA_HOME"] = "/usr/local/cuda"
21
+ # os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
22
+ # os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
23
+ # os.environ["CUDA_HOME"],
24
+ # "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
25
+ # )
26
+ # # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
27
+ # os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
28
+
29
+ # print("Successfully installed CUDA toolkit at: ", os.environ["CUDA_HOME"])
30
+
31
+ # subprocess.call('rm /usr/bin/gcc', shell=True)
32
+ # subprocess.call('rm /usr/bin/g++', shell=True)
33
+ # subprocess.call('rm /usr/local/cuda/bin/gcc', shell=True)
34
+ # subprocess.call('rm /usr/local/cuda/bin/g++', shell=True)
35
+
36
+ # subprocess.call('ln -s /usr/bin/gcc-11 /usr/bin/gcc', shell=True)
37
+ # subprocess.call('ln -s /usr/bin/g++-11 /usr/bin/g++', shell=True)
38
+
39
+ # subprocess.call('ln -s /usr/bin/gcc-11 /usr/local/cuda/bin/gcc', shell=True)
40
+ # subprocess.call('ln -s /usr/bin/g++-11 /usr/local/cuda/bin/g++', shell=True)
41
+
42
+ # subprocess.call('gcc --version', shell=True)
43
+ # subprocess.call('g++ --version', shell=True)
44
+
45
+ # install_cuda_toolkit()
46
+
47
+ # subprocess.run('pip install git+https://github.com/nerfstudio-project/gsplat.git@32f2a54d21c7ecb135320bb02b136b7407ae5712 --no-build-isolation --use-pep517', env={'CUDA_HOME': "/usr/local/cuda", "TORCH_CUDA_ARCH_LIST": "8.0;8.6"}, shell=True)
48
+
49
+ from flask import Flask, jsonify, request, send_file, render_template
50
+ import base64
51
+ import io
52
+ from PIL import Image
53
+ import torch
54
+ import numpy as np
55
+ import os
56
+ import argparse
57
+ import imageio
58
+ import json
59
+
60
+ import time
61
+ import threading
62
+
63
+ from concurrency_manager import ConcurrencyManager
64
+
65
+ from huggingface_hub import hf_hub_download
66
+
67
+ import einops
68
+ import torch
69
+ import torch.nn as nn
70
+ import torch.nn.functional as F
71
+ import numpy as np
72
+
73
+ import imageio
74
+
75
+ from models import *
76
+ from utils import *
77
+
78
+ from transformers import T5TokenizerFast, UMT5EncoderModel
79
+
80
+ from diffusers import FlowMatchEulerDiscreteScheduler
81
+
82
+ class MyFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
83
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
84
+ if schedule_timesteps is None:
85
+ schedule_timesteps = self.timesteps
86
+
87
+ return torch.argmin(
88
+ (timestep - schedule_timesteps.to(timestep.device)).abs(), dim=0).item()
89
+
90
+ class GenerationSystem(nn.Module):
91
+ def __init__(self, ckpt_path=None, device="cuda:0", offload_t5=False, offload_vae=False):
92
+ super().__init__()
93
+ self.device = device
94
+ self.offload_t5 = offload_t5
95
+ self.offload_vae = offload_vae
96
+
97
+ self.latent_dim = 48
98
+ self.temporal_downsample_factor = 4
99
+ self.spatial_downsample_factor = 16
100
+
101
+ self.feat_dim = 1024
102
+
103
+ self.latent_patch_size = 2
104
+
105
+ self.denoising_steps = [0, 250, 500, 750]
106
+
107
+ model_id = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
108
+
109
+ self.vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float).eval()
110
+
111
+ from models.autoencoder_kl_wan import WanCausalConv3d
112
+ with torch.no_grad():
113
+ for name, module in self.vae.named_modules():
114
+ if isinstance(module, WanCausalConv3d):
115
+ time_pad = module._padding[4]
116
+ module.padding = (0, module._padding[2], module._padding[0])
117
+ module._padding = (0, 0, 0, 0, 0, 0)
118
+ module.weight = torch.nn.Parameter(module.weight[:, :, time_pad:].clone())
119
+
120
+ self.vae.requires_grad_(False)
121
+
122
+ self.register_buffer('latents_mean', torch.tensor(self.vae.config.latents_mean).float().view(1, self.vae.config.z_dim, 1, 1, 1).to(self.device))
123
+ self.register_buffer('latents_std', torch.tensor(self.vae.config.latents_std).float().view(1, self.vae.config.z_dim, 1, 1, 1).to(self.device))
124
+
125
+ self.latent_scale_fn = lambda x: (x - self.latents_mean) / self.latents_std
126
+ self.latent_unscale_fn = lambda x: x * self.latents_std + self.latents_mean
127
+
128
+ self.tokenizer = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer")
129
+
130
+ self.text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.float32).eval().requires_grad_(False).to(self.device if not self.offload_t5 else "cpu")
131
+
132
+ self.transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.float32).train().requires_grad_(False)
133
+
134
+ self.transformer.patch_embedding.weight = nn.Parameter(F.pad(self.transformer.patch_embedding.weight, (0, 0, 0, 0, 0, 0, 0, 6 + self.latent_dim)))
135
+ # self.transformer.rope.freqs_f[:] = self.transformer.rope.freqs_f[:1]
136
+
137
+ weight = self.transformer.proj_out.weight.reshape(self.latent_patch_size ** 2, self.latent_dim, self.transformer.proj_out.weight.shape[1])
138
+ bias = self.transformer.proj_out.bias.reshape(self.latent_patch_size ** 2, self.latent_dim)
139
+
140
+ extra_weight = torch.randn(self.latent_patch_size ** 2, self.feat_dim, self.transformer.proj_out.weight.shape[1]) * 0.02
141
+ extra_bias = torch.zeros(self.latent_patch_size ** 2, self.feat_dim)
142
+
143
+ self.transformer.proj_out.weight = nn.Parameter(torch.cat([weight, extra_weight], dim=1).flatten(0, 1).detach().clone())
144
+ self.transformer.proj_out.bias = nn.Parameter(torch.cat([bias, extra_bias], dim=1).flatten(0, 1).detach().clone())
145
+
146
+ self.recon_decoder = WANDecoderPixelAligned3DGSReconstructionModel(self.vae, self.feat_dim, use_render_checkpointing=True, use_network_checkpointing=False).train().requires_grad_(False).to(self.device)
147
+
148
+ self.scheduler = MyFlowMatchEulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler", shift=3)
149
+
150
+ self.register_buffer('timesteps', self.scheduler.timesteps.clone().to(self.device))
151
+
152
+ self.transformer.disable_gradient_checkpointing()
153
+ self.transformer.gradient_checkpointing = False
154
+
155
+ self.add_feedback_for_transformer()
156
+
157
+ if ckpt_path is not None:
158
+ state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
159
+ self.transformer.load_state_dict(state_dict["transformer"])
160
+ self.recon_decoder.load_state_dict(state_dict["recon_decoder"])
161
+ print(f"Loaded {ckpt_path}.")
162
+
163
+ from quant import FluxFp8GeMMProcessor
164
+
165
+ FluxFp8GeMMProcessor(self.transformer)
166
+
167
+ del self.vae.post_quant_conv, self.vae.decoder
168
+ self.vae.to(self.device if not self.offload_vae else "cpu")
169
+
170
+ self.transformer.to(self.device)
171
+
172
+ def add_feedback_for_transformer(self):
173
+ self.use_feedback = True
174
+ self.transformer.patch_embedding.weight = nn.Parameter(F.pad(self.transformer.patch_embedding.weight, (0, 0, 0, 0, 0, 0, 0, self.feat_dim + self.latent_dim)))
175
+
176
+ def encode_text(self, texts):
177
+ max_sequence_length = 512
178
+
179
+ text_inputs = self.tokenizer(
180
+ texts,
181
+ padding="max_length",
182
+ max_length=max_sequence_length,
183
+ truncation=True,
184
+ add_special_tokens=True,
185
+ return_attention_mask=True,
186
+ return_tensors="pt",
187
+ )
188
+ if getattr(self, "offload_t5", False):
189
+ text_input_ids = text_inputs.input_ids.to("cpu")
190
+ mask = text_inputs.attention_mask.to("cpu")
191
+ else:
192
+ text_input_ids = text_inputs.input_ids.to(self.device)
193
+ mask = text_inputs.attention_mask.to(self.device)
194
+ seq_lens = mask.gt(0).sum(dim=1).long()
195
+
196
+ if getattr(self, "offload_t5", False):
197
+ with torch.no_grad():
198
+ text_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state.to(self.device)
199
+ else:
200
+ text_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state
201
+ text_embeds = [u[:v] for u, v in zip(text_embeds, seq_lens)]
202
+ text_embeds = torch.stack(
203
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in text_embeds], dim=0
204
+ )
205
+ return text_embeds.float()
206
+
207
+ def forward_generator(self, noisy_latents, raymaps, condition_latents, t, text_embeds, cameras, render_cameras, image_height, image_width, need_3d_mode=True):
208
+
209
+ out = self.transformer(
210
+ hidden_states=torch.cat([noisy_latents, raymaps, condition_latents], dim=1),
211
+ timestep=t,
212
+ encoder_hidden_states=text_embeds,
213
+ return_dict=False,
214
+ )[0]
215
+
216
+ v_pred, feats = out.split([self.latent_dim, self.feat_dim], dim=1)
217
+
218
+ sigma = torch.stack([self.scheduler.sigmas[self.scheduler.index_for_timestep(_t)] for _t in t.unbind(0)], dim=0).to(self.device)
219
+ latents_pred_2d = noisy_latents - sigma * v_pred
220
+
221
+ if need_3d_mode:
222
+ scene_params = self.recon_decoder(
223
+ einops.rearrange(feats, 'B C T H W -> (B T) C H W').unsqueeze(2),
224
+ einops.rearrange(self.latent_unscale_fn(latents_pred_2d.detach()), 'B C T H W -> (B T) C H W').unsqueeze(2),
225
+ cameras
226
+ ).flatten(1, -2)
227
+
228
+ images_pred, _ = self.recon_decoder.render(scene_params.unbind(0), render_cameras, image_height, image_width, bg_mode="white")
229
+
230
+ latents_pred_3d = einops.rearrange(self.latent_scale_fn(self.vae.encode(
231
+ einops.rearrange(images_pred, 'B T C H W -> (B T) C H W', T=images_pred.shape[1]).unsqueeze(2).to(self.device if not self.offload_vae else "cpu").float()
232
+ ).latent_dist.sample().to(self.device)).squeeze(2), '(B T) C H W -> B C T H W', T=images_pred.shape[1]).to(noisy_latents.dtype)
233
+
234
+ return {
235
+ '2d': latents_pred_2d,
236
+ '3d': latents_pred_3d if need_3d_mode else None,
237
+ 'rgb_3d': images_pred if need_3d_mode else None,
238
+ 'scene': scene_params if need_3d_mode else None,
239
+ 'feat': feats
240
+ }
241
+
242
+ @torch.no_grad()
243
+ @torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda")
244
+ def generate(self, cameras, n_frame, image=None, text="", image_index=0, image_height=480, image_width=704, video_output_path=None):
245
+ with torch.no_grad():
246
+ batch_size = 1
247
+
248
+ cameras = cameras.to(self.device).unsqueeze(0)
249
+
250
+ if cameras.shape[1] != n_frame:
251
+ render_cameras = cameras.clone()
252
+ cameras = sample_from_dense_cameras(cameras.squeeze(0), torch.linspace(0, 1, n_frame, device=self.device)).unsqueeze(0)
253
+ else:
254
+ render_cameras = cameras
255
+
256
+ cameras, ref_w2c, T_norm = normalize_cameras(cameras, return_meta=True, n_frame=None)
257
+
258
+ render_cameras = normalize_cameras(render_cameras, ref_w2c=ref_w2c, T_norm=T_norm, n_frame=None)
259
+
260
+ text = "[Static] " + text
261
+
262
+ text_embeds = self.encode_text([text])
263
+ # neg_text_embeds = self.encode_text([""]).repeat(batch_size, 1, 1)
264
+
265
+ masks = torch.zeros(batch_size, n_frame, device=self.device)
266
+
267
+ condition_latents = torch.zeros(batch_size, self.latent_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
268
+
269
+ if image is not None:
270
+ image = image.to(self.device)
271
+
272
+ latent = self.latent_scale_fn(self.vae.encode(
273
+ image.unsqueeze(0).unsqueeze(2).to(self.device if not self.offload_vae else "cpu").float()
274
+ ).latent_dist.sample().to(self.device)).squeeze(2)
275
+
276
+ masks[:, image_index] = 1
277
+ condition_latents[:, :, image_index] = latent
278
+
279
+ raymaps = create_raymaps(cameras, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor)
280
+ raymaps = einops.rearrange(raymaps, 'B T H W C -> B C T H W', T=n_frame)
281
+
282
+ noise = torch.randn(batch_size, self.latent_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
283
+
284
+ noisy_latents = noise
285
+
286
+ torch.cuda.empty_cache()
287
+
288
+ if self.use_feedback:
289
+ prev_latents_pred = torch.zeros(batch_size, self.latent_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
290
+
291
+ prev_feats = torch.zeros(batch_size, self.feat_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
292
+
293
+ for i in range(len(self.denoising_steps)):
294
+ t_ids = torch.full((noisy_latents.shape[0],), self.denoising_steps[i], device=self.device)
295
+
296
+ t = self.timesteps[t_ids]
297
+
298
+ if self.use_feedback:
299
+ _condition_latents = torch.cat([condition_latents, prev_feats, prev_latents_pred], dim=1)
300
+ else:
301
+ _condition_latents = condition_latents
302
+
303
+ if i < len(self.denoising_steps) - 1:
304
+ out = self.forward_generator(noisy_latents, raymaps, _condition_latents, t, text_embeds, cameras, cameras, image_height, image_width, need_3d_mode=True)
305
+
306
+ latents_pred = out["3d"]
307
+
308
+ if self.use_feedback:
309
+ prev_latents_pred = latents_pred
310
+ prev_feats = out['feat']
311
+
312
+ noisy_latents = self.scheduler.scale_noise(latents_pred, self.timesteps[torch.full((noisy_latents.shape[0],), self.denoising_steps[i + 1], device=self.device)], torch.randn_like(noise))
313
+
314
+ else:
315
+ out = self.transformer(
316
+ hidden_states=torch.cat([noisy_latents, raymaps, _condition_latents], dim=1),
317
+ timestep=t,
318
+ encoder_hidden_states=text_embeds,
319
+ return_dict=False,
320
+ )[0]
321
+
322
+ v_pred, feats = out.split([self.latent_dim, self.feat_dim], dim=1)
323
+
324
+ sigma = torch.stack([self.scheduler.sigmas[self.scheduler.index_for_timestep(_t)] for _t in t.unbind(0)], dim=0).to(self.device)
325
+ latents_pred = noisy_latents - sigma * v_pred
326
+
327
+ scene_params = self.recon_decoder(
328
+ einops.rearrange(feats, 'B C T H W -> (B T) C H W').unsqueeze(2),
329
+ einops.rearrange(self.latent_unscale_fn(latents_pred.detach()), 'B C T H W -> (B T) C H W').unsqueeze(2),
330
+ cameras
331
+ ).flatten(1, -2)
332
+
333
+ if video_output_path is not None:
334
+ interpolated_images_pred, _ = self.recon_decoder.render(scene_params.unbind(0), render_cameras, image_height, image_width, bg_mode="white")
335
+
336
+ interpolated_images_pred = einops.rearrange(interpolated_images_pred[0].clamp(-1, 1).add(1).div(2), 'T C H W -> T H W C')
337
+
338
+ interpolated_images_pred = [img.detach().cpu().mul(255).numpy().astype(np.uint8) for img in interpolated_images_pred.unbind(0)]
339
+
340
+ imageio.mimwrite(video_output_path, interpolated_images_pred, fps=15, quality=8, macro_block_size=1)
341
+
342
+ scene_params = scene_params[0]
343
+
344
+ scene_params = scene_params.detach().cpu()
345
+
346
+ return scene_params, ref_w2c, T_norm
347
+
348
+ if __name__ == "__main__":
349
+ parser = argparse.ArgumentParser()
350
+ parser.add_argument('--port', type=int, default=7860)
351
+ parser.add_argument("--ckpt", default=None)
352
+ parser.add_argument("--gpu", type=int, default=0)
353
+ parser.add_argument("--cache_dir", type=str, default="./tmpfiles")
354
+ parser.add_argument("--offload_t5", type=bool, default=False)
355
+ parser.add_argument("--max_concurrent", type=int, default=1, help="Maximum concurrent generation tasks")
356
+ args, _ = parser.parse_known_args()
357
+
358
+ # Ensure model.ckpt exists, download if not present
359
+ if args.ckpt is None:
360
+ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
361
+ ckpt_path = os.path.join(HUGGINGFACE_HUB_CACHE, "models--imlixinyang--FlashWorld", "snapshots", "6a8e88c6f88678ac098e4c82675f0aee555d6e5d", "model.ckpt")
362
+ if not os.path.exists(ckpt_path):
363
+ hf_hub_download(repo_id="imlixinyang/FlashWorld", filename="model.ckpt", local_dir_use_symlinks=False)
364
+ else:
365
+ ckpt_path = args.ckpt
366
+
367
+ app = Flask(__name__)
368
+
369
+ # Initialize the GenerationSystem
370
+ device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
371
+ generation_system = GenerationSystem(ckpt_path=ckpt_path, device=device)
372
+
373
+ # Initialize the concurrency manager
374
+ concurrency_manager = ConcurrencyManager(max_concurrent=args.max_concurrent)
375
+
376
+ @app.after_request
377
+ def after_request(response):
378
+ response.headers.add('Access-Control-Allow-Origin', '*')
379
+ response.headers.add('Access-Control-Allow-Headers', 'Content-Type,Authorization')
380
+ response.headers.add('Access-Control-Allow-Methods', 'GET,PUT,POST,DELETE,OPTIONS')
381
+ return response
382
+
383
+ @GPU
384
+ def generate_wrapper(cameras, n_frame, image, text_prompt, image_index, image_height, image_width, video_output_path=None):
385
+ """生成函数的包装器,用于并发控制"""
386
+ return generation_system.generate(cameras, n_frame, image, text_prompt, image_index, image_height, image_width, video_output_path)
387
+
388
+ def job_generate(file_id, cache_dir, payload):
389
+ """工作线程执行的生成任务:负责生成并落盘,返回可下载信息"""
390
+ # Unpack parameters
391
+ cameras = payload["cameras"]
392
+ n_frame = payload["n_frame"]
393
+ image = payload["image"]
394
+ text_prompt = payload["text_prompt"]
395
+ image_index = payload["image_index"]
396
+ image_height = payload["image_height"]
397
+ image_width = payload["image_width"]
398
+ data = payload["raw_request"]
399
+
400
+ # Run the generation
401
+ scene_params, ref_w2c, T_norm = generation_system.generate(
402
+ cameras, n_frame, image, text_prompt, image_index, image_height, image_width, video_output_path=None
403
+ )
404
+
405
+ # Save the request metadata
406
+ with open(os.path.join(cache_dir, f'{file_id}.json'), 'w') as f:
407
+ json.dump(data, f)
408
+
409
+ # Export the PLY file
410
+ splat_path = os.path.join(cache_dir, f'{file_id}.ply')
411
+ export_ply_for_gaussians(splat_path, scene_params, opacity_threshold=0.001, T_norm=T_norm)
412
+
413
+ file_size = os.path.getsize(splat_path) if os.path.exists(splat_path) else 0
414
+
415
+ return {
416
+ 'file_id': file_id,
417
+ 'file_path': splat_path,
418
+ 'file_size': file_size,
419
+ 'download_url': f'/download/{file_id}'
420
+ }
421
+
422
+ @app.route('/generate', methods=['POST', 'OPTIONS'])
423
+ def generate():
424
+ # Handle preflight request
425
+ if request.method == 'OPTIONS':
426
+ return jsonify({'status': 'ok'})
427
+
428
+ try:
429
+ data = request.get_json(force=True)
430
+
431
+ image_prompt = data.get('image_prompt', None)
432
+ text_prompt = data.get('text_prompt', "")
433
+ cameras = data.get('cameras')
434
+ resolution = data.get('resolution')
435
+ image_index = data.get('image_index', 0)
436
+
437
+ n_frame, image_height, image_width = resolution
438
+
439
+ if not image_prompt and text_prompt == "":
440
+ return jsonify({'error': 'No Prompts provided'}), 400
441
+
442
+ # Process the image
443
+ if image_prompt:
444
+ # image_prompt can be a file path or a base64-encoded image
445
+ if os.path.exists(image_prompt):
446
+ image_prompt = Image.open(image_prompt)
447
+ else:
448
+ # image_prompt may be a data URL such as "data:image/png;base64,...."
449
+ if ',' in image_prompt:
450
+ image_prompt = image_prompt.split(',', 1)[1]
451
+
452
+ try:
453
+ image_bytes = base64.b64decode(image_prompt)
454
+ image_prompt = Image.open(io.BytesIO(image_bytes))
455
+ except Exception as img_e:
456
+ return jsonify({'error': f'Image decode error: {str(img_e)}'}), 400
457
+
458
+ image = image_prompt.convert('RGB')
459
+
460
+ w, h = image.size
461
+
462
+ # center crop
463
+ if image_height / h > image_width / w:
464
+ scale = image_height / h
465
+ else:
466
+ scale = image_width / w
467
+
468
+ new_h = int(image_height / scale)
469
+ new_w = int(image_width / scale)
470
+
471
+ image = image.crop(((w - new_w) // 2, (h - new_h) // 2,
472
+ new_w + (w - new_w) // 2, new_h + (h - new_h) // 2)).resize((image_width, image_height))
473
+
474
+ for camera in cameras:
475
+ camera['fx'] = camera['fx'] * scale
476
+ camera['fy'] = camera['fy'] * scale
477
+ camera['cx'] = (camera['cx'] - (w - new_w) // 2) * scale
478
+ camera['cy'] = (camera['cy'] - (h - new_h) // 2) * scale
479
+
480
+ image = torch.from_numpy(np.array(image)).float().permute(2, 0, 1) / 255.0 * 2 - 1
481
+ else:
482
+ image = None
483
+
484
+ cameras = torch.stack([
485
+ torch.from_numpy(np.array([camera['quaternion'][0], camera['quaternion'][1], camera['quaternion'][2], camera['quaternion'][3], camera['position'][0], camera['position'][1], camera['position'][2], camera['fx'] / image_width, camera['fy'] / image_height, camera['cx'] / image_width, camera['cy'] / image_height], dtype=np.float32))
486
+ for camera in cameras
487
+ ], dim=0)
488
+
489
+ file_id = str(int(time.time() * 1000))
490
+
491
+ # Assemble the task parameters; defer execution and disk writes to the worker thread
492
+ payload = {
493
+ 'cameras': cameras,
494
+ 'n_frame': n_frame,
495
+ 'image': image,
496
+ 'text_prompt': text_prompt,
497
+ 'image_index': image_index,
498
+ 'image_height': image_height,
499
+ 'image_width': image_width,
500
+ 'raw_request': data,
501
+ }
502
+
503
+ # Submit the task to the concurrency manager (asynchronously)
504
+ task_id = concurrency_manager.submit_task(
505
+ job_generate, file_id, args.cache_dir, payload
506
+ )
507
+
508
+ # Return queue information immediately after submission
509
+ queue_status = concurrency_manager.get_queue_status()
510
+ queued_tasks = queue_status.get('queued_tasks', [])
511
+ try:
512
+ queue_position = queued_tasks.index(task_id) + 1
513
+ except ValueError:
514
+ # If the task was immediately picked up by a worker thread, treat it as already running (position 0)
515
+ queue_position = 0
516
+
517
+ return jsonify({
518
+ 'success': True,
519
+ 'task_id': task_id,
520
+ 'file_id': file_id,
521
+ 'queue': {
522
+ 'queued_count': queue_status.get('queued_count', 0),
523
+ 'running_count': queue_status.get('running_count', 0),
524
+ 'position': queue_position
525
+ }
526
+ }), 202
527
+
528
+ except Exception as e:
529
+ return jsonify({'error': f'Server error: {str(e)}'}), 500
530
+
531
+ @app.route('/download/<file_id>', methods=['GET'])
532
+ def download_file(file_id):
533
+ """下载生成的PLY文件"""
534
+ file_path = os.path.join(args.cache_dir, f'{file_id}.ply')
535
+
536
+ if not os.path.exists(file_path):
537
+ return jsonify({'error': 'File not found'}), 404
538
+
539
+ return send_file(file_path, as_attachment=True, download_name=f'{file_id}.ply')
540
+
541
+ @app.route('/delete/<file_id>', methods=['DELETE', 'POST', 'OPTIONS'])
542
+ def delete_file_endpoint(file_id):
543
+ """删除生成的文件及其元数据(由前端在下载完成后调用)"""
544
+ # CORS preflight
545
+ if request.method == 'OPTIONS':
546
+ return jsonify({'status': 'ok'})
547
+
548
+ try:
549
+ ply_path = os.path.join(args.cache_dir, f'{file_id}.ply')
550
+ json_path = os.path.join(args.cache_dir, f'{file_id}.json')
551
+ deleted = []
552
+ for path in [ply_path, json_path]:
553
+ if os.path.exists(path):
554
+ os.remove(path)
555
+ deleted.append(os.path.basename(path))
556
+ return jsonify({'success': True, 'deleted': deleted})
557
+ except Exception as e:
558
+ return jsonify({'success': False, 'error': str(e)}), 500
559
+
560
+ @app.route('/status', methods=['GET'])
561
+ def get_status():
562
+ """获取系统状态和队列信息"""
563
+ try:
564
+ queue_status = concurrency_manager.get_queue_status()
565
+ return jsonify({
566
+ 'success': True,
567
+ 'status': queue_status,
568
+ 'timestamp': time.time()
569
+ })
570
+ except Exception as e:
571
+ return jsonify({'error': f'Failed to get status: {str(e)}'}), 500
572
+
573
+ @app.route('/task/<task_id>', methods=['GET'])
574
+ def get_task_status(task_id):
575
+ """获取特定任务的状态(包含排队位置和完成后的文件信息)"""
576
+ try:
577
+ task = concurrency_manager.get_task_status(task_id)
578
+ if not task:
579
+ return jsonify({'error': 'Task not found'}), 404
580
+
581
+ queue_status = concurrency_manager.get_queue_status()
582
+ queued_tasks = queue_status.get('queued_tasks', [])
583
+ try:
584
+ queue_position = queued_tasks.index(task_id) + 1
585
+ except ValueError:
586
+ queue_position = 0
587
+
588
+ resp = {
589
+ 'success': True,
590
+ 'task_id': task_id,
591
+ 'status': task.status.value,
592
+ 'created_at': task.created_at,
593
+ 'started_at': task.started_at,
594
+ 'completed_at': task.completed_at,
595
+ 'error': task.error,
596
+ 'queue': {
597
+ 'queued_count': queue_status.get('queued_count', 0),
598
+ 'running_count': queue_status.get('running_count', 0),
599
+ 'position': queue_position
600
+ }
601
+ }
602
+
603
+ if task.status.value == 'completed' and isinstance(task.result, dict):
604
+ resp.update({
605
+ 'file_id': task.result.get('file_id'),
606
+ 'file_path': task.result.get('file_path'),
607
+ 'file_size': task.result.get('file_size'),
608
+ 'download_url': task.result.get('download_url'),
609
+ 'generation_time': (task.completed_at - task.started_at)
610
+ })
611
+
612
+ # Update task status
613
+
614
+ return jsonify(resp)
615
+ except Exception as e:
616
+ return jsonify({'error': f'Failed to get task status: {str(e)}'}), 500
617
+
618
+ @app.route("/")
619
+ def index():
620
+ return send_file("index.html")
621
+
622
+ os.makedirs(args.cache_dir, exist_ok=True)
623
+
624
+ # Background periodic cleanup: remove cache files not accessed/modified for more than 30 minutes
625
+ def cleanup_worker(cache_dir: str, max_age_seconds: int = 1800, interval_seconds: int = 300):
626
+ while True:
627
+ try:
628
+ now = time.time()
629
+ for name in os.listdir(cache_dir):
630
+ # Only clean up task-related .ply/.json files
631
+ if not (name.endswith('.ply') or name.endswith('.json')):
632
+ continue
633
+ path = os.path.join(cache_dir, name)
634
+ try:
635
+ mtime = os.path.getmtime(path)
636
+ if now - mtime > max_age_seconds:
637
+ os.remove(path)
638
+ except FileNotFoundError:
639
+ pass
640
+ except Exception:
641
+ # Ignore errors for individual files and keep cleaning
642
+ pass
643
+ except Exception:
644
+ # Prevent the thread from exiting due to an exception
645
+ pass
646
+ time.sleep(interval_seconds)
647
+
648
+ cleaner_thread = threading.Thread(target=cleanup_worker, args=(args.cache_dir,), daemon=True)
649
+ cleaner_thread.start()
650
+
651
+ app.run(host='0.0.0.0', port=args.port)
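
For context, here is a minimal client sketch for the Flask endpoints defined in app.py above. The route paths and JSON field names are taken from the code; the host, port, prompt, camera values, and the use of the third-party requests library are placeholder assumptions, and the 'failed' status string assumes the ConcurrencyManager reports such a state.

# Hypothetical client for the /generate -> /task/<id> -> /download/<id> flow above.
import time
import requests

BASE = "http://localhost:7860"  # assumption: server started with the default --port

# One camera per frame: quaternion (w, x, y, z), position, and pixel-space intrinsics.
camera = {
    "quaternion": [1.0, 0.0, 0.0, 0.0],
    "position": [0.0, 0.0, 0.0],
    "fx": 500.0, "fy": 500.0, "cx": 352.0, "cy": 240.0,
}
payload = {
    "text_prompt": "a cozy reading room",
    "image_prompt": None,            # or a base64 data URL / server-side image path
    "cameras": [camera] * 8,
    "resolution": [8, 480, 704],     # [n_frame, image_height, image_width]
    "image_index": 0,
}

resp = requests.post(f"{BASE}/generate", json=payload).json()
task_id = resp["task_id"]

# Poll the task endpoint until the job finishes, then download the PLY.
while True:
    status = requests.get(f"{BASE}/task/{task_id}").json()
    if status.get("status") == "completed":
        ply = requests.get(BASE + status["download_url"])
        with open("scene.ply", "wb") as f:
            f.write(ply.content)
        break
    if status.get("status") == "failed":  # assumption: ConcurrencyManager exposes a 'failed' state
        raise RuntimeError(status.get("error"))
    time.sleep(2.0)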
quant.py ADDED
@@ -0,0 +1,195 @@
1
+ import gc
2
+ from typing import Tuple
3
+ import copy
4
+ import torch
5
+ import tqdm
6
+
7
+
8
+ def cleanup_memory():
9
+ gc.collect()
10
+ torch.cuda.empty_cache()
11
+
12
+
13
+ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
14
+ """Quantize a tensor using per-tensor static scaling factor.
15
+ Args:
16
+ tensor: The input tensor.
17
+ """
18
+ finfo = torch.finfo(torch.float8_e4m3fn)
19
+ # Calculate the scale as dtype max divided by absmax.
20
+ # Since .abs() creates a new tensor, we use aminmax to get
21
+ # the min and max first and then calculate the absmax.
22
+ if tensor.numel() == 0:
23
+ # Deal with empty tensors (triggered by empty MoE experts)
24
+ min_val, max_val = (
25
+ torch.tensor(-16.0, dtype=tensor.dtype),
26
+ torch.tensor(16.0, dtype=tensor.dtype),
27
+ )
28
+ else:
29
+ min_val, max_val = tensor.aminmax()
30
+ amax = torch.maximum(min_val.abs(), max_val.abs())
31
+ scale = finfo.max / amax.clamp(min=1e-12)
32
+ # scale and clamp the tensor to bring it to
33
+ # the representable range of the float8 data type
34
+ # (as default cast is unsaturated)
35
+ qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
36
+ # Return both float8 data and the inverse scale (as float),
37
+ # as both are required as inputs to torch._scaled_mm
38
+ qweight = qweight.to(torch.float8_e4m3fn)
39
+ scale = scale.float().reciprocal()
40
+ return qweight, scale
41
+
42
+
43
+ def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
44
+ """Quantizes a floating-point tensor to FP8 (E4M3 format) using static scaling.
45
+
46
+ Performs uniform quantization of the input tensor by:
47
+ 1. Scaling the tensor values using the provided inverse scale factor
48
+ 2. Clamping values to the representable range of FP8 E4M3 format
49
+ 3. Converting to FP8 data type
50
+
51
+ Args:
52
+ tensor (torch.Tensor): Input tensor to be quantized (any floating-point dtype)
53
+ inv_scale (float): Inverse of the quantization scale factor (1/scale)
54
+ (Must be pre-calculated based on tensor statistics)
55
+
56
+ Returns:
57
+ torch.Tensor: Quantized tensor in torch.float8_e4m3fn format
58
+
59
+ Note:
60
+ - Uses the E4M3 format (4 exponent bits, 3 mantissa bits, no infinity/nan)
61
+ - This is a static quantization (scale factor must be pre-determined)
62
+ - For dynamic quantization, see per_tensor_quantize()
63
+ """
64
+ finfo = torch.finfo(torch.float8_e4m3fn)
65
+ qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
66
+ return qweight.to(torch.float8_e4m3fn)
67
+
68
+
69
+ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype, native_fp8_support=False):
70
+ """Performs FP8 GEMM (General Matrix Multiplication) operation with optional native hardware support.
71
+ Args:
72
+ A (torch.Tensor): Input tensor A (FP8 or other dtype)
73
+ A_scale (torch.Tensor/float): Scale factor for tensor A
74
+ B (torch.Tensor): Input tensor B (FP8 or other dtype)
75
+ B_scale (torch.Tensor/float): Scale factor for tensor B
76
+ bias (torch.Tensor/None): Optional bias tensor
77
+ out_dtype (torch.dtype): Output data type
78
+ native_fp8_support (bool): Whether to use hardware-accelerated FP8 operations
79
+
80
+ Returns:
81
+ torch.Tensor: Result of GEMM operation
82
+ """
83
+ if A.numel() == 0:
84
+ # Deal with empty tensors (triggered by empty MoE experts)
85
+ return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device)
86
+
87
+ if native_fp8_support:
88
+ need_reshape = A.dim() == 3
89
+ if need_reshape:
90
+ batch_size = A.shape[0]
91
+ A_input = A.reshape(-1, A.shape[-1])
92
+ else:
93
+ batch_size = None
94
+ A_input = A
95
+ output = torch._scaled_mm(
96
+ A_input,
97
+ B.t(),
98
+ out_dtype=out_dtype,
99
+ scale_a=A_scale,
100
+ scale_b=B_scale,
101
+ bias=bias,
102
+ )
103
+ if need_reshape:
104
+ output = output.reshape(
105
+ batch_size, output.shape[0] // batch_size, output.shape[1]
106
+ )
107
+ else:
108
+ output = torch.nn.functional.linear(
109
+ A.to(out_dtype) * A_scale,
110
+ B.to(out_dtype) * B_scale.to(out_dtype),
111
+ bias=bias,
112
+ )
113
+
114
+ return output
115
+
116
+ def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Module):
117
+ if "." in name:
118
+ parent_name = name.rsplit(".", 1)[0]
119
+ child_name = name[len(parent_name) + 1:]
120
+ parent = model.get_submodule(parent_name)
121
+ else:
122
+ parent_name = ""
123
+ parent = model
124
+ child_name = name
125
+ setattr(parent, child_name, new_module)
126
+
127
+
128
+ # Class responsible for quantizing weights
129
+ class FP8DynamicLinear(torch.nn.Module):
130
+ def __init__(
131
+ self,
132
+ weight: torch.Tensor,
133
+ weight_scale: torch.Tensor,
134
+ bias: torch.nn.Parameter,
135
+ native_fp8_support: bool = False,
136
+ dtype: torch.dtype = torch.bfloat16,
137
+ ):
138
+ super().__init__()
139
+ self.weight = torch.nn.Parameter(weight, requires_grad=False)
140
+ self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
141
+ self.bias = bias
142
+ self.native_fp8_support = native_fp8_support
143
+ self.dtype = dtype
144
+
145
+ # @torch.compile
146
+ def forward(self, x):
147
+ if x.dtype != self.dtype:
148
+ x = x.to(self.dtype)
149
+ qinput, x_scale = per_tensor_quantize(x)
150
+ output = fp8_gemm(
151
+ A=qinput,
152
+ A_scale=x_scale,
153
+ B=self.weight,
154
+ B_scale=self.weight_scale,
155
+ bias=self.bias,
156
+ out_dtype=x.dtype,
157
+ native_fp8_support=self.native_fp8_support,
158
+ )
159
+ return output
160
+
161
+
162
+ def FluxFp8GeMMProcessor(model: torch.nn.Module):
163
+ """Processes a PyTorch model to convert eligible Linear layers to FP8 precision.
164
+
165
+ This function performs the following operations:
166
+ 1. Checks for native FP8 support on the current GPU
167
+ 2. Identifies target Linear layers in transformer blocks
168
+ 3. Quantizes weights to FP8 format
169
+ 4. Replaces original Linear layers with FP8DynamicLinear versions
170
+ 5. Performs memory cleanup
171
+
172
+ Args:
173
+ model (torch.nn.Module): The neural network model to be processed.
174
+ Should contain transformer blocks with Linear layers.
175
+ """
176
+ native_fp8_support = (
177
+ torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
178
+ )
179
+ named_modules = list(model.named_modules())
180
+ for name, linear in tqdm.tqdm(named_modules, desc="Quantizing weights to fp8"):
181
+ if isinstance(linear, torch.nn.Linear) and "blocks" in name:
182
+ quant_weight, weight_scale = per_tensor_quantize(linear.weight)
183
+ bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
184
+ quant_linear = FP8DynamicLinear(
185
+ weight=quant_weight,
186
+ weight_scale=weight_scale,
187
+ bias=bias,
188
+ native_fp8_support=native_fp8_support,
189
+ dtype=linear.weight.dtype
190
+ )
191
+ replace_module(model, name, quant_linear)
192
+ del linear.weight
193
+ del linear.bias
194
+ del linear
195
+ cleanup_memory()
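
For context, a small usage sketch of quant.py (not part of the upload): FluxFp8GeMMProcessor only replaces torch.nn.Linear layers whose qualified name contains "blocks", so the toy module below names its container accordingly. ToyTransformer and all shapes are illustrative assumptions; the snippet assumes a PyTorch build with float8_e4m3fn support (the pinned torch==2.6.0 qualifies) and uses the plain-GEMM fallback on hardware without native FP8.

# Toy sketch: quantize the "blocks" Linears of a hypothetical module and compare outputs.
import torch
import torch.nn as nn
from quant import FluxFp8GeMMProcessor, FP8DynamicLinear

class ToyTransformer(nn.Module):
    def __init__(self, dim=64, n_layers=2):
        super().__init__()
        # Name the container "blocks" so FluxFp8GeMMProcessor picks these layers up.
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
            for _ in range(n_layers)
        )
        self.head = nn.Linear(dim, dim)  # not under "blocks": stays unquantized

    def forward(self, x):
        for blk in self.blocks:
            x = x + blk(x)
        return self.head(x)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ToyTransformer().to(device=device, dtype=torch.bfloat16).eval()
x = torch.randn(2, 16, 64, device=device, dtype=torch.bfloat16)

with torch.no_grad():
    ref = model(x)
    FluxFp8GeMMProcessor(model)   # replaces eligible Linears with FP8DynamicLinear
    quant_out = model(x)

assert isinstance(model.blocks[0][0], FP8DynamicLinear)
print("max abs error vs. unquantized:", (ref - quant_out).abs().max().item())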
requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ torch==2.6.0
2
+ torchvision==0.21.0
3
+ triton==3.2.0
4
+ transformers==4.57.0
5
+ omegaconf==2.3.0
6
+ ninja==1.13.0
7
+ numpy==2.2.6
8
+ einops==0.8.1
9
+ moviepy==1.0.3
10
+ opencv-python==4.12.0.88
11
+ av==15.1.0
12
+ plyfile==1.1.2
13
+ ftfy==6.3.1
14
+ flask==3.1.2
15
+ gradio==5.49.1
16
+ gsplat==1.5.2
17
+ accelerate==1.10.1
18
+ git+https://github.com/huggingface/diffusers.git@447e8322f76efea55d4769cd67c372edbf0715b8
19
+ git+https://github.com/nerfstudio-project/gsplat.git@32f2a54d21c7ecb135320bb02b136b7407ae5712
utils.py ADDED
@@ -0,0 +1,531 @@
1
+ from io import BytesIO
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import importlib
8
+ from plyfile import PlyData, PlyElement
9
+
10
+ import copy
11
+
12
+ class EmbedContainer(nn.Module):
13
+ def __init__(self, tensor):
14
+ super().__init__()
15
+ self.tensor = nn.Parameter(tensor)
16
+
17
+ def forward(self):
18
+ return self.tensor
19
+
20
+ @torch.no_grad
21
+ def zero_init(module):
22
+ if type(module) is torch.nn.Conv2d or type(module) is torch.nn.Linear:
23
+ module.weight.zero_()
24
+ module.bias.zero_()
25
+ return module
26
+
27
+ def import_str(string):
28
+ # From https://github.com/CompVis/taming-transformers
29
+ module, cls = string.rsplit(".", 1)
30
+ return getattr(importlib.import_module(module, package=None), cls)
31
+
32
+ """
33
+ from https://github.com/Kai-46/minFM/blob/main/utils/ema.py
34
+ Exponential Moving Average (EMA) utilities for PyTorch models.
35
+
36
+ This module provides utilities for maintaining and updating EMA models,
37
+ which are commonly used to improve model stability and generalization
38
+ in training deep neural networks. It supports both regular tensors and
39
+ DTensors (from FSDP-wrapped models).
40
+ """
41
+ class EMA_FSDP:
42
+ def __init__(self, fsdp_module: torch.nn.Module, decay: float = 0.999):
43
+ self.decay = decay
44
+ self.shadow = {}
45
+ self._init_shadow(fsdp_module)
46
+
47
+ @torch.no_grad()
48
+ def _init_shadow(self, fsdp_module):
49
+ # Check whether the module is FSDP-wrapped
50
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
51
+ if isinstance(fsdp_module, FSDP):
52
+ with FSDP.summon_full_params(fsdp_module, writeback=False):
53
+ for n, p in fsdp_module.module.named_parameters():
54
+ self.shadow[n] = p.detach().clone().float().cpu()
55
+ else:
56
+ for n, p in fsdp_module.named_parameters():
57
+ self.shadow[n] = p.detach().clone().float().cpu()
58
+
59
+ @torch.no_grad()
60
+ def update(self, fsdp_module):
61
+ d = self.decay
62
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
63
+ if isinstance(fsdp_module, FSDP):
64
+ with FSDP.summon_full_params(fsdp_module, writeback=False):
65
+ for n, p in fsdp_module.module.named_parameters():
66
+ self.shadow[n].mul_(d).add_(p.detach().float().cpu(), alpha=1. - d)
67
+ else:
68
+ for n, p in fsdp_module.named_parameters():
69
+ print(n, self.shadow[n])
70
+ self.shadow[n].mul_(d).add_(p.detach().float().cpu(), alpha=1. - d)
71
+
72
+ # Optional helpers ---------------------------------------------------
73
+ def state_dict(self):
74
+ return self.shadow # picklable
75
+
76
+ def load_state_dict(self, sd):
77
+ self.shadow = {k: v.clone() for k, v in sd.items()}
78
+
79
+ def copy_to(self, fsdp_module):
80
+ # load EMA weights into an (unwrapped) copy of the generator
81
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
82
+ with FSDP.summon_full_params(fsdp_module, writeback=True):
83
+ for n, p in fsdp_module.module.named_parameters():
84
+ if n in self.shadow:
85
+ p.data.copy_(self.shadow[n].to(p.dtype, device=p.device))
86
+
87
+ def create_raymaps(cameras, h, w):
88
+ rays_o, rays_d = create_rays(cameras, h, w)
89
+ raymaps = torch.cat([rays_d, rays_o - (rays_o * rays_d).sum(dim=-1, keepdim=True) * rays_d], dim=-1)
90
+ return raymaps
91
+
92
+ # def create_raymaps(cameras, h, w):
93
+ # rays_o, rays_d = create_rays(cameras, h, w)
94
+ # raymaps = torch.cat([rays_d, torch.cross(rays_d, rays_o, dim=-1)], dim=-1)
95
+ # return raymaps
96
+
97
+ class EMANorm(nn.Module):
98
+ def __init__(self, beta):
99
+ super().__init__()
100
+ self.register_buffer('magnitude_ema', torch.ones([]))
101
+ self.beta = beta
102
+
103
+ def forward(self, x):
104
+ if self.training:
105
+ magnitude_cur = x.detach().to(torch.float32).square().mean()
106
+ self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema.to(torch.float32), self.beta))
107
+ input_gain = self.magnitude_ema.rsqrt()
108
+ x = x.mul(input_gain)
109
+ return x
110
+
111
+ class TimestepEmbedding(nn.Module):
112
+ def __init__(self, dim, max_period=10000, time_factor: float = 1000.0, zero_weight: bool = True):
113
+ super().__init__()
114
+ self.max_period = max_period
115
+ self.time_factor = time_factor
116
+ self.dim = dim
117
+ if zero_weight:
118
+ self.weight = nn.Parameter(torch.zeros(dim))
119
+ else:
120
+ self.weight = None
121
+
122
+ def forward(self, t):
123
+ if self.weight is None:
124
+ return timestep_embedding(t, self.dim, self.max_period, self.time_factor)
125
+ else:
126
+ return timestep_embedding(t, self.dim, self.max_period, self.time_factor) * self.weight.unsqueeze(0)
127
+
128
+ @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
129
+ def timestep_embedding(t, dim, max_period=10000, time_factor: float = 1000.0):
130
+ """
131
+ Create sinusoidal timestep embeddings.
132
+ :param t: a 1-D Tensor of N indices, one per batch element.
133
+ These may be fractional.
134
+ :param dim: the dimension of the output.
135
+ :param max_period: controls the minimum frequency of the embeddings.
136
+ :return: an (N, D) Tensor of positional embeddings.
137
+ """
138
+ t = time_factor * t
139
+ half = dim // 2
140
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
141
+
142
+ args = t[:, None].float() * freqs[None]
143
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
144
+ if dim % 2:
145
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
146
+ if torch.is_floating_point(t):
147
+ embedding = embedding.to(t)
148
+ return embedding
149
+
150
+ def quaternion_to_matrix(quaternions):
151
+ """
152
+ Convert rotations given as quaternions to rotation matrices.
153
+ Args:
154
+ quaternions: quaternions with real part first,
155
+ as tensor of shape (..., 4).
156
+ Returns:
157
+ Rotation matrices as tensor of shape (..., 3, 3).
158
+ """
159
+ r, i, j, k = torch.unbind(quaternions, -1)
160
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
161
+
162
+ o = torch.stack(
163
+ (
164
+ 1 - two_s * (j * j + k * k),
165
+ two_s * (i * j - k * r),
166
+ two_s * (i * k + j * r),
167
+ two_s * (i * j + k * r),
168
+ 1 - two_s * (i * i + k * k),
169
+ two_s * (j * k - i * r),
170
+ two_s * (i * k - j * r),
171
+ two_s * (j * k + i * r),
172
+ 1 - two_s * (i * i + j * j),
173
+ ),
174
+ -1,
175
+ )
176
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
177
+
178
+ # from https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#matrix_to_quaternion
179
+ def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
180
+ """
181
+ Convert a unit quaternion to a standard form: one in which the real
182
+ part is non negative.
183
+
184
+ Args:
185
+ quaternions: Quaternions with real part first,
186
+ as tensor of shape (..., 4).
187
+
188
+ Returns:
189
+ Standardized quaternions as tensor of shape (..., 4).
190
+ """
191
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
192
+
193
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
194
+ """
195
+ Returns torch.sqrt(torch.max(0, x))
196
+ but with a zero subgradient where x is 0.
197
+ """
198
+ ret = torch.zeros_like(x)
199
+ positive_mask = x > 0
200
+ if torch.is_grad_enabled():
201
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
202
+ else:
203
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
204
+ return ret
205
+
206
+ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
207
+ """
208
+ Convert rotations given as rotation matrices to quaternions.
209
+
210
+ Args:
211
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
212
+
213
+ Returns:
214
+ quaternions with real part first, as tensor of shape (..., 4).
215
+ """
216
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
217
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
218
+
219
+ batch_dim = matrix.shape[:-2]
220
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
221
+ matrix.reshape(batch_dim + (9,)), dim=-1
222
+ )
223
+
224
+ q_abs = _sqrt_positive_part(
225
+ torch.stack(
226
+ [
227
+ 1.0 + m00 + m11 + m22,
228
+ 1.0 + m00 - m11 - m22,
229
+ 1.0 - m00 + m11 - m22,
230
+ 1.0 - m00 - m11 + m22,
231
+ ],
232
+ dim=-1,
233
+ )
234
+ )
235
+
236
+ # we produce the desired quaternion multiplied by each of r, i, j, k
237
+ quat_by_rijk = torch.stack(
238
+ [
239
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
240
+ # `int`.
241
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
242
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
243
+ # `int`.
244
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
245
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
246
+ # `int`.
247
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
248
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
249
+ # `int`.
250
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
251
+ ],
252
+ dim=-2,
253
+ )
254
+
255
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
256
+ # the candidate won't be picked.
257
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
258
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
259
+
260
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
261
+ # forall i; we pick the best-conditioned one (with the largest denominator)
262
+ indices = q_abs.argmax(dim=-1, keepdim=True)
263
+ expand_dims = list(batch_dim) + [1, 4]
264
+ gather_indices = indices.unsqueeze(-1).expand(expand_dims)
265
+ out = torch.gather(quat_candidates, -2, gather_indices).squeeze(-2)
266
+ return standardize_quaternion(out)
267
+
268
+ @torch.amp.autocast(device_type="cuda", enabled=False)
269
+ def normalize_cameras(cameras, return_meta=False, ref_w2c=None, T_norm=None, n_frame=None):
270
+ B, N = cameras.shape[:2]
271
+
272
+ c2ws = torch.zeros(B, N, 3, 4, device=cameras.device)
273
+
274
+ c2ws[..., :3, :3] = quaternion_to_matrix(cameras[..., 0:4])
275
+ c2ws[..., :, 3] = cameras[..., 4:7]
276
+
277
+ _c2ws = c2ws
278
+
279
+ ref_w2c = torch.inverse(matrix_to_square(_c2ws[:, :1])) if ref_w2c is None else ref_w2c
280
+ _c2ws = (ref_w2c.repeat(1, N, 1, 1) @ matrix_to_square(_c2ws))[..., :3, :]
281
+
282
+ if n_frame is not None:
283
+ T_norm = _c2ws[..., :n_frame, :3, 3].norm(dim=-1).max(dim=1)[0][..., None, None] if T_norm is None else T_norm
284
+ else:
285
+ T_norm = _c2ws[..., :3, 3].norm(dim=-1).max(dim=1)[0][..., None, None] if T_norm is None else T_norm
286
+
287
+ _c2ws[..., :3, 3] = _c2ws[..., :3, 3] / (T_norm + 1e-2)
288
+
289
+ R = matrix_to_quaternion(_c2ws[..., :3, :3])
290
+ T = _c2ws[..., :3, 3]
291
+ cameras = torch.cat([R.float(), T.float(), cameras[..., 7:]], dim=-1)
292
+
293
+ if return_meta:
294
+ return cameras, ref_w2c, T_norm
295
+ else:
296
+ return cameras
297
+
298
+ def create_rays(cameras, h, w, uv_offset=None):
299
+ prefix_shape = cameras.shape[:-1]
300
+ cameras = cameras.flatten(0, -2)
301
+ device = cameras.device
302
+ N = cameras.shape[0]
303
+
304
+ c2w = torch.eye(4, device=device)[None].repeat(N, 1, 1)
305
+ c2w[:, :3, :3] = quaternion_to_matrix(cameras[:, :4])
306
+ c2w[:, :3, 3] = cameras[:, 4:7]
307
+
308
+ # fx, fy, cx, cy should be divided by original H, W
309
+ fx, fy, cx, cy = cameras[:, 7:].chunk(4, -1)
310
+
311
+ fx, cx = fx * w, cx * w
312
+ fy, cy = fy * h, cy * h
313
+
314
+ inds = torch.arange(0, h*w, device=device).expand(N, h*w)
315
+
316
+ i = inds % w + 0.5
317
+ j = torch.div(inds, w, rounding_mode='floor') + 0.5
318
+
319
+ u = i / cx + (uv_offset[..., 0].reshape(N, h*w) if uv_offset is not None else 0)
320
+ v = j / cy + (uv_offset[..., 1].reshape(N, h*w) if uv_offset is not None else 0)
321
+
322
+ zs = - torch.ones_like(i)
323
+ xs = - (u - 1) * cx / fx * zs
324
+ ys = (v - 1) * cy / fy * zs
325
+ directions = torch.stack((xs, ys, zs), dim=-1)
326
+
327
+ rays_d = F.normalize(directions @ c2w[:, :3, :3].transpose(-1, -2), dim=-1)
328
+
329
+ rays_o = c2w[..., :3, 3] # [B, 3]
330
+ rays_o = rays_o[..., None, :].expand_as(rays_d)
331
+
332
+ rays_o = rays_o.reshape(*prefix_shape, h, w, 3)
333
+ rays_d = rays_d.reshape(*prefix_shape, h, w, 3)
334
+
335
+ return rays_o, rays_d
336
+
337
+ def matrix_to_square(mat):
338
+ l = len(mat.shape)
339
+ if l==3:
340
+ return torch.cat([mat, torch.tensor([0,0,0,1]).repeat(mat.shape[0],1,1).to(mat.device)],dim=1)
341
+ elif l==4:
342
+ return torch.cat([mat, torch.tensor([0,0,0,1]).repeat(mat.shape[0],mat.shape[1],1,1).to(mat.device)],dim=2)
343
+
344
+ def export_ply_for_gaussians(path, gaussians, opacity_threshold=0.00, T_norm=None):
345
+
346
+ sh_degree = int(math.sqrt((gaussians.shape[-1] - sum([3, 1, 3, 4])) / 3 - 1))
347
+
348
+ xyz, opacity, scale, rotation, feature = gaussians.float().split([3, 1, 3, 4, (sh_degree + 1)**2 * 3], dim=-1)
349
+
350
+ means3D = xyz.contiguous().float()
351
+ opacity = opacity.contiguous().float()
352
+ scales = scale.contiguous().float()
353
+ rotations = rotation.contiguous().float()
354
+ shs = feature.contiguous().float() # [N, 1, 3]
355
+
356
+ # print(means3D.shape, opacity.shape, scales.shape, rotations.shape, shs.shape)
357
+
358
+ # prune by opacity
359
+ if opacity_threshold > 0:
360
+ mask = opacity[..., 0] >= opacity_threshold
361
+ means3D = means3D[mask]
362
+ opacity = opacity[mask]
363
+ scales = scales[mask]
364
+ rotations = rotations[mask]
365
+ shs = shs[mask]
366
+
367
+ print("Gaussian percentage: ", mask.float().mean())
368
+
369
+ if T_norm is not None:
370
+ means3D = means3D * T_norm.item()
371
+ scales = scales * T_norm.item()
372
+
373
+ # invert activation to make it compatible with the original ply format
374
+ opacity = torch.log(opacity/(1-opacity))
375
+ scales = torch.log(scales + 1e-8)
376
+
377
+ xyzs = means3D.detach() # .cpu().numpy()
378
+ f_dc = shs.detach().flatten(start_dim=1).contiguous() #.cpu().numpy()
379
+ opacities = opacity.detach() #.cpu().numpy()
380
+ scales = scales.detach() #.cpu().numpy()
381
+ rotations = rotations.detach() #.cpu().numpy()
382
+
383
+ l = ['x', 'y', 'z']
384
+ # All channels except the 3 DC
385
+ for i in range(f_dc.shape[1]):
386
+ l.append('f_dc_{}'.format(i))
387
+ l.append('opacity')
388
+ for i in range(scales.shape[1]):
389
+ l.append('scale_{}'.format(i))
390
+ for i in range(rotations.shape[1]):
391
+ l.append('rot_{}'.format(i))
392
+
393
+ dtype_full = [(attribute, 'f4') for attribute in l]
394
+
395
+ # Most efficient approach: build the structured array directly with numpy's recarray
396
+ attributes = torch.cat((xyzs, f_dc, opacities, scales, rotations), dim=1).cpu().numpy()
397
+
398
+ # Create via recarray directly, avoiding per-element loops and dtype conversions
399
+ elements = np.rec.fromarrays([attributes[:, i] for i in range(attributes.shape[1])], names=l, formats=['f4'] * len(l))
400
+ el = PlyElement.describe(elements, 'vertex')
401
+
402
+ print(path)
403
+
404
+ PlyData([el]).write(path)
405
+
406
+ # plydata = PlyData([el])
407
+
408
+ # vert = plydata["vertex"]
409
+ # sorted_indices = np.argsort(
410
+ # -np.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"])
411
+ # / (1 + np.exp(-vert["opacity"]))
412
+ # )
413
+ # buffer = BytesIO()
414
+ # for idx in sorted_indices:
415
+ # v = plydata["vertex"][idx]
416
+ # position = np.array([v["x"], v["y"], v["z"]], dtype=np.float32)
417
+ # scales = np.exp(
418
+ # np.array(
419
+ # [v["scale_0"], v["scale_1"], v["scale_2"]],
420
+ # dtype=np.float32,
421
+ # )
422
+ # )
423
+ # rot = np.array(
424
+ # [v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]],
425
+ # dtype=np.float32,
426
+ # )
427
+ # SH_C0 = 0.28209479177387814
428
+ # color = np.array(
429
+ # [
430
+ # 0.5 + SH_C0 * v["f_dc_0"],
431
+ # 0.5 + SH_C0 * v["f_dc_1"],
432
+ # 0.5 + SH_C0 * v["f_dc_2"],
433
+ # 1 / (1 + np.exp(-v["opacity"])),
434
+ # ]
435
+ # )
436
+ # buffer.write(position.tobytes())
437
+ # buffer.write(scales.tobytes())
438
+ # buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
439
+ # buffer.write(
440
+ # ((rot / np.linalg.norm(rot)) * 128 + 128)
441
+ # .clip(0, 255)
442
+ # .astype(np.uint8)
443
+ # .tobytes()
444
+ # )
445
+
446
+ # with open(path + '.splat', "wb") as f:
447
+ # f.write(buffer.getvalue())
448
+
449
+ @torch.amp.autocast(device_type="cuda", enabled=False)
450
+ def quaternion_slerp(
451
+ q0, q1, fraction, spin: int = 0, shortestpath: bool = True
452
+ ):
453
+ """Return spherical linear interpolation between two quaternions.
454
+ Args:
455
+ quat0: first quaternion
456
+ quat1: second quaternion
457
+ fraction: how much to interpolate between quat0 vs quat1 (if 0, closer to quat0; if 1, closer to quat1)
458
+ spin: how much of an additional spin to place on the interpolation
459
+ shortestpath: whether to return the short or long path to rotation
460
+ """
461
+ d = (q0 * q1).sum(-1)
462
+ if shortestpath:
463
+ # invert rotation
464
+ neg = d < 0.0
+ d[neg] = -d[neg]
+ q1[neg] = -q1[neg]
466
+
467
+ _d = d.clamp(0, 1.0)
468
+
469
+ # theta = torch.arccos(d) * fraction
470
+ # q2 = q1 - q0 * d
471
+ # q2 = q2 / (q2.norm(dim=-1) + 1e-10)
472
+
473
+ # return torch.cos(theta) * q0 + torch.sin(theta) * q2
474
+
475
+ angle = torch.acos(_d) + spin * math.pi
476
+ isin = 1.0 / (torch.sin(angle)+ 1e-10)
477
+ q0_ = q0 * (torch.sin((1.0 - fraction) * angle) * isin)[..., None]
478
+ q1_ = q1 * (torch.sin(fraction * angle) * isin)[..., None]
479
+
480
+ q = q0_ + q1_
481
+
482
+ q[angle < 1e-5] = q0[angle < 1e-5]
483
+ # q[fraction < 1e-5] = q0[fraction < 1e-5]
484
+ # q[fraction > 1 - 1e-5] = q1[fraction > 1 - 1e-5]
485
+ # q[(d.abs() - 1).abs() < 1e-5] = q0[(d.abs() - 1).abs() < 1e-5]
486
+
487
+ return q
488
+
489
+ def sample_from_two_pose(pose_a, pose_b, fraction, noise_strengths=[0, 0]):
490
+ """
491
+ Args:
492
+ pose_a: first pose
493
+ pose_b: second pose
494
+ fraction
495
+ """
496
+
497
+ quat_a = pose_a[..., :4]
498
+ quat_b = pose_b[..., :4]
499
+
500
+ dot = torch.sum(quat_a * quat_b, dim=-1, keepdim=True)
501
+ quat_b = torch.where(dot < 0, -quat_b, quat_b)
502
+
503
+ quaternion = quaternion_slerp(quat_a, quat_b, fraction)
504
+ quaternion = torch.nn.functional.normalize(quaternion + torch.randn_like(quaternion) * noise_strengths[0], dim=-1)
505
+
506
+ T = (1 - fraction)[:, None] * pose_a[..., 4:] + fraction[:, None] * pose_b[..., 4:]
507
+ T = T + torch.randn_like(T) * noise_strengths[1]
508
+
509
+ new_pose = pose_a.clone()
510
+ new_pose[..., :4] = quaternion
511
+ new_pose[..., 4:] = T
512
+ return new_pose
513
+
514
+ def sample_from_dense_cameras(dense_cameras, t, noise_strengths=[0, 0, 0, 0]):
515
+ N, C = dense_cameras.shape
516
+ M = t.shape
517
+
518
+ left = torch.floor(t * (N-1)).long().clamp(0, N-2)
519
+ right = left + 1
520
+ fraction = t * (N-1) - left
521
+
522
+ a = torch.gather(dense_cameras, 0, left[..., None].repeat(1, C))
523
+ b = torch.gather(dense_cameras, 0, right[..., None].repeat(1, C))
524
+
525
+ new_pose = sample_from_two_pose(a[:, :7],
526
+ b[:, :7], fraction, noise_strengths=noise_strengths[:2])
527
+
528
+ new_ins = (1 - fraction)[:, None] * a[:, 7:] + fraction[:, None] * b[:, 7:]
529
+
530
+ return torch.cat([new_pose, new_ins], dim=1)
531
+
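
For context, a small sketch (not part of the upload) of the camera utilities in utils.py: it builds random cameras in the [qw, qx, qy, qz, tx, ty, tz, fx/W, fy/H, cx/W, cy/H] layout the helpers expect, normalizes them to the first frame, builds raymaps, and interpolates a denser trajectory. All values and shapes are illustrative assumptions.

# Toy camera-utilities walkthrough using the helpers defined above.
import torch
import torch.nn.functional as F
from utils import normalize_cameras, create_raymaps, sample_from_dense_cameras

B, N, H, W = 1, 8, 30, 44   # latent-resolution raymaps, e.g. 480/16 x 704/16

quat = F.normalize(torch.randn(B, N, 4), dim=-1)            # random rotations (real part first)
pos = torch.randn(B, N, 3)                                   # random positions
intr = torch.tensor([0.8, 0.8, 0.5, 0.5]).expand(B, N, 4)    # normalized fx, fy, cx, cy
cameras = torch.cat([quat, pos, intr], dim=-1)               # (B, N, 11)

# Express all poses relative to the first camera and rescale translations.
cameras_n, ref_w2c, T_norm = normalize_cameras(cameras, return_meta=True)

# Per-pixel ray direction plus the ray-origin component orthogonal to it,
# as fed to the transformer in app.py.
raymaps = create_raymaps(cameras_n, H, W)                    # (B, N, H, W, 6)
print(raymaps.shape, T_norm.shape)

# Interpolate a denser trajectory between the normalized keyframe cameras.
t = torch.linspace(0, 1, 24)
dense = sample_from_dense_cameras(cameras_n.squeeze(0), t)   # (24, 11)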