# Gemini Object Detection — Gradio app (Hugging Face Space)
| import google.generativeai as genai | |
| from google.generativeai.types import HarmBlockThreshold, HarmCategory | |
| import gradio as gr | |
| from PIL import Image, ImageDraw, ImageFont | |
| import json | |
# Fetch bounding boxes and labels
async def get_bounding_boxes(prompt: str, image, api_key: str):
    """Ask Gemini to locate the objects described by *prompt* in *image*.

    Args:
        prompt: Description of the object(s) to locate.
        image: PIL image, sent to the model as a multimodal message part.
        api_key: Gemini API key used to configure the client.

    Returns:
        Tuple of ``(bounding_boxes, explanation)`` where ``bounding_boxes``
        is a list of ``{"label": str, "box": [ymin, xmin, ymax, xmax]}``
        dicts — coordinates are in the model's 0–1000 normalized range
        (see adjust_bounding_box for the pixel conversion).

    Raises:
        gr.Error: On an invalid API key, rate limiting, any other
            generation failure, or a malformed model response.
    """
    system_prompt = """
You are a helpful assistant, who always responds with the bounding box and label with the explanation JSON based on the user input, and nothing else.
Your response can also include multiple bounding boxes and their labels in the list.
The values in the list should be integers.
Here are some example responses:
{
"explanation": "User asked for the bounding box of the dragon, so I will provide the bounding box of the dragon.",
"bounding_boxes": [
{"label": "dragon", "box": [ymin, xmin, ymax, xmax]}
]
}
{
"explanation": "User asked for the bounding box of the fruits which are red in color, so I will provide the bounding box of the Apple and the Tomato.",
"bounding_boxes": [
{"label": "apple", "box": [ymin, xmin, ymax, xmax]},
{"label": "tomato", "box": [ymin, xmin, ymax, xmax]}
]
}
""".strip()
    prompt = f"Return the bounding boxes and labels of: {prompt}"
    messages = [
        {"role": "user", "parts": [prompt, image]},
    ]
    genai.configure(api_key=api_key)
    generation_config = {
        "temperature": 1,
        "max_output_tokens": 8192,
        # Force JSON output so the json.loads call below can parse it.
        "response_mime_type": "application/json",
    }
    model = genai.GenerativeModel(
        model_name="gemini-2.0-flash",
        generation_config=generation_config,
        # Object-detection prompts are benign; disable safety blocking so
        # ordinary images/labels are not refused.
        safety_settings={
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE
        },
        system_instruction=system_prompt
    )
    try:
        response = await model.generate_content_async(messages)
    except Exception as e:
        # Map common failures onto user-facing Gradio errors.
        if "API key not valid" in str(e):
            raise gr.Error(
                "Invalid API key. Please provide a valid Gemini API key.")
        elif "rate limit" in str(e).lower():
            raise gr.Error("Rate limit exceeded for the API key.")
        else:
            raise gr.Error(f"Failed to generate content: {str(e)}")
    # FIX: despite the JSON mime type, the model can still return malformed
    # JSON or omit the expected keys (e.g. on truncation). Surface that as a
    # gr.Error instead of an unhandled JSONDecodeError/KeyError traceback.
    try:
        response_json = json.loads(response.text)
        explanation = response_json["explanation"]
        bounding_boxes = response_json["bounding_boxes"]
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        raise gr.Error(f"Model returned an unexpected response: {e}")
    return bounding_boxes, explanation
# Adjust bounding boxes based on image size
async def adjust_bounding_box(bounding_boxes, image):
    """Convert the model's 0-1000 normalized boxes to pixel coordinates.

    Takes ``{"label", "box": [ymin, xmin, ymax, xmax]}`` entries and
    returns new entries whose boxes are reordered to
    ``[xmin, ymin, xmax, ymax]`` in pixels, the order expected by
    ``ImageDraw.rectangle``.
    """
    img_w, img_h = image.size
    scaled = []
    for entry in bounding_boxes:
        y0, x0, y1, x1 = entry["box"]
        scaled.append({
            "label": entry["label"],
            "box": [
                x0 / 1000 * img_w,
                y0 / 1000 * img_h,
                x1 / 1000 * img_w,
                y1 / 1000 * img_h,
            ],
        })
    return scaled
# Process the image and draw bounding boxes and labels
async def process_image(image, text, api_key):
    """Detect objects in *image* and annotate them.

    Args:
        image: Filepath to the input image (Gradio ``type="filepath"``).
        text: Description of the object(s) to detect.
        api_key: Gemini API key.

    Returns:
        Tuple of ``(explanation, annotated PIL image, coordinates string)``.

    Raises:
        gr.Error: If no API key is provided, or propagated from
            get_bounding_boxes on API/response failures.
    """
    if not api_key:
        raise gr.Error("Please provide a Gemini API key.")
    # Open the image using PIL
    image = Image.open(image)
    # Call the async bounding box function
    bounding_boxes, explanation = await get_bounding_boxes(text, image, api_key)
    # Adjust the bounding box based on the image dimensions
    adjusted_boxes = await adjust_bounding_box(bounding_boxes, image)
    # Draw the bounding boxes and labels on the image
    draw = ImageDraw.Draw(image)
    # NOTE(review): load_default(size=...) requires Pillow >= 10.1 — confirm
    # the Space pins a compatible Pillow version.
    font = ImageFont.load_default(size=20)
    for item in adjusted_boxes:
        box = item["box"]
        label = item["label"]
        draw.rectangle(box, outline="red", width=3)
        # FIX: clamp the label's y position at 0 so boxes touching the top
        # edge do not render their label outside the visible image.
        draw.text((box[0], max(box[1] - 25, 0)), label, fill="red", font=font)
    # Format adjusted boxes for display
    adjusted_boxes_str = "\n".join(
        f"{item['label']}: {item['box']}" for item in adjusted_boxes)
    return explanation, image, adjusted_boxes_str
# Gradio app
async def gradio_app(image, text, api_key):
    """Gradio entry point: delegate straight to process_image."""
    outputs = await process_image(image, text, api_key)
    return outputs
# Build the Gradio interface: filepath image + prompt + API key in,
# explanation + annotated image + coordinates out.
iface = gr.Interface(
    fn=gradio_app,
    inputs=[
        gr.Image(type="filepath"),
        gr.Textbox(label="Object(s) to detect", value="person"),
        gr.Textbox(label="Your Gemini API Key", type="password")
    ],
    outputs=[
        gr.Textbox(label="Explanation"),
        gr.Image(type="pil", label="Output Image"),
        gr.Textbox(label="Coordinates of the detected objects")
    ],
    title="Gemini Object Detection ✨",
    description="Detect objects in images using the Gemini 2.0 Flash model.",
    allow_flagging="never"
)

# FIX: guard the launch so importing this module (e.g. from tests or
# another app) does not start a web server; running the script directly
# still launches the app as before.
if __name__ == "__main__":
    iface.launch()