Spaces:
Running
Running
| import os | |
| import imageio | |
| import numpy as np | |
| import torch | |
| import torchvision | |
| import cv2 | |
| from einops import rearrange | |
| from PIL import Image | |
def color_transfer(sc, dc):
    """
    Transfer the colour distribution of ``dc`` onto ``sc``.

    Both images are converted from RGB to LAB. Each LAB channel of ``sc`` is
    shifted and scaled so that its mean and standard deviation match the
    corresponding channel of ``dc``. The result is clipped to [0, 255] and
    converted back to RGB.

    Args:
        sc (numpy.ndarray): RGB image to be recoloured. It must be an image that
            ``cv2.cvtColor(..., cv2.COLOR_RGB2LAB)`` accepts. It is assumed to be
            8-bit, since the result is clipped to [0, 255].
        dc (numpy.ndarray): RGB reference image whose colour statistics are
            copied onto ``sc``.

    Returns:
        numpy.ndarray: uint8 RGB image with the same size as ``sc``.
    """
    def get_mean_and_std(img):
        # Per-channel mean/std, rounded to 2 decimals, as flat 1-D arrays.
        x_mean, x_std = cv2.meanStdDev(img)
        return np.around(x_mean, 2).ravel(), np.around(x_std, 2).ravel()

    sc = cv2.cvtColor(sc, cv2.COLOR_RGB2LAB)
    s_mean, s_std = get_mean_and_std(sc)
    dc = cv2.cvtColor(dc, cv2.COLOR_RGB2LAB)
    t_mean, t_std = get_mean_and_std(dc)

    # A flat source channel (std 0, e.g. a uniformly coloured frame) would
    # give t_std / 0 = inf and then 0 * inf = NaN. Use a scale of 1 there, so
    # the channel is only shifted to the target mean.
    scale = np.divide(t_std, s_std, out=np.ones_like(t_std), where=s_std != 0)
    img_n = (sc - s_mean) * scale + t_mean
    np.clip(img_n, 0, 255, out=img_n)
    # convertScaleAbs rounds and saturates to uint8 before converting back to RGB.
    dst = cv2.cvtColor(cv2.convertScaleAbs(img_n), cv2.COLOR_LAB2RGB)
    return dst
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=12, imageio_backend=True, color_transfer_post_process=False):
    """
    Tile a batch of videos into one grid per frame and write the result to disk.

    Args:
        videos (torch.Tensor): Videos laid out as ``(b, c, t, h, w)``. Values
            are expected in [0, 1], or in [-1, 1] with ``rescale=True``. Values
            outside that range are clamped before the uint8 conversion.
        path (str): Output file. Paths ending in ``mp4`` are written as video
            by imageio. With ``imageio_backend=False``, such paths get a
            ``.gif`` extension instead. Missing parent directories are created.
        rescale (bool): Map values from [-1, 1] to [0, 1] before conversion.
        n_rows (int): Forwarded to ``torchvision.utils.make_grid`` as ``nrow``.
        fps (int): Playback rate of the written animation.
        imageio_backend (bool): Write with imageio; otherwise write a GIF via PIL.
        color_transfer_post_process (bool): Recolour every frame after the
            first so that its LAB colour statistics match the first frame.

    Raises:
        ValueError: If ``videos`` contains no frames.
    """
    # Iterate over time: each element is the batch of frames at one timestep.
    videos = rearrange(videos, "b c t h w -> t b c h w")
    frames = []
    for frame_batch in videos:
        grid = torchvision.utils.make_grid(frame_batch, nrow=n_rows)
        # (c, h, w) -> (h, w, c); squeeze drops a trailing singleton channel.
        grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1
        # Clamp so that out-of-range values saturate instead of wrapping around
        # in the uint8 cast. detach/cpu let .numpy() work on GPU or autograd tensors.
        grid = grid.clamp(0.0, 1.0)
        frames.append((grid * 255).detach().cpu().numpy().astype(np.uint8))

    if not frames:
        raise ValueError("videos must contain at least one frame")

    if color_transfer_post_process:
        # The first frame is the colour reference and is itself left unchanged.
        reference = frames[0]
        frames = [reference] + [color_transfer(frame, reference) for frame in frames[1:]]
    outputs = [Image.fromarray(frame) for frame in frames]

    # os.makedirs("") raises, so only create a directory when path names one.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    if imageio_backend:
        if path.endswith("mp4"):
            imageio.mimsave(path, outputs, fps=fps)
        else:
            # Presumably a per-frame duration in milliseconds; the unit depends
            # on the installed imageio version (TODO confirm).
            imageio.mimsave(path, outputs, duration=(1000 * 1/fps))
    else:
        if path.endswith("mp4"):
            # Swap only the extension, not every ".mp4" that occurs in the path.
            path = os.path.splitext(path)[0] + ".gif"
        # append_images holds only the frames *after* the first; passing the
        # full list would show frame 0 twice. PIL takes duration in integer ms.
        outputs[0].save(path, format='GIF', append_images=outputs[1:], save_all=True, duration=round(1000 / fps), loop=0)