# FastAPI service exposing a Qwen2.5-VL image-description endpoint.

import base64
from io import BytesIO

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from pydantic import BaseModel
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

app = FastAPI()
# Define request model
class PredictRequest(BaseModel):
    image_base64: str
    prompt: str
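# Example /predict request body (illustrative values, not taken from the deployment):
#   {"image_base64": "<Base64-encoded JPEG bytes>", "prompt": "Describe this image."}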
# The smaller Qwen/Qwen2-VL-2B-Instruct checkpoint (loaded with
# transformers.Qwen2VLForConditionalGeneration) can be swapped in here
# if GPU memory is tight; the processor arguments are identical.
checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
min_pixels = 256 * 28 * 28   # each 28x28 patch is one visual token: at least 256 tokens/image
max_pixels = 1280 * 28 * 28  # at most 1280 tokens/image, capping memory use on large inputs
processor = AutoProcessor.from_pretrained(
    checkpoint,
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # attn_implementation="flash_attention_2",  # enable if flash-attn is installed
)
@app.get("/")
def read_root():
    return {"message": "API is live. Use the /predict endpoint."}
def encode_image(image_data: BytesIO, max_size=(800, 800), quality=85):
    """
    Converts image file data to a Base64-encoded string with optimized size.

    Args:
        image_data (BytesIO): The raw image bytes.
        max_size (tuple): The maximum width and height of the resized image.
        quality (int): JPEG quality (1-100; higher means better quality but bigger size).

    Returns:
        str: Base64-encoded representation of the optimized image.
    """
    try:
        with Image.open(image_data) as img:
            # Convert to RGB (avoid issues with PNG transparency)
            img = img.convert("RGB")
            # Resize while maintaining aspect ratio
            img.thumbnail(max_size, Image.LANCZOS)
            # Save to a buffer as JPEG to reduce size
            buffer = BytesIO()
            img.save(buffer, format="JPEG", quality=quality)
            return base64.b64encode(buffer.getvalue()).decode("utf-8")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error encoding image: {e}")
@app.post("/upload")  # route path is an assumption; adjust to your deployment
async def upload_and_encode_image(file: UploadFile = File(...)):
    """
    Endpoint to upload an image file and return its Base64-encoded representation.
    """
    try:
        image_data = BytesIO(await file.read())
        encoded_string = encode_image(image_data)
        return {"filename": file.filename, "encoded_image": encoded_string}
    except HTTPException:
        raise  # keep the 500 raised by encode_image instead of masking it as a 400
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid file: {e}")
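# For reference, a sketch of exercising the upload endpoint with curl, assuming a
# locally running server and the assumed /upload route; "example.jpg" is a placeholder:
#   curl -X POST -F "file=@example.jpg" http://localhost:7860/upload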
@app.post("/predict")
def predict(data: PredictRequest):
    """
    Generates a description for an image using the Qwen2.5-VL model.

    Args:
        data (PredictRequest): The Base64-encoded image and the text prompt
            that guides the model's response.

    Returns:
        dict: The generated description of the image under the "response" key.
    """
    # Create the input message structure
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": f"data:image;base64,{data.image_base64}"},
                {"type": "text", "text": data.prompt},
            ],
        }
    ]
    # Prepare inputs for the model
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)
    # Generate the output (no_grad avoids building an unused autograd graph)
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=2056)
    # Strip the prompt tokens so only newly generated tokens are decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return {"response": output_text[0] if output_text else "No description generated."}