'''
Tool for generating editing videos across different domains.

Given a set of latent codes and pre-trained models, it will interpolate between the different codes in each of the target domains
and combine the resulting images into a video.

Example run command:

python generate_videos.py --ckpt /model_dir/pixar.pt \
                                 /model_dir/ukiyoe.pt \
                                 /model_dir/edvard_munch.pt \
                                 /model_dir/botero.pt \
                          --out_dir /output/video/ \
                          --source_latent /latents/latent000.npy \
                          --target_latents /latents/
'''
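# Note on inputs (illustrative, not part of the original docstring): --source_latent and
# --target_latents point at .npy files holding latent codes. For a 1024px StyleGAN2
# generator, a W+ code would typically be an array of shape (1, 18, 512), e.g.:
#
#   import numpy as np
#   latent = np.random.randn(1, 18, 512).astype(np.float32)  # placeholder; real codes come from an encoder/projector
#   np.save("/latents/latent000.npy", latent)
#
# Flattened S-space ("style space") codes are instead detected below by their 9088-channel width.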
import os
import argparse

import torch
from torchvision import utils

from model.sg2_model import Generator
from tqdm import tqdm
from pathlib import Path

import numpy as np

import subprocess
import shutil
import copy

from styleclip.styleclip_global import style_tensor_to_style_dict, style_dict_to_style_tensor
VALID_EDITS = ["pose", "age", "smile", "gender", "hair_length", "beard"]

SUGGESTED_DISTANCES = {
    "pose": 3.0,
    "smile": 2.0,
    "age": 4.0,
    "gender": 3.0,
    "hair_length": -4.0,
    "beard": 2.0,
}
def project_code(latent_code, boundary, distance=3.0):
    # A 2D boundary (e.g. shape [1, 512]) is reshaped so it broadcasts over a W+ code of shape [1, num_layers, 512].
    if boundary.ndim == 2:
        boundary = boundary.reshape(1, 1, -1)

    return latent_code + distance * boundary


def project_code_by_edit_name(latent_code, name, strength):
    boundary_dir = Path(os.path.abspath(__file__)).parents[0].joinpath("editing", "interfacegan_boundaries")

    distance = SUGGESTED_DISTANCES[name] * strength
    boundary = torch.load(os.path.join(boundary_dir, f'{name}.pt'), map_location="cpu").numpy()

    return project_code(latent_code, boundary, distance)
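# Example usage (illustrative; assumes `latent_code` is a W+ code of shape [1, 18, 512]
# held as a numpy array):
#
#   smiling = project_code_by_edit_name(latent_code, "smile", strength=1.0)
#   younger = project_code_by_edit_name(latent_code, "age", strength=-0.5)
#
# strength scales the suggested distance; negative values reverse the edit direction
# (note that hair_length already ships with a negative suggested distance).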
def generate_frames(source_latent, target_latents, g_ema_list, output_dir):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # A flattened StyleCLIP S-space code has 9088 channels; anything else is treated as a W+ code.
    code_is_s = target_latents[0].size()[1] == 9088

    if code_is_s:
        source_s_dict = g_ema_list[0].get_s_code(source_latent, input_is_latent=True)[0]
        np_latent = style_dict_to_style_tensor(source_s_dict, g_ema_list[0]).cpu().detach().numpy()
    else:
        np_latent = source_latent.squeeze(0).cpu().detach().numpy()

    np_target_latents = [target_latent.cpu().detach().numpy() for target_latent in target_latents]

    # More targets means fewer interpolation steps per target, so the overall video length stays bounded.
    num_alphas = 20 if code_is_s else min(10, 30 // len(target_latents))
    alphas = np.linspace(0, 1, num=num_alphas)

    latents = interpolate_with_target_latents(np_latent, np_target_latents, alphas)

    # With more than one generator, blend the generator weights between consecutive checkpoints
    # over equal-length segments of the frame sequence.
    segments = len(g_ema_list) - 1

    if segments:
        segment_length = len(latents) / segments

        g_ema = copy.deepcopy(g_ema_list[0])

        src_pars = dict(g_ema.named_parameters())
        mix_pars = [dict(model.named_parameters()) for model in g_ema_list]
    else:
        g_ema = g_ema_list[0]

    print("Generating frames for video...")

    for idx, latent in tqdm(enumerate(latents), total=len(latents)):
        if segments:
            # Linearly interpolate the weights between the two checkpoints bounding this segment.
            mix_alpha = (idx % segment_length) * 1.0 / segment_length
            segment_id = int(idx // segment_length)

            for k in src_pars.keys():
                src_pars[k].data.copy_(mix_pars[segment_id][k] * (1 - mix_alpha) + mix_pars[segment_id + 1][k] * mix_alpha)

        # Only re-render when the weights or the latent changed (e.g. skip the hold at a target);
        # otherwise the previous frame is saved again below.
        if idx == 0 or segments or latent is not latents[idx - 1]:
            latent_tensor = torch.from_numpy(latent).float().to(device)

            with torch.no_grad():
                if code_is_s:
                    latent_for_gen = style_tensor_to_style_dict(latent_tensor, g_ema)
                    img, _ = g_ema(latent_for_gen, input_is_s_code=True, input_is_latent=True, truncation=1, randomize_noise=False)
                else:
                    img, _ = g_ema([latent_tensor], input_is_latent=True, truncation=1, randomize_noise=False)

        utils.save_image(img, f"{output_dir}/{str(idx).zfill(3)}.jpg", nrow=1, normalize=True, scale_each=True, range=(-1, 1))
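# Worked example (illustrative): with the four checkpoints from the docstring there are
# 3 segments. If interpolation yields 90 frames, segment_length is 30, so frame 45 falls
# in segment 1 (ukiyoe -> edvard_munch) with mix_alpha = 15 / 30 = 0.5, i.e. that frame is
# rendered with an even blend of the two checkpoints' weights.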
def interpolate_forward_backward(source_latent, target_latent, alphas):
    latents_forward = [a * target_latent + (1 - a) * source_latent for a in alphas]  # interpolate from source to target
    latents_backward = latents_forward[::-1]  # interpolate from target to source

    return latents_forward + [target_latent] * len(alphas) + latents_backward  # forward + short delay at target + return


def interpolate_with_target_latents(source_latent, target_latents, alphas):
    # interpolate latent codes with all targets
    print("Interpolating latent codes...")

    latents = []
    for target_latent in target_latents:
        latents.extend(interpolate_forward_backward(source_latent, target_latent, alphas))

    return latents
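# Each target therefore contributes 3 * len(alphas) frames: a forward sweep from source to
# target, len(alphas) repeated frames holding at the target, and the reversed sweep back.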
def video_from_interpolations(fps, output_dir):
    # combine frames to a video
    command = ["ffmpeg",
               "-r", f"{fps}",
               "-i", f"{output_dir}/%03d.jpg",
               "-c:v", "libx264",
               "-vf", f"fps={fps}",
               "-pix_fmt", "yuv420p",
               f"{output_dir}/out.mp4"]

    subprocess.call(command)
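# Minimal driver sketch (illustrative; the script's actual argument parsing and model loading
# are not shown here and may differ). It assumes checkpoints compatible with
# model.sg2_model.Generator, W+ latents stored as .npy files as in the docstring example,
# and an `args` namespace produced by argparse with the flags listed above.
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#
#   generators = []
#   for ckpt_path in args.ckpt:
#       g_ema = Generator(1024, 512, 8).to(device)
#       g_ema.load_state_dict(torch.load(ckpt_path, map_location=device)["g_ema"], strict=False)
#       g_ema.eval()
#       generators.append(g_ema)
#
#   source_latent = torch.from_numpy(np.load(args.source_latent)).float().to(device)
#   target_latents = [torch.from_numpy(np.load(p)).float().to(device)
#                     for p in sorted(Path(args.target_latents).glob("*.npy"))]
#
#   os.makedirs(args.out_dir, exist_ok=True)
#   generate_frames(source_latent, target_latents, generators, args.out_dir)
#   video_from_interpolations(fps=30, output_dir=args.out_dir)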