LFM2-ColBERT / app.py
mlabonne's picture
Update app.py
bc7fbc4 verified
import subprocess
import sys
# Install the pinned transformers release (4.56.2) with the running
# interpreter's pip at startup, before the third-party imports below execute.
# check_call raises CalledProcessError if pip exits with a non-zero status.
# NOTE(review): presumably needed for compatibility with pylate / the model
# checkpoint — confirm, and prefer pinning in requirements.txt if possible.
subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "transformers==4.56.2"])
import logging
from typing import List, Dict, Tuple
import gradio as gr
from pylate import indexes, models, retrieve
from documents import MULTILINGUAL_DOCUMENTS
# Layout of every log line: timestamp, logger name, severity, message.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

# Configure the root logger at import time with INFO as the threshold.
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger shared by the retriever class and the UI callbacks.
logger = logging.getLogger(__name__)
class CrossLingualRetriever:
    """Cross-lingual retrieval system using LiquidAI's LFM2-ColBERT model.

    Bundles a PyLate ColBERT encoder, a PLAID index and a ColBERT retriever.
    ``load_documents`` encodes and indexes a corpus; ``search`` encodes a
    query, retrieves the best matches and joins them back to the stored
    document metadata.
    """

    def __init__(self, model_name: str = "LiquidAI/LFM2-ColBERT-350M-RC"):
        """Initialize the retriever with model and index.

        Args:
            model_name: Model name or path handed to ``models.ColBERT``.
        """
        logger.info("Loading model: %s", model_name)
        self.model = models.ColBERT(model_name_or_path=model_name)

        # Reuse the end-of-sequence token as the padding token.
        self.model.tokenizer.pad_token = self.model.tokenizer.eos_token

        # Initialize PLAID index.
        # NOTE(review): override=True presumably discards any index already
        # stored under pylate-index/index — confirm against the PyLate docs.
        self.index = indexes.PLAID(
            index_folder="pylate-index",
            index_name="index",
            override=True,
        )
        self.retriever = retrieve.ColBERT(index=self.index)

        # Raw documents as passed to load_documents, plus an id -> document
        # map so search() can resolve hits without rescanning the list.
        self.documents_data = []
        self._documents_by_id = {}
        logger.info("Model and index initialized successfully")

    def load_documents(self, documents: List[Dict[str, str]]) -> None:
        """Load and index multilingual documents.

        Args:
            documents: Dicts with at least ``"id"`` and ``"text"`` keys;
                ``search`` additionally reads ``"language"``, ``"title"``
                and ``"category"`` from them.
        """
        logger.info("Loading %d documents", len(documents))
        self.documents_data = documents

        # Build the lookup once. setdefault keeps the first document for a
        # duplicated id, matching a first-match linear scan.
        documents_by_id = {}
        for doc in documents:
            documents_by_id.setdefault(doc["id"], doc)
        self._documents_by_id = documents_by_id

        documents_ids = [doc["id"] for doc in documents]
        documents_text = [doc["text"] for doc in documents]

        # Encode documents
        documents_embeddings = self.model.encode(
            documents_text,
            batch_size=32,
            is_query=False,
            show_progress_bar=True,
        )

        # Add to index
        self.index.add_documents(
            documents_ids=documents_ids,
            documents_embeddings=documents_embeddings,
        )
        logger.info("Documents indexed successfully")

    def search(self, query: str, k: int = 5) -> List[Dict]:
        """Perform cross-lingual search.

        Args:
            query: Query text, encoded with ``is_query=True``.
            k: Number of hits requested from the retriever.

        Returns:
            One dict per hit whose id is known from ``load_documents``, in
            retriever order, with keys ``id``, ``score`` (rounded to four
            decimals), ``text``, ``language``, ``title`` and ``category``.
            Hits with an unknown id are dropped.
        """
        logger.info("Searching for: %s", query)

        # Encode query
        query_embedding = self.model.encode(
            [query],
            batch_size=32,
            is_query=True,
            show_progress_bar=False,
        )

        # Retrieve results; a single query was encoded, so only the first
        # entry of the returned sequence is used.
        scores = self.retriever.retrieve(
            queries_embeddings=query_embedding,
            k=k,
        )

        # Format results
        results = []
        for hit in scores[0]:
            doc = self._documents_by_id.get(hit["id"])
            if not doc:
                continue
            results.append({
                "id": hit["id"],
                "score": round(hit["score"], 4),
                "text": doc["text"],
                "language": doc["language"],
                "title": doc["title"],
                "category": doc["category"],
            })
        return results
