# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import binascii
import gc
import inspect
import os
import os.path as osp

import cv2
import imageio
import numpy as np
import torch
import torchvision
from einops import rearrange
from PIL import Image

__all__ = ['cache_video', 'cache_image', 'str2bool']

def filter_kwargs(cls, kwargs):
    # Keep only the keyword arguments that cls.__init__ actually accepts.
    sig = inspect.signature(cls.__init__)
    valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
    filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
    return filtered_kwargs

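# A minimal usage sketch (the Decoder class below is hypothetical, not part of
# this module): surplus config keys are dropped so a loose dict can be splatted
# into a constructor safely.
#
#     class Decoder:
#         def __init__(self, dim, depth):
#             ...
#
#     cfg = {'dim': 512, 'depth': 8, 'comment': 'ignored'}
#     Decoder(**filter_kwargs(Decoder, cfg))   # only dim and depth are passed
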
def rand_name(length=8, suffix=''):
    name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
    if suffix:
        if not suffix.startswith('.'):
            suffix = '.' + suffix
        name += suffix
    return name

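# Example: rand_name(suffix='mp4') yields something like 'a1b2c3d4e5f60718.mp4'
# (8 random bytes hex-encode to 16 characters; the dot is added automatically).
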
def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    # cache file
    cache_file = osp.join('/tmp', rand_name(
        suffix=suffix)) if save_file is None else save_file

    # save to cache
    error = None
    for _ in range(retry):
        try:
            # preprocess: expects [B, C, T, H, W]; each frame becomes one grid image
            tensor = tensor.clamp(min(value_range), max(value_range))
            tensor = torch.stack([
                torchvision.utils.make_grid(
                    u, nrow=nrow, normalize=normalize, value_range=value_range)
                for u in tensor.unbind(2)
            ], dim=1).permute(1, 2, 3, 0)
            tensor = (tensor * 255).type(torch.uint8).cpu()

            # write video
            writer = imageio.get_writer(
                cache_file, fps=fps, codec='libx264', quality=8)
            for frame in tensor.numpy():
                writer.append_data(frame)
            writer.close()
            return cache_file
        except Exception as e:
            error = e
            continue
    else:
        # reached only after every retry raised
        print(f'cache_video failed, error: {error}', flush=True)
        return None

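# Usage sketch (assumes a video tensor in [-1, 1]; the unbind(2) above implies
# a [B, C, T, H, W] layout, with frames stacked along dim 2):
#
#     video = torch.randn(1, 3, 16, 64, 64).clamp(-1, 1)
#     out = cache_video(video, save_file='sample.mp4', fps=16)
#     if out is None:
#         print('encoding failed after all retries')
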
def cache_image(tensor,
                save_file,
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    # cache file: fall back to .png for unsupported extensions
    suffix = osp.splitext(save_file)[1]
    if suffix.lower() not in [
            '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp'
    ]:
        save_file = osp.splitext(save_file)[0] + '.png'

    # save to cache
    error = None
    for _ in range(retry):
        try:
            tensor = tensor.clamp(min(value_range), max(value_range))
            torchvision.utils.save_image(
                tensor,
                save_file,
                nrow=nrow,
                normalize=normalize,
                value_range=value_range)
            return save_file
        except Exception as e:
            error = e
            continue
    else:
        # report the last error, mirroring cache_video
        print(f'cache_image failed, error: {error}', flush=True)
        return None

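# Usage sketch (same [-1, 1] convention as cache_video; a 4D [B, C, H, W]
# batch is tiled into a grid by torchvision.utils.save_image):
#
#     images = torch.randn(4, 3, 64, 64).clamp(-1, 1)
#     cache_image(images, save_file='grid.png', nrow=2)
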
def str2bool(v):
    """
    Convert a string to a boolean.

    Supported true values: 'yes', 'true', 't', 'y', '1'
    Supported false values: 'no', 'false', 'f', 'n', '0'

    Args:
        v (str): String to convert.

    Returns:
        bool: Converted boolean value.

    Raises:
        argparse.ArgumentTypeError: If the value cannot be converted to boolean.
    """
    if isinstance(v, bool):
        return v
    v_lower = v.lower()
    if v_lower in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v_lower in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected (True/False)')

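# Typical argparse wiring (flag name chosen for illustration): pass str2bool as
# the type so '--use_cache true' and '--use_cache 0' both parse, while garbage
# raises a clean CLI error instead of silently truthy-casting.
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--use_cache', type=str2bool, default=False)
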
def color_transfer(sc, dc):
    """
    Transfer the color distribution of the reference image dc onto sc.

    Args:
        sc (numpy.ndarray): Source RGB image whose colors are remapped.
        dc (numpy.ndarray): Reference RGB image supplying the target statistics.

    Returns:
        numpy.ndarray: sc with its per-channel LAB mean/std matched to dc.
    """

    def get_mean_and_std(img):
        x_mean, x_std = cv2.meanStdDev(img)
        x_mean = np.hstack(np.around(x_mean, 2))
        x_std = np.hstack(np.around(x_std, 2))
        return x_mean, x_std

    sc = cv2.cvtColor(sc, cv2.COLOR_RGB2LAB)
    s_mean, s_std = get_mean_and_std(sc)
    dc = cv2.cvtColor(dc, cv2.COLOR_RGB2LAB)
    t_mean, t_std = get_mean_and_std(dc)
    # Reinhard-style statistics matching: normalize by the source stats,
    # then rescale and shift to the target stats.
    img_n = ((sc - s_mean) * (t_std / s_std)) + t_mean
    np.putmask(img_n, img_n > 255, 255)
    np.putmask(img_n, img_n < 0, 0)
    dst = cv2.cvtColor(cv2.convertScaleAbs(img_n), cv2.COLOR_LAB2RGB)
    return dst

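# Example with illustrative arrays: remap a frame's palette toward a reference.
# Both inputs must be uint8 RGB; note that a perfectly flat source image gives
# s_std == 0 and would divide by zero above.
#
#     frame = np.uint8(np.random.randint(0, 255, (64, 64, 3)))
#     ref = np.uint8(np.random.randint(0, 255, (64, 64, 3)))
#     matched = color_transfer(frame, ref)
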
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=12, imageio_backend=True,
                     color_transfer_post_process=False):
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)  # CHW -> HWC
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(Image.fromarray(x))

    if color_transfer_post_process:
        # match every frame's color statistics to the first frame to reduce flicker
        for i in range(1, len(outputs)):
            outputs[i] = Image.fromarray(color_transfer(np.uint8(outputs[i]), np.uint8(outputs[0])))

    os.makedirs(os.path.dirname(path), exist_ok=True)
    if imageio_backend:
        if path.endswith("mp4"):
            imageio.mimsave(path, outputs, fps=fps)
        else:
            imageio.mimsave(path, outputs, duration=(1000 * 1 / fps))
    else:
        if path.endswith("mp4"):
            path = path.replace('.mp4', '.gif')
        outputs[0].save(path, format='GIF', append_images=outputs, save_all=True, duration=100, loop=0)

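# Usage sketch: videos are [B, C, T, H, W] in [0, 1] (or [-1, 1] with
# rescale=True); non-mp4 paths fall through to the GIF-style mimsave branch.
#
#     fake = torch.rand(2, 3, 8, 64, 64)
#     save_videos_grid(fake, 'outputs/demo.mp4', fps=8)
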
def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size):
    """
    Build the pixel-space conditioning video and mask for image-to-video generation.

    sample_size is (height, width). Returns:
        input_video: [1, 3, video_length, H, W] float tensor in [0, 1]
            (zeros when no reference image is given).
        input_video_mask: [1, 1, video_length, H, W]; 0 marks frames fixed by
            the reference image(s), 255 marks frames to be generated.
        clip_image: the start image kept for the CLIP image encoder, or None.
    """
    if validation_image_start is not None and validation_image_end is not None:
        if type(validation_image_start) is str and os.path.isfile(validation_image_start):
            image_start = clip_image = Image.open(validation_image_start).convert("RGB")
            image_start = image_start.resize([sample_size[1], sample_size[0]])
            clip_image = clip_image.resize([sample_size[1], sample_size[0]])
        else:
            image_start = clip_image = validation_image_start
            image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
            clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]

        if type(validation_image_end) is str and os.path.isfile(validation_image_end):
            image_end = Image.open(validation_image_end).convert("RGB")
            image_end = image_end.resize([sample_size[1], sample_size[0]])
        else:
            image_end = validation_image_end
            image_end = [_image_end.resize([sample_size[1], sample_size[0]]) for _image_end in image_end]

        if type(image_start) is list:
            clip_image = clip_image[0]
            start_video = torch.cat(
                [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0)
                 for _image_start in image_start],
                dim=2
            )
            input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
            input_video[:, :, :len(image_start)] = start_video

            input_video_mask = torch.zeros_like(input_video[:, :1])
            input_video_mask[:, :, len(image_start):] = 255
        else:
            input_video = torch.tile(
                torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
                [1, 1, video_length, 1, 1]
            )
            input_video_mask = torch.zeros_like(input_video[:, :1])
            input_video_mask[:, :, 1:] = 255

        if type(image_end) is list:
            image_end = [
                _image_end.resize(image_start[0].size if type(image_start) is list else image_start.size)
                for _image_end in image_end
            ]
            end_video = torch.cat(
                [torch.from_numpy(np.array(_image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0)
                 for _image_end in image_end],
                dim=2
            )
            # index the temporal dim by the number of end frames; len() of the
            # tensor would read the batch dim (always 1) and misalign multi-frame ends
            input_video[:, :, -len(image_end):] = end_video

            input_video_mask[:, :, -len(image_end):] = 0
        else:
            image_end = image_end.resize(image_start[0].size if type(image_start) is list else image_start.size)
            input_video[:, :, -1:] = torch.from_numpy(np.array(image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0)
            input_video_mask[:, :, -1:] = 0

        input_video = input_video / 255

    elif validation_image_start is not None:
        if type(validation_image_start) is str and os.path.isfile(validation_image_start):
            image_start = clip_image = Image.open(validation_image_start).convert("RGB")
            image_start = image_start.resize([sample_size[1], sample_size[0]])
            clip_image = clip_image.resize([sample_size[1], sample_size[0]])
        else:
            image_start = clip_image = validation_image_start
            image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
            clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
        image_end = None

        if type(image_start) is list:
            clip_image = clip_image[0]
            start_video = torch.cat(
                [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0)
                 for _image_start in image_start],
                dim=2
            )
            input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
            input_video[:, :, :len(image_start)] = start_video
            input_video = input_video / 255

            input_video_mask = torch.zeros_like(input_video[:, :1])
            input_video_mask[:, :, len(image_start):] = 255
        else:
            input_video = torch.tile(
                torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
                [1, 1, video_length, 1, 1]
            ) / 255
            input_video_mask = torch.zeros_like(input_video[:, :1])
            input_video_mask[:, :, 1:] = 255
    else:
        image_start = None
        image_end = None
        input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]])
        input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255
        clip_image = None

    del image_start
    del image_end
    gc.collect()

    return input_video, input_video_mask, clip_image

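# Usage sketch (hypothetical paths; sample_size is (height, width)):
#
#     video, mask, clip_img = get_image_to_video_latent(
#         'start.png', 'end.png', video_length=49, sample_size=(480, 832))
#     # video: [1, 3, 49, 480, 832] in [0, 1]; frames whose mask is 255 are
#     # the ones the diffusion model must generate.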