Spaces: Running on Zero
import os

os.environ["GRADIO_SSR_MODE"] = "false"

if not os.path.exists("checkpoints"):
    os.makedirs("checkpoints")

os.system("pip install gdown")
os.system("gdown https://drive.google.com/uc?id=1eQe6blJcyI7oy78C8ozwj1IUkbkFEItf; unzip -o dam_3b_v1.zip -d checkpoints")
from segment_anything import sam_model_registry, SamPredictor
import gradio as gr
import numpy as np
import cv2
import base64
import torch
from PIL import Image
import io
import argparse
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from transformers import SamModel, SamProcessor
from dam import DescribeAnythingModel, disable_torch_init

try:
    from spaces import GPU
except ImportError:
    print("Spaces not installed, using dummy GPU decorator")
    GPU = lambda fn: fn

# Load SAM model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
@GPU  # run on an attached GPU when hosted on ZeroGPU Spaces (no-op with the dummy decorator above)
def image_to_sam_embedding(base64_image):
    try:
        # Decode base64 string to bytes
        image_bytes = base64.b64decode(base64_image)
        # Convert bytes to PIL Image
        image = Image.open(io.BytesIO(image_bytes))

        # Process image with SAM processor
        inputs = sam_processor(image, return_tensors="pt").to(device)

        # Get image embedding
        with torch.no_grad():
            image_embedding = sam_model.get_image_embeddings(inputs["pixel_values"])

        # Convert to CPU and numpy
        image_embedding = image_embedding.cpu().numpy()

        # Encode the embedding as base64
        embedding_bytes = image_embedding.tobytes()
        embedding_base64 = base64.b64encode(embedding_bytes).decode('utf-8')

        return embedding_base64
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        raise gr.Error(f"Failed to process image: {str(e)}")
@GPU  # DAM inference needs a GPU when running on ZeroGPU Spaces
def describe(image_base64: str, mask_base64: str, query: str):
    # Convert base64 to PIL Image (strip a data-URL prefix if present)
    image_bytes = base64.b64decode(image_base64.split(',')[1] if ',' in image_base64 else image_base64)
    img = Image.open(io.BytesIO(image_bytes))
    mask_bytes = base64.b64decode(mask_base64.split(',')[1] if ',' in mask_base64 else mask_base64)
    mask = Image.open(io.BytesIO(mask_bytes))

    # Binarize the mask: any nonzero pixel becomes foreground (255)
    mask = Image.fromarray((np.array(mask.convert('L')) > 0).astype(np.uint8) * 255)

    # Get description using DAM with streaming
    description_generator = dam.get_description(img, mask, query, streaming=True)

    # Stream the tokens, yielding the accumulated text so far
    text = ""
    for token in description_generator:
        text += token
        yield text
@GPU  # DAM inference needs a GPU when running on ZeroGPU Spaces
def describe_without_streaming(image_base64: str, mask_base64: str, query: str):
    # Convert base64 to PIL Image (strip a data-URL prefix if present)
    image_bytes = base64.b64decode(image_base64.split(',')[1] if ',' in image_base64 else image_base64)
    img = Image.open(io.BytesIO(image_bytes))
    mask_bytes = base64.b64decode(mask_base64.split(',')[1] if ',' in mask_base64 else mask_base64)
    mask = Image.open(io.BytesIO(mask_bytes))

    # Binarize the mask: any nonzero pixel becomes foreground (255)
    mask = Image.fromarray((np.array(mask.convert('L')) > 0).astype(np.uint8) * 255)

    # Get description using DAM (single, complete response)
    description = dam.get_description(img, mask, query)
    return description
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Describe Anything gradio demo")
    parser.add_argument("--model-path", type=str, default="checkpoints/dam_3b_v1", help="Path to the model checkpoint")
    parser.add_argument("--prompt-mode", type=str, default="full+focal_crop", help="Prompt mode")
    parser.add_argument("--conv-mode", type=str, default="v1", help="Conversation mode")
    parser.add_argument("--temperature", type=float, default=0.2, help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.5, help="Top-p for sampling")
    args = parser.parse_args()

    # Initialize DAM model
    disable_torch_init()
    dam = DescribeAnythingModel(
        model_path=args.model_path,
        conv_mode=args.conv_mode,
        prompt_mode=args.prompt_mode,
        temperature=args.temperature,
        top_p=args.top_p,
        num_beams=1,
        max_new_tokens=512,
    ).to(device)
    # Create Gradio interface
    with gr.Blocks() as demo:
        gr.Interface(
            fn=image_to_sam_embedding,
            inputs=gr.Textbox(label="Image Base64"),
            outputs=gr.Textbox(label="Embedding Base64"),
            title="Image Embedding Generator",
            api_name="image_to_sam_embedding"
        )
        gr.Interface(
            fn=describe,
            inputs=[
                gr.Textbox(label="Image Base64"),
                gr.Text(label="Mask Base64"),
                gr.Text(label="Prompt")
            ],
            outputs=[
                gr.Text(label="Description")
            ],
            title="Mask Description Generator",
            api_name="describe"
        )
        gr.Interface(
            fn=describe_without_streaming,
            inputs=[
                gr.Textbox(label="Image Base64"),
                gr.Text(label="Mask Base64"),
                gr.Text(label="Prompt")
            ],
            outputs=[
                gr.Text(label="Description")
            ],
            title="Mask Description Generator (Non-Streaming)",
            api_name="describe_without_streaming"
        )
    # launch() normally blocks the main thread; stub out block_thread so we can
    # patch the underlying FastAPI app after the server starts.
    demo._block_thread = demo.block_thread
    demo.block_thread = lambda: None
    demo.launch()

    # Swap Gradio's root route for the prebuilt frontend in ./dist,
    # leaving the API routes in place.
    for route in demo.app.routes:
        if route.path == "/":
            demo.app.routes.remove(route)
    demo.app.mount("/", StaticFiles(directory="dist", html=True), name="demo")

    # Now block the main thread as launch() would normally have done.
    demo._block_thread()
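
Each gr.Interface above registers a named API endpoint on the Space, so the demo can also be called programmatically. A minimal client-side sketch, assuming the gradio_client package, base64-encoded PNG files for the image and mask, and a placeholder Space ID (swap in the real deployment):

# Hypothetical client usage; the Space ID below is a placeholder, not the real one.
import base64
from gradio_client import Client

def to_b64(path):
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

client = Client("your-username/describe-anything-demo")  # placeholder Space ID

# Call the non-streaming endpoint defined by api_name="describe_without_streaming"
description = client.predict(
    to_b64("image.png"),                      # Image Base64
    to_b64("mask.png"),                       # Mask Base64 (white = region to describe)
    "Describe the masked region in detail.",  # Prompt
    api_name="/describe_without_streaming",
)
print(description)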