# Initialize retriever and load documents.
# Both steps run at import time: constructing CrossLingualRetriever loads the
# model and creates the index, and load_documents encodes and indexes the
# whole MULTILINGUAL_DOCUMENTS corpus. The resulting module-level `retriever`
# is the instance used by search_documents.
retriever = CrossLingualRetriever()
retriever.load_documents(MULTILINGUAL_DOCUMENTS)
def format_results(results: List[Dict]) -> str:
    """Format search results as HTML for better visualization.

    Args:
        results: Result dicts with ``title``, ``category``, ``language``,
            ``score`` and ``text`` keys, in display order.

    Returns:
        An HTML fragment with one card per result, or a "No results found"
        placeholder when ``results`` is empty. The score badge is green
        above 30, yellow above 20 and red otherwise.
    """
    # Imported locally because only this function needs it.
    from html import escape

    if not results:
        return "<div style='padding: 20px; text-align: center; color: #666;'>No results found</div>"

    parts = ["<div style='font-family: Arial, sans-serif;'>"]
    for rank, result in enumerate(results, 1):
        score = result["score"]
        if score > 30:
            score_color = "#22c55e"
        elif score > 20:
            score_color = "#eab308"
        else:
            score_color = "#ef4444"

        # Document fields are plain text; escape them so characters such as
        # '<' or '&' are displayed literally instead of being parsed as markup.
        title = escape(str(result["title"]))
        category = escape(str(result["category"]))
        language = escape(str(result["language"]))
        text = escape(str(result["text"]))

        parts.append(
            "\n"
            "<div style='margin-bottom: 20px; padding: 15px; border: 1px solid #e5e7eb; border-radius: 8px; background: #f9fafb;'>\n"
            "<div style='display: flex; justify-content: space-between; align-items: center; margin-bottom: 10px;'>\n"
            "<div>\n"
            f"<span style='font-weight: bold; font-size: 16px;'>#{rank} {title}</span>\n"
            f"<span style='margin-left: 10px; padding: 2px 8px; background: #fef3c7; color: #92400e; border-radius: 4px; font-size: 12px;'>{category}</span>\n"
            f"<span style='margin-left: 5px; padding: 2px 8px; background: #dbeafe; color: #1e40af; border-radius: 4px; font-size: 12px;'>{language}</span>\n"
            "</div>\n"
            f"<span style='padding: 4px 12px; background: {score_color}; color: white; border-radius: 4px; font-weight: bold;'>\n"
            f"Score: {score}\n"
            "</span>\n"
            "</div>\n"
            "<div style='color: #374151; line-height: 1.6;'>\n"
            f"{text}\n"
            "</div>\n"
            "</div>\n"
        )
    parts.append("</div>")
    return "".join(parts)
def search_documents(query: str, top_k: int) -> Tuple[str, str]:
    """Search documents and return formatted results.

    Args:
        query: Query text from the UI; empty, whitespace-only or ``None``
            values short-circuit with a prompt to enter a query.
        top_k: Requested number of results; coerced to ``int`` and capped
            at 10 before being passed to the retriever.

    Returns:
        ``(results_html, summary)``. On an empty query or on any error the
        HTML part is an empty string and the summary carries the message.
    """
    if not query or not query.strip():
        return "", "Please enter a search query."

    try:
        # UI sliders may deliver the value as a float; the retriever needs
        # an integer count.
        k = min(int(top_k), 10)
        results = retriever.search(query, k=k)
        formatted_results = format_results(results)

        # Create summary
        if results:
            languages_found = {r["language"] for r in results}
            summary = (
                f"✅ Found {len(results)} relevant documents across "
                f"{len(languages_found)} language(s): {', '.join(sorted(languages_found))}"
            )
        else:
            summary = "❌ No relevant documents found."
        return formatted_results, summary
    except Exception as e:
        # Top-level UI boundary: log with traceback and report the error to
        # the user instead of crashing the callback.
        logger.exception("Search error: %s", e)
        return "", f"❌ Error during search: {str(e)}"
