import gradio as gr
import spaces
import torch
from huggingface_hub import hf_hub_download
from transformers import Mistral3ForConditionalGeneration, AutoTokenizer
from typing import Any
import base64
import mimetypes
from pathlib import Path
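

# Magistral ships SYSTEM_PROMPT.txt with an embedded [THINK]...[/THINK]
# reasoning example. The helper below drops that span, marker tokens
# included, and returns the surrounding instructions as a system message.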
def load_system_prompt(repo_id: str, filename: str) -> dict[str, Any]:
    file_path = hf_hub_download(repo_id=repo_id, filename=filename)
    with open(file_path, "r") as file:
        system_prompt = file.read()

    index_begin_think = system_prompt.find("[THINK]")
    index_end_think = system_prompt.find("[/THINK]")

    return {
        "role": "system",
        "content": [
            {"type": "text", "text": system_prompt[:index_begin_think]},
            {
                "type": "text",
                "text": system_prompt[index_end_think + len("[/THINK]") :],
            },
        ],
    }
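

# Load everything once at import time. tokenizer_type="mistral" asks
# AutoTokenizer for the mistral-common backed tokenizer that carries the
# model's chat template; device_map="auto" places the bfloat16 weights on
# the available GPU.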
model_id = "mistralai/Magistral-Small-2509"
tokenizer = AutoTokenizer.from_pretrained(model_id, tokenizer_type="mistral")
model = Mistral3ForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
).eval()

SYSTEM_PROMPT = load_system_prompt(model_id, "SYSTEM_PROMPT.txt")
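

# On Hugging Face Spaces, @spaces.GPU borrows a ZeroGPU worker for up to
# `duration` seconds per call; outside Spaces the decorator should be a
# no-op. Note: predict() assumes Gradio's tuple-style history, a list of
# (user_message, assistant_message) string pairs.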
@spaces.GPU(duration=120)
def predict(message: dict, history: list) -> str:
    messages = [SYSTEM_PROMPT]
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
        if assistant_msg:
            messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})
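
    # The MultimodalTextbox payload is {"text": ..., "files": [...]}; only
    # the first attached file is used, inlined as a base64 data URL so the
    # chat template can decode the image without fetching a remote URL.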
    user_content = [{"type": "text", "text": message["text"]}]
    if message["files"]:
        image_path = Path(message["files"][0])
        image_bytes = image_path.read_bytes()
        encoded_image = base64.b64encode(image_bytes).decode("utf-8")
        mime_type, _ = mimetypes.guess_type(image_path)
        if mime_type is None:
            mime_type = "image/png"
        data_url = f"data:{mime_type};base64,{encoded_image}"
        user_content.append({"type": "image_url", "image_url": {"url": data_url}})

    messages.append({"role": "user", "content": user_content})
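
    # With the mistral tokenizer, apply_chat_template(return_dict=True)
    # returns plain Python lists (token ids, attention mask, and pixel
    # values when an image is present), so tensors are built by hand.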
    tokenized = tokenizer.apply_chat_template(messages, return_dict=True)

    input_ids = torch.tensor(tokenized.input_ids, device="cuda").unsqueeze(0)
    attention_mask = torch.tensor(tokenized.attention_mask, device="cuda").unsqueeze(0)
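
    # pixel_values is only produced when the latest user turn carried an
    # image; text-only requests take the plain generate() path below.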
    if "pixel_values" in tokenized and len(tokenized.pixel_values) > 0:
        pixel_values = torch.tensor(
            tokenized.pixel_values[0], dtype=torch.bfloat16, device="cuda"
        ).unsqueeze(0)
        image_sizes = torch.tensor(pixel_values.shape[-2:], device="cuda").unsqueeze(0)
        output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_sizes=image_sizes,
        )[0]
    else:
        output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )[0]
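
    # Decode only the newly generated tokens: skip the prompt prefix and
    # drop the trailing EOS token if the model emitted one.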
    decoded_output = tokenizer.decode(
        output[
            len(tokenized.input_ids) : (
                -1 if output[-1] == tokenizer.eos_token_id else len(output)
            )
        ]
    )
    return decoded_output
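

# multimodal=True swaps the plain textbox for a MultimodalTextbox, which is
# what delivers the {"text": ..., "files": [...]} dict predict() expects.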
demo = gr.ChatInterface(
    fn=predict,
    multimodal=True,
    title="Magistral Chat App",
    description='Chat with Magistral AI. Upload an image if relevant to your question.<br>Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a>',
)

if __name__ == "__main__":
    demo.launch()