# NOTE: the "Spaces: Running / Running" lines below this point in the original
# capture were Hugging Face Spaces page-status chrome, not application code.
# Runtime dependency pin: the Space upgrades transformers to a known-good
# version before anything else imports it.
# NOTE(review): a failed install raises CalledProcessError and aborts startup;
# pinning via requirements.txt would be the more conventional place for this.
import subprocess
import sys

subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "transformers==4.56.2"])

import logging
from typing import List, Dict, Tuple

import gradio as gr
from pylate import indexes, models, retrieve

# Demo corpus: list of dicts with "id", "text", "language", "title", "category".
from documents import MULTILINGUAL_DOCUMENTS

# Module-wide logging: timestamped INFO-level records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
class CrossLingualRetriever:
    """Cross-lingual retrieval system using LiquidAI's LFM2-ColBERT model.

    Wraps a PyLate ColBERT encoder plus a PLAID index so a multilingual
    corpus can be indexed once and then searched with queries written in
    any supported language.
    """

    def __init__(self, model_name: str = "LiquidAI/LFM2-ColBERT-350M-RC"):
        """Load the ColBERT model and create a fresh PLAID index.

        Args:
            model_name: Hugging Face model id of the ColBERT checkpoint.
        """
        logger.info(f"Loading model: {model_name}")
        self.model = models.ColBERT(model_name_or_path=model_name)
        # The tokenizer needs a pad token for batched encoding; reuse EOS.
        self.model.tokenizer.pad_token = self.model.tokenizer.eos_token
        # override=True rebuilds the on-disk index from scratch on every start,
        # so a stale index folder never shadows freshly loaded documents.
        self.index = indexes.PLAID(
            index_folder="pylate-index",
            index_name="index",
            override=True,
        )
        self.retriever = retrieve.ColBERT(index=self.index)
        self.documents_data: List[Dict[str, str]] = []
        # id -> document mapping, populated by load_documents(); gives O(1)
        # lookups in search() (previously a linear scan per retrieved hit).
        self._docs_by_id: Dict[str, Dict[str, str]] = {}
        logger.info("Model and index initialized successfully")

    def load_documents(self, documents: List[Dict[str, str]]) -> None:
        """Encode *documents* and add them to the PLAID index.

        Each document dict must provide "id" and "text"; the extra keys
        ("language", "title", "category") are surfaced by search().
        """
        logger.info(f"Loading {len(documents)} documents")
        self.documents_data = documents
        self._docs_by_id = {doc["id"]: doc for doc in documents}
        documents_ids = [doc["id"] for doc in documents]
        documents_text = [doc["text"] for doc in documents]
        # Token-level document embeddings (is_query=False selects the
        # document-side encoding path of the ColBERT model).
        documents_embeddings = self.model.encode(
            documents_text,
            batch_size=32,
            is_query=False,
            show_progress_bar=True,
        )
        self.index.add_documents(
            documents_ids=documents_ids,
            documents_embeddings=documents_embeddings,
        )
        logger.info("Documents indexed successfully")

    def search(self, query: str, k: int = 5) -> List[Dict]:
        """Return the top-*k* documents for *query*, best score first.

        Each result dict carries the retrieval score (rounded to 4 decimals)
        plus the matched document's text and metadata.
        """
        logger.info(f"Searching for: {query}")
        query_embedding = self.model.encode(
            [query],
            batch_size=32,
            is_query=True,
            show_progress_bar=False,
        )
        scores = self.retriever.retrieve(
            queries_embeddings=query_embedding,
            k=k,
        )
        # scores[0] holds the hits for our single query; hits whose id is
        # unknown (should not happen) are skipped, matching prior behavior.
        results = []
        for score in scores[0]:
            doc = self._docs_by_id.get(score["id"])
            if doc:
                results.append({
                    "id": score["id"],
                    "score": round(score["score"], 4),
                    "text": doc["text"],
                    "language": doc["language"],
                    "title": doc["title"],
                    "category": doc["category"],
                })
        return results
