# Spaces: Runtime error
import html
from pathlib import Path

import gradio as gr
# Keep `spaces` ahead of the mistral_inference imports (which load torch):
# ZeroGPU expects `spaces` to be imported before CUDA is initialised.
import spaces
from huggingface_hub import snapshot_download
from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_inference.generate import generate
from mistral_inference.transformer import Transformer
# Download the model files and prepare the local model directory.
mistral_models_path = Path.home() / "mistral_models" / "Pixtral"
mistral_models_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
    repo_id="mistral-community/pixtral-12b-240910",
    # Only the three files needed for inference are fetched.
    allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
    local_dir=mistral_models_path,
)

# Load the tokenizer and the model from the downloaded folder.
# The tokenizer path is built with pathlib instead of string formatting and
# passed as a str, matching the type the original f-string produced.
tokenizer = MistralTokenizer.from_file(str(mistral_models_path / "tekken.json"))
model = Transformer.from_folder(mistral_models_path)
# Inference
# `spaces` is imported at the top of the file but was never used; on ZeroGPU
# hardware at least one GPU-bound function must be decorated with @spaces.GPU,
# otherwise the Space fails at startup. Outside ZeroGPU the decorator is inert.
@spaces.GPU
def mistral_inference(prompt, image_url, max_tokens=1024, temperature=0.35):
    """Run Pixtral on a text prompt plus an optional image URL.

    Args:
        prompt: Text instruction sent to the model.
        image_url: URL of the image to describe. When it is None, empty, or
            whitespace only, just the text prompt is sent rather than an
            empty image chunk.
        max_tokens: Maximum number of tokens to generate (previously fixed at 1024).
        temperature: Sampling temperature (previously fixed at 0.35).

    Returns:
        The decoded text of the first (and only) generated sequence.
    """
    content = []
    if image_url and image_url.strip():
        content.append(ImageURLChunk(image_url=image_url.strip()))
    content.append(TextChunk(text=prompt))
    completion_request = ChatCompletionRequest(messages=[UserMessage(content=content)])

    encoded = tokenizer.encode_chat_completion(completion_request)
    # generate() takes a batch: one token list and one image list per prompt.
    out_tokens, _ = generate(
        [encoded.tokens],
        model,
        images=[encoded.images],
        max_tokens=max_tokens,
        temperature=temperature,
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    return tokenizer.decode(out_tokens[0])
# UI label translations, keyed by the language codes offered in the dropdown.
# Built once at import time instead of on every get_labels() call.
_UI_LABELS = {
    'en': {
        'title': "Pixtral Model Image Description",
        'text_prompt': "Text Prompt",
        'image_url': "Image URL",
        'output': "Model Output",
        'image_display': "Input Image",
        'submit': "Run Inference"
    },
    'zh': {
        'title': "Pixtral模型图像描述",
        'text_prompt': "文本提示",
        'image_url': "图片网址",
        'output': "模型输出",
        'image_display': "输入图片",
        'submit': "运行推理"
    },
    'jp': {
        'title': "Pixtralモデルによる画像説明生成",
        'text_prompt': "テキストプロンプト",
        'image_url': "画像URL",
        'output': "モデルの出力結果",
        'image_display': "入力された画像",
        'submit': "推論を実行"
    }
}


def get_labels(language):
    """Return the UI label strings for *language*.

    Args:
        language: One of 'en', 'zh' or 'jp' (the dropdown choices). Any other
            value, including None, falls back to English instead of raising
            KeyError.

    Returns:
        A new dict with the keys 'title', 'text_prompt', 'image_url',
        'output', 'image_display' and 'submit'. It is a copy, so mutating it
        does not affect later calls.
    """
    return dict(_UI_LABELS.get(language, _UI_LABELS['en']))
# Gradio interface callbacks
def process_input(text, image_url):
    """Run inference and build an HTML preview of the input image.

    Args:
        text: Prompt typed by the user.
        image_url: Image URL typed by the user.

    Returns:
        A (model_output, preview_html) tuple for the output Textbox and the
        HTML component. The preview is an empty string when no URL was given.
    """
    result = mistral_inference(text, image_url)
    if not (image_url and image_url.strip()):
        return result, ""
    # The URL is untrusted user input rendered by gr.HTML: escape it (quotes
    # included) so it cannot close the src attribute and inject markup.
    safe_url = html.escape(image_url, quote=True)
    return result, f'<img src="{safe_url}" alt="Input Image" width="300">'
def update_ui(language):
    """Return component updates that relabel the UI for *language*.

    The outputs are, in order: title Markdown, prompt Textbox, URL Textbox,
    output Textbox, image HTML and submit Button (as wired in
    language_choice.change). Returning a bare string to a Textbox/HTML output
    would replace the component's value, wiping what the user typed, rather
    than its label. Labels are therefore changed with gr.update(label=...).
    The title keeps the "## " heading markup it is created with, and the
    button's text is its value.
    """
    labels = get_labels(language)
    return (
        gr.update(value=f"## {labels['title']}"),
        gr.update(label=labels['text_prompt']),
        gr.update(label=labels['image_url']),
        gr.update(label=labels['output']),
        gr.update(label=labels['image_display']),
        gr.update(value=labels['submit']),
    )
# Initial (English) labels come from get_labels so they always match the
# strings update_ui applies when the language is switched back to 'en'.
_initial_labels = get_labels('en')

with gr.Blocks() as demo:
    language_choice = gr.Dropdown(choices=['en', 'zh', 'jp'], label="Select Language", value='en')
    title = gr.Markdown(f"## {_initial_labels['title']}")
    with gr.Row():
        text_input = gr.Textbox(label=_initial_labels['text_prompt'], placeholder="e.g. Describe the image.")
        image_input = gr.Textbox(label=_initial_labels['image_url'], placeholder="e.g. https://example.com/image.png")
    # 8 visible lines, growing to at most 20 (sized to roughly 500 px tall).
    result_output = gr.Textbox(label=_initial_labels['output'], lines=8, max_lines=20)
    # Area that shows a preview of the input image URL.
    image_output = gr.HTML(label=_initial_labels['image_display'])
    submit_button = gr.Button(_initial_labels['submit'])
    submit_button.click(process_input, inputs=[text_input, image_input], outputs=[result_output, image_output])
    # Relabel the UI when the selected language changes.
    language_choice.change(
        fn=update_ui,
        inputs=[language_choice],
        outputs=[title, text_input, image_input, result_output, image_output, submit_button],
    )

demo.launch()