# OVI / inference.py
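"""Batch inference entry point for the OVI fusion engine.

Reads prompts (and, in i2v mode, image paths) from an OmegaConf YAML config,
splits the samples across sequence-parallel (SP) groups, and writes the
generated videos (with audio) to ``output_dir``. When world_size > 1 the
script expects to run under a distributed launcher such as torchrun, since it
initializes NCCL via the ``env://`` method; the exact CLI flags are defined
by ``ovi.utils.utils.get_arguments``.
"""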
import os
import sys
import logging
import torch
from tqdm import tqdm
from omegaconf import OmegaConf
from ovi.utils.io_utils import save_video
from ovi.utils.processing_utils import format_prompt_for_filename, validate_and_process_user_prompt
from ovi.utils.utils import get_arguments
from ovi.distributed_comms.util import get_world_size, get_local_rank, get_global_rank
from ovi.distributed_comms.parallel_states import initialize_sequence_parallel_state, get_sequence_parallel_state, nccl_info
from ovi.ovi_fusion_engine import OviFusionEngine
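
# Illustrative config sketch (loaded below via OmegaConf from a YAML file).
# Keys mirror the config.get() calls in main(); values are placeholders, and
# the defaults shown are the ones this script falls back to:
#
#   mode: t2v                       # one of: t2v, i2v, t2i2v
#   text_prompt: "a placeholder prompt"
#   image_path: null                # required per sample in i2v mode
#   sp_size: 1
#   output_dir: ./outputs
#   video_frame_height_width: [512, 512]  # placeholder [height, width]
#   seed: 100
#   solver_name: unipc
#   sample_steps: 50
#   shift: 5.0
#   video_guidance_scale: 4.0
#   audio_guidance_scale: 3.0
#   slg_layer: 11
#   video_negative_prompt: ""
#   audio_negative_prompt: ""
#   each_example_n_times: 1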

def _init_logging(rank):
    # Only rank 0 logs at INFO level; all other ranks log errors only,
    # so multi-GPU runs do not duplicate every message.
    if rank == 0:
        logging.basicConfig(
            level=logging.INFO,
            format="[%(asctime)s] %(levelname)s: %(message)s",
            handlers=[logging.StreamHandler(stream=sys.stdout)])
    else:
        logging.basicConfig(level=logging.ERROR)

def main(config, args):
    world_size = get_world_size()
    global_rank = get_global_rank()
    local_rank = get_local_rank()
    device = local_rank
    torch.cuda.set_device(local_rank)

    sp_size = config.get("sp_size", 1)
    assert sp_size <= world_size and world_size % sp_size == 0, \
        "sp_size must be less than or equal to world_size, and world_size must be divisible by sp_size."

    _init_logging(global_rank)
    if world_size > 1:
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=global_rank,
            world_size=world_size)
    else:
        assert sp_size == 1, f"When world_size is 1, sp_size must also be 1, but got {sp_size}."

    ## TODO: assert not sharding t5 etc...
    initialize_sequence_parallel_state(sp_size)
    logging.info(f"Using SP: {get_sequence_parallel_state()}, SP_SIZE: {sp_size}")
    args.local_rank = local_rank
    args.device = device
    target_dtype = torch.bfloat16

    # Validate inputs before loading the model so invalid input fails fast.
    text_prompt = config.get("text_prompt")
    image_path = config.get("image_path", None)
    assert config.get("mode") in ["t2v", "i2v", "t2i2v"], \
        f"Invalid mode {config.get('mode')}, must be one of ['t2v', 'i2v', 't2i2v']"
    text_prompts, image_paths = validate_and_process_user_prompt(text_prompt, image_path, mode=config.get("mode"))

    if config.get("mode") != "i2v":
        logging.info(f"mode: {config.get('mode')}, setting all image_paths to None")
        image_paths = [None] * len(text_prompts)
    else:
        assert all(p is not None and os.path.isfile(p) for p in image_paths), \
            f"In i2v mode, every image path must be provided and exist on disk. Got: {image_paths}"
logging.info("Loading OVI Fusion Engine...")
ovi_engine = OviFusionEngine(config=config, device=device, target_dtype=target_dtype)
logging.info("OVI Fusion Engine loaded!")
output_dir = config.get("output_dir", "./outputs")
os.makedirs(output_dir, exist_ok=True)
# Load CSV data
all_eval_data = list(zip(text_prompts, image_paths))

    # Get SP configuration
    use_sp = get_sequence_parallel_state()
    if use_sp:
        sp_size = nccl_info.sp_size
        sp_rank = nccl_info.rank_within_group
        sp_group_id = global_rank // sp_size
        num_sp_groups = world_size // sp_size
    else:
        # No SP: treat each GPU as its own group
        sp_size = 1
        sp_rank = 0
        sp_group_id = global_rank
        num_sp_groups = world_size
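    # Worked example (illustrative): with world_size=8 and sp_size=4,
    # ranks 0-3 form SP group 0 and ranks 4-7 form SP group 1
    # (sp_group_id = global_rank // sp_size), so num_sp_groups = 2
    # groups consume the sample list independently.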

    # Data distribution - by SP groups
    total_files = len(all_eval_data)
    require_sample_padding = False
    if total_files == 0:
        logging.error("No evaluation samples found.")
        this_rank_eval_data = []
    else:
        # Optionally pad so every SP group gets the same number of samples.
        remainder = total_files % num_sp_groups
        if require_sample_padding and remainder != 0:
            pad_count = num_sp_groups - remainder
            all_eval_data += [all_eval_data[0]] * pad_count

        # Distribute across SP groups by striding through the sample list.
        this_rank_eval_data = all_eval_data[sp_group_id :: num_sp_groups]
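        # Striding example (illustrative): with 6 samples and num_sp_groups=2,
        # group 0 takes indices 0, 2, 4 and group 1 takes indices 1, 3, 5.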

    for text_prompt, image_path in tqdm(this_rank_eval_data):
        # Per-sample generation settings, with this script's fallback defaults.
        video_frame_height_width = config.get("video_frame_height_width", None)
        seed = config.get("seed", 100)
        solver_name = config.get("solver_name", "unipc")
        sample_steps = config.get("sample_steps", 50)
        shift = config.get("shift", 5.0)
        video_guidance_scale = config.get("video_guidance_scale", 4.0)
        audio_guidance_scale = config.get("audio_guidance_scale", 3.0)
        slg_layer = config.get("slg_layer", 11)
        video_negative_prompt = config.get("video_negative_prompt", "")
        audio_negative_prompt = config.get("audio_negative_prompt", "")
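        # Re-run each sample `each_example_n_times` times, offsetting the seed
        # by the repeat index so repeated generations differ.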
        for idx in range(config.get("each_example_n_times", 1)):
            generated_video, generated_audio, generated_image = ovi_engine.generate(
                text_prompt=text_prompt,
                image_path=image_path,
                video_frame_height_width=video_frame_height_width,
                seed=seed + idx,
                solver_name=solver_name,
                sample_steps=sample_steps,
                shift=shift,
                video_guidance_scale=video_guidance_scale,
                audio_guidance_scale=audio_guidance_scale,
                slg_layer=slg_layer,
                video_negative_prompt=video_negative_prompt,
                audio_negative_prompt=audio_negative_prompt)

            # Only one rank per SP group writes output, to avoid duplicate files.
            # Note: the filename format assumes video_frame_height_width is set.
            if sp_rank == 0:
                formatted_prompt = format_prompt_for_filename(text_prompt)
                output_path = os.path.join(
                    output_dir,
                    f"{formatted_prompt}_{'x'.join(map(str, video_frame_height_width))}_{seed + idx}_{global_rank}.mp4")
                save_video(output_path, generated_video, generated_audio, fps=24, sample_rate=16000)
                if generated_image is not None:
                    generated_image.save(output_path.replace('.mp4', '.png'))
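
# Launch sketch (illustrative; the actual CLI flags are defined in
# ovi.utils.utils.get_arguments, so the flag name below is an assumption):
#   single GPU: python inference.py --config-file configs/inference.yaml
#   multi GPU:  torchrun --nproc_per_node=8 inference.py --config-file configs/inference.yaml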
if __name__ == "__main__":
    args = get_arguments()
    config = OmegaConf.load(args.config_file)
    main(config=config, args=args)