| import numpy as np | |
| from PIL import Image | |
| from configs import global_config | |
| import torch | |
| import matplotlib.pyplot as plt | |
def log_image_from_w(w, G, name):
    """Render the image for latent ``w`` with ``G`` and log it to wandb.

    Args:
        w: Latent code handed to ``get_image_from_w``.
        G: Generator handed to ``get_image_from_w``.
        name: Used as the wandb log key and inside the image caption.
    """
    rendered = Image.fromarray(get_image_from_w(w, G))
    # NOTE(review): `wandb` is not among the imports visible at the top of
    # this file -- confirm it is brought into scope elsewhere, otherwise this
    # function raises NameError.
    logged_image = wandb.Image(rendered, caption=f"current inversion {name}")
    # The entry is logged against the global training step counter.
    wandb.log({f"{name}": [logged_image]}, step=global_config.training_step)
def log_images_from_w(ws, G, names):
    """Log one image per (name, latent) pair via ``log_image_from_w``.

    Args:
        ws: Iterable of latent codes; each is moved to
            ``global_config.device`` before rendering.
        G: Generator forwarded to ``log_image_from_w``.
        names: Iterable of log names, paired positionally with ``ws``.
            Pairing uses ``zip``, so extra items in the longer iterable
            are ignored.
    """
    for label, latent in zip(names, ws):
        log_image_from_w(latent.to(global_config.device), G, label)
def plot_image_from_w(w, G):
    """Render the image for latent ``w`` with ``G`` and show it with matplotlib.

    Args:
        w: Latent code handed to ``get_image_from_w``.
        G: Generator handed to ``get_image_from_w``.
    """
    pixels = get_image_from_w(w, G)
    plt.imshow(Image.fromarray(pixels))
    plt.show()
def plot_image(img):
    """Display the first image of a batched tensor with matplotlib.

    Args:
        img: 4-D tensor; the ``permute(0, 2, 3, 1)`` below implies a
            channels-second layout. Assumes values are roughly in [-1, 1]
            given the ``* 127.5 + 128`` rescale -- TODO confirm with callers.
    """
    channels_last = img.permute(0, 2, 3, 1)
    # Rescale to the 0..255 byte range and clamp anything that overshoots.
    scaled = (channels_last * 127.5 + 128).clamp(0, 255)
    batch = scaled.to(torch.uint8).detach().cpu().numpy()
    # Only the first element of the batch is shown.
    plt.imshow(Image.fromarray(batch[0]))
    plt.show()
def save_image(name, method_type, results_dir, image, run_id):
    """Save ``image`` as ``<results_dir>/<method_type>_<name>_<run_id>.jpg``.

    Args:
        name: Image identifier used in the file name.
        method_type: Prefix of the file name.
        results_dir: Directory part of the destination path.
        image: Object exposing ``save(path)``.
        run_id: Suffix of the file name.
    """
    file_name = f'{method_type}_{name}_{run_id}.jpg'
    destination = f'{results_dir}/{file_name}'
    image.save(destination)
def save_w(w, G, name, method_type, results_dir, run_id=None):
    """Render the image for latent ``w`` with ``G`` and save it as a JPEG.

    The previous implementation called ``save_image`` without its required
    ``run_id`` argument, so every call raised ``TypeError``. ``run_id`` is
    now an optional parameter of this function.

    Args:
        w: Latent code handed to ``get_image_from_w``.
        G: Generator handed to ``get_image_from_w``.
        name: Image identifier used in the file name.
        method_type: Prefix of the file name.
        results_dir: Directory the image is written to.
        run_id: Optional run identifier. When given, the file is written by
            ``save_image`` as ``<method_type>_<name>_<run_id>.jpg``; when
            omitted, it is written as ``<method_type>_<name>.jpg``.
    """
    im = get_image_from_w(w, G)
    im = Image.fromarray(im, mode='RGB')
    if run_id is None:
        # No run id available: save directly rather than calling save_image,
        # which requires one.
        im.save(f'{results_dir}/{method_type}_{name}.jpg')
    else:
        save_image(name, method_type, results_dir, im, run_id)
def save_concat_image(base_dir, image_latents, new_inv_image_latent, new_G,
                      old_G,
                      file_name,
                      extra_image=None):
    """Render several latents, join them side by side, and save one JPEG.

    Panel order: ``extra_image`` (if given), then one image per latent in
    ``image_latents`` rendered with ``old_G``, then ``new_inv_image_latent``
    rendered with ``new_G``.

    Args:
        base_dir: Directory part of the destination path.
        image_latents: Iterable of latents rendered with ``old_G``.
        new_inv_image_latent: Latent rendered with ``new_G``.
        new_G: Generator used for ``new_inv_image_latent``.
        old_G: Generator used for ``image_latents``.
        file_name: File name without the ``.jpg`` extension.
        extra_image: Optional image placed first in the row.
    """
    panels = [] if extra_image is None else [extra_image]
    panels.extend(get_image_from_w(latent, old_G) for latent in image_latents)
    panels.append(get_image_from_w(new_inv_image_latent, new_G))

    create_alongside_images(panels).save(f'{base_dir}/{file_name}.jpg')
def save_single_image(base_dir, image_latent, G, file_name):
    """Render ``image_latent`` with ``G`` and save it as ``<base_dir>/<file_name>.jpg``.

    Args:
        base_dir: Directory part of the destination path.
        image_latent: Latent code handed to ``get_image_from_w``.
        G: Generator handed to ``get_image_from_w``.
        file_name: File name without the ``.jpg`` extension.
    """
    pixels = get_image_from_w(image_latent, G)
    rendered = Image.fromarray(pixels, mode='RGB')
    rendered.save(f'{base_dir}/{file_name}.jpg')
def create_alongside_images(images):
    """Join images along axis 1 and return them as a single RGB PIL image.

    Args:
        images: Iterable of array-convertible images. They are concatenated
            along axis 1, so they must agree on every other dimension.

    Returns:
        A ``PIL.Image.Image`` built from the concatenated array in RGB mode.
    """
    arrays = list(map(np.array, images))
    stitched = np.concatenate(arrays, axis=1)
    return Image.fromarray(stitched, mode='RGB')
def get_image_from_w(w, G):
    """Synthesize an image from latent ``w`` and return it as a uint8 array.

    Args:
        w: Latent tensor. Inputs with at most two dimensions get a leading
            batch axis added before synthesis.
        G: Object exposing ``synthesis(w, noise_mode=...)``. The permute
            below implies a 4-D channels-second output; assumes values are
            roughly in [-1, 1] given the rescale -- TODO confirm.

    Returns:
        The first image of the synthesized batch as a channels-last
        ``numpy`` array of dtype ``uint8``.
    """
    if w.dim() <= 2:
        w = w.unsqueeze(0)

    with torch.no_grad():
        synthesized = G.synthesis(w, noise_mode='const')
        channels_last = synthesized.permute(0, 2, 3, 1)
        # Rescale to the 0..255 byte range and clamp anything that overshoots.
        scaled = (channels_last * 127.5 + 128).clamp(0, 255)
        batch = scaled.to(torch.uint8).detach().cpu().numpy()
    return batch[0]