# Ovi-ZEROGPU/ovi/ovi_fusion_engine.py
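"""Fusion inference engine for Ovi: joint video + audio generation.

Wraps the fused video/audio score model, the Wan 2.2 video VAE, the MMAudio
audio VAE, the T5 text encoder, and (optionally) a FLUX.1 Krea image model
for t2i2v mode, and exposes a single `generate()` entry point.
"""
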
import logging
import os
import traceback

import torch
from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline
from omegaconf import OmegaConf
from tqdm import tqdm

from ovi.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                                  get_sampling_sigmas, retrieve_timesteps)
from ovi.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from ovi.utils.model_loading_utils import (init_fusion_score_model_ovi,
                                           init_mmaudio_vae, init_text_model,
                                           init_wan_vae_2_2,
                                           load_fusion_checkpoint)
from ovi.utils.processing_utils import (clean_text, preprocess_image_tensor,
                                        scale_hw_to_area_divisible,
                                        snap_hw_to_multiple_of_32)

DEFAULT_CONFIG = OmegaConf.load('ovi/configs/inference/inference_fusion.yaml')
class OviFusionEngine:
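    """Inference wrapper around the Ovi fusion (video + audio) diffusion model.

    Loads the fusion score model, the video/audio VAEs, and the T5 text encoder
    from `config.ckpt_dir`. When `config.mode == "t2i2v"`, a FLUX.1 Krea pipeline
    is also loaded to synthesize the first frame from the text prompt. Set
    `cpu_offload` in the config to keep idle components on the CPU.
    """
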
def __init__(self, config=DEFAULT_CONFIG, device=0, target_dtype=torch.bfloat16):
# Load fusion model
self.device = device
self.target_dtype = target_dtype
meta_init = True
self.cpu_offload = config.get("cpu_offload", False) or config.get("mode") == "t2i2v"
if self.cpu_offload:
logging.info("CPU offloading is enabled. Initializing all models aside from VAEs on CPU")
model, video_config, audio_config = init_fusion_score_model_ovi(rank=device, meta_init=meta_init)
if not meta_init:
model = model.to(dtype=target_dtype).to(device=device if not self.cpu_offload else "cpu").eval()
# Load VAEs
vae_model_video = init_wan_vae_2_2(config.ckpt_dir, rank=device)
vae_model_video.model.requires_grad_(False).eval()
vae_model_video.model = vae_model_video.model.bfloat16()
self.vae_model_video = vae_model_video
vae_model_audio = init_mmaudio_vae(config.ckpt_dir, rank=device)
vae_model_audio.requires_grad_(False).eval()
self.vae_model_audio = vae_model_audio.bfloat16()
# Load T5 text model
self.text_model = init_text_model(config.ckpt_dir, rank=device)
if config.get("shard_text_model", False):
raise NotImplementedError("Sharding text model is not implemented yet.")
if self.cpu_offload:
self.offload_to_cpu(self.text_model.model)
# Find fusion ckpt in the same dir used by other components
checkpoint_path = os.path.join(config.ckpt_dir, "Ovi", "model.safetensors")
if not os.path.exists(checkpoint_path):
            raise RuntimeError(f"No fusion checkpoint found at {checkpoint_path}")
load_fusion_checkpoint(model, checkpoint_path=checkpoint_path, from_meta=meta_init)
if meta_init:
model = model.to(dtype=target_dtype).to(device=device if not self.cpu_offload else "cpu").eval()
model.set_rope_params()
self.model = model
        # Load text-to-image model (FLUX.1 Krea) as part of the pipeline for t2i2v mode
self.image_model = None
if config.get("mode") == "t2i2v":
logging.info(f"Loading Flux Krea for first frame generation...")
self.image_model = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", torch_dtype=torch.bfloat16)
self.image_model.enable_model_cpu_offload(gpu_id=self.device) #save some VRAM by offloading the model to CPU. Remove this if you have enough GPU VRAM
# Fixed attributes, non-configurable
self.audio_latent_channel = audio_config.get("in_dim")
self.video_latent_channel = video_config.get("in_dim")
self.audio_latent_length = 157
self.video_latent_length = 31
logging.info(f"OVI Fusion Engine initialized, cpu_offload={self.cpu_offload}. GPU VRAM allocated: {torch.cuda.memory_allocated(device)/1e9:.2f} GB, reserved: {torch.cuda.memory_reserved(device)/1e9:.2f} GB")
@torch.inference_mode()
def generate(self,
text_prompt,
image_path=None,
video_frame_height_width=None,
seed=100,
solver_name="unipc",
sample_steps=50,
shift=5.0,
video_guidance_scale=5.0,
audio_guidance_scale=4.0,
slg_layer=9,
video_negative_prompt="",
audio_negative_prompt=""
):
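        """Generate a video/audio pair from a text prompt, optionally conditioned on a first frame.

        Runs classifier-free guidance over `sample_steps` flow-matching steps with the
        chosen solver ("unipc", "dpm++", or "euler"), then decodes the latents with the
        video and audio VAEs.

        Returns:
            (video, audio, image): video as a float numpy array of shape (c, f, h, w),
            audio as a 1-D float numpy waveform, and the generated first-frame image
            (None unless t2i2v mode). Returns None if generation fails.
        """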
params = {
"Text Prompt": text_prompt,
"Image Path": image_path if image_path else "None (T2V mode)",
"Frame Height Width": video_frame_height_width,
"Seed": seed,
"Solver": solver_name,
"Sample Steps": sample_steps,
"Shift": shift,
"Video Guidance Scale": video_guidance_scale,
"Audio Guidance Scale": audio_guidance_scale,
"SLG Layer": slg_layer,
"Video Negative Prompt": video_negative_prompt,
"Audio Negative Prompt": audio_negative_prompt,
}
pretty = "\n".join(f"{k:>24}: {v}" for k, v in params.items())
logging.info("\n========== Generation Parameters ==========\n"
f"{pretty}\n"
"==========================================")
try:
scheduler_video, timesteps_video = self.get_scheduler_time_steps(
sampling_steps=sample_steps,
device=self.device,
solver_name=solver_name,
shift=shift
)
scheduler_audio, timesteps_audio = self.get_scheduler_time_steps(
sampling_steps=sample_steps,
device=self.device,
solver_name=solver_name,
shift=shift
)
is_t2v = image_path is None
is_i2v = not is_t2v
first_frame = None
image = None
if is_i2v and not self.image_model:
# Load first frame from path
first_frame = preprocess_image_tensor(image_path, self.device, self.target_dtype)
else:
                assert video_frame_height_width is not None, "For t2v or t2i2v mode, video_frame_height_width must be provided."
video_h, video_w = video_frame_height_width
video_h, video_w = snap_hw_to_multiple_of_32(video_h, video_w, area = 720 * 720)
video_latent_h, video_latent_w = video_h // 16, video_w // 16
if self.image_model is not None:
                    # t2i2v mode: generate the first frame from the text prompt with the image model
image_h, image_w = scale_hw_to_area_divisible(video_h, video_w, area = 1024 * 1024)
image = self.image_model(
clean_text(text_prompt),
height=image_h,
width=image_w,
guidance_scale=4.5,
generator=torch.Generator().manual_seed(seed)
).images[0]
first_frame = preprocess_image_tensor(image, self.device, self.target_dtype)
is_i2v = True
else:
print(f"Pure T2V mode: calculated video latent size: {video_latent_h} x {video_latent_w}")
if self.cpu_offload:
self.text_model.model = self.text_model.model.to(self.device)
text_embeddings = self.text_model([text_prompt, video_negative_prompt, audio_negative_prompt], self.text_model.device)
text_embeddings = [emb.to(self.target_dtype).to(self.device) for emb in text_embeddings]
if self.cpu_offload:
self.offload_to_cpu(self.text_model.model)
# Split embeddings
text_embeddings_audio_pos = text_embeddings[0]
text_embeddings_video_pos = text_embeddings[0]
text_embeddings_video_neg = text_embeddings[1]
text_embeddings_audio_neg = text_embeddings[2]
if is_i2v:
with torch.no_grad():
latents_images = self.vae_model_video.wrapped_encode(first_frame[:, :, None]).to(self.target_dtype).squeeze(0) # c 1 h w
latents_images = latents_images.to(self.target_dtype)
video_latent_h, video_latent_w = latents_images.shape[2], latents_images.shape[3]
video_noise = torch.randn((self.video_latent_channel, self.video_latent_length, video_latent_h, video_latent_w), device=self.device, dtype=self.target_dtype, generator=torch.Generator(device=self.device).manual_seed(seed)) # c, f, h, w
audio_noise = torch.randn((self.audio_latent_length, self.audio_latent_channel), device=self.device, dtype=self.target_dtype, generator=torch.Generator(device=self.device).manual_seed(seed)) # 1, l c -> l, c
# Calculate sequence lengths from actual latents
            max_seq_len_audio = audio_noise.shape[0]  # L dimension from audio_noise shape [L, C]
            _patch_size_h, _patch_size_w = self.model.video_model.patch_size[1], self.model.video_model.patch_size[2]
            max_seq_len_video = video_noise.shape[1] * video_noise.shape[2] * video_noise.shape[3] // (_patch_size_h * _patch_size_w)  # f * (h/ph) * (w/pw) tokens from video_noise shape [c, f, h, w]
# Sampling loop
if self.cpu_offload:
self.model = self.model.to(self.device)
with torch.amp.autocast('cuda', enabled=self.target_dtype != torch.float32, dtype=self.target_dtype):
                for t_v, t_a in tqdm(zip(timesteps_video, timesteps_audio), total=len(timesteps_video)):
timestep_input = torch.full((1,), t_v, device=self.device)
if is_i2v:
video_noise[:, :1] = latents_images
# Positive (conditional) forward pass
pos_forward_args = {
'audio_context': [text_embeddings_audio_pos],
'vid_context': [text_embeddings_video_pos],
'vid_seq_len': max_seq_len_video,
'audio_seq_len': max_seq_len_audio,
'first_frame_is_clean': is_i2v
}
pred_vid_pos, pred_audio_pos = self.model(
vid=[video_noise],
audio=[audio_noise],
t=timestep_input,
**pos_forward_args
)
# Negative (unconditional) forward pass
neg_forward_args = {
'audio_context': [text_embeddings_audio_neg],
'vid_context': [text_embeddings_video_neg],
'vid_seq_len': max_seq_len_video,
'audio_seq_len': max_seq_len_audio,
'first_frame_is_clean': is_i2v,
'slg_layer': slg_layer
}
pred_vid_neg, pred_audio_neg = self.model(
vid=[video_noise],
audio=[audio_noise],
t=timestep_input,
**neg_forward_args
)
# Apply classifier-free guidance
pred_video_guided = pred_vid_neg[0] + video_guidance_scale * (pred_vid_pos[0] - pred_vid_neg[0])
pred_audio_guided = pred_audio_neg[0] + audio_guidance_scale * (pred_audio_pos[0] - pred_audio_neg[0])
# Update noise using scheduler
video_noise = scheduler_video.step(
pred_video_guided.unsqueeze(0), t_v, video_noise.unsqueeze(0), return_dict=False
)[0].squeeze(0)
audio_noise = scheduler_audio.step(
pred_audio_guided.unsqueeze(0), t_a, audio_noise.unsqueeze(0), return_dict=False
)[0].squeeze(0)
if self.cpu_offload:
self.offload_to_cpu(self.model)
if is_i2v:
video_noise[:, :1] = latents_images
# Decode audio
audio_latents_for_vae = audio_noise.unsqueeze(0).transpose(1, 2) # 1, c, l
generated_audio = self.vae_model_audio.wrapped_decode(audio_latents_for_vae)
generated_audio = generated_audio.squeeze().cpu().float().numpy()
# Decode video
video_latents_for_vae = video_noise.unsqueeze(0) # 1, c, f, h, w
generated_video = self.vae_model_video.wrapped_decode(video_latents_for_vae)
generated_video = generated_video.squeeze(0).cpu().float().numpy() # c, f, h, w
return generated_video, generated_audio, image
        except Exception:
            logging.error(traceback.format_exc())
            return None
def offload_to_cpu(self, model):
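        """Move a model to CPU and release the cached GPU memory it occupied."""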
model = model.cpu()
torch.cuda.synchronize()
torch.cuda.empty_cache()
return model
def get_scheduler_time_steps(self, sampling_steps, solver_name='unipc', device=0, shift=5.0):
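        """Build a flow-matching scheduler ("unipc", "dpm++", or "euler") and its timesteps."""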
torch.manual_seed(4)
if solver_name == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=1000,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=device, shift=shift)
timesteps = sample_scheduler.timesteps
elif solver_name == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=1000,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift=shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=device,
sigmas=sampling_sigmas)
elif solver_name == 'euler':
sample_scheduler = FlowMatchEulerDiscreteScheduler(
shift=shift
)
timesteps, sampling_steps = retrieve_timesteps(
sample_scheduler,
sampling_steps,
device=device,
)
else:
raise NotImplementedError("Unsupported solver.")
return sample_scheduler, timesteps
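

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the shipped pipeline).
    # It assumes this is run from the repo root so that the default config in
    # ovi/configs/inference/inference_fusion.yaml resolves and points at a valid
    # ckpt_dir, and that a CUDA device is available. Output paths are placeholders.
    import numpy as np

    engine = OviFusionEngine(device=0)
    result = engine.generate(
        text_prompt="A dog barking on a beach at sunset.",
        video_frame_height_width=(720, 720),
        seed=100,
        solver_name="unipc",
        sample_steps=50,
    )
    if result is not None:
        # video: (c, f, h, w) float array; audio: 1-D float waveform; image is None in pure t2v mode.
        video, audio, image = result
        np.save("generated_video.npy", video)
        np.save("generated_audio.npy", audio)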