# Build the retriever and index the demo corpus once at import time so the
# Gradio event handlers below share a single warmed-up instance.
retriever = CrossLingualRetriever()
retriever.load_documents(MULTILINGUAL_DOCUMENTS)
def format_results(results: List[Dict]) -> str:
    """Render search results as an HTML card list.

    Args:
        results: Dicts as produced by CrossLingualRetriever.search(); each
            must carry "title", "category", "language", "score" and "text".

    Returns:
        An HTML string, or a centered placeholder message when *results*
        is empty.
    """
    if not results:
        return "<div style='padding: 20px; text-align: center; color: #666;'>No results found</div>"
    # Accumulate fragments and join once at the end instead of the original
    # quadratic `html += ...` concatenation in the loop.
    parts = ["<div style='font-family: Arial, sans-serif;'>"]
    for i, result in enumerate(results, 1):
        # Traffic-light score badge: green = strong match, yellow = medium,
        # red = weak (thresholds 30 / 20 on the raw ColBERT score).
        score_color = "#22c55e" if result["score"] > 30 else "#eab308" if result["score"] > 20 else "#ef4444"
        # NOTE(review): document fields are interpolated unescaped — fine for
        # the trusted local corpus, but switch to html.escape() if the corpus
        # ever becomes user-supplied.
        parts.append(f"""
        <div style='margin-bottom: 20px; padding: 15px; border: 1px solid #e5e7eb; border-radius: 8px; background: #f9fafb;'>
            <div style='display: flex; justify-content: space-between; align-items: center; margin-bottom: 10px;'>
                <div>
                    <span style='font-weight: bold; font-size: 16px;'>#{i} {result["title"]}</span>
                    <span style='margin-left: 10px; padding: 2px 8px; background: #fef3c7; color: #92400e; border-radius: 4px; font-size: 12px;'>{result["category"]}</span>
                    <span style='margin-left: 5px; padding: 2px 8px; background: #dbeafe; color: #1e40af; border-radius: 4px; font-size: 12px;'>{result["language"]}</span>
                </div>
                <span style='padding: 4px 12px; background: {score_color}; color: white; border-radius: 4px; font-weight: bold;'>
                    Score: {result["score"]}
                </span>
            </div>
            <div style='color: #374151; line-height: 1.6;'>
                {result["text"]}
            </div>
        </div>
        """)
    parts.append("</div>")
    return "".join(parts)
def search_documents(query: str, top_k: int) -> Tuple[str, str]:
    """Run a search and return (results_html, summary_text) for the UI.

    Args:
        query: Free-text query in any supported language.
        top_k: Requested number of results. Clamped to 12, the slider's
            maximum — the previous cap of 10 silently contradicted the UI.

    Returns:
        A tuple of (HTML-formatted results, one-line status summary).
        On failure the HTML part is empty and the summary carries the error.
    """
    if not query.strip():
        return "", "Please enter a search query."
    try:
        # int() guards against Gradio delivering the slider value as a float.
        results = retriever.search(query, k=min(int(top_k), 12))
        formatted_results = format_results(results)
        if results:
            languages_found = {r["language"] for r in results}
            summary = (
                f"✅ Found {len(results)} relevant documents across "
                f"{len(languages_found)} language(s): {', '.join(sorted(languages_found))}"
            )
        else:
            summary = "❌ No relevant documents found."
        return formatted_results, summary
    except Exception as e:
        # logger.exception records the full traceback (logger.error did not).
        logger.exception(f"Search error: {e}")
        return "", f"❌ Error during search: {str(e)}"
