Spaces:
Runtime error
Runtime error
| # from transformers_stream_generator import init_stream_support | |
| # init_stream_support() | |
| import spaces | |
| import os | |
| import numpy as np | |
| import argparse | |
| import torch | |
| import gradio as gr | |
| from typing import Any, Iterator | |
| from typing import Iterator, List, Optional, Tuple | |
| import filelock | |
| import glob | |
| import json | |
| import time | |
| from gradio.routes import Request | |
| from gradio.utils import SyncToAsyncIterator, async_iteration | |
| from gradio.helpers import special_args | |
| import anyio | |
| from typing import AsyncGenerator, Callable, Literal, Union, cast | |
| from gradio_client.documentation import document, set_documentation_group | |
| from typing import List, Optional, Union, Dict, Tuple | |
| from tqdm.auto import tqdm | |
| from huggingface_hub import snapshot_download | |
| from gradio.components import Button | |
| from gradio.events import Dependency, EventListenerMethod | |
| from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer | |
| import types | |
| import sys | |
| from .base_engine import BaseEngine | |
| from .transformers_engine import TransformersEngine, NewGenerationMixin | |
| from ..configs import ( | |
| STREAM_CHECK_MULTIPLE, | |
| STREAM_YIELD_MULTIPLE, | |
| ) | |
CODE_PATH = os.environ.get("CODE_PATH", "")
MODEL_PATH = os.environ.get("MODEL_PATH", "")

# Image placeholder string and the fixed token budget counted per image
# (IMAGE_LENGTH * MAX_PACHES tokens for each image in a request).
IMAGE_TOKEN = "[IMAGE]<|image|>[/IMAGE]"
IMAGE_LENGTH = 576
MAX_PACHES = 1  # (sic) name kept as-is because other code references it


def split_env_list(raw, *, lower=False):
    """Split a ``;``-separated configuration string into a list of items.

    Each item is whitespace-stripped, and empty entries (from a trailing ``;``
    or from ``a;;b``) are dropped.

    Args:
        raw: The raw string, e.g. an environment variable value. ``None`` is
            treated as an empty string.
        lower: If true, lower-case every item.

    Returns:
        A list of non-empty items, possibly empty.
    """
    items = [item.strip() for item in (raw or "").split(";")]
    items = [item for item in items if item]
    return [item.lower() for item in items] if lower else items


# Language codes compared against langdetect's output; matching messages are rejected.
BLOCK_LANGS = split_env_list(os.environ.get("BLOCK_LANGS", ""))
# When enabled, a conversation whose history already contains a
# language-block reply keeps being blocked.
LANG_BLOCK_HISTORY = bool(int(os.environ.get("LANG_BLOCK_HISTORY", "0")))
# Lower-cased substrings; any message containing one is rejected. Empty
# entries must be dropped because "" is a substring of every text and would
# block every request.
KEYWORDS = split_env_list(os.environ.get("KEYWORDS", ""), lower=True)
LANG_BLOCK_MESSAGE = """Unsupported language."""
KEYWORD_BLOCK_MESSAGE = "Invalid request."
def _detect_lang(text):
    """Return the langdetect language code for *text*.

    Used to disable languages that may carry safety risk. When langdetect
    fails with "No features in text." the text is reported as ``"en"``; any
    other detection failure is reported as ``"zh"``.
    """
    # NOTE(review): falling back to "zh" presumably makes undetectable text
    # subject to the language block when "zh" is in BLOCK_LANGS — confirm.
    from langdetect import detect as detect_lang

    try:
        return detect_lang(text)
    except Exception as err:
        featureless = "No features in text." in str(err)
        return "en" if featureless else "zh"
def block_lang(
    message: str,
    history: Optional[List[Tuple[str, str]]] = None,
) -> bool:
    """Decide whether *message* must be rejected because of its language.

    Args:
        message: The text to classify.
        history: Optional chat history as pairs whose second element is the
            assistant reply. When ``LANG_BLOCK_HISTORY`` is enabled and any
            reply already contains ``LANG_BLOCK_MESSAGE``, the conversation
            stays blocked regardless of *message*.

    Returns:
        ``True`` if the message (or conversation) is blocked, else ``False``.
        Always ``False`` when ``BLOCK_LANGS`` is empty.
    """
    if not BLOCK_LANGS:
        return False
    if LANG_BLOCK_HISTORY and history is not None:
        # Reply slots may be None (presumably for a turn whose reply is still
        # pending — confirm against callers); skip them instead of crashing.
        if any(
            turn[1] is not None and LANG_BLOCK_MESSAGE in turn[1]
            for turn in history
        ):
            return True
    return _detect_lang(message) in BLOCK_LANGS
