# Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
'''
/*
*Copyright (c) 2021, Alibaba Group;
*Licensed under the Apache License, Version 2.0 (the "License");
*you may not use this file except in compliance with the License.
*You may obtain a copy of the License at
*     http://www.apache.org/licenses/LICENSE-2.0
*Unless required by applicable law or agreed to in writing, software
*distributed under the License is distributed on an "AS IS" BASIS,
*WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*See the License for the specific language governing permissions and
*limitations under the License.
*/
'''
import os
import re
import os.path as osp
import sys
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-4]))
import json
import math
import torch
import pynvml
import logging
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch.cuda.amp as amp
from importlib import reload
import torch.distributed as dist
import torch.multiprocessing as mp
import random
from einops import rearrange
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torch.nn.parallel import DistributedDataParallel

import utils.transforms as data
from ..modules.config import cfg
from utils.seed import setup_seed
from utils.multi_port import find_free_port
from utils.assign_cfg import assign_signle_cfg
from utils.distributed import generalized_all_gather, all_reduce
from utils.video_op import save_i2vgen_video, save_t2vhigen_video_safe, save_video_multiple_conditions_not_gif_horizontal_3col
from tools.modules.autoencoder import get_first_stage_encoding
from utils.registry_class import INFER_ENGINE, MODEL, EMBEDDER, AUTO_ENCODER, DIFFUSION
from copy import copy


def inference_unianimate_long_entrance(cfg_update, **kwargs):
    for k, v in cfg_update.items():
        if isinstance(v, dict) and k in cfg:
            cfg[k].update(v)
        else:
            cfg[k] = v

    if not 'MASTER_ADDR' in os.environ:
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = find_free_port()
    cfg.pmi_rank = int(os.getenv('RANK', 0))
    cfg.pmi_world_size = int(os.getenv('WORLD_SIZE', 1))

    if cfg.debug:
        cfg.gpus_per_machine = 1
        cfg.world_size = 1
    else:
        cfg.gpus_per_machine = torch.cuda.device_count()
        cfg.world_size = cfg.pmi_world_size * cfg.gpus_per_machine

    if cfg.world_size == 1:
        worker(0, cfg, cfg_update)
    else:
        mp.spawn(worker, nprocs=cfg.gpus_per_machine, args=(cfg, cfg_update))
    return cfg
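

# --- Editorial usage sketch (not part of the original script) ----------------
# A minimal, hypothetical way to drive the entrance function above. The keys
# mirror fields that worker() reads from cfg later in this file; the concrete
# values are placeholders, not shipped defaults.
def _example_entrance_call():
    cfg_update = {
        'seed': 11,
        'round': 1,
        'max_frames': 32,
        'test_model': 'checkpoints/placeholder_unianimate.pth',            # placeholder checkpoint path
        'test_list_path': [[2, 'data/ref_image.png', 'data/saved_pose']],  # [frame_interval, ref image, pose dir]
    }
    return inference_unianimate_long_entrance(cfg_update)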


def make_masked_images(imgs, masks):
    masked_imgs = []
    for i, mask in enumerate(masks):
        # concatenation
        masked_imgs.append(torch.cat([imgs[i] * (1 - mask), (1 - mask)], dim=1))
    return torch.stack(masked_imgs, dim=0)
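

# Editorial note (hedged): with imgs of shape [B, F, 3, H, W] and masks of
# shape [B, F, 1, H, W] (assumed shapes, not asserted by the function), each
# masked frame concatenates the masked RGB channels with the inverted mask
# along dim=1, so make_masked_images returns a [B, F, 4, H, W] tensor, e.g.
#   make_masked_images(torch.zeros(2, 8, 3, 64, 64),
#                      torch.zeros(2, 8, 1, 64, 64)).shape == (2, 8, 4, 64, 64)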


def load_video_frames(ref_image_path, pose_file_path, train_trans, vit_transforms, train_trans_pose, max_frames=32, frame_interval=1, resolution=[512, 768], get_first_frame=True, vit_resolution=[224, 224]):
    for _ in range(5):
        try:
            dwpose_all = {}
            frames_all = {}
            for ii_index in sorted(os.listdir(pose_file_path)):
                if ii_index != "ref_pose.jpg":
                    dwpose_all[ii_index] = Image.open(pose_file_path + "/" + ii_index)
                    frames_all[ii_index] = Image.fromarray(cv2.cvtColor(cv2.imread(ref_image_path), cv2.COLOR_BGR2RGB))
                    # frames_all[ii_index] = Image.open(ref_image_path)

            pose_ref = Image.open(os.path.join(pose_file_path, "ref_pose.jpg"))
            first_eq_ref = False

            # sample max_frames poses for video generation
            stride = frame_interval
            _total_frame_num = len(frames_all)
            if max_frames == "None":
                max_frames = (_total_frame_num - 1) // frame_interval + 1
            cover_frame_num = (stride * (max_frames - 1) + 1)
            if _total_frame_num < cover_frame_num:
                print('_total_frame_num is smaller than cover_frame_num, the sampled frame interval is changed')
                start_frame = 0  # we set start_frame = 0 because the pose alignment is performed on the first frame
                end_frame = _total_frame_num
                # re-derive the stride so that max_frames samples fit within the available frames
                stride = max((_total_frame_num - 1) // (max_frames - 1), 1)
                end_frame = stride * max_frames
            else:
                start_frame = 0  # we set start_frame = 0 because the pose alignment is performed on the first frame
                end_frame = start_frame + cover_frame_num

            frame_list = []
            dwpose_list = []
            random_ref_frame = frames_all[list(frames_all.keys())[0]]
            if random_ref_frame.mode != 'RGB':
                random_ref_frame = random_ref_frame.convert('RGB')
            random_ref_dwpose = pose_ref
            if random_ref_dwpose.mode != 'RGB':
                random_ref_dwpose = random_ref_dwpose.convert('RGB')

            for i_index in range(start_frame, end_frame, stride):
                if i_index == start_frame and first_eq_ref:
                    i_key = list(frames_all.keys())[i_index]
                    i_frame = frames_all[i_key]
                    if i_frame.mode != 'RGB':
                        i_frame = i_frame.convert('RGB')
                    i_dwpose = pose_ref
                    if i_dwpose.mode != 'RGB':
                        i_dwpose = i_dwpose.convert('RGB')
                    frame_list.append(i_frame)
                    dwpose_list.append(i_dwpose)
                else:
                    # added
                    if first_eq_ref:
                        i_index = i_index - stride
                    i_key = list(frames_all.keys())[i_index]
                    i_frame = frames_all[i_key]
                    if i_frame.mode != 'RGB':
                        i_frame = i_frame.convert('RGB')
                    i_dwpose = dwpose_all[i_key]
                    if i_dwpose.mode != 'RGB':
                        i_dwpose = i_dwpose.convert('RGB')
                    frame_list.append(i_frame)
                    dwpose_list.append(i_dwpose)

            have_frames = len(frame_list) > 0
            middle_index = 0
            if have_frames:
                ref_frame = frame_list[middle_index]
                vit_frame = vit_transforms(ref_frame)
                random_ref_frame_tmp = train_trans_pose(random_ref_frame)
                random_ref_dwpose_tmp = train_trans_pose(random_ref_dwpose)
                misc_data_tmp = torch.stack([train_trans_pose(ss) for ss in frame_list], dim=0)
                video_data_tmp = torch.stack([train_trans(ss) for ss in frame_list], dim=0)
                dwpose_data_tmp = torch.stack([train_trans_pose(ss) for ss in dwpose_list], dim=0)

            video_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
            dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
            misc_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
            random_ref_frame_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])  # e.g. [32, 3, 768, 512] for resolution=[512, 768]
            random_ref_dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
            if have_frames:
                video_data[:len(frame_list), ...] = video_data_tmp
                misc_data[:len(frame_list), ...] = misc_data_tmp
                dwpose_data[:len(frame_list), ...] = dwpose_data_tmp
                random_ref_frame_data[:, ...] = random_ref_frame_tmp
                random_ref_dwpose_data[:, ...] = random_ref_dwpose_tmp

            break
        except Exception as e:
            logging.info('{} read video frame failed with error: {}'.format(pose_file_path, e))
            continue

    return vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data, max_frames
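

# Editorial note (hedged): load_video_frames returns zero-padded tensors of
# shape [max_frames, 3, resolution[1], resolution[0]] for the video, pose and
# reference-frame streams, plus the CLIP-preprocessed reference frame and the
# (possibly recomputed) max_frames. If all five read attempts fail, the return
# statement raises a NameError because the locals were never assigned, so
# callers may want to validate ref_image_path and pose_file_path up front.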


def worker(gpu, cfg, cfg_update):
    '''
    Inference worker for each gpu
    '''
    for k, v in cfg_update.items():
        if isinstance(v, dict) and k in cfg:
            cfg[k].update(v)
        else:
            cfg[k] = v

    cfg.gpu = gpu
    cfg.seed = int(cfg.seed)
    cfg.rank = cfg.pmi_rank * cfg.gpus_per_machine + gpu
    setup_seed(cfg.seed + cfg.rank)

    if not cfg.debug:
        torch.cuda.set_device(gpu)
        torch.backends.cudnn.benchmark = True
        if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
            torch.backends.cudnn.benchmark = False
        dist.init_process_group(backend='nccl', world_size=cfg.world_size, rank=cfg.rank)

    # [Log] Save logging and make log dir
    log_dir = generalized_all_gather(cfg.log_dir)[0]
    inf_name = osp.basename(cfg.cfg_file).split('.')[0]
    test_model = osp.basename(cfg.test_model).split('.')[0].split('_')[-1]

    cfg.log_dir = osp.join(cfg.log_dir, '%s' % (inf_name))
    os.makedirs(cfg.log_dir, exist_ok=True)
    log_file = osp.join(cfg.log_dir, 'log_%02d.txt' % (cfg.rank))
    cfg.log_file = log_file
    reload(logging)
    logging.basicConfig(
        level=logging.INFO,
        format='[%(asctime)s] %(levelname)s: %(message)s',
        handlers=[
            logging.FileHandler(filename=log_file),
            logging.StreamHandler(stream=sys.stdout)])
    logging.info(cfg)
    logging.info(f"Running UniAnimate inference on gpu {gpu}")

    # [Diffusion]
    diffusion = DIFFUSION.build(cfg.Diffusion)

    # [Data] Data Transform
    train_trans = data.Compose([
        data.Resize(cfg.resolution),
        data.ToTensor(),
        data.Normalize(mean=cfg.mean, std=cfg.std)
        ])
    train_trans_pose = data.Compose([
        data.Resize(cfg.resolution),
        data.ToTensor(),
        ])
    vit_transforms = T.Compose([
        data.Resize(cfg.vit_resolution),
        T.ToTensor(),
        T.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])

    # [Model] embedder
    clip_encoder = EMBEDDER.build(cfg.embedder)
    clip_encoder.model.to(gpu)
    with torch.no_grad():
        _, _, zero_y = clip_encoder(text="")

    # [Model] autoencoder
    autoencoder = AUTO_ENCODER.build(cfg.auto_encoder)
    autoencoder.eval()  # freeze
    for param in autoencoder.parameters():
        param.requires_grad = False
    autoencoder.cuda()

    # [Model] UNet
    if "config" in cfg.UNet:
        cfg.UNet["config"] = cfg
    cfg.UNet["zero_y"] = zero_y
    model = MODEL.build(cfg.UNet)
    state_dict = torch.load(cfg.test_model, map_location='cpu')
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    if 'step' in state_dict:
        resume_step = state_dict['step']
    else:
        resume_step = 0
    status = model.load_state_dict(state_dict, strict=True)
    logging.info('Load model from {} with status {}'.format(cfg.test_model, status))
    model = model.to(gpu)
    model.eval()
    if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
        model.to(torch.float16)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu]) if not cfg.debug else model
    torch.cuda.empty_cache()

    test_list = cfg.test_list_path
    num_videos = len(test_list)
    logging.info(f'There are {num_videos} videos; each will be sampled {cfg.round} times')
    test_list = [item for _ in range(cfg.round) for item in test_list]

    for idx, file_path in enumerate(test_list):
        cfg.frame_interval, ref_image_key, pose_seq_key = file_path[0], file_path[1], file_path[2]

        manual_seed = int(cfg.seed + cfg.rank + idx // num_videos)
        setup_seed(manual_seed)
        logging.info(f"[{idx}]/[{len(test_list)}] Begin to sample {ref_image_key}, pose sequence from {pose_seq_key} init seed {manual_seed} ...")

        vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data, max_frames = load_video_frames(ref_image_key, pose_seq_key, train_trans, vit_transforms, train_trans_pose, max_frames=cfg.max_frames, frame_interval=cfg.frame_interval, resolution=cfg.resolution)
        cfg.max_frames_new = max_frames
        misc_data = misc_data.unsqueeze(0).to(gpu)
        vit_frame = vit_frame.unsqueeze(0).to(gpu)
        dwpose_data = dwpose_data.unsqueeze(0).to(gpu)
        random_ref_frame_data = random_ref_frame_data.unsqueeze(0).to(gpu)
        random_ref_dwpose_data = random_ref_dwpose_data.unsqueeze(0).to(gpu)

        ### save for visualization
        misc_backups = copy(misc_data)
        frames_num = misc_data.shape[1]
        misc_backups = rearrange(misc_backups, 'b f c h w -> b c f h w')
        mv_data_video = []

        ### local image (first frame)
        image_local = []
        if 'local_image' in cfg.video_compositions:
            frames_num = misc_data.shape[1]
            bs_vd_local = misc_data.shape[0]
            image_local = misc_data[:, :1].clone().repeat(1, frames_num, 1, 1, 1)
            image_local_clone = rearrange(image_local, 'b f c h w -> b c f h w', b=bs_vd_local)
            image_local = rearrange(image_local, 'b f c h w -> b c f h w', b=bs_vd_local)
            if hasattr(cfg, "latent_local_image") and cfg.latent_local_image:
                with torch.no_grad():
                    temporal_length = frames_num
                    encoder_posterior = autoencoder.encode(video_data[:, 0])
                    local_image_data = get_first_stage_encoding(encoder_posterior).detach()
                    image_local = local_image_data.unsqueeze(1).repeat(1, temporal_length, 1, 1, 1)  # [10, 16, 4, 64, 40]

        ### encode the video_data
        bs_vd = misc_data.shape[0]
        misc_data = rearrange(misc_data, 'b f c h w -> (b f) c h w')
        misc_data_list = torch.chunk(misc_data, misc_data.shape[0] // cfg.chunk_size, dim=0)

        with torch.no_grad():
            random_ref_frame = []
            if 'randomref' in cfg.video_compositions:
                random_ref_frame_clone = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w')
                if hasattr(cfg, "latent_random_ref") and cfg.latent_random_ref:
                    temporal_length = random_ref_frame_data.shape[1]
                    encoder_posterior = autoencoder.encode(random_ref_frame_data[:, 0].sub(0.5).div_(0.5))
                    random_ref_frame_data = get_first_stage_encoding(encoder_posterior).detach()
                    random_ref_frame_data = random_ref_frame_data.unsqueeze(1).repeat(1, temporal_length, 1, 1, 1)  # [10, 16, 4, 64, 40]
                random_ref_frame = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w')

            if 'dwpose' in cfg.video_compositions:
                bs_vd_local = dwpose_data.shape[0]
                dwpose_data_clone = rearrange(dwpose_data.clone(), 'b f c h w -> b c f h w', b=bs_vd_local)
                if 'randomref_pose' in cfg.video_compositions:
                    dwpose_data = torch.cat([random_ref_dwpose_data[:, :1], dwpose_data], dim=1)
                dwpose_data = rearrange(dwpose_data, 'b f c h w -> b c f h w', b=bs_vd_local)

            y_visual = []
            if 'image' in cfg.video_compositions:
                with torch.no_grad():
                    vit_frame = vit_frame.squeeze(1)
                    y_visual = clip_encoder.encode_image(vit_frame).unsqueeze(1)  # [60, 1024]
                    y_visual0 = y_visual.clone()

            with amp.autocast(enabled=True):
                pynvml.nvmlInit()
                handle = pynvml.nvmlDeviceGetHandleByIndex(0)
                meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
                cur_seed = torch.initial_seed()
                logging.info(f"Current seed {cur_seed} ..., cfg.max_frames_new: {cfg.max_frames_new} ....")

                noise = torch.randn([1, 4, cfg.max_frames_new, int(cfg.resolution[1] / cfg.scale), int(cfg.resolution[0] / cfg.scale)])
                noise = noise.to(gpu)

                # add a noise prior
                noise = diffusion.q_sample(random_ref_frame.clone(), getattr(cfg, "noise_prior_value", 939), noise=noise)

                if hasattr(cfg.Diffusion, "noise_strength"):
                    b, c, f, _, _ = noise.shape
                    offset_noise = torch.randn(b, c, f, 1, 1, device=noise.device)
                    noise = noise + cfg.Diffusion.noise_strength * offset_noise

                # construct model inputs (CFG)
                full_model_kwargs = [{
                    'y': None,
                    "local_image": None if len(image_local) == 0 else image_local[:],
                    'image': None if len(y_visual) == 0 else y_visual0[:],
                    'dwpose': None if len(dwpose_data) == 0 else dwpose_data[:],
                    'randomref': None if len(random_ref_frame) == 0 else random_ref_frame[:],
                },
                {
                    'y': None,
                    "local_image": None,
                    'image': None,
                    'randomref': None,
                    'dwpose': None,
                }]

                # for visualization
                full_model_kwargs_vis = [{
                    'y': None,
                    "local_image": None if len(image_local) == 0 else image_local_clone[:],
                    'image': None,
                    'dwpose': None if len(dwpose_data_clone) == 0 else dwpose_data_clone[:],
                    'randomref': None if len(random_ref_frame) == 0 else random_ref_frame_clone[:, :3],
                },
                {
                    'y': None,
                    "local_image": None,
                    'image': None,
                    'randomref': None,
                    'dwpose': None,
                }]

                partial_keys = [
                    ['image', 'randomref', "dwpose"],
                ]
                if hasattr(cfg, "partial_keys") and cfg.partial_keys:
                    partial_keys = cfg.partial_keys

                for partial_keys_one in partial_keys:
                    model_kwargs_one = prepare_model_kwargs(partial_keys=partial_keys_one,
                                                            full_model_kwargs=full_model_kwargs,
                                                            use_fps_condition=cfg.use_fps_condition)
                    model_kwargs_one_vis = prepare_model_kwargs(partial_keys=partial_keys_one,
                                                                full_model_kwargs=full_model_kwargs_vis,
                                                                use_fps_condition=cfg.use_fps_condition)
                    noise_one = noise

                    if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
                        clip_encoder.cpu()  # add this line
                        autoencoder.cpu()  # add this line
                        torch.cuda.empty_cache()  # add this line

                    video_data = diffusion.ddim_sample_loop(
                        noise=noise_one,
                        context_size=cfg.context_size,
                        context_stride=cfg.context_stride,
                        context_overlap=cfg.context_overlap,
                        model=model.eval(),
                        model_kwargs=model_kwargs_one,
                        guide_scale=cfg.guide_scale,
                        ddim_timesteps=cfg.ddim_timesteps,
                        eta=0.0,
                        context_batch_size=getattr(cfg, "context_batch_size", 1)
                    )

                    if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
                        # if the autoencoder or clip_encoder forward pass is needed again, move them back onto the GPU
                        clip_encoder.cuda()
                        autoencoder.cuda()
                    video_data = 1. / cfg.scale_factor * video_data  # [1, 4, f, h, w]
                    video_data = rearrange(video_data, 'b c f h w -> (b f) c h w')
                    chunk_size = min(cfg.decoder_bs, video_data.shape[0])
                    video_data_list = torch.chunk(video_data, video_data.shape[0] // chunk_size, dim=0)
                    decode_data = []
                    for vd_data in video_data_list:
                        gen_frames = autoencoder.decode(vd_data)
                        decode_data.append(gen_frames)
                    video_data = torch.cat(decode_data, dim=0)
                    video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b=cfg.batch_size).float()

                    text_size = cfg.resolution[-1]
                    cap_name = re.sub(r'[^\w\s]', '', ref_image_key.split("/")[-1].split('.')[0])  # .replace(' ', '_')
                    name = f'seed_{cur_seed}'
                    for ii in partial_keys_one:
                        name = name + "_" + ii
                    file_name = f'rank_{cfg.world_size:02d}_{cfg.rank:02d}_{idx:02d}_{name}_{cap_name}_{cfg.resolution[1]}x{cfg.resolution[0]}.mp4'
                    local_path = os.path.join(cfg.log_dir, f'{file_name}')
                    os.makedirs(os.path.dirname(local_path), exist_ok=True)
                    captions = "human"
                    del model_kwargs_one_vis[0][list(model_kwargs_one_vis[0].keys())[0]]
                    del model_kwargs_one_vis[1][list(model_kwargs_one_vis[1].keys())[0]]
                    save_video_multiple_conditions_not_gif_horizontal_3col(local_path, video_data.cpu(), model_kwargs_one_vis, misc_backups,
                                                                           cfg.mean, cfg.std, nrow=1, save_fps=cfg.save_fps)
                    # try:
                    #     save_t2vhigen_video_safe(local_path, video_data.cpu(), captions, cfg.mean, cfg.std, text_size)
                    #     logging.info('Save video to dir %s:' % (local_path))
                    # except Exception as e:
                    #     logging.info(f'Step: save text or video error with {e}')

    logging.info('Congratulations! The inference is completed!')
    # synchronize to finish some processes
    if not cfg.debug:
        torch.cuda.synchronize()
        dist.barrier()
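

# Editorial note (hedged): the ddim_sample_loop call above receives
# context_size / context_stride / context_overlap, i.e. the sliding-window
# schedule that denoises a long frame sequence in overlapping temporal chunks;
# the blending of overlapping windows is handled inside the diffusion module,
# not in this worker.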


def prepare_model_kwargs(partial_keys, full_model_kwargs, use_fps_condition=False):
    if use_fps_condition is True:
        partial_keys.append('fps')

    partial_model_kwargs = [{}, {}]
    for partial_key in partial_keys:
        partial_model_kwargs[0][partial_key] = full_model_kwargs[0][partial_key]
        partial_model_kwargs[1][partial_key] = full_model_kwargs[1][partial_key]

    return partial_model_kwargs
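

# --- Editorial usage sketch (not part of the original script) ----------------
# prepare_model_kwargs pairs a conditional and an unconditional kwargs dict for
# classifier-free guidance. The dummy tensors below only illustrate which keys
# survive the filtering; shapes and values are placeholders.
def _example_prepare_model_kwargs():
    dummy = torch.zeros(1)
    full = [
        {'y': None, 'local_image': dummy, 'image': dummy, 'randomref': dummy, 'dwpose': dummy},
        {'y': None, 'local_image': None, 'image': None, 'randomref': None, 'dwpose': None},
    ]
    cond_kwargs, uncond_kwargs = prepare_model_kwargs(['image', 'randomref', 'dwpose'], full, use_fps_condition=False)
    # cond_kwargs keeps only the requested conditions; uncond_kwargs carries the matching null entries.
    return cond_kwargs, uncond_kwargs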