from abc import ABC, abstractmethod
from contextlib import nullcontext
import time
import os
from functools import partial
from copy import deepcopy
from multiprocessing import Pool
from threading import Lock

from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
import einops
from transformers import LlamaForCausalLM
import spaces

from vqvae_muse import VQGANModel, get_tokenizer_muse
from torch_vqvae_model import get_tokenizer


def get_torch_float_dtype(dtype):
    if dtype in (torch.float16, torch.bfloat16, torch.float32):
        return dtype
    return {
        'float16': torch.float16,
        'fp16': torch.float16,
        'f16': torch.float16,
        'bfloat16': torch.bfloat16,
        'bf16': torch.bfloat16,
        'float32': torch.float32,
        'fp32': torch.float32,
        'f32': torch.float32,
    }[dtype]
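

# Returns the worker's PID after a short sleep. The sleep appears to be there
# so that each pool worker stays busy long enough to pick up exactly one of
# these tasks, letting the parent collect one distinct PID per worker.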
def get_pid():
    time.sleep(1)
    return os.getpid()
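

# Abstract interface shared by the single-device and multi-process backends.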
class InferenceModel(ABC):

    @abstractmethod
    def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0):
        raise NotImplementedError()
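

# Runs tokenization and autoregressive generation on a single torch device.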
class LocalInferenceModel(InferenceModel):

    def __init__(self, checkpoint, dtype='float16', torch_device='cuda',
                 context_frames=16, use_lock=False):
        self.checkpoint = checkpoint
        self.dtype = dtype
        self.torch_device = torch_device
        self.context_frames = context_frames

        # new tokenizer
        self.tokenizer = get_tokenizer_muse()
        self.tokenizer.to(self.torch_device)

        self.model = LlamaForCausalLM.from_pretrained(
            self.checkpoint, torch_dtype=get_torch_float_dtype(self.dtype)
        ).to(self.torch_device)
        print("torch device", self.torch_device)
        print("init device", self.model.device)

        if use_lock:
            self.lock = Lock()
        else:
            self.lock = nullcontext()
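
    # Perplexity of the target frames given the input frames, computed from
    # the model's next-token log-probabilities over the concatenated sequence.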
    def compute_perplexity(self, input_images, target_images):
        input_images = np.array(input_images)
        target_images = np.array(target_images)
        assert len(input_images.shape) == 5 and len(target_images.shape) == 5  # [B, S, H, W, C]
        assert input_images.shape[0] == target_images.shape[0]
        batch_size = input_images.shape[0]
        with self.lock:
            input_images = torch.tensor(
                einops.rearrange(input_images, 'b s h w c -> b s c h w')
            ).to(self.torch_device)
            target_images = torch.tensor(
                einops.rearrange(target_images, 'b s h w c -> b s c h w')
            ).to(self.torch_device)
            input_ids = self.tokenizer.tokenize(input_images).view(batch_size, -1)
            target_ids = self.tokenizer.tokenize(target_images).view(batch_size, -1)
            all_ids = torch.cat([input_ids, target_ids], dim=1)
            logits = self.model(all_ids).logits
            log_probs = F.log_softmax(logits, dim=-1)
            target_ids_onehot = F.one_hot(target_ids, num_classes=logits.shape[-1])
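            # The logit at position t predicts the token at position t + 1, so
            # the predictions for the target tokens start one step before the
            # first target and stop one step before the end of the sequence.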
            target_log_probs = log_probs[:, input_ids.shape[1] - 1 : -1]
            perplexity = torch.exp(
                -torch.mean(
                    torch.sum(target_log_probs * target_ids_onehot, dim=-1),
                    dim=-1
                )
            )
            return perplexity.detach().cpu().numpy()

    def generate_once(self, input_images, n_new_frames, temperature=1.0, top_p=1.0):
        assert type(input_images) == np.ndarray
        with self.lock:
            input_images = np.array(input_images, dtype=np.float32)
            input_images = torch.tensor(
                einops.rearrange(input_images, 'b h w c -> b c h w')
            ).to(self.torch_device)

            # Not quite sure why this needs to be redone here.
            self.model.to(self.torch_device)
            self.tokenizer.to(self.torch_device)

            # new tokenizer
            _, input_ids = self.tokenizer.encode(input_images)
            input_ids = input_ids.view(1, -1)
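
            # Each frame is 256 tokens; keep only the most recent
            # (context_frames - 1) frames so the context window leaves room
            # to generate at least one new frame.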
            input_ids = input_ids[:, -(self.context_frames - 1) * 256:]

            new_tokens = []
            current_context_frames = input_ids.shape[1] // 256
            first_generation_left = self.context_frames - current_context_frames
            first_new_frames = min(first_generation_left, n_new_frames)
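            # The first generate call fills whatever room is left in the
            # context window; suppress_tokens blocks ids at or above 8192 so
            # sampling stays within the image-token vocabulary.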
            input_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=torch.ones_like(input_ids),
                pad_token_id=8192,
                max_new_tokens=256 * first_new_frames,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                suppress_tokens=list(range(8192, self.model.vocab_size)),
            )
            new_tokens.append(input_ids[:, -256 * first_new_frames:])
            input_ids = input_ids[:, -(self.context_frames - 1) * 256:]
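
            # Remaining frames are generated one at a time with a sliding
            # window: after each frame, drop the oldest tokens so the context
            # never exceeds (context_frames - 1) frames.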
            for _ in range(max(0, n_new_frames - first_new_frames)):
                input_ids = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=torch.ones_like(input_ids),
                    pad_token_id=8192,
                    max_new_tokens=256,
                    do_sample=True,
                    top_p=top_p,
                    temperature=temperature,
                    suppress_tokens=list(range(8192, self.model.vocab_size)),
                )
                new_tokens.append(input_ids[:, -256:])
                input_ids = input_ids[:, -(self.context_frames - 1) * 256:]
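
            # Decode the generated tokens back to images, one row of 256
            # tokens per frame, and clamp pixel values to [0, 1].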
            new_tokens = torch.cat(new_tokens, dim=1).view(-1, 256)
            new_images = einops.rearrange(
                torch.clamp(self.tokenizer.decode_code(new_tokens), 0.0, 1.0),
                'b c h w -> b h w c'
            ).detach().cpu().numpy()
        return new_images

    def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0):
        output = []
        for seq in input_images:
            output.append(
                [self.generate_once(seq, n_new_frames, temperature, top_p)
                 for _ in range(n_candidates)]
            )
        return output
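

# Data-parallel wrapper: one worker process per device, each holding its own
# LocalInferenceModel, with requests fanned out across the pool.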
class MultiProcessInferenceModel(InferenceModel):

    def __init__(self, checkpoint, torch_devices=None, dtype='float16',
                 context_frames=16, use_lock=False, perplexity_batch_size=2):
        if torch_devices is None or torch_devices == '':
            torch_devices = [f'cuda:{i}' for i in range(torch.cuda.device_count())]

        self.torch_devices = torch_devices
        self.n_processes = len(torch_devices)
        print(f'Using {self.n_processes} processes for inference')
        self.worker_pool = Pool(self.n_processes)
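        # Collect one PID per worker (get_pid sleeps so that no worker grabs
        # more than one task), then pin each worker to a device by PID.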
        self.worker_pids = self.worker_pool.starmap(get_pid, [tuple() for _ in range(self.n_processes)])
        self.device_map = {
            pid: torch_device
            for pid, torch_device in zip(self.worker_pids, self.torch_devices)
        }
        self.worker_pool.starmap(
            self.initialize_worker,
            [(self.device_map, checkpoint, dtype, context_frames) for _ in range(self.n_processes)]
        )
        self.perplexity_batch_size = perplexity_batch_size

        if use_lock:
            self.lock = Lock()
        else:
            self.lock = nullcontext()
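
    # These helpers run inside the worker processes; each worker keeps its own
    # LocalInferenceModel in a process-global variable. They are static so
    # that pickling them for starmap does not drag self (including the
    # unpicklable pool itself) along as an argument.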
    @staticmethod
    def initialize_worker(device_map, checkpoint, dtype, context_frames):
        global _current_process_backend
        torch_device = device_map[os.getpid()]
        _current_process_backend = LocalInferenceModel(
            checkpoint, dtype, torch_device, context_frames
        )

    @staticmethod
    def generate_once(input_images, n_new_frames, temperature=1.0, top_p=1.0):
        return _current_process_backend.generate_once(input_images, n_new_frames, temperature, top_p)

    @staticmethod
    def compute_perplexity_once(input_images, target_images):
        return _current_process_backend.compute_perplexity(input_images, target_images)

    def compute_perplexity(self, input_images, target_images):
        with self.lock:
            map_args = []
            for i in range(0, len(input_images), self.perplexity_batch_size):
                map_args.append((
                    input_images[i : i + self.perplexity_batch_size],
                    target_images[i : i + self.perplexity_batch_size]
                ))
            outputs = self.worker_pool.starmap(self.compute_perplexity_once, map_args)
            return np.concatenate(outputs, axis=0)
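
    # Fan out one generation task per (sequence, candidate) pair, then regroup
    # the flat results into a list of n_candidates outputs per input sequence.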
    def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0):
        with self.lock:
            map_args = []
            for seq in input_images:
                for _ in range(n_candidates):
                    map_args.append((seq, n_new_frames, temperature, top_p))
            outputs = self.worker_pool.starmap(self.generate_once, map_args)
            reshaped_output = []
            index = 0
            for _ in range(len(input_images)):
                candidates = []
                for _ in range(n_candidates):
                    candidates.append(outputs[index])
                    index += 1
                reshaped_output.append(candidates)
        return reshaped_output