# Ovi-ZEROGPU — ovi/utils/model_loading_utils.py
# Source: alexnasa, "Upload 121 files" (commit a3a2e41, verified)
import torch
import os
import json
from safetensors.torch import load_file
from ovi.modules.fusion import FusionModel
from ovi.modules.t5 import T5EncoderModel
from ovi.modules.vae2_2 import Wan2_2_VAE
from ovi.modules.mmaudio.features_utils import FeaturesUtils
def init_wan_vae_2_2(ckpt_dir, rank=0):
    """Build the Wan2.2 video VAE from weights under *ckpt_dir*.

    Args:
        ckpt_dir: Root checkpoint directory containing ``Wan2.2-TI2V-5B/``.
        rank: Device index (or process rank) the VAE is placed on.

    Returns:
        A constructed ``Wan2_2_VAE`` instance.
    """
    weights_path = os.path.join(ckpt_dir, "Wan2.2-TI2V-5B/Wan2.2_VAE.pth")
    return Wan2_2_VAE(device=rank, vae_pth=weights_path)
def init_mmaudio_vae(ckpt_dir, rank=0):
    """Build the MMAudio feature/VAE stack and move it to device *rank*.

    Args:
        ckpt_dir: Root checkpoint directory containing ``MMAudio/ext_weights/``.
        rank: Device index (or process rank) the module is moved to.

    Returns:
        A ``FeaturesUtils`` instance (16 kHz mode, encoder included) on *rank*.
    """
    ext_weights = os.path.join(ckpt_dir, "MMAudio/ext_weights")
    features = FeaturesUtils(
        mode='16k',
        need_vae_encoder=True,
        tod_vae_ckpt=os.path.join(ext_weights, "v1-16.pth"),
        bigvgan_vocoder_ckpt=os.path.join(ext_weights, "best_netG.pt"),
    )
    return features.to(rank)
def init_fusion_score_model_ovi(rank: int = 0, meta_init=False):
    """Instantiate the fused video+audio DiT score model from its JSON configs.

    NOTE(review): config paths are relative to the current working directory —
    callers must run from the repository root.

    Args:
        rank: Process rank; the parameter count is printed on rank 0 only.
        meta_init: If True, build the model on the "meta" device (shapes and
            dtypes only, no storage) so real weights can be assigned later.

    Returns:
        Tuple of ``(fusion_model, video_config_dict, audio_config_dict)``.

    Raises:
        FileNotFoundError: If either JSON config file is missing.
    """
    video_config_path = "ovi/configs/model/dit/video.json"
    audio_config_path = "ovi/configs/model/dit/audio.json"

    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    for cfg_path in (video_config_path, audio_config_path):
        if not os.path.exists(cfg_path):
            raise FileNotFoundError(f"{cfg_path} does not exist")

    with open(video_config_path) as f:
        video_config = json.load(f)
    with open(audio_config_path) as f:
        audio_config = json.load(f)

    if meta_init:
        # Meta device: parameters are allocated shape-only, with no memory.
        with torch.device("meta"):
            fusion_model = FusionModel(video_config, audio_config)
    else:
        fusion_model = FusionModel(video_config, audio_config)

    if rank == 0:
        # Only rank 0 reports, so only rank 0 pays for the count.
        params_all = sum(p.numel() for p in fusion_model.parameters())
        print(
            f"Score model (Fusion) all parameters:{params_all}"
        )

    return fusion_model, video_config, audio_config
def init_text_model(ckpt_dir, rank):
    """Build the UMT5-XXL text encoder (bf16) from weights under *ckpt_dir*.

    Args:
        ckpt_dir: Root checkpoint directory containing ``Wan2.2-TI2V-5B/``.
        rank: Device index (or process rank) the encoder is placed on.

    Returns:
        A constructed ``T5EncoderModel`` instance (512-token context).
    """
    wan_root = os.path.join(ckpt_dir, "Wan2.2-TI2V-5B")
    return T5EncoderModel(
        text_len=512,
        dtype=torch.bfloat16,
        device=rank,
        checkpoint_path=os.path.join(wan_root, "models_t5_umt5-xxl-enc-bf16.pth"),
        tokenizer_path=os.path.join(wan_root, "google/umt5-xxl"),
        shard_fn=None)
def load_fusion_checkpoint(model, checkpoint_path, from_meta=False):
if checkpoint_path and os.path.exists(checkpoint_path):
if checkpoint_path.endswith(".safetensors"):
df = load_file(checkpoint_path, device="cpu")
elif checkpoint_path.endswith(".pt"):
try:
df = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
df = df['module'] if 'module' in df else df
except Exception as e:
df = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
df = df['app']['model']
else:
raise RuntimeError("We only support .safetensors and .pt checkpoints")
missing, unexpected = model.load_state_dict(df, strict=True, assign=from_meta)
del df
import gc
gc.collect()
print(f"Successfully loaded fusion checkpoint from {checkpoint_path}")
else:
raise RuntimeError("{checkpoint=} does not exists'")