def safety_check(text, history=None) -> Optional[str]:
    """Apply the keyword and language filters to a single message.

    Despite our effort in safety tuning and red teaming, our models may still
    generate harmful or illegal content. This provides an additional security
    measure to enhance safety and compliance with local regulations.

    Args:
        text: The message to check.
        history: Optional chat history, forwarded to ``block_lang``.

    Returns:
        ``KEYWORD_BLOCK_MESSAGE`` if a configured keyword occurs in the
        lower-cased *text*, ``LANG_BLOCK_MESSAGE`` if the language filter
        rejects it, otherwise ``None``.
    """
    if KEYWORDS:
        lowered = text.lower()
        if any(word in lowered for word in KEYWORDS):
            return KEYWORD_BLOCK_MESSAGE
    if BLOCK_LANGS and block_lang(text, history):
        return LANG_BLOCK_MESSAGE
    return None
def safety_check_conversation_string(text, delimiter=None) -> Optional[str]:
    """Apply the keyword and language filters to a serialized conversation.

    The keyword filter runs on the whole string. For the language filter the
    string is split into turns on *delimiter*, and each non-blank turn is
    checked separately; one blocked turn rejects the whole conversation.

    Args:
        text: The conversation/prompt string.
        delimiter: Regex pattern(s) separating turns, given either as a single
            pattern string or as an iterable of patterns. When falsy, defaults
            to patterns for the ``<|im_start|>`` user/assistant/system markers.

    Returns:
        ``KEYWORD_BLOCK_MESSAGE``, ``LANG_BLOCK_MESSAGE``, or ``None``.
    """
    if KEYWORDS:
        lowered = text.lower()
        if any(word in lowered for word in KEYWORDS):
            return KEYWORD_BLOCK_MESSAGE
    if BLOCK_LANGS:
        import re

        if not delimiter:
            delimiter = (
                r"</s><\|im_start\|>user\n",
                r"</s><\|im_start\|>assistant\n",
                r"<\|im_start\|>system\n",
            )
        elif isinstance(delimiter, str):
            # "|".join on a bare string would interleave "|" between its
            # characters, making every character a split point.
            delimiter = (delimiter,)
        for turn in re.split("|".join(delimiter), text):
            if turn.strip() and block_lang(turn):
                return LANG_BLOCK_MESSAGE
    return None
def is_check_safety():
    """Return whether any safety filter (keywords or blocked languages) is configured."""
    return bool(KEYWORDS or BLOCK_LANGS)
def safety_check_conversation(conversation) -> Optional[str]:
    """Apply the keyword and language filters to every message of a conversation.

    Despite our effort in safety tuning and red teaming, our models may still
    generate harmful or illegal content. This provides an additional security
    measure to enhance safety and compliance with local regulations.

    Args:
        conversation: Iterable of message mappings, each with a ``"content"``
            entry holding the message text.

    Returns:
        The first block message triggered, or ``None``. Messages are checked in
        order, and for each one the keyword filter runs before the language
        filter.
    """
    # Collect all contents up front, so a message missing "content" raises
    # before any message is checked.
    contents = [turn["content"] for turn in conversation]
    for content in contents:
        if KEYWORDS:
            lowered = content.lower()
            if any(word in lowered for word in KEYWORDS):
                return KEYWORD_BLOCK_MESSAGE
        if BLOCK_LANGS and block_lang(content):
            return LANG_BLOCK_MESSAGE
    return None