# Example queries in different languages, paired with a top-k value for the
# slider. Fed verbatim to gr.Examples below.
# NOTE(review): the non-Latin query strings appear mojibake'd in this capture
# (likely Korean / Arabic / Chinese originally) — verify against the source
# repo before relying on them; kept byte-identical here.
EXAMPLE_QUERIES = [
    ["What is artificial intelligence?", 8],
    ["ยฟQuรฉ es el cambio climรกtico?", 4],
    ["์์ ์ปดํจํ ์ด๋ ๋ฌด์์ธ๊ฐ์?", 6],
    ["ู ุง ูู ุงูุตุญุฉ ุงูููุณูุฉุ", 5],
    ["้ๅญ่ฎก็ฎๆฏไปไน๏ผ", 8],
]
# Build Gradio interface.
# NOTE(review): emoji throughout the UI strings appear mojibake'd ("๐...") in
# this capture; they are runtime text, so they are kept byte-identical here —
# restore from the source repo if possible.
with gr.Blocks(title="Cross-Lingual Retrieval Demo", theme=gr.themes.Soft(primary_hue="purple")) as demo:
    # Header. NOTE(review): links advertise LFM2-ColBERT-350M while the
    # retriever defaults to LFM2-ColBERT-350M-RC — confirm which is intended.
    gr.Markdown(
        """
        # ๐ Cross-Lingual Document Retrieval
        ### Powered by [LiquidAI/LFM2-ColBERT-350M](https://huggingface.co/LiquidAI/LFM2-ColBERT-350M)
        Find semantically similar documents across different languages.
        **Supported Languages:** English, Arabic, Chinese, French, German, Japanese, Korean, and Spanish
        """
    )
    with gr.Row():
        # Left column: query input, result-count slider, and search button.
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                label="๐ Enter your query",
                placeholder="E.g., 'artificial intelligence', 'cambio climรกtico', 'energie renouvelable'...",
                lines=2
            )
            top_k_slider = gr.Slider(
                minimum=1,
                maximum=12,
                value=5,
                step=1,
                label="Number of results to retrieve",
            )
            search_btn = gr.Button("Search", variant="primary", size="lg")
        # Right column: static description of the indexed corpus.
        with gr.Column(scale=1):
            gr.Markdown(
                """
                ### ๐ Document corpus
                - ๐ค **Technology** (16 docs): AI, Quantum Computing
                - ๐ **Environment** (16 docs): Climate, Biodiversity
                - โก **Energy** (8 docs): Renewable Sources
                - ๐ฅ **Health** (16 docs): Medicine, Mental Wellness
                - ๐ผ **Business** (16 docs): Digital Economy, Startups
                - ๐ **Education** (8 docs): Online Learning
                - ๐ญ **Culture** (8 docs): Global Connectivity
                - ๐ **Science** (8 docs): Space Exploration
                """
            )
    # Outputs: one-line summary plus the HTML result cards.
    summary_output = gr.Textbox(
        label="๐ Search Summary",
        interactive=False,
        lines=2
    )
    results_output = gr.HTML(
        label="๐ฏ Search Results"
    )
    # Event handlers: button click and textbox Enter both trigger the search.
    search_btn.click(
        fn=search_documents,
        inputs=[query_input, top_k_slider],
        outputs=[results_output, summary_output]
    )
    query_input.submit(
        fn=search_documents,
        inputs=[query_input, top_k_slider],
        outputs=[results_output, summary_output]
    )
    # Examples section: clickable sample queries (uncached, so each click
    # runs a live search).
    gr.Examples(
        examples=EXAMPLE_QUERIES,
        inputs=[query_input, top_k_slider],
        outputs=[results_output, summary_output],
        fn=search_documents,
        cache_examples=False,
    )
    gr.Markdown(
        """**How it works:** This demo uses the [LiquidAI/LFM2-ColBERT-350M](https://huggingface.co/LiquidAI/LFM2-ColBERT-350M) model with late interaction retrieval.
        The model encodes both queries and documents into token-level embeddings, enabling fine-grained matching
        across languages with high speed and accuracy."""
    )
# Launch the app only when this module is executed as a script.
if __name__ == "__main__":
    demo.launch()