from qdrant_client import QdrantClient, models
from qdrant_client.http.models import PointStruct, PayloadSchemaType
from sentence_transformers import SentenceTransformer
import uuid
import os
import logging
from typing import List, Dict
from dotenv import load_dotenv
import asyncio
# Load environment variables
load_dotenv()

# Configure logging
logger = logging.getLogger(__name__)

class VectorStore:
    def __init__(self):
        self.collection_name = "ca-documents"

        # Get Qdrant configuration from environment variables
        qdrant_url = os.getenv("QDRANT_URL")
        qdrant_api_key = os.getenv("QDRANT_API_KEY")

        if not qdrant_url or not qdrant_api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY environment variables are required")

        # Connect to the Qdrant cluster with the API key
        self.client = QdrantClient(
            url=qdrant_url,
            api_key=qdrant_api_key,
            prefer_grpc=True,
        )
        print("Connected to Qdrant")

        # Initialize the embedding model from a local directory
        self.embedding_model = self._initialize_embedding_model()
    async def initialize(self):
        """Asynchronous initialization to be called after object creation."""
        await self._ensure_collection_exists()

    def _initialize_embedding_model(self):
        """Initialize the embedding model from a local directory."""
        try:
            print("Loading sentence transformer model from local path...")

            # Resolve the local path to the model directory
            current_dir = os.path.dirname(os.path.abspath(__file__))
            local_model_path = os.path.join(current_dir, "..", "model", "all-MiniLM-L6-v2")

            model = SentenceTransformer(local_model_path)
            print("Successfully loaded local sentence transformer model")
            return model
        except Exception as e:
            print(f"Failed to load local model: {e}")
            raise RuntimeError("Failed to initialize embedding model from local path")
    async def _collection_exists_and_accessible(self) -> bool:
        """
        Check if the collection exists and is accessible by trying to get its info.

        Returns:
            bool: True if the collection exists and is accessible
        """
        try:
            # Getting collection info is more reliable than just listing collections
            self.client.get_collection(self.collection_name)
            print(f"Collection '{self.collection_name}' exists and is accessible")
            return True
        except Exception as e:
            print(f"Collection '{self.collection_name}' is not accessible: {e}")
            return False
    async def _create_collection(self) -> bool:
        """
        Create the collection with the proper configuration.

        Returns:
            bool: True if the collection was created successfully or already exists
        """
        try:
            print(f"Creating new collection: {self.collection_name}")

            # Vector size for all-MiniLM-L6-v2 is 384
            vector_size = 384

            # Create the collection with the vector configuration.
            # m=0 disables the global HNSW graph and payload_m=16 builds
            # per-payload-value graphs instead, so searches are expected to be
            # filtered by an indexed payload field (e.g. document_id).
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(
                    size=vector_size,
                    distance=models.Distance.COSINE
                ),
                hnsw_config=models.HnswConfigDiff(
                    payload_m=16,
                    m=0,
                ),
            )

            # Wait a moment for the collection to be fully created
            await asyncio.sleep(1)

            # Create payload indices
            payload_indices = {
                "document_id": PayloadSchemaType.KEYWORD,
                "content": PayloadSchemaType.TEXT
            }
            for field_name, schema_type in payload_indices.items():
                try:
                    self.client.create_payload_index(
                        collection_name=self.collection_name,
                        field_name=field_name,
                        field_schema=schema_type
                    )
                except Exception as idx_error:
                    print(f"Warning: Failed to create index for {field_name}: {idx_error}")

            print(f"Successfully created collection: {self.collection_name}")
            return True
        except Exception as e:
            # Check whether the error is because the collection already exists
            if "already exists" in str(e).lower() or "ALREADY_EXISTS" in str(e):
                print(f"Collection '{self.collection_name}' already exists, using existing collection")
                return True

            error_msg = f"Failed to create collection {self.collection_name}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            print(error_msg)
            return False
    async def _ensure_collection_exists(self) -> bool:
        """
        Ensure the collection exists and is accessible, creating it if necessary.

        Returns:
            bool: True if the collection exists or was created successfully
        """
        try:
            # First, check whether the collection exists and is accessible
            if await self._collection_exists_and_accessible():
                print(f"Collection '{self.collection_name}' is ready to use")
                return True

            # If not accessible, try to create it (or verify it exists)
            print(f"Collection '{self.collection_name}' not immediately accessible, attempting to create/verify...")
            created = await self._create_collection()

            # After the creation attempt, verify that it is accessible
            if created and await self._collection_exists_and_accessible():
                print(f"Collection '{self.collection_name}' is now ready to use")
                return True
            elif created:
                # Created successfully but not immediately accessible, which is okay
                print(f"Collection '{self.collection_name}' created/verified successfully")
                return True
            else:
                return False
        except Exception as e:
            error_msg = f"Failed to ensure collection exists: {str(e)}"
            logger.error(error_msg, exc_info=True)
            print(error_msg)
            return False
    async def add_document(self, text: str, metadata: Dict = None) -> bool:
        """Add a document to the collection with retry logic."""
        max_retries = 3
        retry_delay = 1

        for attempt in range(max_retries):
            try:
                # Ensure the collection exists before adding the document
                if not await self._collection_exists_and_accessible():
                    print("Collection not accessible, trying to recreate...")
                    if not await self._create_collection():
                        raise Exception("Failed to create collection")

                # Generate the embedding
                embedding = self.embedding_model.encode([text])[0]

                # Generate a document ID
                document_id = str(uuid.uuid4())

                # Create the payload with indexed fields
                payload = {
                    "document_id": document_id,  # KEYWORD index
                    "content": text,  # TEXT index - stores the actual text content
                }

                # Add metadata fields if provided
                if metadata:
                    payload.update(metadata)

                # Create the point
                point = PointStruct(
                    id=document_id,
                    vector=embedding.tolist(),
                    payload=payload
                )

                # Store in Qdrant
                result = self.client.upsert(
                    collection_name=self.collection_name,
                    points=[point]
                )

                # Check whether the upsert was successful
                if hasattr(result, 'status') and result.status == 'completed':
                    return True
                elif hasattr(result, 'operation_id'):
                    return True
                else:
                    print(f"Unexpected upsert result: {result}")
                    return True  # Assume success if no error was raised
            except Exception as e:
                print(f"Error adding document (attempt {attempt + 1}/{max_retries}): {e}")

                if "Not found" in str(e) and "doesn't exist" in str(e):
                    # Collection doesn't exist, try to recreate it
                    print("Collection not found, attempting to recreate...")
                    await self._create_collection()

                if attempt < max_retries - 1:
                    print(f"Retrying in {retry_delay} seconds...")
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2  # Exponential backoff
                else:
                    print(f"Failed to add document after {max_retries} attempts")
                    return False

        return False
    async def search_similar(self, query: str, limit: int = 5) -> List[Dict]:
        """Search for similar documents with error handling."""
        try:
            # Ensure the collection exists before searching
            if not await self._collection_exists_and_accessible():
                print("Collection not accessible for search")
                return []

            # Generate the query embedding
            query_embedding = self.embedding_model.encode([query])[0]

            # Search in Qdrant
            results = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding.tolist(),
                limit=limit
            )

            # Return the results, flattening any extra payload fields
            return [
                {
                    "text": hit.payload["content"],  # Use the content field
                    "document_id": hit.payload.get("document_id"),
                    "score": hit.score,
                    # Include any additional metadata fields
                    **{k: v for k, v in hit.payload.items() if k not in ["content", "document_id"]}
                }
                for hit in results
            ]
        except Exception as e:
            print(f"Error searching: {e}")
            return []
    async def get_collection_info(self) -> Dict:
        """Get information about the collection."""
        try:
            collection_info = self.client.get_collection(self.collection_name)
            return {
                # CollectionInfo has no name attribute, so report the configured name
                "name": self.collection_name,
                "vector_size": collection_info.config.params.vectors.size,
                "distance": collection_info.config.params.vectors.distance,
                "points_count": collection_info.points_count,
                "on_disk": collection_info.config.params.vectors.on_disk
            }
        except Exception as e:
            print(f"Error getting collection info: {e}")
            return {}
    async def verify_collection_health(self) -> bool:
        """Verify that the collection is healthy and accessible."""
        try:
            # Try to get collection info
            info = await self.get_collection_info()
            if not info:
                return False

            # Try a simple search to verify functionality; this should not
            # fail even if no results are found
            await self.search_similar("test query", limit=1)

            print(f"Collection health check passed. Points count: {info.get('points_count', 0)}")
            return True
        except Exception as e:
            print(f"Collection health check failed: {e}")
            return False