class SeaLMMMv0Engine(TransformersEngine):
    """Streaming generation engine for the SeaLMM image + text model.

    ``load_model`` loads ``SeaLMMForCausalLM`` and its ``AutoProcessor`` from
    ``MODEL_PATH``. It patches the model's ``sample`` with
    ``NewGenerationMixin.sample_stream`` so that ``generate`` returns a token
    iterator. ``generate_yield_string`` streams the decoded response and runs
    the module's keyword/language safety filters on the prompt and the response.
    """

    @property
    def image_token(self):
        """The image placeholder string (``IMAGE_TOKEN``)."""
        return IMAGE_TOKEN

    @property
    def max_position_embeddings(self) -> int:
        """``max_position_embeddings`` from the loaded model's config."""
        return self._model.config.max_position_embeddings

    @property
    def tokenizer(self):
        """The processor's tokenizer, set by ``load_model``."""
        return self._tokenizer

    @property
    def processor(self):
        """The ``AutoProcessor`` set by ``load_model``."""
        return self._processor

    def load_model(self):
        """Load the processor and bfloat16 model weights from ``MODEL_PATH`` onto CUDA.

        If the directory has ``pytorch_model_fsdp.bin`` but no
        ``pytorch_model.bin``, a ``pytorch_model.bin`` symlink to the FSDP file
        is created first (presumably so ``from_pretrained`` finds the weights).
        """
        from transformers import AutoProcessor
        from .modeling_sealmm import SeaLMMForCausalLM

        model_path = MODEL_PATH
        print(f'Loading model from {model_path}')
        print(f'model_path={model_path}')
        if os.path.exists(f"{model_path}/pytorch_model_fsdp.bin") and not os.path.exists(f"{model_path}/pytorch_model.bin"):
            os.symlink("pytorch_model_fsdp.bin", f"{model_path}/pytorch_model.bin")
        self._processor = AutoProcessor.from_pretrained(model_path)
        self._model = SeaLMMForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda"
        ).eval()
        # Keep the stock sampler reachable, then install the streaming sampler
        # so generate() yields tokens one at a time.
        self._model.sample_old = self._model.sample
        self._model.sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
        self._tokenizer = self._processor.tokenizer
        print(self._model)
        print(f"{self.max_position_embeddings=}")

    def get_multimodal_tokens(self, full_prompt, image_paths=None):
        """Estimate the token count of a prompt plus its images.

        Args:
            full_prompt: Text to tokenize.
            image_paths: Optional iterable of image paths. Each image counts as
                ``IMAGE_LENGTH * MAX_PACHES`` tokens. ``None`` means no images.

        Returns:
            The number of text tokens plus the fixed per-image budget.
        """
        num_text_tokens = len(self.tokenizer.encode(full_prompt))
        num_images = 0 if image_paths is None else sum(1 for _ in image_paths)
        return num_text_tokens + num_images * IMAGE_LENGTH * MAX_PACHES

    def maybe_raise_safety(self, message, gen_index=-1):
        """Raise ``gr.Error`` if *message* trips a configured safety filter.

        Args:
            message: Prompt text, or a partial or complete response.
            gen_index: Index of the current streamed token. When negative (the
                default) the check always runs. Otherwise it runs only when
                ``STREAM_CHECK_MULTIPLE > 0`` and *gen_index* is a multiple of it.

        Raises:
            gr.Error: Carries the message returned by
                ``safety_check_conversation_string``.
        """
        if not is_check_safety():
            return
        periodic_hit = STREAM_CHECK_MULTIPLE > 0 and gen_index % STREAM_CHECK_MULTIPLE == 0
        if gen_index >= 0 and not periodic_hit:
            return
        message_safety = safety_check_conversation_string(message)
        if message_safety is not None:
            raise gr.Error(message_safety)

    def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
        """Stream a sampled response to *prompt*, optionally conditioned on images.

        Args:
            prompt: The full prompt string passed to the processor.
            temperature: Sampling temperature.
            max_tokens: Maximum number of new tokens.
            stop_strings: Accepted for interface compatibility; not used here.
            **kwargs: ``image_paths`` — optional list of image file paths.

        Yields:
            ``(response, num_tokens)`` pairs, where ``response`` is the text
            decoded so far and ``num_tokens`` is the prompt's token count. After
            generation finishes, one extra pair is yielded whose count covers
            the prompt plus the response.

        Raises:
            gr.Error: If the prompt, a periodic streaming check, or the complete
                response trips a safety filter.
        """
        from PIL import Image

        image_paths = kwargs.get("image_paths", None) or []

        # transformers 4.38 names the sampler ``.sample``; 4.39 uses ``._sample``.
        # The gradio handler that calls this needs the @spaces.GPU decorator.
        self._model.sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
        with torch.no_grad():
            images = []
            try:
                for path in image_paths:
                    images.append(Image.open(path))
                inputs = self.processor(prompt, images or None, return_tensors='pt')
            finally:
                # Image.open keeps the file handle open; release the handles
                # once the processor has produced tensors.
                for image in images:
                    image.close()
            inputs = {k: v.to(self._model.device) for k, v in inputs.items() if v is not None}
            num_tokens = self.get_multimodal_tokens(prompt, image_paths)

            self.maybe_raise_safety(prompt)
            generator = self._model.generate(
                **inputs,
                do_sample=True,
                temperature=temperature,
                max_new_tokens=max_tokens,
                pad_token_id=self.processor.tokenizer.pad_token_id,
            )
            out_tokens = []
            response = None
            for index, token in enumerate(generator):
                out_tokens.append(token.item())
                response = self.processor.tokenizer.decode(out_tokens)
                self.maybe_raise_safety(response, gen_index=index)
                yield response, num_tokens
            del generator

            if response is not None:
                # Streaming checks only run every STREAM_CHECK_MULTIPLE tokens,
                # so the tail of the response may be unchecked. Check the
                # complete response (the prompt was already checked above).
                self.maybe_raise_safety(response)
                full_text = prompt + response
                num_tokens = self.get_multimodal_tokens(full_text, image_paths)
                yield response, num_tokens