import os

# Set CUDA_MODULE_LOADING=LAZY to speed up cold starts of the serverless function.
os.environ["CUDA_MODULE_LOADING"] = "LAZY"
# Set SAFETENSORS_FAST_GPU=1 to speed up loading the checkpoint onto the GPU.
os.environ["SAFETENSORS_FAST_GPU"] = "1"

import time

import cv2
import imageio
import moviepy.editor as mp
import numpy as np
import torch
from tqdm import tqdm

from audio import load_wav, melspectrogram
from fete_model import FETE_model
from preprocess_videos import face_detect, load_from_npz

fps = 25
mel_idx_multiplier = 80.0 / fps
mel_step_size = 16
batch_size = 64 if torch.cuda.is_available() else 4

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} for inference.".format(device))

use_fp16 = torch.cuda.is_available()
if use_fp16:
    print("Using FP16 for inference.")

torch.backends.cudnn.benchmark = device == "cuda"
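
# Timing constants: the mel spectrogram is assumed to run at 80 frames per second
# (hence mel_idx_multiplier = 80.0 / fps = 3.2 mel columns per video frame), and each
# generated frame is conditioned on a mel_step_size = 16-column window, i.e. 0.2 s of audio.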

def init_model():
    checkpoint_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "checkpoints/obama-fp16.safetensors")
    model = FETE_model()
    if checkpoint_path.endswith(".pth") or checkpoint_path.endswith(".ckpt"):
        if device == "cuda":
            checkpoint = torch.load(checkpoint_path)
        else:
            checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        s = checkpoint["state_dict"]
    else:
        from safetensors import safe_open

        s = {}
        with safe_open(checkpoint_path, framework="pt", device=device) as f:
            for key in f.keys():
                s[key] = f.get_tensor(key)
    new_s = {}
    for k, v in s.items():
        new_s[k.replace("module.", "")] = v
    model.load_state_dict(new_s)

    model = model.to(device)
    model.eval()
    print("Model loaded")

    if use_fp16:
        for name, module in model.named_modules():
            if ".query_conv" in name or ".key_conv" in name or ".value_conv" in name:
                # Keep the attention layers in full precision to avoid errors.
                module.to(torch.float)
            else:
                module.to(torch.half)
        print("Model converted to half precision to accelerate inference")
    return model
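
# Note: init_model accepts either a .pth/.ckpt checkpoint (loaded via torch.load and
# expected to carry a "state_dict" entry) or a .safetensors file; any "module." prefixes
# left over from DataParallel training are stripped before load_state_dict.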

def make_mask(image_size=256, border_size=32):
    mask_bar = np.linspace(1, 0, border_size).reshape(1, -1).repeat(image_size, axis=0)
    mask = np.zeros((image_size, image_size), dtype=np.float32)
    mask[-border_size:, :] += mask_bar.T[::-1]
    mask[:, :border_size] = mask_bar
    mask[:, -border_size:] = mask_bar[:, ::-1]
    mask[-border_size:, :][mask[-border_size:, :] < 0.6] = 0.6
    mask = np.stack([mask] * 3, axis=-1).astype(np.float32)
    return mask


face_mask = make_mask()

def blend_images(foreground, background):
    # Blend the foreground and background images using the face mask
    temp_mask = cv2.resize(face_mask, (foreground.shape[1], foreground.shape[0]))
    blended = cv2.multiply(foreground.astype(np.float32), temp_mask)
    blended += cv2.multiply(background.astype(np.float32), 1 - temp_mask)
    blended = np.clip(blended, 0, 255).astype(np.uint8)
    return blended
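
# How the blend behaves: face_mask is close to 1 along the left, right and bottom borders
# of the crop and 0 in the interior, so blend_images keeps the original frame pixels near
# the edges and the generated face in the middle, feathering the seam when the crop is
# pasted back into the full frame.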

def smooth_coord(last_coord, current_coord, factor=0.4):
    # Move only `factor` of the way from the previous box to the new one to damp jitter.
    change = (np.array(current_coord) - np.array(last_coord)) * factor
    return (np.array(last_coord) + change).astype(int).tolist()

def add_black(imgs):
    # Pad each frame with a 100-px black bar on top and a 20-px bar at the bottom.
    for i in range(len(imgs)):
        imgs[i] = cv2.vconcat(
            [np.zeros((100, imgs[i].shape[1], 3), dtype=np.uint8), imgs[i], np.zeros((20, imgs[i].shape[1], 3), dtype=np.uint8)]
        )
    return imgs


def remove_black(img):
    # Undo the padding added by add_black.
    return img[100:-20]

def resize_length(input_attributes, length):
    # Resample an attribute sequence to `length` timesteps by nearest-neighbour indexing.
    input_attributes = np.array(input_attributes)
    resized_attributes = [input_attributes[int(i_ * (input_attributes.shape[0] / length))] for i_ in range(length)]
    return np.array(resized_attributes).T


def output_chunks(input_attributes):
    # Slice a (dims, T) attribute array, e.g. mel (80, T) or pose (3, T), into one
    # mel_step_size-column window per video frame.
    output_chunks = []
    len_ = len(input_attributes[0])
    i = 0
    while True:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > len_:
            output_chunks.append(input_attributes[:, len_ - mel_step_size :])
            break
        output_chunks.append(input_attributes[:, start_idx : start_idx + mel_step_size])
        i += 1
    return output_chunks
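
# Worked example with the defaults above (fps = 25, mel_step_size = 16): frame i reads
# mel columns [int(i * 3.2), int(i * 3.2) + 16), so frame 0 covers columns 0-15, frame 1
# covers 3-18, and frame 10 covers 32-47; the final chunk is clamped to the last 16 columns.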

def prepare_data(face_path, audio_path, pose, emotion, blink, img_size=256, pads=[0, 0, 0, 0]):
    if os.path.isfile(face_path) and face_path.split(".")[-1].lower() in ["jpg", "png", "jpeg"]:
        static = True
        full_frames = [cv2.imread(face_path)]
    else:
        static = False
        video_stream = cv2.VideoCapture(face_path)
        full_frames = []
        while True:
            still_reading, frame = video_stream.read()
            if not still_reading:
                video_stream.release()
                break
            full_frames.append(frame)
    print("Number of frames available for inference: " + str(len(full_frames)))

    wav = load_wav(audio_path, 16000)
    mel = melspectrogram(wav)
    len_ = mel.shape[1]
    mel = mel[:, :len_]

    # Resample the control attributes to the mel time axis.
    pose = resize_length(pose, len_)
    emotion = resize_length(emotion, len_)
    blink = resize_length(blink, len_)

    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError("Mel contains nan! Using a TTS voice? Add a small amount of noise to the wav file and try again")

    mel_chunks = output_chunks(mel)
    pose_chunks = output_chunks(pose)
    emotion_chunks = output_chunks(emotion)
    blink_chunks = output_chunks(blink)

    gen = datagen(face_path, full_frames, mel_chunks, pose_chunks, emotion_chunks, blink_chunks, static=static, img_size=img_size, pads=pads)
    steps = int(np.ceil(float(len(mel_chunks)) / batch_size))
    return gen, steps

def preprocess_batch(batch):
    return torch.FloatTensor(np.reshape(batch, [len(batch), 1, batch[0].shape[0], batch[0].shape[1]])).to(device)
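
# preprocess_batch stacks a list of B arrays of shape (dims, mel_step_size) into a float
# tensor of shape (B, 1, dims, mel_step_size) on `device`; a batch of mel chunks, for
# example, becomes (B, 1, 80, 16).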

def datagen(face_path, frames, mels, poses, emotions, blinks, static=False, img_size=256, pads=[0, 0, 0, 0]):
    img_batch, mel_batch, pose_batch, emotion_batch, blink_batch, frame_batch, coords_batch = [], [], [], [], [], [], []
    scale_factor = img_size // 128

    frames = frames[: len(mels)]
    frames = add_black(frames)

    try:
        # Reuse cached face coordinates for this video if they exist.
        video_name = os.path.basename(face_path).split(".")[0]
        coords = load_from_npz(video_name)
        face_det_results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(frames, coords)]
    except Exception as e:
        print("No existing coords found, running face detection...", "Error: ", e)
        if not static:
            coords = face_detect(frames, pads)
            face_det_results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(frames, coords)]
        else:
            coords = face_detect([frames[0]], pads)
            face_det_results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(frames, coords)]

    face_det_results = face_det_results[: len(mels)]
    while len(frames) < len(mels):
        # Loop the video back and forth until it is at least as long as the audio.
        face_det_results = face_det_results + face_det_results[::-1]
        frames = frames + frames[::-1]
    else:
        face_det_results = face_det_results[: len(mels)]
        frames = frames[: len(mels)]

    for i in range(len(mels)):
        idx = 0 if static else i % len(frames)
        frame_to_save = frames[idx].copy()
        face, coords = face_det_results[idx].copy()
        face = cv2.resize(face, (img_size, img_size))

        img_batch.append(face)
        mel_batch.append(mels[i])
        pose_batch.append(poses[i])
        emotion_batch.append(emotions[i])
        blink_batch.append(blinks[i])
        frame_batch.append(frame_to_save)
        coords_batch.append(coords)

        if len(img_batch) >= batch_size:
            img_masked = np.asarray(img_batch).copy()
            # Zero out the central region of the reference faces, keeping only a
            # 16 * scale_factor pixel border.
            img_masked[:, 16 * scale_factor : -16 * scale_factor, 16 * scale_factor : -16 * scale_factor] = 0.0
            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0
            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)

            mel_batch = preprocess_batch(mel_batch)
            pose_batch = preprocess_batch(pose_batch)
            emotion_batch = preprocess_batch(emotion_batch)
            blink_batch = preprocess_batch(blink_batch)

            if use_fp16:
                yield (
                    img_batch.half(),
                    mel_batch.half(),
                    pose_batch.half(),
                    emotion_batch.half(),
                    blink_batch.half(),
                ), frame_batch, coords_batch
            else:
                yield (img_batch, mel_batch, pose_batch, emotion_batch, blink_batch), frame_batch, coords_batch
            img_batch, mel_batch, pose_batch, emotion_batch, blink_batch, frame_batch, coords_batch = [], [], [], [], [], [], []

    if len(img_batch) > 0:
        img_masked = np.asarray(img_batch).copy()
        img_masked[:, 16 * scale_factor : -16 * scale_factor, 16 * scale_factor : -16 * scale_factor] = 0.0
        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0
        img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)

        mel_batch = preprocess_batch(mel_batch)
        pose_batch = preprocess_batch(pose_batch)
        emotion_batch = preprocess_batch(emotion_batch)
        blink_batch = preprocess_batch(blink_batch)

        if use_fp16:
            yield (img_batch.half(), mel_batch.half(), pose_batch.half(), emotion_batch.half(), blink_batch.half()), frame_batch, coords_batch
        else:
            yield (img_batch, mel_batch, pose_batch, emotion_batch, blink_batch), frame_batch, coords_batch
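
# Each item yielded by datagen is ((img, mel, pose, emotion, blink), frames, coords):
# `img` has shape (B, 6, img_size, img_size) because the centre-masked face and the
# unmasked face are concatenated along the channel axis, the conditioning tensors are
# (B, 1, dims, 16), and `frames`/`coords` carry the padded full frames and face boxes
# needed to paste the prediction back.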

def infenrece(model, face_path, audio_path, pose, emotion, blink, preview=False):
    timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime(time.time()))
    gen, steps = prepare_data(face_path, audio_path, pose, emotion, blink)
    steps = 1 if preview else steps

    if preview:
        outfile = "/tmp/{}.jpg".format(timestamp)
    else:
        outfile = "/tmp/{}.mp4".format(timestamp)
    tmp_video = "/tmp/temp_{}.mp4".format(timestamp)
    writer = (
        imageio.get_writer(tmp_video, fps=fps, codec="libx264", quality=10, pixelformat="yuv420p", macro_block_size=1)
        if not preview
        else None
    )

    for inputs, frames, coords in tqdm(gen, total=steps):
        with torch.no_grad():
            pred = model(*inputs)
        pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.0
        for p, f, c in zip(pred, frames, coords):
            y1, y2, x1, x2 = c
            y1, y2, x1, x2 = int(y1), int(y2), int(x1), int(x2)
            y = round(y2 - y1)
            x = round(x2 - x1)
            p = cv2.resize(p.astype(np.uint8), (x, y))
            try:
                f[y1 : y1 + y, x1 : x1 + x] = blend_images(f[y1 : y1 + y, x1 : x1 + x], p)
            except Exception as e:
                print(e)
                f[y1 : y1 + y, x1 : x1 + x] = p
            f = remove_black(f)
            if preview:
                cv2.imwrite(outfile, f, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
                return outfile
            writer.append_data(cv2.cvtColor(f, cv2.COLOR_BGR2RGB))
    writer.close()

    # Mux the original audio back onto the silent video.
    video_clip = mp.VideoFileClip(tmp_video)
    audio_clip = mp.AudioFileClip(audio_path)
    video_clip = video_clip.set_audio(audio_clip)
    video_clip.write_videofile(outfile, codec="libx264")
    print("Saved to {}".format(outfile) if os.path.exists(outfile) else "Failed to save {}".format(outfile))

    try:
        os.remove(tmp_video)
        del video_clip
        del audio_clip
        del gen
    except Exception:
        pass
    return outfile

if __name__ == "__main__":
    model = init_model()

    from attributtes_utils import input_pose, input_emotion, input_blink

    pose = input_pose()
    emotion = input_emotion()
    blink = input_blink()
    audio_path = "./assets/sample.wav"
    face_path = "./assets/sample.mp4"

    infenrece(model, face_path, audio_path, pose, emotion, blink)
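
    # A quick preview (a single blended JPEG instead of a full video) can be produced by
    # passing preview=True; hypothetical call reusing the objects above:
    # preview_path = infenrece(model, face_path, audio_path, pose, emotion, blink, preview=True)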