Spaces:
Running
Running
| import base64 | |
| import io | |
| import os | |
| import random | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from colpali_engine.models import ColPali, ColPaliProcessor | |
| from datasets import load_dataset | |
| from dotenv import load_dotenv | |
| from PIL import Image, ImageDraw | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.http import models | |
| from tqdm import tqdm | |
| from gradio.themes.base import Base | |
| from gradio.themes.utils import colors, fonts, sizes | |
| from typing import Iterable | |
# Pull settings from a local .env file into os.environ so the os.getenv
# calls that configure Qdrant can see them.
load_dotenv()

# Accelerator preference: first CUDA GPU, then Apple MPS, otherwise CPU.
device = (
    "cuda:0"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)
print(f"Using device: {device}")
# Qdrant connection details are read from the environment; either may be None
# if the variable is unset.
QDRANT_URL, QDRANT_API_KEY = os.getenv("QDRANT_URL"), os.getenv("QDRANT_API_KEY")
# gRPC transport is preferred over REST for query traffic.
qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY, prefer_grpc=True)
# Qdrant collection to search and the fine-tuned ColPali checkpoint that
# produces query embeddings for it.
collection_name = "ufo"
model_name = "davanstrien/finetune_colpali_v1_2-ufo-4bit"

# Page images shown in the gallery; search hits are looked up in this split
# by Qdrant point id.
dataset = load_dataset("davanstrien/ufo-ColPali", split="train")

# NOTE(review): weights are loaded in bfloat16 on whichever device was picked,
# including MPS/CPU where bf16 support/speed varies — confirm on those targets.
colpali_model = ColPali.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map=device
)
# The processor is taken from the base ColPali checkpoint, not the fine-tune.
colpali_processor = ColPaliProcessor.from_pretrained("vidore/colpaligemma-3b-pt-448-base")
def search_images_by_text(query_text, top_k=5):
    """Embed a text query with ColPali and run a multivector search in Qdrant.

    Args:
        query_text: Free-text query.
        top_k: Maximum number of points to return. Coerced to ``int`` because
            values coming from UI widgets (e.g. ``gr.Slider``) may arrive as
            non-int numbers, while Qdrant's ``limit`` expects an integer.

    Returns:
        The response of ``qdrant_client.query_points``; its ``.points``
        attribute holds the scored hits.
    """
    with torch.no_grad():
        batch_query = colpali_processor.process_queries([query_text]).to(
            colpali_model.device
        )
        query_embedding = colpali_model(**batch_query)
    # Single-query batch: take element 0 (assumed to be that query's per-token
    # embeddings — TODO confirm shape). Upcast from bf16 before .numpy(),
    # which cannot convert bfloat16 tensors, then turn into nested lists.
    multivector_query = query_embedding[0].cpu().float().numpy().tolist()
    return qdrant_client.query_points(
        collection_name=collection_name,
        query=multivector_query,
        limit=int(top_k),
        timeout=60,
    )
def search_by_text_and_return_images(query_text, top_k=5):
    """Search by text and return the matching page images for the gallery.

    Args:
        query_text: Free-text query from the UI. ``None``, empty or
            whitespace-only queries return an empty list without running the
            model or contacting Qdrant.
        top_k: Maximum number of images to return (coerced to ``int``).

    Returns:
        A list of ``image`` column values for the matched dataset rows, in the
        order the hits were returned by the search.
    """
    if not (query_text or "").strip():
        return []
    results = search_images_by_text(query_text, int(top_k))
    # NOTE(review): assumes Qdrant point ids were assigned as row indices of
    # `dataset` when the collection was built — confirm against ingestion.
    row_ids = [point.id for point in results.points]
    return list(dataset.select(row_ids)["image"])
