Spaces: Running on Zero
import gradio as gr
import spaces
import torch
import os
import tempfile
import sqlite3
import json
import hashlib
from pathlib import Path
from typing import List, Dict, Any, Tuple
import PyPDF2
import docx
import fitz  # pymupdf
from unstructured.partition.auto import partition
| os.environ["TRITON_CACHE_DIR"] = "/tmp/triton_cache" | |
| os.environ["TORCH_COMPILE_DISABLE"] = "1" | |
# PyLate imports
from pylate import models, indexes, retrieve

# Global variables for PyLate components
model = None
index = None
retriever = None
metadata_db = None
# ===== DOCUMENT PROCESSING FUNCTIONS =====

def extract_text_from_pdf(file_path: str) -> str:
    """Extract text from PDF file."""
    text = ""
    try:
        # Try PyMuPDF first (better for complex PDFs)
        doc = fitz.open(file_path)
        for page in doc:
            text += page.get_text() + "\n"
        doc.close()
    except Exception:
        # Fallback to PyPDF2
        try:
            with open(file_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                for page in pdf_reader.pages:
                    text += page.extract_text() + "\n"
        except Exception:
            # Last resort: unstructured
            try:
                elements = partition(filename=file_path)
                text = "\n".join([str(element) for element in elements])
            except Exception:
                text = "Error: Could not extract text from PDF"
    return text.strip()
def extract_text_from_docx(file_path: str) -> str:
    """Extract text from DOCX file."""
    try:
        doc = docx.Document(file_path)
        text = ""
        for paragraph in doc.paragraphs:
            text += paragraph.text + "\n"
        return text.strip()
    except Exception:
        return "Error: Could not extract text from DOCX"


def extract_text_from_txt(file_path: str) -> str:
    """Extract text from TXT file."""
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read().strip()
    except Exception:
        try:
            with open(file_path, 'r', encoding='latin1') as file:
                return file.read().strip()
        except Exception:
            return "Error: Could not read text file"
def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 100) -> List[Dict[str, Any]]:
    """Chunk text with overlap and return metadata."""
    chunks = []
    start = 0
    chunk_index = 0
    while start < len(text):
        end = start + chunk_size
        chunk = text[start:end]
        # Try to break at a sentence boundary
        if end < len(text):
            last_period = chunk.rfind('.')
            last_newline = chunk.rfind('\n')
            break_point = max(last_period, last_newline)
            if break_point > chunk_size * 0.7:
                chunk = chunk[:break_point + 1]
                end = start + break_point + 1
        if chunk.strip():
            chunks.append({
                'text': chunk.strip(),
                'start': start,
                'end': end,
                'index': chunk_index,
                'length': len(chunk.strip())
            })
            chunk_index += 1
        start = max(start + 1, end - overlap)
    return chunks
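# For illustration (hypothetical input): chunk_text("A. B. C.", chunk_size=4, overlap=1)
# returns dicts such as {'text': 'A. B', 'start': 0, 'end': 4, 'index': 0, 'length': 4},
# where 'start'/'end' are character offsets into the original string.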
# ===== METADATA DATABASE =====

def init_metadata_db():
    """Initialize SQLite database for metadata."""
    global metadata_db
    db_path = "metadata.db"
    metadata_db = sqlite3.connect(db_path, check_same_thread=False)
    metadata_db.execute("""
        CREATE TABLE IF NOT EXISTS documents (
            doc_id TEXT PRIMARY KEY,
            filename TEXT NOT NULL,
            file_hash TEXT NOT NULL,
            original_text TEXT NOT NULL,
            chunk_index INTEGER NOT NULL,
            total_chunks INTEGER NOT NULL,
            chunk_start INTEGER NOT NULL,
            chunk_end INTEGER NOT NULL,
            chunk_size INTEGER NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    """)
    metadata_db.execute("""
        CREATE INDEX IF NOT EXISTS idx_filename ON documents(filename);
    """)
    metadata_db.commit()


def add_document_metadata(doc_id: str, filename: str, file_hash: str,
                          original_text: str, chunk_info: Dict[str, Any], total_chunks: int):
    """Add document metadata to database."""
    global metadata_db
    metadata_db.execute("""
        INSERT OR REPLACE INTO documents
        (doc_id, filename, file_hash, original_text, chunk_index, total_chunks,
         chunk_start, chunk_end, chunk_size)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
    """, (
        doc_id, filename, file_hash, original_text,
        chunk_info['index'], total_chunks,
        chunk_info['start'], chunk_info['end'], chunk_info['length']
    ))
    metadata_db.commit()


def get_document_metadata(doc_id: str) -> Dict[str, Any]:
    """Get document metadata by ID."""
    global metadata_db
    cursor = metadata_db.execute(
        "SELECT * FROM documents WHERE doc_id = ?", (doc_id,)
    )
    row = cursor.fetchone()
    if row:
        columns = [desc[0] for desc in cursor.description]
        return dict(zip(columns, row))
    return {}
# ===== PYLATE INITIALIZATION =====

def initialize_pylate(model_name: str = "lightonai/GTE-ModernColBERT-v1") -> str:
    """Initialize PyLate components on GPU."""
    global model, index, retriever
    try:
        # Initialize metadata database
        init_metadata_db()
        # Load ColBERT model
        model = models.ColBERT(model_name_or_path=model_name)
        # Move to GPU if available
        if torch.cuda.is_available():
            model = model.to('cuda')
        # Initialize PLAID index with CPU fallback for k-means
        index = indexes.PLAID(
            index_folder="./pylate_index",
            index_name="documents",
            override=True,
            kmeans_niters=1,  # Reduce k-means iterations
            nbits=1  # Reduce quantization bits
        )
        # Initialize retriever
        retriever = retrieve.ColBERT(index=index)
        device = 'GPU' if torch.cuda.is_available() else 'CPU'
        return f"✅ PyLate initialized successfully!\nModel: {model_name}\nDevice: {device}"
    except Exception as e:
        return f"❌ Error initializing PyLate: {str(e)}"
# ===== DOCUMENT PROCESSING =====

@spaces.GPU  # assumption: on ZeroGPU, document encoding must run inside a spaces.GPU-decorated function
def process_documents(files, chunk_size: int = 1000, overlap: int = 100) -> str:
    """Process uploaded documents and add them to the index."""
    global model, index, metadata_db
    if not model or not index:
        return "❌ Please initialize PyLate first!"
    if not files:
        return "❌ No files uploaded!"
    try:
        all_documents = []
        all_doc_ids = []
        processed_files = []
        for file in files:
            # Get file info
            filename = Path(file.name).name
            file_path = file.name
            # Calculate file hash
            with open(file_path, 'rb') as f:
                file_hash = hashlib.md5(f.read()).hexdigest()
            # Extract text based on file type
            if filename.lower().endswith('.pdf'):
                text = extract_text_from_pdf(file_path)
            elif filename.lower().endswith('.docx'):
                text = extract_text_from_docx(file_path)
            elif filename.lower().endswith('.txt'):
                text = extract_text_from_txt(file_path)
            else:
                continue
            if not text or text.startswith("Error:"):
                continue
            # Chunk the text
            chunks = chunk_text(text, chunk_size, overlap)
            # Process each chunk
            for chunk in chunks:
                doc_id = f"{filename}_chunk_{chunk['index']}"
                all_documents.append(chunk['text'])
                all_doc_ids.append(doc_id)
                # Store metadata
                add_document_metadata(
                    doc_id=doc_id,
                    filename=filename,
                    file_hash=file_hash,
                    original_text=chunk['text'],
                    chunk_info=chunk,
                    total_chunks=len(chunks)
                )
            processed_files.append(f"{filename}: {len(chunks)} chunks")
        if not all_documents:
            return "❌ No text could be extracted from the uploaded files!"
        # Encode documents with PyLate
        document_embeddings = model.encode(
            all_documents,
            batch_size=16,  # Smaller batch for ZeroGPU
            is_query=False,
            show_progress_bar=True
        )
        # Add to PLAID index
        index.add_documents(
            documents_ids=all_doc_ids,
            documents_embeddings=document_embeddings
        )
        result = f"✅ Successfully processed {len(files)} files:\n"
        result += f"📄 Total chunks: {len(all_documents)}\n"
        result += "📚 Indexed documents:\n"
        for file_info in processed_files:
            result += f"  • {file_info}\n"
        return result
    except Exception as e:
        return f"❌ Error processing documents: {str(e)}"
# ===== SEARCH FUNCTION =====

@spaces.GPU  # assumption: query encoding also needs a ZeroGPU slot
def search_documents(query: str, k: int = 5, show_chunks: bool = True) -> str:
    """Search documents using PyLate."""
    global model, retriever, metadata_db
    if not model or not retriever:
        return "❌ Please initialize PyLate and process documents first!"
    if not query.strip():
        return "❌ Please enter a search query!"
    try:
        # Encode query
        query_embedding = model.encode([query], is_query=True)
        # Search
        results = retriever.retrieve(query_embedding, k=k)[0]
        if not results:
            return "🔍 No results found for your query."
        # Format results with metadata
        formatted_results = [f"🔍 **Search Results for:** '{query}'\n"]
        for i, result in enumerate(results):
            doc_id = result['id']
            score = result['score']
            # Get metadata
            metadata = get_document_metadata(doc_id)
            formatted_results.append(f"## Result {i + 1} (Score: {score:.2f})")
            formatted_results.append(
                f"**File:** {metadata.get('filename', 'Unknown')}")
            formatted_results.append(
                f"**Chunk:** {metadata.get('chunk_index', 0) + 1}/{metadata.get('total_chunks', 1)}")
            if show_chunks:
                text = metadata.get('original_text', '')
                preview = text[:300] + "..." if len(text) > 300 else text
                formatted_results.append(f"**Text:** {preview}")
            formatted_results.append("---")
        return "\n".join(formatted_results)
    except Exception as e:
        return f"❌ Error searching: {str(e)}"
# ===== GRADIO INTERFACE =====

def create_interface():
    """Create the Gradio interface."""
    with gr.Blocks(title="PyLate Document Search", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🔍 PyLate Document Search
        ### Powered by ColBERT and ZeroGPU H100

        Upload documents, process them with PyLate, and perform semantic search!
        """)

        with gr.Tab("🚀 Setup"):
            gr.Markdown("### Initialize PyLate System")
            model_choice = gr.Dropdown(
                choices=[
                    "lightonai/GTE-ModernColBERT-v1",
                    "colbert-ir/colbertv2.0",
                    "sentence-transformers/all-MiniLM-L6-v2"
                ],
                value="lightonai/GTE-ModernColBERT-v1",
                label="Select Model"
            )
            init_btn = gr.Button("Initialize PyLate", variant="primary")
            init_status = gr.Textbox(label="Initialization Status", lines=3)
            init_btn.click(
                initialize_pylate,
                inputs=model_choice,
                outputs=init_status
            )
| with gr.Tab("๐ Document Upload"): | |
| gr.Markdown("### Upload and Process Documents") | |
| with gr.Row(): | |
| with gr.Column(): | |
| file_upload = gr.File( | |
| file_count="multiple", | |
| file_types=[".pdf", ".docx", ".txt"], | |
| label="Upload Documents (PDF, DOCX, TXT)" | |
| ) | |
| with gr.Row(): | |
| chunk_size = gr.Slider( | |
| minimum=500, | |
| maximum=3000, | |
| value=1000, | |
| step=100, | |
| label="Chunk Size (characters)" | |
| ) | |
| overlap = gr.Slider( | |
| minimum=0, | |
| maximum=500, | |
| value=100, | |
| step=50, | |
| label="Chunk Overlap (characters)" | |
| ) | |
| process_btn = gr.Button( | |
| "Process Documents", variant="primary") | |
| with gr.Column(): | |
| process_status = gr.Textbox( | |
| label="Processing Status", | |
| lines=10, | |
| max_lines=15 | |
| ) | |
| process_btn.click( | |
| process_documents, | |
| inputs=[file_upload, chunk_size, overlap], | |
| outputs=process_status | |
| ) | |
| with gr.Tab("๐ Search"): | |
| gr.Markdown("### Search Your Documents") | |
| with gr.Row(): | |
| with gr.Column(): | |
| search_query = gr.Textbox( | |
| label="Search Query", | |
| placeholder="Enter your search query...", | |
| lines=2 | |
| ) | |
| with gr.Row(): | |
| num_results = gr.Slider( | |
| minimum=1, | |
| maximum=20, | |
| value=5, | |
| step=1, | |
| label="Number of Results" | |
| ) | |
| show_chunks = gr.Checkbox( | |
| value=True, | |
| label="Show Text Chunks" | |
| ) | |
| search_btn = gr.Button("Search", variant="primary") | |
| with gr.Column(): | |
| search_results = gr.Textbox( | |
| label="Search Results", | |
| lines=15, | |
| max_lines=20 | |
| ) | |
| search_btn.click( | |
| search_documents, | |
| inputs=[search_query, num_results, show_chunks], | |
| outputs=search_results | |
| ) | |
| with gr.Tab("โน๏ธ Info"): | |
| gr.Markdown(""" | |
| ### About This System | |
| **PyLate Document Search** is a semantic search system that uses: | |
| - **PyLate**: A flexible library for ColBERT models | |
| - **ColBERT**: Late interaction retrieval for high-quality search | |
| - **ZeroGPU**: Hugging Face's free H100 GPU infrastructure | |
| #### Features: | |
| - ๐ Multi-format document support (PDF, DOCX, TXT) | |
| - โ๏ธ Intelligent text chunking with overlap | |
| - ๐ง Semantic search using ColBERT embeddings | |
| - ๐พ Metadata tracking for result context | |
| - โก GPU-accelerated processing | |
| #### Usage Tips: | |
| 1. Initialize the system first (required) | |
| 2. Upload your documents and process them | |
| 3. Use natural language queries for best results | |
| 4. Adjust chunk size based on your document types | |
| #### Model Information: | |
| - **GTE-ModernColBERT**: Latest high-performance model | |
| - **ColBERTv2**: Original Stanford implementation | |
| - **MiniLM**: Faster, smaller model for quick testing | |
| Built with โค๏ธ using PyLate and Gradio | |
| """) | |
| return demo | |
# ===== MAIN =====

if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860
    )
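# The imports above imply roughly the following requirements.txt for the Space;
# distribution names for the DOCX/PDF libraries are the usual ones (python-docx
# provides `docx`, PyMuPDF provides `fitz`), and any version pins are left out
# because they are not given in the source:
#   gradio
#   spaces
#   torch
#   pylate
#   PyPDF2
#   python-docx
#   PyMuPDF
#   unstructured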