# Example queries in different languages (English, Spanish, Korean, Arabic,
# Chinese). Each entry is [query text, number of results to retrieve], in the
# same order as the inputs wired to gr.Examples.
EXAMPLE_QUERIES = [
    ["What is artificial intelligence?", 8],
    ["¿Qué es el cambio climático?", 4],
    ["양자 컴퓨팅이란 무엇인가요?", 6],
    ["ما هي الصحة النفسية؟", 5],
    ["量子计算是什么？", 8],
]
# Build Gradio interface.
# Layout: a header, a row with the query controls (left) and a corpus
# overview (right), then the summary box, the HTML results, the example
# queries and a short explanation.
# NOTE(review): the globe emoji in the header and in the Environment bullet
# was reconstructed from mis-encoded bytes with one byte missing — confirm the
# intended glyph.
with gr.Blocks(title="Cross-Lingual Retrieval Demo", theme=gr.themes.Soft(primary_hue="purple")) as demo:
    gr.Markdown(
        """
# 🌍 Cross-Lingual Document Retrieval
### Powered by [LiquidAI/LFM2-ColBERT-350M](https://huggingface.co/LiquidAI/LFM2-ColBERT-350M)
Find semantically similar documents across different languages.
**Supported Languages:** English, Arabic, Chinese, French, German, Japanese, Korean, and Spanish
"""
    )

    with gr.Row():
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                label="🔍 Enter your query",
                placeholder="E.g., 'artificial intelligence', 'cambio climático', 'energie renouvelable'...",
                lines=2
            )
            # search_documents caps the number of results at 10, so the
            # slider stops there instead of offering values that are
            # silently clamped.
            top_k_slider = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                step=1,
                label="Number of results to retrieve",
            )
            search_btn = gr.Button("Search", variant="primary", size="lg")

        with gr.Column(scale=1):
            gr.Markdown(
                """
### 📚 Document corpus
- 🤖 **Technology** (16 docs): AI, Quantum Computing
- 🌍 **Environment** (16 docs): Climate, Biodiversity
- ⚡ **Energy** (8 docs): Renewable Sources
- 🏥 **Health** (16 docs): Medicine, Mental Wellness
- 💼 **Business** (16 docs): Digital Economy, Startups
- 📖 **Education** (8 docs): Online Learning
- 🎭 **Culture** (8 docs): Global Connectivity
- 🚀 **Science** (8 docs): Space Exploration
"""
            )

    summary_output = gr.Textbox(
        label="📊 Search Summary",
        interactive=False,
        lines=2
    )
    results_output = gr.HTML(
        label="🎯 Search Results"
    )

    # Event handlers: the button click and pressing Enter in the textbox
    # both run the same search callback.
    search_btn.click(
        fn=search_documents,
        inputs=[query_input, top_k_slider],
        outputs=[results_output, summary_output]
    )
    query_input.submit(
        fn=search_documents,
        inputs=[query_input, top_k_slider],
        outputs=[results_output, summary_output]
    )

    # Examples section
    gr.Examples(
        examples=EXAMPLE_QUERIES,
        inputs=[query_input, top_k_slider],
        outputs=[results_output, summary_output],
        fn=search_documents,
        cache_examples=False,
    )

    gr.Markdown(
        """**How it works:** This demo uses the [LiquidAI/LFM2-ColBERT-350M](https://huggingface.co/LiquidAI/LFM2-ColBERT-350M) model with late interaction retrieval.
The model encodes both queries and documents into token-level embeddings, enabling fine-grained matching
across languages with high speed and accuracy."""
    )
# Start the Gradio server only when this file is executed as a script;
# importing the module builds `demo` without launching it.
if __name__ == "__main__":
    demo.launch()