import gc
import json
import os

import torch
from safetensors.torch import load_file

from ovi.modules.fusion import FusionModel
from ovi.modules.t5 import T5EncoderModel
from ovi.modules.vae2_2 import Wan2_2_VAE
from ovi.modules.mmaudio.features_utils import FeaturesUtils


def init_wan_vae_2_2(ckpt_dir, rank=0):
    """Construct the Wan 2.2 VAE from weights under *ckpt_dir*.

    Args:
        ckpt_dir: Root checkpoint directory; the weights are read from
            ``<ckpt_dir>/Wan2.2-TI2V-5B/Wan2.2_VAE.pth``.
        rank: Passed to ``Wan2_2_VAE`` as its ``device`` argument.

    Returns:
        The constructed ``Wan2_2_VAE`` instance.
    """
    vae_config = {
        'device': rank,
        'vae_pth': os.path.join(ckpt_dir, "Wan2.2-TI2V-5B/Wan2.2_VAE.pth"),
    }
    return Wan2_2_VAE(**vae_config)


def init_mmaudio_vae(ckpt_dir, rank=0):
    """Construct the MMAudio ``FeaturesUtils`` object and move it to *rank*.

    Args:
        ckpt_dir: Root checkpoint directory; the two weight files are read
            from ``<ckpt_dir>/MMAudio/ext_weights/``.
        rank: Target passed to ``.to(...)`` on the constructed object.

    Returns:
        The ``FeaturesUtils`` instance after ``.to(rank)``.
    """
    vae_config = {
        'mode': '16k',
        'need_vae_encoder': True,
        'tod_vae_ckpt': os.path.join(
            ckpt_dir, "MMAudio/ext_weights/v1-16.pth"),
        'bigvgan_vocoder_ckpt': os.path.join(
            ckpt_dir, "MMAudio/ext_weights/best_netG.pt"),
    }
    return FeaturesUtils(**vae_config).to(rank)


def init_fusion_score_model_ovi(rank: int = 0, meta_init=False):
    """Build the ``FusionModel`` from the video and audio DiT JSON configs.

    The config paths are relative, so they are resolved against the current
    working directory.

    Args:
        rank: Only rank 0 prints the parameter-count summary.
        meta_init: When true, the model is constructed inside a
            ``torch.device("meta")`` context.

    Returns:
        A tuple ``(fusion_model, video_config, audio_config)`` where the two
        configs are the parsed JSON contents.

    Raises:
        AssertionError: If either config file does not exist.
    """
    video_config = "ovi/configs/model/dit/video.json"
    audio_config = "ovi/configs/model/dit/audio.json"
    assert os.path.exists(video_config), f"{video_config} does not exist"
    assert os.path.exists(audio_config), f"{audio_config} does not exist"

    # From here on the two names hold the parsed configs, not the paths.
    with open(video_config) as f:
        video_config = json.load(f)

    with open(audio_config) as f:
        audio_config = json.load(f)

    if meta_init:
        with torch.device("meta"):
            fusion_model = FusionModel(video_config, audio_config)
    else:
        fusion_model = FusionModel(video_config, audio_config)

    params_all = sum(p.numel() for p in fusion_model.parameters())

    if rank == 0:
        print(
            f"Score model (Fusion) all parameters:{params_all}"
        )

    return fusion_model, video_config, audio_config


def init_text_model(ckpt_dir, rank):
    """Construct the T5 text encoder from files under *ckpt_dir*.

    Args:
        ckpt_dir: Root checkpoint directory; the encoder weights and the
            tokenizer are read from ``<ckpt_dir>/Wan2.2-TI2V-5B/``.
        rank: Passed to ``T5EncoderModel`` as its ``device`` argument.

    Returns:
        The constructed ``T5EncoderModel`` instance.
    """
    wan_dir = os.path.join(ckpt_dir, "Wan2.2-TI2V-5B")
    text_encoder_path = os.path.join(wan_dir, "models_t5_umt5-xxl-enc-bf16.pth")
    text_tokenizer_path = os.path.join(wan_dir, "google/umt5-xxl")

    text_encoder = T5EncoderModel(
        text_len=512,
        dtype=torch.bfloat16,
        device=rank,
        checkpoint_path=text_encoder_path,
        tokenizer_path=text_tokenizer_path,
        shard_fn=None)

    return text_encoder


def load_fusion_checkpoint(model, checkpoint_path, from_meta=False):
    """Load a fusion-model state dict from disk into *model*.

    Args:
        model: Object exposing ``load_state_dict``; it is called with
            ``strict=True``.
        checkpoint_path: Path to a ``.safetensors`` or ``.pt`` file.
        from_meta: Forwarded as ``assign=`` to ``load_state_dict``.

    Raises:
        RuntimeError: If *checkpoint_path* is empty or does not exist, or if
            its extension is neither ``.safetensors`` nor ``.pt``.
    """
    if not checkpoint_path or not os.path.exists(checkpoint_path):
        raise RuntimeError(f"{checkpoint_path=} does not exist")

    if checkpoint_path.endswith(".safetensors"):
        df = load_file(checkpoint_path, device="cpu")
    elif checkpoint_path.endswith(".pt"):
        try:
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only use this on checkpoints from a trusted source.
            df = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
            df = df['module'] if 'module' in df else df
        except Exception:
            # Fallback layout: the state dict is nested under ['app']['model'].
            df = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
            df = df['app']['model']
    else:
        raise RuntimeError("We only support .safetensors and .pt checkpoints")

    # strict=True raises on any key mismatch, so the returned
    # (missing, unexpected) lists carry no information and are ignored.
    model.load_state_dict(df, strict=True, assign=from_meta)

    # Drop the local reference to the loaded dict and collect right away.
    del df
    gc.collect()
    print(f"Successfully loaded fusion checkpoint from {checkpoint_path}")