class Geocities90s(Base):
    """Gradio theme imitating a 1990s GeoCities homepage.

    Defaults to yellow/purple/gray hues, Comic Neue body text with a VT323
    monospace font, a starfield page background, and gradient primary buttons
    whose gradient direction flips on hover.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.yellow,
        secondary_hue: colors.Color | str = colors.purple,
        neutral_hue: colors.Color | str = colors.gray,
        font: fonts.Font | str = fonts.GoogleFont("Comic Neue"),
        font_mono: fonts.Font | str = fonts.GoogleFont("VT323"),
    ):
        # Font fallback stacks: the requested font first, then generic families.
        body_fonts = (font, "Comic Sans MS", "ui-sans-serif", "sans-serif")
        mono_fonts = (font_mono, "Courier New", "monospace")
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            font=body_fonts,
            font_mono=mono_fonts,
        )
        # Theme variable overrides; a leading "*" refers to another theme
        # variable (e.g. *primary_500) per Gradio's theming conventions.
        theme_overrides = {
            "body_background_fill": "url('https://web.archive.org/web/20091020152706/http://hk.geocities.com/neonlightfantasy/image/stars.gif')",
            "button_primary_background_fill": "linear-gradient(90deg, *primary_500, *secondary_500)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_500, *primary_500)",
            "button_primary_text_color": "*neutral_50",
        }
        self.set(**theme_overrides)


# Module-level theme instance passed to gr.Blocks(theme=...).
geocities90s = Geocities90s()
# Extra page CSS passed to gr.Blocks(css=...): green Comic Sans body text,
# a tiled background image on the app container, a colour-cycling <h1>
# (driven by the `flash` keyframes), and the `.yellow-text` class used by the
# marquee banner.
# NOTE(review): both background images are hot-linked from third-party hosts
# and may disappear — consider vendoring them.
css = """
body {
margin: 0;
padding: 0;
color: #00ff00;
font-family: 'Comic Sans MS', cursive;
}
.gradio-container {
background-image: url('https://i.ytimg.com/vi/5WapcCXEcXA/maxresdefault.jpg');
background-repeat: repeat;
background-size: 300px 300px;
}
h1 {
text-align: center;
color: #ff00ff;
text-shadow: 2px 2px #000000;
font-size: 36px;
animation: flash 1s linear infinite;
}
@keyframes flash {
0% { color: #ff00ff; }
50% { color: #00ffff; }
100% { color: #ff00ff; }
}
.yellow-text {
color: #ffff00;
text-shadow: 2px 2px #000000;
font-weight: bold;
}
"""
# Single-page search UI: header banners, query controls, and a results gallery.
with gr.Blocks(css=css, theme=geocities90s) as demo:
    # Header banners, rendered top to bottom in this order.
    for banner_html in (
        "<h1>🛸 Top Secret UFO Document Search 🛸</h1>",
        "<p style='text-align: center; font-style: italic;'>Powered by <a href='https://danielvanstrien.xyz/posts/post-with-code/colpali-qdrant/2024-10-02_using_colpali_with_qdrant.html' target='_blank' style='color: #00ff00;'>ColPali and Qdrant</a></p>",
        "<p style='text-align: center; color: #ff00ff;'>👽 Discover how to build your own alien-approved search engine! Learn the secrets of ColPali and Qdrant, and join the ranks of interstellar code warriors. Warning: May attract Men in Black. 🕴️👽</p>",
        "<marquee direction='left' scrollamount='5' class='yellow-text'>Uncover the truth that's out there! The government doesn't want you to know! ColPali will reveal the truth!</marquee>",
    ):
        gr.HTML(banner_html)

    # Query box beside the result-count slider, relative widths 3:1.
    with gr.Row():
        with gr.Column(scale=3):
            query_input = gr.Textbox(
                label="Enter your cosmic query",
                placeholder="e.g., alien abduction, crop circles",
            )
        with gr.Column(scale=1):
            num_results = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                label="Number of classified documents",
                value=5,
            )

    search_button = gr.Button("Declassify Documents")
    gallery_output = gr.Gallery(label="Declassified UFO Sightings", elem_id="gallery")

    # Button click: (query text, result count) -> list of page images.
    search_button.click(
        search_by_text_and_return_images,
        [query_input, num_results],
        gallery_output,
    )

# Start the Gradio server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()