import os
import sys
import uuid
import cv2
import glob
import torch
import logging
from textwrap import indent
import torch.nn as nn
from diffusers import FluxPipeline
from tqdm import tqdm
from ovi.distributed_comms.parallel_states import get_sequence_parallel_state, nccl_info
from ovi.utils.model_loading_utils import init_fusion_score_model_ovi, init_text_model, init_mmaudio_vae, init_wan_vae_2_2, load_fusion_checkpoint
from ovi.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from diffusers import FlowMatchEulerDiscreteScheduler
from ovi.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                                  get_sampling_sigmas, retrieve_timesteps)
import traceback
from omegaconf import OmegaConf
from ovi.utils.processing_utils import clean_text, preprocess_image_tensor, snap_hw_to_multiple_of_32, scale_hw_to_area_divisible

DEFAULT_CONFIG = OmegaConf.load('ovi/configs/inference/inference_fusion.yaml')


class OviFusionEngine:
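    """Joint audio-video generation engine.

    Wraps the Ovi fusion score model together with the Wan 2.2 video VAE, the
    MMAudio audio VAE, a T5 text encoder, and (in t2i2v mode) a FLUX.1 Krea
    image model, and exposes a single generate() entry point for t2v, i2v,
    and t2i2v inference.
    """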
    def __init__(self, config=DEFAULT_CONFIG, device=0, target_dtype=torch.bfloat16):
        # Load fusion model
        self.device = device
        self.target_dtype = target_dtype
        meta_init = True

        self.cpu_offload = config.get("cpu_offload", False) or config.get("mode") == "t2i2v"
        if self.cpu_offload:
            logging.info("CPU offloading is enabled. Initializing all models aside from the VAEs on CPU.")

        model, video_config, audio_config = init_fusion_score_model_ovi(rank=device, meta_init=meta_init)
        if not meta_init:
            model = model.to(dtype=target_dtype).to(device=device if not self.cpu_offload else "cpu").eval()

        # Load VAEs
        vae_model_video = init_wan_vae_2_2(config.ckpt_dir, rank=device)
        vae_model_video.model.requires_grad_(False).eval()
        vae_model_video.model = vae_model_video.model.bfloat16()
        self.vae_model_video = vae_model_video

        vae_model_audio = init_mmaudio_vae(config.ckpt_dir, rank=device)
        vae_model_audio.requires_grad_(False).eval()
        self.vae_model_audio = vae_model_audio.bfloat16()

        # Load T5 text model
        self.text_model = init_text_model(config.ckpt_dir, rank=device)
        if config.get("shard_text_model", False):
            raise NotImplementedError("Sharding the text model is not implemented yet.")
        if self.cpu_offload:
            self.offload_to_cpu(self.text_model.model)

        # Find the fusion checkpoint in the same directory used by the other components
        checkpoint_path = os.path.join(config.ckpt_dir, "Ovi", "model.safetensors")
        if not os.path.exists(checkpoint_path):
            raise RuntimeError(f"No fusion checkpoint found in {config.ckpt_dir}")

        load_fusion_checkpoint(model, checkpoint_path=checkpoint_path, from_meta=meta_init)
        if meta_init:
            model = model.to(dtype=target_dtype).to(device=device if not self.cpu_offload else "cpu").eval()

        model.set_rope_params()
        self.model = model

        # Load the T2I model as part of the pipeline (t2i2v mode only)
        self.image_model = None
        if config.get("mode") == "t2i2v":
            logging.info("Loading FLUX.1 Krea for first-frame generation...")
            self.image_model = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", torch_dtype=torch.bfloat16)
            self.image_model.enable_model_cpu_offload(gpu_id=self.device)  # Saves VRAM by offloading the model to CPU; remove this if you have enough GPU VRAM.

        # Fixed attributes, non-configurable
        self.audio_latent_channel = audio_config.get("in_dim")
        self.video_latent_channel = video_config.get("in_dim")
        self.audio_latent_length = 157
        self.video_latent_length = 31

        logging.info(f"OVI Fusion Engine initialized, cpu_offload={self.cpu_offload}. "
                     f"GPU VRAM allocated: {torch.cuda.memory_allocated(device)/1e9:.2f} GB, "
                     f"reserved: {torch.cuda.memory_reserved(device)/1e9:.2f} GB")
    def generate(self,
                 text_prompt,
                 image_path=None,
                 video_frame_height_width=None,
                 seed=100,
                 solver_name="unipc",
                 sample_steps=50,
                 shift=5.0,
                 video_guidance_scale=5.0,
                 audio_guidance_scale=4.0,
                 slg_layer=9,
                 video_negative_prompt="",
                 audio_negative_prompt=""
                 ):
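        """Generate a video/audio pair from a text prompt, optionally conditioned on a first frame.

        On success returns a tuple (video, audio, image): video is a (c, f, h, w)
        float numpy array, audio is the decoded audio as a float numpy array, and
        image is the PIL first frame produced in t2i2v mode (otherwise None).
        Returns None if generation fails.
        """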
        params = {
            "Text Prompt": text_prompt,
            "Image Path": image_path if image_path else "None (T2V mode)",
            "Frame Height Width": video_frame_height_width,
            "Seed": seed,
            "Solver": solver_name,
            "Sample Steps": sample_steps,
            "Shift": shift,
            "Video Guidance Scale": video_guidance_scale,
            "Audio Guidance Scale": audio_guidance_scale,
            "SLG Layer": slg_layer,
            "Video Negative Prompt": video_negative_prompt,
            "Audio Negative Prompt": audio_negative_prompt,
        }

        pretty = "\n".join(f"{k:>24}: {v}" for k, v in params.items())
        logging.info("\n========== Generation Parameters ==========\n"
                     f"{pretty}\n"
                     "==========================================")

        try:
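            # Build matching schedulers for the video and audio branches (same solver,
            # step count, and shift) so their timestep grids stay aligned.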
            scheduler_video, timesteps_video = self.get_scheduler_time_steps(
                sampling_steps=sample_steps,
                device=self.device,
                solver_name=solver_name,
                shift=shift
            )
            scheduler_audio, timesteps_audio = self.get_scheduler_time_steps(
                sampling_steps=sample_steps,
                device=self.device,
                solver_name=solver_name,
                shift=shift
            )
            is_t2v = image_path is None
            is_i2v = not is_t2v
            first_frame = None
            image = None
            if is_i2v and not self.image_model:
                # i2v: load the conditioning first frame from the given path
                first_frame = preprocess_image_tensor(image_path, self.device, self.target_dtype)
            else:
                assert video_frame_height_width is not None, "If mode is t2v or t2i2v, video_frame_height_width must be provided."
                video_h, video_w = video_frame_height_width
                video_h, video_w = snap_hw_to_multiple_of_32(video_h, video_w, area=720 * 720)
                video_latent_h, video_latent_w = video_h // 16, video_w // 16
                if self.image_model is not None:
                    # t2i2v: generate the first frame from the text prompt with the image model
                    image_h, image_w = scale_hw_to_area_divisible(video_h, video_w, area=1024 * 1024)
                    image = self.image_model(
                        clean_text(text_prompt),
                        height=image_h,
                        width=image_w,
                        guidance_scale=4.5,
                        generator=torch.Generator().manual_seed(seed)
                    ).images[0]
                    first_frame = preprocess_image_tensor(image, self.device, self.target_dtype)
                    is_i2v = True
                else:
                    logging.info(f"Pure T2V mode: calculated video latent size: {video_latent_h} x {video_latent_w}")
            if self.cpu_offload:
                self.text_model.model = self.text_model.model.to(self.device)
            text_embeddings = self.text_model([text_prompt, video_negative_prompt, audio_negative_prompt], self.text_model.device)
            text_embeddings = [emb.to(self.target_dtype).to(self.device) for emb in text_embeddings]
            if self.cpu_offload:
                self.offload_to_cpu(self.text_model.model)

            # Split embeddings: the positive prompt is shared by the audio and video branches
            text_embeddings_audio_pos = text_embeddings[0]
            text_embeddings_video_pos = text_embeddings[0]
            text_embeddings_video_neg = text_embeddings[1]
            text_embeddings_audio_neg = text_embeddings[2]

            if is_i2v:
                with torch.no_grad():
                    latents_images = self.vae_model_video.wrapped_encode(first_frame[:, :, None]).to(self.target_dtype).squeeze(0)  # c, 1, h, w
                latents_images = latents_images.to(self.target_dtype)
                video_latent_h, video_latent_w = latents_images.shape[2], latents_images.shape[3]

            video_noise = torch.randn(
                (self.video_latent_channel, self.video_latent_length, video_latent_h, video_latent_w),
                device=self.device, dtype=self.target_dtype,
                generator=torch.Generator(device=self.device).manual_seed(seed)
            )  # c, f, h, w
            audio_noise = torch.randn(
                (self.audio_latent_length, self.audio_latent_channel),
                device=self.device, dtype=self.target_dtype,
                generator=torch.Generator(device=self.device).manual_seed(seed)
            )  # l, c

            # Calculate sequence lengths from the actual latents
            max_seq_len_audio = audio_noise.shape[0]  # l dimension of the (l, c) audio latents
            _patch_size_h, _patch_size_w = self.model.video_model.patch_size[1], self.model.video_model.patch_size[2]
            max_seq_len_video = video_noise.shape[1] * video_noise.shape[2] * video_noise.shape[3] // (_patch_size_h * _patch_size_w)  # f * h * w / patch area, from (c, f, h, w)

            # Sampling loop
            if self.cpu_offload:
                self.model = self.model.to(self.device)

            with torch.amp.autocast('cuda', enabled=self.target_dtype != torch.float32, dtype=self.target_dtype):
                for i, (t_v, t_a) in tqdm(enumerate(zip(timesteps_video, timesteps_audio)), total=len(timesteps_video)):
                    timestep_input = torch.full((1,), t_v, device=self.device)
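                    # For i2v, clamp the first latent frame to the encoded conditioning image at every step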
| if is_i2v: | |
| video_noise[:, :1] = latents_images | |
| # Positive (conditional) forward pass | |
| pos_forward_args = { | |
| 'audio_context': [text_embeddings_audio_pos], | |
| 'vid_context': [text_embeddings_video_pos], | |
| 'vid_seq_len': max_seq_len_video, | |
| 'audio_seq_len': max_seq_len_audio, | |
| 'first_frame_is_clean': is_i2v | |
| } | |
| pred_vid_pos, pred_audio_pos = self.model( | |
| vid=[video_noise], | |
| audio=[audio_noise], | |
| t=timestep_input, | |
| **pos_forward_args | |
| ) | |
| # Negative (unconditional) forward pass | |
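                    # Uses the negative prompts and additionally passes slg_layer (skip-layer guidance)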
                    neg_forward_args = {
                        'audio_context': [text_embeddings_audio_neg],
                        'vid_context': [text_embeddings_video_neg],
                        'vid_seq_len': max_seq_len_video,
                        'audio_seq_len': max_seq_len_audio,
                        'first_frame_is_clean': is_i2v,
                        'slg_layer': slg_layer
                    }

                    pred_vid_neg, pred_audio_neg = self.model(
                        vid=[video_noise],
                        audio=[audio_noise],
                        t=timestep_input,
                        **neg_forward_args
                    )

                    # Apply classifier-free guidance
                    pred_video_guided = pred_vid_neg[0] + video_guidance_scale * (pred_vid_pos[0] - pred_vid_neg[0])
                    pred_audio_guided = pred_audio_neg[0] + audio_guidance_scale * (pred_audio_pos[0] - pred_audio_neg[0])

                    # Update noise using scheduler
                    video_noise = scheduler_video.step(
                        pred_video_guided.unsqueeze(0), t_v, video_noise.unsqueeze(0), return_dict=False
                    )[0].squeeze(0)
                    audio_noise = scheduler_audio.step(
                        pred_audio_guided.unsqueeze(0), t_a, audio_noise.unsqueeze(0), return_dict=False
                    )[0].squeeze(0)

            if self.cpu_offload:
                self.offload_to_cpu(self.model)

            if is_i2v:
                video_noise[:, :1] = latents_images

            # Decode audio
            audio_latents_for_vae = audio_noise.unsqueeze(0).transpose(1, 2)  # 1, c, l
            generated_audio = self.vae_model_audio.wrapped_decode(audio_latents_for_vae)
            generated_audio = generated_audio.squeeze().cpu().float().numpy()

            # Decode video
            video_latents_for_vae = video_noise.unsqueeze(0)  # 1, c, f, h, w
            generated_video = self.vae_model_video.wrapped_decode(video_latents_for_vae)
            generated_video = generated_video.squeeze(0).cpu().float().numpy()  # c, f, h, w

            return generated_video, generated_audio, image

        except Exception as e:
            logging.error(traceback.format_exc())
            return None

    def offload_to_cpu(self, model):
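        """Move a model to CPU and free the cached GPU memory it was holding."""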
        model = model.cpu()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        return model

    def get_scheduler_time_steps(self, sampling_steps, solver_name='unipc', device=0, shift=5.0):
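        """Build a flow-matching scheduler ('unipc', 'dpm++', or 'euler') and return (scheduler, timesteps)."""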
        torch.manual_seed(4)

        if solver_name == 'unipc':
            sample_scheduler = FlowUniPCMultistepScheduler(
                num_train_timesteps=1000,
                shift=1,
                use_dynamic_shifting=False)
            sample_scheduler.set_timesteps(
                sampling_steps, device=device, shift=shift)
            timesteps = sample_scheduler.timesteps
        elif solver_name == 'dpm++':
            sample_scheduler = FlowDPMSolverMultistepScheduler(
                num_train_timesteps=1000,
                shift=1,
                use_dynamic_shifting=False)
            sampling_sigmas = get_sampling_sigmas(sampling_steps, shift=shift)
            timesteps, _ = retrieve_timesteps(
                sample_scheduler,
                device=device,
                sigmas=sampling_sigmas)
        elif solver_name == 'euler':
            sample_scheduler = FlowMatchEulerDiscreteScheduler(
                shift=shift)
            timesteps, sampling_steps = retrieve_timesteps(
                sample_scheduler,
                sampling_steps,
                device=device,
            )
        else:
            raise NotImplementedError(f"Unsupported solver: {solver_name}")

        return sample_scheduler, timesteps
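

if __name__ == "__main__":
    # Minimal usage sketch, assuming the default config at
    # ovi/configs/inference/inference_fusion.yaml points at valid checkpoints
    # and a CUDA device is available. The prompt, resolution, and output
    # handling below are illustrative only.
    engine = OviFusionEngine()
    result = engine.generate(
        text_prompt="A street musician plays guitar while the crowd claps along.",
        video_frame_height_width=(720, 720),
        seed=100,
    )
    if result is not None:
        video, audio, image = result  # (c, f, h, w) video array, decoded audio array, optional PIL first frame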