Spaces:
Running
Running
| """Model-related code and constants.""" | |
| import spaces | |
| import dataclasses | |
| import os | |
| import re | |
| import PIL.Image | |
| # pylint: disable=g-bad-import-order | |
| import gradio_helpers | |
| import llama_cpp | |
| ORGANIZATION = 'abetlen' | |
| BASE_MODELS = [ | |
| ('paligemma-3b-mix-224-gguf', 'paligemma-3b-mix-224'), | |
| ] | |
| MODELS = { | |
| **{ | |
| model_name: ( | |
| f'{ORGANIZATION}/{repo}', | |
| (f'{model_name}-text-model-q4_k_m.gguf', f'{model_name}-mmproj-f16.gguf'), | |
| ) | |
| for repo, model_name in BASE_MODELS | |
| }, | |
| } | |
| MODELS_INFO = { | |
| 'paligemma-3b-mix-224': ( | |
| 'GGUF PaliGemma 3B weights quantized in Q4_K_M Format, finetuned with 224x224 input images and 256 token input/output ' | |
| 'text sequences on a mixture of downstream academic datasets. The models are available in float32, ' | |
| 'bfloat16 and float16 format for research purposes only.' | |
| ), | |
| } | |
| MODELS_RES_SEQ = { | |
| 'paligemma-3b-mix-224': (224, 256), | |
| } | |
| # "CPU basic" has 16G RAM, "T4 small" has 15 GB RAM. | |
| # Below value should be smaller than "available RAM - one model". | |
| # A single bf16 is about 5860 MB. | |
| MAX_RAM_CACHE = int(float(os.environ.get('RAM_CACHE_GB', '0')) * 1e9) | |
| # config = paligemma_bv.PaligemmaConfig( | |
| # ckpt='', # will be set below | |
| # res=224, | |
| # text_len=64, | |
| # tokenizer='gemma(tokensets=("loc", "seg"))', | |
| # vocab_size=256_000 + 1024 + 128, | |
| # ) | |
| def get_cached_model( | |
| model_name: str, | |
| ):# -> tuple[paligemma_bv.PaliGemmaModel, paligemma_bv.ParamsCpu]: | |
| """Returns model and params, using RAM cache.""" | |
| res, seq = MODELS_RES_SEQ[model_name] | |
| model_path = gradio_helpers.get_paths()[model_name] | |
| config_ = dataclasses.replace(config, ckpt=model_path, res=res, text_len=seq) | |
| model, params_cpu = gradio_helpers.get_memory_cache( | |
| config_, | |
| lambda: paligemma_bv.load_model(config_), | |
| max_cache_size_bytes=MAX_RAM_CACHE, | |
| ) | |
| return model, params_cpu | |
| def pil_image_to_base64(image: PIL.Image.Image) -> str: | |
| """Converts PIL image to base64.""" | |
| import io | |
| import base64 | |
| buffered = io.BytesIO() | |
| image.save(buffered, format='JPEG') | |
| return base64.b64encode(buffered.getvalue()).decode('utf-8') | |
| def generate( | |
| model_name: str, sampler: str, image: PIL.Image.Image, prompt: str | |
| ) -> str: | |
| """Generates output with specified `model_name`, `sampler`.""" | |
| # model, params_cpu = get_cached_model(model_name) | |
| # batch = model.shard_batch(model.prepare_batch([image], [prompt])) | |
| # with gradio_helpers.timed('sharding'): | |
| # params = model.shard_params(params_cpu) | |
| # with gradio_helpers.timed('computation', start_message=True): | |
| # tokens = model.predict(params, batch, sampler=sampler) | |
| model_path, clip_path = gradio_helpers.get_paths()[model_name] | |
| print(model_path) | |
| print(gradio_helpers.get_paths()) | |
| model = llama_cpp.Llama( | |
| model_path, | |
| chat_handler=llama_cpp.llama_chat_format.PaliGemmaChatHandler( | |
| clip_path | |
| ), | |
| n_ctx=1024, | |
| n_ubatch=512, | |
| n_batch=512, | |
| n_gpu_layers=-1, | |
| ) | |
| return model.create_chat_completion(messages=[{ | |
| "role": "user", | |
| "content": [ | |
| { | |
| "type": "text", | |
| "text": prompt | |
| }, | |
| { | |
| "type": "image_url", | |
| "image_url": "data:image/jpeg;base64," + pil_image_to_base64(image) | |
| } | |
| ] | |
| }])["choices"][0]["message"]["content"] | |