#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import base64
import json
import logging
import os
import time
import uuid
from io import BytesIO

import torch
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.staticfiles import StaticFiles
from PIL import Image
from pydantic import BaseModel
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

# Create the temporary folder if it doesn't exist.
TEMP_DIR = "/temp"
os.makedirs(TEMP_DIR, exist_ok=True)

app = FastAPI()

# Mount the temporary folder so annotated images can be served at /temp/<filename>.
app.mount("/temp", StaticFiles(directory=TEMP_DIR), name="temp")


# Define the request model.
class PredictRequest(BaseModel):
    image_base64: list[str]
    prompt: str


# Use the desired checkpoint: Qwen/Qwen2.5-VL-3B-Instruct-AWQ
checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct-AWQ"
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28

# Load the processor with the image resolution settings.
processor = AutoProcessor.from_pretrained(
    checkpoint, min_pixels=min_pixels, max_pixels=max_pixels
)

# Load the Qwen2.5-VL model.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype="auto",
    device_map="auto",
    # attn_implementation="flash_attention_2",
)


@app.get("/")  # health-check handler; registered at the root so the API responds with a live message
def read_root():
    return {"message": "API is live. Use the /predict endpoint."}


def encode_image(image_data: BytesIO, max_size=(800, 800), quality=85):
    """
    Converts an image from file data to a Base64-encoded string with optimized size.
    """
    try:
        with Image.open(image_data) as img:
            img = img.convert("RGB")
            img.thumbnail(max_size, Image.LANCZOS)
            buffer = BytesIO()
            img.save(buffer, format="JPEG", quality=quality)
            return base64.b64encode(buffer.getvalue()).decode("utf-8")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error encoding image: {e}")


@app.post("/upload")  # NOTE: the route path "/upload" is assumed; the source does not name it
async def upload_and_encode_image(file: UploadFile = File(...)):
    """
    Endpoint to upload an image file and return its Base64-encoded representation.
    """
    try:
        image_data = BytesIO(await file.read())
        encoded_string = encode_image(image_data)
        return {"filename": file.filename, "encoded_image": encoded_string}
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid file: {e}")


@app.post("/predict")
def predict(data: PredictRequest, annotate: bool = False):
    """
    Generates a description (e.g. bounding boxes with labels) for image(s) using
    Qwen2.5-VL-3B-Instruct-AWQ. If 'annotate' is True (as a query parameter), the first
    image is annotated with the predicted bounding boxes, stored in a temporary folder,
    and its URL is returned.

    Request:
        - image_base64: List of base64-encoded images.
        - prompt: A prompt string.

    Response (JSON):
        {
            "response": <text generated by Qwen2.5-VL>,
            "annotated_image_url": "/temp/<filename>"  # only if annotate=True
        }
    """
    logging.warning("Calling /predict endpoint...")

    # Ensure image_base64 is a list.
    image_list = data.image_base64 if isinstance(data.image_base64, list) else [data.image_base64]

    # Create input messages: include all images and then the prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": f"data:image;base64,{image}"}
                for image in image_list
            ] + [{"type": "text", "text": data.prompt}],
        }
    ]

    logging.info("Processing inputs... Number of images: %d", len(image_list))

    # Prepare inputs for the model using the processor's chat interface.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    logging.warning("Starting generation...")
    start_time = time.time()

    # Generate output using the model.
    generated_ids = model.generate(**inputs, max_new_tokens=2056)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    generation_time = time.time() - start_time
    logging.warning("Generation completed in %.2fs.", generation_time)

    # The generated output text is expected to be JSON (e.g., a list of detections).
    result_text = output_text[0] if output_text else "No description generated."
    response_data = {"response": result_text}

    if annotate:
        # Decode the first image for annotation.
        try:
            img_str = image_list[0]
            # If the image string contains a data URI prefix, remove it.
            if img_str.startswith("data:image"):
                img_str = img_str.split(",")[1]
            img_data = base64.b64decode(img_str)
            image = Image.open(BytesIO(img_data))
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error decoding image for annotation: {e}")

        # Determine image dimensions (width, height).
        input_wh = image.size
        resolution_wh = input_wh  # Assuming no resolution change.

        # Parse the detection result from the model output.
        try:
            detection_result = json.loads(result_text)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error parsing detection result: {e}")

        # Use the supervision library to create detections and annotate the image.
        try:
            import supervision as sv

            detections = sv.Detections.from_vlm(
                vlm=sv.VLM.QWEN_2_5_VL,
                result=detection_result,
                input_wh=input_wh,
                resolution_wh=resolution_wh,
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error creating detections: {e}")

        try:
            box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
            label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
            annotated_image = image.copy()
            annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)
            annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error annotating image: {e}")

        # Save the annotated image in the temporary folder.
        try:
            filename = f"{uuid.uuid4()}.jpg"
            filepath = os.path.join(TEMP_DIR, filename)
            annotated_image.save(filepath, format="JPEG")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error saving annotated image: {e}")

        # Add the annotated image URL to the response.
        response_data["annotated_image_url"] = f"/temp/{filename}"

    return response_data
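

# Minimal launch sketch (an assumption, not part of the source listing): serve the app
# directly with uvicorn. The host/port values are assumed; 7860 is the default port
# used by Hugging Face Spaces.
if __name__ == "__main__":
    import uvicorn

    # Illustrative client call once the server is up (placeholders, not from the source):
    #   curl -X POST "http://localhost:7860/predict?annotate=true" \
    #        -H "Content-Type: application/json" \
    #        -d '{"image_base64": ["<base64 string>"], "prompt": "Outline each object as JSON bounding boxes."}'
    uvicorn.run(app, host="0.0.0.0", port=7860)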