import gradio as gr
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
import spaces

DESCRIPTION = """
# Qwen2.5-VL-3B/7B-Instruct
"""

css = '''
h1 {
  text-align: center;
  display: block;
}

#duplicate-button {
  margin: auto;
  color: #fff;
  background: #1565c0;
  border-radius: 100vh;
}
'''

# Define an animated progress bar HTML snippet
def progress_bar_html(label: str) -> str:
    return f'''
<div style="display: flex; align-items: center;">
    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
    <div style="width: 110px; height: 5px; background-color: #FFF0F5; border-radius: 2px; overflow: hidden;">
        <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
    </div>
</div>
<style>
@keyframes loading {{
    0% {{ transform: translateX(-100%); }}
    100% {{ transform: translateX(100%); }}
}}
</style>
'''

# Model IDs for the 3B and 7B variants
MODEL_ID_3B = "Qwen/Qwen2.5-VL-3B-Instruct"
MODEL_ID_7B = "Qwen/Qwen2.5-VL-7B-Instruct"

# Load the processor and model for both variants
processor_3b = AutoProcessor.from_pretrained(MODEL_ID_3B, trust_remote_code=True)
model_3b = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_3B,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to("cuda").eval()

processor_7b = AutoProcessor.from_pretrained(MODEL_ID_7B, trust_remote_code=True)
model_7b = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_7B,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to("cuda").eval()
@spaces.GPU  # request GPU time per call on ZeroGPU hardware (the likely reason `spaces` is imported)
def model_inference(input_dict, history):
| text = input_dict["text"] | |
| files = input_dict["files"] | |
| # Determine which model to use based on the prefix tag | |
| if text.lower().startswith("@3b"): | |
| yield progress_bar_html("processing with Qwen2.5-VL-3B-Instruct") | |
| selected_model = model_3b | |
| selected_processor = processor_3b | |
| text = text[len("@3b"):].strip() | |
| elif text.lower().startswith("@7b"): | |
| yield progress_bar_html("processing with Qwen2.5-VL-7B-Instruct") | |
| selected_model = model_7b | |
| selected_processor = processor_7b | |
| text = text[len("@7b"):].strip() | |
| else: | |
| yield "Error: Please prefix your query with @3b or @7b to select the model." | |
| return | |

    # Load any attached images (Gradio may pass a single path or a list of paths)
    if files:
        files = files if isinstance(files, list) else [files]
        images = [load_image(f) for f in files]
    else:
        images = []

    # Validate input: a text query is required
    if text == "":
        yield "Error: Please enter a text query along with the image(s), if any."
        return

    # Prepare messages for the model
    messages = [{
        "role": "user",
        "content": [
            *[{"type": "image", "image": image} for image in images],
            {"type": "text", "text": text},
        ],
    }]

    # Apply the chat template and process the inputs
    prompt = selected_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = selected_processor(
        text=[prompt],
        images=images if images else None,
        return_tensors="pt",
        padding=True,
    ).to("cuda")

    # Set up a streamer for real-time text generation
    streamer = TextIteratorStreamer(selected_processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)

    # Start generation in a separate thread so tokens can be streamed as they arrive
    thread = Thread(target=selected_model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield an animated progress message until the first tokens arrive
    yield progress_bar_html("Almost there, hold tight!")
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer
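    # Make sure the background generation thread has fully exited; the streamer
    # is exhausted at this point, so the join returns almost immediately.
    thread.join()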

# Example inputs with model prefixes
examples = [
    [{"text": "@3b Describe the document.", "files": ["example_images/document.jpg"]}],
    [{"text": "@7b What does this say?", "files": ["example_images/math.jpg"]}],
    [{"text": "@3b What is this UI about?", "files": ["example_images/s2w_example.png"]}],
    [{"text": "@7b Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
]
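# The example image paths above are resolved relative to the app's working
# directory, i.e. they are assumed to ship with the Space under example_images/.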

demo = gr.ChatInterface(
    fn=model_inference,
    description=DESCRIPTION,
    css=css,
    examples=examples,
    textbox=gr.MultimodalTextbox(
        label="Query Input",
        file_types=["image"],
        file_count="multiple",
        placeholder="Use the @3b / @7b tags to select a model",
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
)

demo.launch(debug=True)
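
# Dependency sketch (an assumption; the Space's requirements.txt is not shown here):
# Qwen2_5_VLForConditionalGeneration needs a recent transformers release
# (roughly 4.49+), alongside gradio, spaces, and torch, e.g.:
#   pip install gradio spaces torch "transformers>=4.49.0"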