Spaces:
Runtime error
Runtime error
| # Install required dependency | |
| # !pip install mistral-common | |
| import gradio as gr | |
| import torch | |
| import tempfile | |
| import os | |
| from typing import List, Tuple | |
| from transformers import VoxtralForConditionalGeneration, AutoProcessor | |
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hugging Face Hub id shared by the processor and the model checkpoint.
repo_id = "mistralai/Voxtral-Mini-3B-2507"

# The processor turns chat-style conversations into model inputs and decodes
# generated token ids back into text.
processor = AutoProcessor.from_pretrained(repo_id)

# Weights are loaded in bfloat16 and placed directly on `device`.
model = VoxtralForConditionalGeneration.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    device_map=device,
)
def respond(audio_files: List[str], question: str) -> Tuple[str, List[str]]:
    """Answer a question about one or more audio files with Voxtral.

    Args:
        audio_files: Paths of the uploaded audio files. A single path given
            as a plain string is treated as a one-element list.
        question: Text prompt about the audio. When it is blank or ``None``
            only the audio is sent to the model.

    Returns:
        A ``(answer, paths)`` tuple: the decoded model reply and the list of
        audio paths that were sent. When no audio is supplied, a hint
        message and an empty list are returned instead.
    """
    if not audio_files:
        return "Please upload at least one audio file.", []

    # A bare string would otherwise be iterated character by character.
    if isinstance(audio_files, str):
        audio_files = [audio_files]
    paths = list(audio_files)

    content = [{"type": "audio", "path": path} for path in paths]
    # Skip the text chunk for a blank question instead of sending an empty
    # (or None) text entry; the model then responds to the audio alone.
    if question and question.strip():
        content.append({"type": "text", "text": question})

    conversation = [{"role": "user", "content": content}]

    inputs = processor.apply_chat_template(conversation)
    inputs = inputs.to(device, dtype=torch.bfloat16)

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=500)

    # generate() returns prompt + completion; drop the prompt tokens so only
    # the newly generated reply is decoded.
    prompt_length = inputs.input_ids.shape[1]
    decoded = processor.batch_decode(
        outputs[:, prompt_length:],
        skip_special_tokens=True,
    )
    return decoded[0], paths
# Gradio UI: several audio files plus a question in, answer plus the list of
# submitted files out.
#
# gr.Audio accepts a single clip and has no `file_count` option, so passing
# it raised a TypeError while the interface was being built. A multi-file
# gr.File restricted to audio is used for the upload instead. Likewise,
# gr.Gallery displays images/videos, so the submitted audio paths are echoed
# back through a gr.File output.
demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.File(
            label="Audio files",
            file_count="multiple",
            file_types=["audio"],
            type="filepath",  # respond() expects plain path strings
        ),
        gr.Textbox(
            lines=2,
            placeholder="Ask something about the audio(s)...",
            label="Question",
        ),
    ],
    outputs=[
        gr.Textbox(label="Answer"),
        gr.File(label="Uploaded audio files", file_count="multiple"),
    ],
    title="Voxtral-Mini-3B-2507 Audio Q&A",
    description="Upload one or more audio files and ask any question about them.",
    examples=[
        [
            [
                "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/mary_had_lamb.mp3",
                "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/winning_call.mp3",
            ],
            "What sport and what nursery rhyme are referenced?",
        ]
    ],
    # Examples are not pre-computed at startup, so the model is only run on
    # an example when a user selects it.
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()