# Ovi-ZEROGPU / ovi / utils / utils.py
# alexnasa's picture
# Upload 121 files
# a3a2e41 verified
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import binascii
import os
import os.path as osp
import imageio
import torch
import torchvision
from sys import argv
# Names exported by a wildcard import of this module; the remaining helpers
# defined below must be imported explicitly by name.
__all__ = ['cache_video', 'cache_image', 'str2bool']
def get_arguments(args=argv[1:]):
    """Parse command-line options and settle the process's local rank.

    Args:
        args: Argument strings to parse. Defaults to the command-line
            arguments captured when this module was imported.

    Returns:
        The parsed namespace, with ``no_cuda`` forced to ``False`` and
        ``local_rank`` filled from the ``LOCAL_RANK`` or ``SLURM_LOCALID``
        environment variable when it was not given on the command line.
    """
    options = get_argument_parser().parse_args(args)

    # No rank on the command line: fall back to the environment.
    rank_missing = getattr(options, "local_rank", -1) == -1
    if rank_missing:
        raw_rank = (os.environ.get("LOCAL_RANK")
                    or os.environ.get("SLURM_LOCALID"))
        if raw_rank is not None:
            try:
                options.local_rank = int(raw_rank)
            except ValueError:
                # Non-numeric value in the environment: leave the rank as is.
                pass

    # no cuda mode is not supported
    options.no_cuda = False

    # Optionally pin this process to one CUDA device (best effort).
    if torch.cuda.is_available() and getattr(options, "local_rank", -1) >= 0:
        try:
            device_index = options.local_rank % torch.cuda.device_count()
            torch.cuda.set_device(device_index)
        except Exception:
            # Binding is optional; a failure here must not abort parsing.
            pass

    return options
def get_argument_parser():
    """Build the command-line parser used by this module.

    Returns:
        argparse.ArgumentParser: Parser accepting ``--config-file`` (str)
        and ``--local_rank`` (int, default -1).
    """
    option_specs = (
        ("--config-file", {
            "type": str,
            "default": "ovi/configs/inference/inference_fusion.yaml",
        }),
        ("--local_rank", {
            "type": int,
            "default": -1,
            "help": "local_rank for distributed training on gpus",
        }),
    )

    cli = argparse.ArgumentParser()
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    return cli
def rand_name(length=8, suffix=''):
    """Return a random hexadecimal name, optionally followed by a suffix.

    Args:
        length: Number of random bytes to draw; the hex part of the name
            is twice this many characters.
        suffix: Optional extension. A leading dot is added when missing.

    Returns:
        str: The hex string, with the dotted suffix appended if one was given.
    """
    token = os.urandom(length).hex()
    if not suffix:
        return token
    extension = suffix if suffix.startswith('.') else '.' + suffix
    return token + extension
def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Tile a batch of videos into a grid and write it as one video file.

    Every slice along dim 2 of ``tensor`` is turned into a grid image with
    ``torchvision.utils.make_grid``; the resulting frames are encoded with
    imageio using the libx264 codec.

    Args:
        tensor: Video batch. Frames are taken along dim 2, so this is
            presumably (batch, channels, frames, height, width) -- confirm
            against callers.
        save_file: Destination path. When ``None``, a random file name with
            ``suffix`` is created under ``/tmp``.
        fps: Frame rate passed to the video writer.
        suffix: Extension for the generated name when ``save_file`` is None.
        nrow: Grid width forwarded to ``make_grid``.
        normalize: Forwarded to ``make_grid``.
        value_range: ``(low, high)`` used both to clamp the input and as
            the normalization range.
        retry: Maximum number of attempts at writing the file.

    Returns:
        The path that was written, or ``None`` when preprocessing failed or
        every write attempt failed (the last error is printed).
    """
    # cache file
    if save_file is None:
        cache_file = osp.join('/tmp', rand_name(suffix=suffix))
    else:
        cache_file = save_file

    # Preprocess exactly once, into a separate variable. Doing this inside
    # the retry loop while rebinding `tensor` meant that a retry re-clamped
    # and re-gridded data that had already been converted to uint8 frames.
    try:
        clamped = tensor.clamp(min(value_range), max(value_range))
        grids = [
            torchvision.utils.make_grid(
                u, nrow=nrow, normalize=normalize, value_range=value_range)
            for u in clamped.unbind(2)
        ]
        # Stack the per-frame grids on dim 1, then move that axis first and
        # the leading (channel) axis last so each frame is one array.
        frames = torch.stack(grids, dim=1).permute(1, 2, 3, 0)
        frames = (frames * 255).type(torch.uint8).cpu().numpy()
    except Exception as e:
        # Preprocessing is deterministic, so retrying it cannot help.
        print(f'cache_video failed, error: {e}', flush=True)
        return None

    # save to cache, retrying only the write step
    error = None
    for _ in range(retry):
        try:
            writer = imageio.get_writer(
                cache_file, fps=fps, codec='libx264', quality=8)
            try:
                for frame in frames:
                    writer.append_data(frame)
            finally:
                # Always release the encoder, even if a frame failed.
                writer.close()
            return cache_file
        except Exception as e:
            error = e

    print(f'cache_video failed, error: {error}', flush=True)
    return None
def cache_image(tensor,
                save_file,
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Clamp an image batch and save it as a grid image.

    Args:
        tensor: Image tensor handed to ``torchvision.utils.save_image``
            after clamping to ``value_range``.
        save_file: Destination passed to ``save_image``.
        nrow: Grid width forwarded to ``save_image``.
        normalize: Forwarded to ``save_image``.
        value_range: ``(low, high)`` used both to clamp the input and as
            the normalization range.
        retry: Maximum number of save attempts.

    Returns:
        ``save_file`` on success, or ``None`` when every attempt failed
        (the last error is printed, matching ``cache_video``).
    """
    # NOTE: an earlier version derived a `suffix` from `save_file` here but
    # never used it; that dead code has been removed.

    # save to cache
    error = None
    for _ in range(retry):
        try:
            # Keep the caller's tensor untouched across retries.
            clamped = tensor.clamp(min(value_range), max(value_range))
            torchvision.utils.save_image(
                clamped,
                save_file,
                nrow=nrow,
                normalize=normalize,
                value_range=value_range)
            return save_file
        except Exception as e:
            error = e

    # Report the failure instead of silently returning None.
    print(f'cache_image failed, error: {error}', flush=True)
    return None
def str2bool(v):
    """
    Convert a string to a boolean.

    Supported true values: 'yes', 'true', 't', 'y', '1'
    Supported false values: 'no', 'false', 'f', 'n', '0'
    Matching is case-insensitive. A value that is already a bool is
    returned unchanged.

    Args:
        v (str): String to convert.

    Returns:
        bool: Converted boolean value.

    Raises:
        argparse.ArgumentTypeError: If the value cannot be converted to
            boolean, including when it is neither a bool nor a string.
    """
    if isinstance(v, bool):
        return v
    # Non-string input (None, int, ...) has no .lower(); raise the
    # documented error type rather than leaking an AttributeError.
    if not isinstance(v, str):
        raise argparse.ArgumentTypeError('Boolean value expected (True/False)')

    v_lower = v.lower()
    if v_lower in ('yes', 'true', 't', 'y', '1'):
        return True
    if v_lower in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected (True/False)')