# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

import base64
import io
import json
import math
import queue
import threading
from dataclasses import dataclass, field
from enum import Enum
from multiprocessing import managers, queues, synchronize
from typing import Literal, Union

import PIL
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from PIL.Image import Image
from tokenizers import Tokenizer
from tqdm import tqdm
from transformers import (
    LogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
    enable_full_determinism,
)

from chameleon.inference import loader
from chameleon.inference.alignment import AlignPromptRight
from chameleon.inference.generation import ChameleonGenerator
from chameleon.inference.image_tokenizer import ImageTokenizer
from chameleon.inference.logits_processor import (
    AllowOnlyTokensLogitsProcessor,
    DisallowTokensAtOrAfterIndexLogitsProcessor,
    InBatchInstructCFGLogitsProcessor,
)
from chameleon.inference.model_adapter import ChameleonModelAdapter
from chameleon.inference.stopping_criteria import (
    MaxLengthCriteria,
    StopOnEOSAfterBatchIndex,
)
from chameleon.inference.token_selector import (
    ArgmaxTokenSelector,
    MultinomialTokenSelector,
    ReplicatedInputTokenSelector,
)
from chameleon.inference.transformer import Transformer
from chameleon.inference.utils import DynamicGenerator, advance, random_unused_port
from chameleon.inference.vocab import VocabInfo, VocabTranslation


@dataclass
class Options:
    @dataclass
    class Text:
        repetition_penalty: float = 1.2
        temp: float = 1.0
        top_p: float = 0.9
        greedy: bool = False

    @dataclass
    class Image:
        @dataclass
        class CFG:
            guidance_scale_text: float = 3.0
            guidance_scale_image: float = 1.2

        cfg: CFG = field(default_factory=CFG)
        temp: float = 0.7
        top_p: float = 0.9
        greedy: bool = False

    max_seq_len: int = 4096
    max_gen_len: int = 4096
    seed: int | None = None
    txt: Text | bool = True
    img: Image | bool = True
    extra_eos_tokens: list[int | str] = field(default_factory=lambda: [])

    def __post_init__(self):
        if self.txt is True:
            self.txt = Options.Text()
        if self.img is True:
            self.img = Options.Image()
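

# Usage sketch (illustrative only; the field names above are the source of
# truth): text-only decoding can be requested by disabling the image modality,
#
#     options = Options(img=False, max_gen_len=256)
#     options.txt.temp = 0.8
#
# and image-only generation with Options(txt=False). The boolean defaults are
# expanded into Options.Text() / Options.Image() instances by __post_init__.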


class TokenManager:
    def __init__(
        self,
        tokenizer_path: str,
        vqgan_cfg_path: str,
        vqgan_ckpt_path: str,
        device: str | None = None,
    ):
        self.tokenizer = Tokenizer.from_file(tokenizer_path)
        self.vocab = VocabInfo(json.load(open(tokenizer_path))["model"]["vocab"])
        self.translation = VocabTranslation(self.vocab, device=device)
        self.image_tokenizer = ImageTokenizer(
            cfg_path=vqgan_cfg_path, ckpt_path=vqgan_ckpt_path, device=device
        )

    def pil_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> PIL.Image:
        image_tensor = self.translation.convert_bpe2img(bpe_tokens)
        if image_tensor.shape[0] < 1024:
            padding = (
                torch.ones(
                    [1024 - image_tensor.shape[0]],
                    dtype=int,
                    device=image_tensor.device,
                )
                * image_tensor[0]
            )
            image_tensor = torch.cat((image_tensor, padding)).unsqueeze(0)
        return self.image_tokenizer.pil_from_img_toks(image_tensor)

    def png_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> bytes:
        pil = self.pil_from_bpe_tokens(bpe_tokens)
        img_io = io.BytesIO()
        pil.save(img_io, format="PNG")
        return img_io.getvalue()

    def tokenize_text(self, text: str) -> list[int]:
        return self.tokenizer.encode(text).ids

    def tokenize_image(self, img: Image) -> list[int]:
        return (
            [self.vocab.begin_image]
            + self.translation.convert_img2bp2(
                self.image_tokenizer.img_tokens_from_pil(img)  # [0 : 8191], vqgan codebook ids
            ).tolist()
            + [self.vocab.end_image]
        )

    def tokenize_b64img(self, b64img: str) -> list[int]:
        image_data = base64.b64decode(b64img)
        image_file = io.BytesIO(image_data)
        return self.tokenize_image(PIL.Image.open(image_file))

    def tokens_from_ui(self, inputs: list[dict]) -> list[int]:
        tokens = [self.vocab.bos_id]
        for input_ in inputs:
            if input_["type"] == "text":
                tokens += self.tokenize_text(input_["value"])
            elif input_["type"] == "image":
                if isinstance(input_["value"], str):
                    if input_["value"].startswith("data:"):
                        # Value Format: 'data:image/[^;]+;base64,[A-Za-z0-9+/]+={0,2}'
                        tokens += self.tokenize_b64img(input_["value"].split(",", 1)[1])
                    elif input_["value"].startswith("file:"):
                        tokens += self.tokenize_image(
                            PIL.Image.open(input_["value"].split(":", 1)[1])
                        )
                    else:
                        raise ValueError("Unknown image format.")
                elif isinstance(input_["value"], Image):
                    tokens += self.tokenize_image(input_["value"])
                else:
                    raise ValueError("Unknown image type.")
            elif input_["type"] == "sentinel":
                tokens += [
                    {
                        "<START-OF-IMAGE>": self.vocab.begin_image,
                        "<END-OF-TURN>": self.vocab.eot_id,
                    }[input_["value"]]
                ]
            elif input_["type"] == "ids":
                tokens += input_["value"]
            else:
                raise ValueError("Unknown input type.")
        return tokens

    def decode_text(self, ids: torch.LongTensor | list[list[int]]) -> list[str]:
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
        for row, values in enumerate(ids):
            try:
                ids[row] = values[: values.index(self.vocab.eos_id)]
            except ValueError:
                pass
        return self.tokenizer.decode_batch(ids)

    def decode_image(self, ids: torch.LongTensor) -> list[PIL.Image]:
        return [self.pil_from_bpe_tokens(sample) for sample in ids]
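

# Usage sketch (illustrative; paths and file names are placeholders): a mixed
# text/image prompt is flattened into a single token-id sequence, e.g.
#
#     mgr = TokenManager(tokenizer_path, vqgan_cfg_path, vqgan_ckpt_path, device="cuda")
#     ids = mgr.tokens_from_ui([
#         {"type": "image", "value": "file:/path/to/input.png"},
#         {"type": "text", "value": "Describe this image."},
#         {"type": "sentinel", "value": "<END-OF-TURN>"},
#     ])
#
# Each image contributes 1024 VQGAN codes wrapped in begin/end image sentinels.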


@dataclass
class DecodePiece:
    token: ChameleonGenerator.Token
    next_decoder: type["Decoder"] | None


class Decoder:
    def __init__(
        self,
        model: Transformer,
        vocab: VocabInfo,
        options: Options,
        input_ids: list[list[int]],
    ): ...

    def __next__(self) -> DecodePiece: ...


class TextDecoder(Decoder):
    def __init__(
        self,
        model: Transformer,
        vocab: VocabInfo,
        options: Options,
        input_ids: list[list[int]],
    ):
        self.vocab = vocab
        self.options = options
        assert vocab.eos_id is not None

        prompt_lens = [len(inp) for inp in input_ids]
        max_prompt_len = max(prompt_lens)
        max_seq_len = min(options.max_seq_len, max_prompt_len + options.max_gen_len)

        self.eos_ids = [vocab.eos_id]
        for extra_eos_token in options.extra_eos_tokens:
            if isinstance(extra_eos_token, str):
                extra_eos_token = vocab.name2val[extra_eos_token]
            assert isinstance(extra_eos_token, int)
            self.eos_ids.append(extra_eos_token)

        stopping_criteria = [MaxLengthCriteria(max_seq_len)] + [
            StopOnEOSAfterBatchIndex(eos_id, [max_prompt_len] * len(prompt_lens))
            for eos_id in self.eos_ids
        ]
        self.gen = ChameleonGenerator(
            model=ChameleonModelAdapter(model, max_seq_len=max_seq_len),
            input_ids=input_ids,
            stopping_criteria=stopping_criteria,
            logits_processors=self._logits_processors(),
            alignment=AlignPromptRight(vocab.pad_id),
            token_selector=(
                ArgmaxTokenSelector()
                if options.txt.greedy
                else MultinomialTokenSelector()
            ),
        )
        advance(self.gen, max_prompt_len)

    def _allowed_tokens(self) -> list[int]:
        allowed_tokens = [self.vocab.eos_id]
        if self.options.txt:
            allowed_tokens += self.vocab.text_tokens
        if self.options.img:
            allowed_tokens += [self.vocab.begin_image]
        return allowed_tokens

    def _logits_processors(self) -> list[LogitsProcessor]:
        logits_processors = [
            AllowOnlyTokensLogitsProcessor(self._allowed_tokens()),
        ]
        if isinstance(self.options.img, Options.Image):
            logits_processors += [
                # Reserve room for a full image block: begin_image + 1024 codes + end_image.
                DisallowTokensAtOrAfterIndexLogitsProcessor(
                    [self.vocab.begin_image],
                    self.options.max_seq_len - 1026,
                ),
            ]
        if isinstance(self.options.txt, Options.Text):
            logits_processors += [
                RepetitionPenaltyLogitsProcessor(self.options.txt.repetition_penalty),
                TemperatureLogitsWarper(self.options.txt.temp),
                TopPLogitsWarper(self.options.txt.top_p),
            ]
        return logits_processors

    def __next__(self) -> DecodePiece:
        tok = next(self.gen)
        next_decoder = None
        if (
            self.vocab.begin_image not in self.eos_ids
            and (tok.id == self.vocab.begin_image).all()
        ):
            next_decoder = ImageDecoder
        return DecodePiece(tok, next_decoder)


class ImageDecoder(Decoder):
    def __init__(
        self,
        model: Transformer,
        vocab: VocabInfo,
        options: Options,
        input_ids: list[list[int]],
    ):
        assert isinstance(options.img, Options.Image)
        self.vocab = vocab
        self.options = options
        self.batch_size = len(input_ids)
        logits_processors = [
            InBatchInstructCFGLogitsProcessor(
                options.img.cfg.guidance_scale_text,
                options.img.cfg.guidance_scale_image,
            ),
            AllowOnlyTokensLogitsProcessor(vocab.image_tokens),
            TemperatureLogitsWarper(options.img.temp),
            TopPLogitsWarper(options.img.top_p),
        ]
        for inp in input_ids:
            if inp[-1] != self.vocab.begin_image:
                inp.append(self.vocab.begin_image)
        max_prompt_len = max(len(inp) for inp in input_ids)
        self.gen = ChameleonGenerator(
            model=ChameleonModelAdapter(model, max_seq_len=max_prompt_len + 1024),
            input_ids=self._split_inputs_for_cfg(input_ids),
            logits_processors=logits_processors,
            alignment=AlignPromptRight(vocab.pad_id),
            token_selector=ReplicatedInputTokenSelector(
                (
                    ArgmaxTokenSelector()
                    if options.img.greedy
                    else MultinomialTokenSelector()
                ),
                n=3,
            ),
        )
        advance(self.gen, max_prompt_len)
        self.gen_count = 0

    def _split_inputs_for_cfg(self, input_ids: list[list[int]]) -> list[list[int]]:
        # Classifier-free guidance runs three copies of each prompt through the
        # model: the full prompt, an image-only conditioning (text stripped),
        # and an unconditioned prompt. InBatchInstructCFGLogitsProcessor combines
        # their logits back into a single guided distribution per sample.
        image_conditioned_allowed = set(self.vocab.image_tokens) | {
            self.vocab.bos_id,
            self.vocab.begin_image,
            self.vocab.end_image,
        }

        full_conditioned = input_ids

        image_conditioned = [
            [id for id in sample if id in image_conditioned_allowed]
            for sample in input_ids
        ]

        unconditioned = [
            [
                self.vocab.bos_id,
                self.vocab.begin_image,
            ]
        ] * self.batch_size

        return full_conditioned + image_conditioned + unconditioned

    def __next__(self) -> DecodePiece:
        if self.gen_count == 1024:
            # An image is exactly 1024 VQGAN tokens; emit end_image and hand
            # control back to the text decoder.
            id = torch.tensor([self.vocab.end_image] * self.batch_size)
            logits = torch.full(
                (self.batch_size, len(self.vocab.all_tokens)), -math.inf
            )
            logits[:, self.vocab.end_image] = 0
            return DecodePiece(
                ChameleonGenerator.Token(id=id, logits=logits),
                TextDecoder,
            )

        tok = next(self.gen)
        tok.id = tok.id.chunk(3)[0]  # keep only the fully-conditioned replica
        self.gen_count += 1
        return DecodePiece(tok, None)


class Generator(Decoder):
    def __init__(
        self,
        model: Transformer,
        vocab: VocabInfo,
        options: Options,
        input_ids: list[list[int]],
    ):
        if options.seed is not None:
            enable_full_determinism(options.seed, warn_only=True)

        self.model = model
        self.vocab = vocab
        self.input_ids = input_ids[:]
        self.generated_token_ids: list[torch.LongTensor] = []
        self.options = options
        if not self.options.txt:
            self.dyngen = DynamicGenerator(
                ImageDecoder(model, vocab, options, input_ids)
            )
        else:
            self.dyngen = DynamicGenerator(
                TextDecoder(model, vocab, options, input_ids)
            )

    def __iter__(self):
        return self

    def __next__(self) -> ChameleonGenerator.Token:
        piece = next(self.dyngen)
        self.generated_token_ids.append(piece.token.id)

        if piece.next_decoder is not None:
            if not self.options.txt:
                raise StopIteration

            self.input_ids = [
                old_list + generated
                for old_list, generated in zip(
                    self.input_ids, torch.stack(self.generated_token_ids).T.tolist()
                )
            ]
            self.generated_token_ids = []
            self.dyngen.gen = piece.next_decoder(
                self.model,
                self.vocab,
                self.options,
                self.input_ids,
            )

        return piece.token


class DistributedMode(Enum):
    AUTO = 0
    THREAD = 1
    PROCESS = 2


@dataclass
class _DistributedContext:
    req_q: Union[queue.Queue, queues.Queue]
    res_q: Union[queue.Queue, queues.Queue]
    active_key: Union[dict[int, Literal[True]], managers.DictProxy]
    active_key_lock: Union[threading.Lock, synchronize.Lock]
    ready_barrier: Union[threading.Barrier, synchronize.Barrier]
    worker_launcher: Union[type[threading.Thread], type[mp.Process]]

    @staticmethod
    def make_for_threading(world_size: int):
        return _DistributedContext(
            req_q=queue.Queue(),
            res_q=queue.Queue(),
            active_key={},
            active_key_lock=threading.Lock(),
            ready_barrier=threading.Barrier(world_size + 1),
            worker_launcher=threading.Thread,
        )

    @staticmethod
    def make_for_multiprocessing(world_size: int):
        local_mp = mp.get_context("spawn")
        return _DistributedContext(
            req_q=local_mp.Queue(),
            res_q=local_mp.Queue(),
            active_key=local_mp.Manager().dict(),
            active_key_lock=local_mp.Lock(),
            ready_barrier=local_mp.Barrier(world_size + 1),
            worker_launcher=local_mp.Process,
        )

    @staticmethod
    def make(mode: DistributedMode, world_size: int):
        if mode == DistributedMode.AUTO:
            mode = DistributedMode.PROCESS
        if mode == DistributedMode.THREAD:
            return _DistributedContext.make_for_threading(world_size)
        elif mode == DistributedMode.PROCESS:
            return _DistributedContext.make_for_multiprocessing(world_size)
        else:
            raise ValueError("Unknown DistributedMode")


def _worker_impl(
    init_method: str,
    model: Transformer | str,
    world_size: int,
    rank: int,
    vocab: VocabInfo,
    dctx: _DistributedContext,
):
    dist.init_process_group(
        "nccl",
        init_method=init_method,
        world_size=world_size,
        rank=rank,
    )

    torch.set_default_device(f"cuda:{rank}")
    torch.cuda.set_device(rank)
    if isinstance(model, str):
        model = loader.load_model(model, rank=rank)
    dctx.ready_barrier.wait()

    is_coord = rank == 0

    while True:
        # Rank 0 pulls requests from the queue; every rank receives them via broadcast.
        req = [Options(), [], 0, False]
        if is_coord:
            req = dctx.req_q.get()
        dist.broadcast_object_list(req, src=0)

        options, input_ids, key, shutdown = req
        if shutdown:
            break

        for token in Generator(
            model=model,
            vocab=vocab,
            options=options,
            input_ids=input_ids,
        ):
            if is_coord:
                dctx.res_q.put((key, token))

            # Stop early if the caller has dropped the request key (stream cancelled).
            to_continue = [True]
            if is_coord:
                with dctx.active_key_lock:
                    to_continue = [key in dctx.active_key]
            dist.broadcast_object_list(to_continue, src=0)
            if not to_continue[0]:
                break

        if is_coord:
            dctx.res_q.put((key, None))
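

# Request/response protocol between ChameleonInferenceModel and its workers
# (as implemented above): requests are [options, batch_input_ids, key, shutdown]
# lists placed on req_q, and rank 0 answers with (key, token) tuples on res_q,
# terminated by (key, None).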


class ChameleonInferenceModel:
    def __init__(
        self,
        model: Transformer | str,
        tokenizer_path: str,
        vqgan_cfg_path: str,
        vqgan_ckpt_path: str,
        *,
        options: Options | None = None,
        distributed_mode: DistributedMode = DistributedMode.AUTO,
    ):
        self.options = options or Options()
        self.next_key = 0

        self.token_manager = TokenManager(
            tokenizer_path=tokenizer_path,
            vqgan_cfg_path=vqgan_cfg_path,
            vqgan_ckpt_path=vqgan_ckpt_path,
            device="cuda",
        )
        self.vocab = self.token_manager.vocab

        world_size = 1
        if isinstance(model, str):
            world_size = loader.detect_shard_count(model)
        self.dctx = _DistributedContext.make(distributed_mode, world_size)

        init_method = f"tcp://0.0.0.0:{random_unused_port()}"
        self.workers = [
            self.dctx.worker_launcher(
                target=_worker_impl,
                args=(init_method, model, world_size, i, self.vocab, self.dctx),
                daemon=True,
            )
            for i in range(world_size)
        ]
        for w in self.workers:
            w.start()
        self.dctx.ready_barrier.wait()

    def __del__(self):
        try:
            with self.dctx.active_key_lock:
                self.dctx.active_key.clear()
            self.dctx.req_q.put([None, None, None, True])
            for w in self.workers:
                w.join()
        except FileNotFoundError:
            pass

    def stream(
        self,
        *,
        input_ids: list[int] | None = None,
        prompt_text: str | None = None,
        prompt_ui: list[dict] | None = None,
        batch_input_ids: list[list[int]] | None = None,
        batch_prompt_text: list[str] | None = None,
        batch_prompt_ui: list[list[dict]] | None = None,
        options: Options | None = None,
    ):
        # NOTE: Not thread-safe! Only one instance of generate may be run at a time.
        if (
            sum(
                x is not None
                for x in [
                    input_ids,
                    prompt_text,
                    prompt_ui,
                    batch_input_ids,
                    batch_prompt_text,
                    batch_prompt_ui,
                ]
            )
            != 1
        ):
            raise ValueError(
                "Must specify exactly one of: input_ids, prompt_text, prompt_ui, "
                "batch_input_ids, batch_prompt_text, batch_prompt_ui"
            )

        options = options or self.options

        if prompt_text is not None:
            batch_prompt_text = [prompt_text]
        if prompt_ui is not None:
            batch_prompt_ui = [prompt_ui]
        if input_ids is not None:
            batch_input_ids = [input_ids]
        if batch_prompt_text is not None:
            batch_prompt_ui = [
                [{"type": "text", "value": prompt_text}]
                for prompt_text in batch_prompt_text
            ]
        if batch_prompt_ui is not None:
            batch_input_ids = [
                self.token_manager.tokens_from_ui(prompt_ui)
                for prompt_ui in batch_prompt_ui
            ]
        assert batch_input_ids

        if not options.txt and not options.img:
            raise ValueError("Must specify at least one modality.")
        if options.txt and options.img and len(batch_input_ids) > 1:
            raise ValueError(
                "Batch generation only supported for one modality at a time."
            )

        req_key = self.next_key
        self.next_key += 1

        with self.dctx.active_key_lock:
            self.dctx.active_key[req_key] = True

        self.dctx.req_q.put([options, batch_input_ids, req_key, False])

        try:
            while key_token := self.dctx.res_q.get():
                key, token = key_token
                if key != req_key:
                    # Residual from prior calls to generation. Skip.
                    continue
                if token is None:
                    break
                yield token
        finally:
            with self.dctx.active_key_lock:
                del self.dctx.active_key[req_key]

    def step(self, *args, **kwargs) -> ChameleonGenerator.Token:
        return next(self.stream(*args, **kwargs))

    def generate(self, *args, **kwargs) -> torch.LongTensor:
        tokens = [t.id for t in self.stream(*args, **kwargs)]
        if not tokens:
            return torch.LongTensor()
        return torch.stack(tokens).T

    def decode_text(self, ids: torch.LongTensor | list[list[int]]) -> list[str]:
        return self.token_manager.decode_text(ids)

    def decode_image(self, ids: torch.LongTensor) -> list[PIL.Image]:
        return self.token_manager.decode_image(ids)
    def sft_tokenization(self, json_path: str) -> list[dict]:
        # Tokenize a JSONL file of {"text": ..., "image": ...} entries for SFT.
        with open(json_path, "r") as input_file:
            jsonl_input = [json.loads(line) for line in input_file]

        output_data = []
        for entry in tqdm(jsonl_input, desc="Tokenize dataset"):
            text_tokens = self.token_manager.tokenize_text(entry["text"])
            image_tokens = self.token_manager.tokenize_image(
                PIL.Image.open(entry["image"])
            )
            entry["text_tokens"] = text_tokens
            entry["image_tokens"] = image_tokens
            output_data.append(entry)
        return output_data
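

if __name__ == "__main__":
    # Minimal end-to-end usage sketch. The checkpoint and tokenizer paths below
    # are placeholders for a local Chameleon download and must be adjusted to
    # your setup; they are not shipped with this module.
    chameleon = ChameleonInferenceModel(
        "./data/models/7b/",
        "./data/tokenizer/text_tokenizer.json",
        "./data/tokenizer/vqgan.yaml",
        "./data/tokenizer/vqgan.ckpt",
    )
    tokens = chameleon.generate(prompt_text="What do you see in abstract art?")
    print(chameleon.decode_text(tokens)[0])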