Spaces:
Sleeping
Sleeping
| import os | |
| from qdrant_client import QdrantClient, models, grpc | |
| from qdrant_client.http.models import PayloadSchemaType | |
| import logging | |
| from dotenv import load_dotenv | |
| import asyncio | |
# Module-level logger; handler and level configuration is left to the application.
logger = logging.getLogger(__name__)

# Load a local .env file (python-dotenv) before the environment is read below.
load_dotenv()

# Qdrant connection settings; each is None when its variable is not set.
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
QDRANT_URL = os.environ.get("QDRANT_URL")
class QdrantManager:
    """Wrapper around a ``QdrantClient`` that provisions company collections.

    The client is built once in ``__init__`` from the module-level
    ``QDRANT_URL`` / ``QDRANT_API_KEY`` settings and reused for every call.
    """

    # Vector dimensionality used when the caller does not supply one.
    # NOTE(review): 384 does not match text-embedding-3-small (1536 dims),
    # which the original comment mentioned; pass ``vector_size`` explicitly if
    # the embeddings stored in the collection have a different size.
    DEFAULT_VECTOR_SIZE = 384

    def __init__(self):
        """Build the Qdrant client, preferring the gRPC transport."""
        self.qdrant_client = QdrantClient(
            url=QDRANT_URL,
            api_key=QDRANT_API_KEY,
            prefer_grpc=True,
        )
        print("Connected to Qdrant")

    async def get_or_create_company_collection(
        self,
        collection_name: str,
        vector_size: int = DEFAULT_VECTOR_SIZE,
    ) -> str:
        """Return ``collection_name``, creating the collection if it is missing.

        An existing collection is returned untouched. A new collection gets a
        cosine-distance vector config plus payload indices on ``document_id``
        (keyword) and ``content`` (full text).

        The Qdrant client calls below are synchronous and are not awaited, so
        they block the running event loop until they complete.

        Args:
            collection_name: Name of the collection.
            vector_size: Dimensionality of the stored vectors. Only used when
                the collection has to be created.

        Returns:
            str: Collection name.

        Raises:
            ValueError: If checking for, or creating, the collection (or one of
                its payload indices) fails. The original error is chained.
        """
        try:
            # "Get" half: without this check a second call for the same name
            # fails because the collection already exists.
            if self.qdrant_client.collection_exists(collection_name=collection_name):
                print(f"Using existing collection: {collection_name}")
                return collection_name

            print(f"Creating new collection: {collection_name}")

            self.qdrant_client.create_collection(
                collection_name=collection_name,
                vectors_config=models.VectorParams(
                    size=vector_size,
                    distance=models.Distance.COSINE,
                ),
                # m=0 with payload_m=16 is the setup Qdrant's multitenancy
                # guide uses to skip the global HNSW graph and build
                # per-payload-value graphs instead.
                # NOTE(review): confirm searches are always filtered by an
                # indexed payload field, otherwise this is likely not intended.
                hnsw_config=models.HnswConfigDiff(
                    payload_m=16,
                    m=0,
                ),
            )

            # Payload fields indexed on every newly created collection.
            # NOTE: if index creation fails after the collection was created,
            # a later call finds the collection and does not retry the indices.
            payload_indices = {
                "document_id": PayloadSchemaType.KEYWORD,
                "content": PayloadSchemaType.TEXT,
            }
            for field_name, schema_type in payload_indices.items():
                self.qdrant_client.create_payload_index(
                    collection_name=collection_name,
                    field_name=field_name,
                    field_schema=schema_type,
                )

            print(f"Successfully created collection: {collection_name}")
            return collection_name
        except Exception as e:
            error_msg = f"Failed to create collection {collection_name}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            raise ValueError(error_msg) from e
| # # Example usage | |
| # async def main(): | |
| # try: | |
| # qdrant_manager = QdrantManager() | |
| # collection_name = "ca-documents" | |
| # result = await qdrant_manager.get_or_create_company_collection(collection_name) | |
| # print(f"Collection name: {result}") | |
| # except Exception as e: | |
| # print(f"Error: {e}") | |
| # if __name__ == "__main__": | |
| # asyncio.run(main()) |