Spaces: Running on Zero
| import os | |
| import sys | |
| import logging | |
| import torch | |
| from tqdm import tqdm | |
| from omegaconf import OmegaConf | |
| from ovi.utils.io_utils import save_video | |
| from ovi.utils.processing_utils import format_prompt_for_filename, validate_and_process_user_prompt | |
| from ovi.utils.utils import get_arguments | |
| from ovi.distributed_comms.util import get_world_size, get_local_rank, get_global_rank | |
| from ovi.distributed_comms.parallel_states import initialize_sequence_parallel_state, get_sequence_parallel_state, nccl_info | |
| from ovi.ovi_fusion_engine import OviFusionEngine | |
| def _init_logging(rank): | |
| # logging | |
| if rank == 0: | |
| # set format | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format="[%(asctime)s] %(levelname)s: %(message)s", | |
| handlers=[logging.StreamHandler(stream=sys.stdout)]) | |
| else: | |
| logging.basicConfig(level=logging.ERROR) | |
def main(config, args):
    """Run (optionally sequence-parallel) OVI video+audio generation over a prompt list.

    Sets up the distributed/SP topology, validates prompts and images before the
    (expensive) model load, distributes examples round-robin across SP groups,
    generates video+audio for each example, and saves outputs from each group's
    rank-0 process.

    Args:
        config: OmegaConf-style mapping with generation settings (mode, prompts,
            sampler options, output_dir, ...). Read-only here.
        args: parsed CLI namespace; `local_rank` and `device` are attached to it
            as a side effect for downstream consumers.
    """
    world_size = get_world_size()
    global_rank = get_global_rank()
    local_rank = get_local_rank()
    device = local_rank
    torch.cuda.set_device(local_rank)

    sp_size = config.get("sp_size", 1)
    assert sp_size <= world_size and world_size % sp_size == 0, \
        "sp_size must be less than or equal to world_size and world_size must be divisible by sp_size."
    _init_logging(global_rank)

    if world_size > 1:
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=global_rank,
            world_size=world_size)
    else:
        assert sp_size == 1, f"When world_size is 1, sp_size must also be 1, but got {sp_size}."
        ## TODO: assert not sharding t5 etc...

    initialize_sequence_parallel_state(sp_size)
    logging.info(f"Using SP: {get_sequence_parallel_state()}, SP_SIZE: {sp_size}")

    args.local_rank = local_rank
    args.device = device
    target_dtype = torch.bfloat16

    # Validate inputs before loading the model so invalid input fails fast.
    text_prompt = config.get("text_prompt")
    image_path = config.get("image_path", None)
    mode = config.get("mode")
    assert mode in ["t2v", "i2v", "t2i2v"], \
        f"Invalid mode {mode}, must be one of ['t2v', 'i2v', 't2i2v']"
    text_prompts, image_paths = validate_and_process_user_prompt(text_prompt, image_path, mode=mode)
    if mode != "i2v":
        logging.info(f"mode: {mode}, setting all image_paths to None")
        image_paths = [None] * len(text_prompts)
    else:
        assert all(p is not None and os.path.isfile(p) for p in image_paths), \
            f"In i2v mode, all image paths must be provided.{image_paths}"

    logging.info("Loading OVI Fusion Engine...")
    ovi_engine = OviFusionEngine(config=config, device=device, target_dtype=target_dtype)
    logging.info("OVI Fusion Engine loaded!")

    output_dir = config.get("output_dir", "./outputs")
    os.makedirs(output_dir, exist_ok=True)

    all_eval_data = list(zip(text_prompts, image_paths))

    # SP configuration: with SP, all ranks in a group process the same example,
    # so work is sharded by group rather than by individual rank.
    use_sp = get_sequence_parallel_state()
    if use_sp:
        sp_size = nccl_info.sp_size
        sp_rank = nccl_info.rank_within_group
        sp_group_id = global_rank // sp_size
        num_sp_groups = world_size // sp_size
    else:
        # No SP: treat each GPU as its own group.
        sp_size = 1
        sp_rank = 0
        sp_group_id = global_rank
        num_sp_groups = world_size

    # Distribute examples round-robin across SP groups.
    total_files = len(all_eval_data)
    require_sample_padding = False  # set True to pad so every group gets equal work
    if total_files == 0:
        logging.error("ERROR: No evaluation files found")
        this_rank_eval_data = []
    else:
        if require_sample_padding:
            # Pad (by repeating the first example) so examples divide evenly
            # into SP groups; only computed when padding is actually requested.
            remainder = total_files % num_sp_groups
            if remainder != 0:
                all_eval_data += [all_eval_data[0]] * (num_sp_groups - remainder)
        this_rank_eval_data = all_eval_data[sp_group_id::num_sp_groups]

    # Generation settings are loop-invariant: read them once, outside the loop.
    video_frame_height_width = config.get("video_frame_height_width", None)
    seed = config.get("seed", 100)
    solver_name = config.get("solver_name", "unipc")
    sample_steps = config.get("sample_steps", 50)
    shift = config.get("shift", 5.0)
    video_guidance_scale = config.get("video_guidance_scale", 4.0)
    audio_guidance_scale = config.get("audio_guidance_scale", 3.0)
    slg_layer = config.get("slg_layer", 11)
    video_negative_prompt = config.get("video_negative_prompt", "")
    audio_negative_prompt = config.get("audio_negative_prompt", "")
    n_repeats = config.get("each_example_n_times", 1)

    # BUG FIX: the original built 'x'.join(map(str, video_frame_height_width))
    # unconditionally in the output filename, raising TypeError whenever the
    # config omits video_frame_height_width (its default is None).
    size_tag = "x".join(map(str, video_frame_height_width)) if video_frame_height_width else "default"

    for text_prompt, image_path in tqdm(this_rank_eval_data):
        for idx in range(n_repeats):
            generated_video, generated_audio, generated_image = ovi_engine.generate(
                text_prompt=text_prompt,
                image_path=image_path,
                video_frame_height_width=video_frame_height_width,
                seed=seed + idx,
                solver_name=solver_name,
                sample_steps=sample_steps,
                shift=shift,
                video_guidance_scale=video_guidance_scale,
                audio_guidance_scale=audio_guidance_scale,
                slg_layer=slg_layer,
                video_negative_prompt=video_negative_prompt,
                audio_negative_prompt=audio_negative_prompt)

            # Only one rank per SP group writes files, to avoid duplicate outputs.
            if sp_rank == 0:
                formatted_prompt = format_prompt_for_filename(text_prompt)
                output_path = os.path.join(
                    output_dir,
                    f"{formatted_prompt}_{size_tag}_{seed + idx}_{global_rank}.mp4")
                save_video(output_path, generated_video, generated_audio, fps=24, sample_rate=16000)
                if generated_image is not None:
                    generated_image.save(output_path.replace('.mp4', '.png'))
| if __name__ == "__main__": | |
| args = get_arguments() | |
| config = OmegaConf.load(args.config_file) | |
| main(config=config,args=args) |