# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.
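
# Miniviewer: a minimal Flask + Flask-SocketIO web UI that streams interleaved
# text and image generations from a Chameleon model to the browser
# (miniviewer.html, served from this directory).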

import base64
import os
import threading
import time
from dataclasses import dataclass
from enum import Enum
from pathlib import Path

import click
import torch
from flask import Flask, request
from flask_socketio import SocketIO

from chameleon.inference.chameleon import ChameleonInferenceModel, Options, TokenManager


# A queued generation request; `room` is the Socket.IO session id of the
# requesting client. (`dataclass` is imported and Request is constructed
# positionally below, so the decorator was clearly intended.)
@dataclass
class Request:
    room: str
    key: str
    options: dict[str, int | float | bool]
    prompt_ui: list[dict]
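

# Translate the flat option dict sent by the web UI into the nested
# chameleon.inference Options structure.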
def convert_options(ui_options: dict) -> Options:
    txt = None
    if ui_options["enable-text"]:
        txt = Options.Text(
            repetition_penalty=ui_options["text-rep-penalty"],
            temp=ui_options["text-temp"],
            top_p=ui_options["text-top-p"],
        )
    img = None
    if ui_options["enable-image"]:
        img = Options.Image(
            cfg=Options.Image.CFG(
                guidance_scale_image=ui_options["img-cfg-gsimage"],
                guidance_scale_text=ui_options["img-cfg-gstext"],
            ),
            temp=ui_options["img-temp"],
            top_p=ui_options["img-top-p"],
        )
    return Options(
        max_seq_len=ui_options["max-seq-len"],
        max_gen_len=ui_options["max-gen-len"],
        seed=ui_options["seed"],
        txt=txt,
        img=img,
    )
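

# Incrementally turn the generated token stream into UI events: text chunks,
# an image_start marker, and partial PNG previews emitted every 32 of the
# 1024 tokens that make up an image.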
class UIDecoder:
    class State(Enum):
        TXT = 1
        IMG = 2
        IMG_END = 3

    def __init__(self, token_manager: TokenManager):
        self.token_manager = token_manager
        self.state = UIDecoder.State.TXT
        self.image_builder = []
        self.image_yield_every_n = 32
        self.image_has_updated = False

    def _image_progress(self) -> dict:
        self.image_has_updated = False
        png = self.token_manager.png_from_bpe_tokens(torch.cat(self.image_builder))
        return {
            "type": "image",
            "value": "data:image/png;base64," + base64.b64encode(png).decode(),
        }

    def next(self, gpu_token: torch.LongTensor) -> dict | None:
        if self.state == UIDecoder.State.TXT:
            cpu_tok = gpu_token.item()
            if cpu_tok == self.token_manager.vocab.begin_image:
                self.state = UIDecoder.State.IMG
                return {"type": "image_start"}
            return {
                "type": "text",
                "value": self.token_manager.tokenizer.decode([cpu_tok]),
            }
        elif self.state == UIDecoder.State.IMG:
            self.image_builder.append(gpu_token)
            self.image_has_updated = True
            if len(self.image_builder) == 1024:
                self.state = UIDecoder.State.IMG_END
            if len(self.image_builder) % self.image_yield_every_n == 0:
                return self._image_progress()
        elif self.state == UIDecoder.State.IMG_END:
            # assert gpu_token == end_image
            self.state = UIDecoder.State.TXT
            progress = self._image_progress() if self.image_has_updated else None
            self.image_builder = []
            return progress


# State is constructed with keyword arguments below, so the missing
# @dataclass decorator is restored here.
@dataclass
class State:
    room_keys: dict[str, set[str]]
    pending_requests: list[Request]
    cond: threading.Condition

    # Entering the state acquires the condition's underlying lock, so
    # `with GlobalState as state:` gives exclusive access to the fields.
    def __enter__(self, *args, **kwargs):
        self.cond.__enter__(*args, **kwargs)
        return self

    def __exit__(self, *args, **kwargs):
        self.cond.__exit__(*args, **kwargs)
        return self
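

# Shared, lock-protected server state plus the Flask/Socket.IO app. The raised
# HTTP buffer size presumably leaves headroom for image-bearing prompt
# payloads sent from the browser.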
GlobalState = State(room_keys={}, pending_requests=[], cond=threading.Condition())

app = Flask(__name__)
socketio = SocketIO(app, max_http_buffer_size=16 * 1024 * 1024)


@app.route("/")
def index():
    with open(Path(__file__).parent / "miniviewer.html") as f:
        return f.read()


@socketio.on("disconnect")
def handle_disconnect():
    with GlobalState as state:
        try:
            del state.room_keys[request.sid]
        except KeyError:
            pass


@socketio.on("cancel")  # event name inferred from the handler name
def handle_cancel(key):
    with GlobalState as state:
        try:
            state.room_keys[request.sid].remove(key)
        except KeyError:
            pass


@socketio.on("generate")  # event name inferred from the handler name
def handle_generate(key, options, prompt_ui):
    with GlobalState as state:
        if request.sid not in state.room_keys:
            state.room_keys[request.sid] = set()
        state.room_keys[request.sid].add(key)
        state.pending_requests.append(Request(request.sid, key, options, prompt_ui))
        state.cond.notify_all()
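

# Worker loop: pop queued requests one at a time, stream tokens from the
# model, and emit incremental "progress" events to the requesting client.
# Cancelling or disconnecting removes the request's key from room_keys,
# which breaks the streaming loop below.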
def generation_thread(model: ChameleonInferenceModel):
    while True:
        with GlobalState as state:
            state.cond.wait_for(lambda: state.pending_requests)
            req = state.pending_requests.pop(0)

        start = time.time()
        ui_decoder = UIDecoder(model.token_manager)
        options = convert_options(req.options)

        if not options.txt:
            # Image-only generation: synthesize the begin_image event so the
            # UI switches to image mode immediately.
            progress = ui_decoder.next(
                torch.tensor([model.token_manager.vocab.begin_image])
            )
            socketio.emit(
                "progress",
                {"key": req.key, **progress},
                room=req.room,
            )

        for token in model.stream(
            prompt_ui=req.prompt_ui,
            options=options,
        ):
            with GlobalState as state:
                if req.key not in state.room_keys.get(req.room, {}):
                    break
            if progress := ui_decoder.next(token.id):
                socketio.emit(
                    "progress",
                    {"key": req.key, **progress},
                    room=req.room,
                )

        timing = time.time() - start
        socketio.emit(
            "progress",
            {"key": req.key, "type": "done", "value": timing},
            room=req.room,
        )
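

# Broadcast each request's current position whenever the pending queue
# changes, so waiting clients can display "position N in queue".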
def queue_position_thread():
    local_pending_requests = []
    while True:
        with GlobalState as state:
            state.cond.wait_for(
                lambda: local_pending_requests != state.pending_requests
            )
            local_pending_requests = state.pending_requests[:]
        for i, req in enumerate(local_pending_requests):
            progress = {
                "type": "queue",
                "key": req.key,
                "value": i + 1,
            }
            socketio.emit("progress", progress, room=req.room)
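

# CLI entry point. The click options below are reconstructed from the
# function signature; the flag names and defaults are assumptions.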
@click.command()
@click.option("--data-path", type=click.Path(), default="./data")
@click.option("--model-size", type=click.Choice(["7b", "30b"]), default="7b")
def main(data_path, model_size):
    data_path = Path(data_path)

    model_path = str(data_path / "models" / model_size)
    tokenizer_path = str(data_path / "tokenizer/text_tokenizer.json")
    vqgan_cfg_path = str(data_path / "tokenizer/vqgan.yaml")
    vqgan_ckpt_path = str(data_path / "tokenizer/vqgan.ckpt")

    if not os.path.exists(model_path):
        raise ValueError(
            "Model not found. Did you run python -m chameleon.download_data {PRESIGNED_URL}?"
        )

    cm3v2_inference_model = ChameleonInferenceModel(
        model_path, tokenizer_path, vqgan_cfg_path, vqgan_ckpt_path
    )

    threading.Thread(
        target=generation_thread,
        args=(cm3v2_inference_model,),
        daemon=True,
    ).start()
    threading.Thread(target=queue_position_thread, daemon=True).start()

    socketio.run(app, debug=False)


if __name__ == "__main__":
    main()