import gradio as gr
import spaces
import torch
from huggingface_hub import hf_hub_download
from transformers import Mistral3ForConditionalGeneration, AutoTokenizer
from typing import Any, List, Dict
import base64
import mimetypes
from pathlib import Path
def load_system_prompt(repo_id: str, filename: str) -> dict[str, Any]:
    """Download the model's system prompt and wrap it as a chat "system" message.

    If the file contains a ``[THINK] ... [/THINK]`` section, that section (markers
    included) is removed. The text before and after it become two separate text
    chunks. If the markers are missing, the whole file is kept as a single text
    chunk. The same applies if ``[/THINK]`` does not follow ``[THINK]``.

    NOTE(review): the text between the THINK markers is discarded here; confirm
    that is intended for this model.

    Args:
        repo_id: Hugging Face Hub repository that contains the prompt file.
        filename: Path of the prompt file inside that repository.

    Returns:
        A message dict ``{"role": "system", "content": [<text chunks>]}``.
    """
    file_path = hf_hub_download(repo_id=repo_id, filename=filename)
    with open(file_path, "r", encoding="utf-8") as file:
        system_prompt = file.read()

    begin_tag, end_tag = "[THINK]", "[/THINK]"
    index_begin_think = system_prompt.find(begin_tag)
    # Only look for the closing tag after the opening one.
    index_end_think = (
        system_prompt.find(end_tag, index_begin_think + len(begin_tag))
        if index_begin_think != -1
        else -1
    )

    if index_end_think == -1:
        # No well-formed THINK section. Slicing with find()'s -1 sentinel would
        # drop the last character and duplicate text, so keep the prompt whole.
        content = [{"type": "text", "text": system_prompt}]
    else:
        content = [
            {"type": "text", "text": system_prompt[:index_begin_think]},
            {
                "type": "text",
                "text": system_prompt[index_end_think + len(end_tag) :],
            },
        ]
    return {"role": "system", "content": content}
# Hub checkpoint served by this app; SYSTEM_PROMPT.txt is read from the same repo.
model_id = "mistralai/Magistral-Small-2509"

# Mistral-format tokenizer for this checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_id, tokenizer_type="mistral")

# Load weights in bfloat16, let device_map="auto" place them, and put the module
# in eval mode since the app only generates.
model = Mistral3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()

SYSTEM_PROMPT = load_system_prompt(repo_id=model_id, filename="SYSTEM_PROMPT.txt")
# Upper bound on tokens generated per reply. Set explicitly rather than relying on
# the checkpoint's generation_config, which may fall back to transformers' default
# max_length of 20 (prompt included). Tune together with the @spaces.GPU duration
# so a full reply can finish inside the GPU time budget.
MAX_NEW_TOKENS = 2048


def _history_to_messages(history: list) -> list[dict]:
    """Convert Gradio chat history into text-only chat messages.

    Accepts both Gradio history formats:
    * "tuples": ``[user_text, assistant_text]`` pairs;
    * "messages": ``{"role": ..., "content": ...}`` dicts.

    Non-string content (e.g. a ``(path,)`` tuple for an uploaded file) is
    skipped, so only text turns are replayed. Empty or ``None`` turns, such as a
    missing assistant reply, are skipped as well.
    """
    messages = []
    for entry in history:
        if isinstance(entry, dict):
            turns = [(entry.get("role"), entry.get("content"))]
        else:
            user_msg, assistant_msg = entry
            turns = [("user", user_msg), ("assistant", assistant_msg)]
        for role, content in turns:
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append(
                    {"role": role, "content": [{"type": "text", "text": content}]}
                )
    return messages


def _file_path(file: Any) -> Path:
    """Return the local path of an uploaded file given as a path or a ``{"path": ...}`` dict."""
    return Path(file["path"] if isinstance(file, dict) else file)


def _image_data_url(image_path: Path) -> str:
    """Encode an image file as a base64 ``data:`` URL (MIME type guessed, PNG fallback)."""
    encoded_image = base64.b64encode(image_path.read_bytes()).decode("utf-8")
    mime_type = mimetypes.guess_type(image_path)[0] or "image/png"
    return f"data:{mime_type};base64,{encoded_image}"


@spaces.GPU(duration=120)
def predict(message: dict, history: list) -> str:
    """Generate Magistral's reply to one chat turn.

    Args:
        message: Multimodal textbox value ``{"text": str, "files": [...]}``. A
            plain string is also accepted. Only the first file is sent, as an image.
        history: Prior turns from ``gr.ChatInterface`` in either "tuples" or
            "messages" format. Only their text is replayed to the model.

    Returns:
        The decoded completion: the prompt is excluded, and a trailing EOS token
        is removed if generation stopped on it.
    """
    if isinstance(message, str):
        message = {"text": message, "files": []}

    messages = [SYSTEM_PROMPT, *_history_to_messages(history)]

    user_content = [{"type": "text", "text": message.get("text") or ""}]
    files = message.get("files") or []
    if files:
        # Only one image is forwarded: the pixel_values handling below takes a
        # single image (tokenized.pixel_values[0]).
        data_url = _image_data_url(_file_path(files[0]))
        user_content.append({"type": "image_url", "image_url": {"url": data_url}})
    messages.append({"role": "user", "content": user_content})

    tokenized = tokenizer.apply_chat_template(messages, return_dict=True)
    input_ids = torch.tensor(tokenized.input_ids, device="cuda").unsqueeze(0)
    attention_mask = torch.tensor(tokenized.attention_mask, device="cuda").unsqueeze(0)

    generate_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "max_new_tokens": MAX_NEW_TOKENS,
    }
    if "pixel_values" in tokenized and len(tokenized.pixel_values) > 0:
        pixel_values = torch.tensor(
            tokenized.pixel_values[0], dtype=torch.bfloat16, device="cuda"
        ).unsqueeze(0)
        generate_kwargs["pixel_values"] = pixel_values
        # Last two dims of the image tensor (presumably height, width), batched.
        generate_kwargs["image_sizes"] = torch.tensor(
            pixel_values.shape[-2:], device="cuda"
        ).unsqueeze(0)

    output = model.generate(**generate_kwargs)[0]

    # Keep only newly generated tokens and drop the final EOS if present.
    completion = output[len(tokenized.input_ids) :]
    if len(completion) > 0 and completion[-1] == tokenizer.eos_token_id:
        completion = completion[:-1]
    return tokenizer.decode(completion)
# Chat UI: multimodal=True gives a textbox that also accepts file uploads, and
# `predict` receives its value as {"text": ..., "files": [...]}.
demo = gr.ChatInterface(
    fn=predict,
    multimodal=True,
    title="Magistral Chat App",
    # Adjacent literals are concatenated; the line break is an explicit "\n".
    description=(
        "Chat with Magistral AI. Upload an image if relevant to your question.\n"
        "Built with anycoder"
    ),
)

if __name__ == "__main__":
    demo.launch()