Spaces:
Running
Running
“vinit5112”
committed on
Commit
·
5b65de2
1
Parent(s):
82dac66
async changes
Browse files
- backend/Qdrant.py +13 -11
- backend/backend_api.py +4 -4
- backend/rag.py +6 -8
- backend/vector_store.py +28 -26
backend/Qdrant.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
import os
|
| 2 |
-
from qdrant_client import QdrantClient, models
|
| 3 |
-
from qdrant_client.models import PayloadSchemaType
|
| 4 |
import logging
|
| 5 |
from dotenv import load_dotenv
|
|
|
|
| 6 |
|
| 7 |
# Configure logging
|
| 8 |
logger = logging.getLogger(__name__)
|
|
@@ -10,8 +11,6 @@ logger = logging.getLogger(__name__)
|
|
| 10 |
load_dotenv()
|
| 11 |
|
| 12 |
# Configuration
|
| 13 |
-
# QDRANT_URL = "https://cc102304-2c06-4d51-9dee-d436f4413549.us-west-1-0.aws.cloud.qdrant.io"
|
| 14 |
-
# QDRANT_API_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.cHs27o6erIf1BQHCdTxE4L4qZg4vCdrp51oNNNghjWM"
|
| 15 |
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
|
| 16 |
QDRANT_URL = os.getenv("QDRANT_URL")
|
| 17 |
|
|
@@ -20,10 +19,11 @@ class QdrantManager:
|
|
| 20 |
self.qdrant_client = QdrantClient(
|
| 21 |
url=QDRANT_URL,
|
| 22 |
api_key=QDRANT_API_KEY,
|
|
|
|
| 23 |
)
|
| 24 |
print("Connected to Qdrant")
|
| 25 |
|
| 26 |
-
def get_or_create_company_collection(self, collection_name: str) -> str:
|
| 27 |
"""
|
| 28 |
Get or create a collection for a company.
|
| 29 |
|
|
@@ -37,14 +37,13 @@ class QdrantManager:
|
|
| 37 |
ValueError: If collection creation fails
|
| 38 |
"""
|
| 39 |
try:
|
| 40 |
-
|
| 41 |
print(f"Creating new collection: {collection_name}")
|
| 42 |
|
| 43 |
# Vector size must match the embedding model's output dimension: 384 here (1536 would be text-embedding-3-small)
|
| 44 |
vector_size = 384
|
| 45 |
|
| 46 |
# Create collection with vector configuration
|
| 47 |
-
self.qdrant_client.create_collection(
|
| 48 |
collection_name=collection_name,
|
| 49 |
vectors_config=models.VectorParams(
|
| 50 |
size=vector_size,
|
|
@@ -63,7 +62,7 @@ class QdrantManager:
|
|
| 63 |
}
|
| 64 |
|
| 65 |
for field_name, schema_type in payload_indices.items():
|
| 66 |
-
self.qdrant_client.create_payload_index(
|
| 67 |
collection_name=collection_name,
|
| 68 |
field_name=field_name,
|
| 69 |
field_schema=schema_type
|
|
@@ -78,11 +77,14 @@ class QdrantManager:
|
|
| 78 |
raise ValueError(error_msg) from e
|
| 79 |
|
| 80 |
# # Example usage
|
| 81 |
-
#
|
| 82 |
# try:
|
| 83 |
# qdrant_manager = QdrantManager()
|
| 84 |
# collection_name = "ca-documents"
|
| 85 |
-
# result = qdrant_manager.get_or_create_company_collection(collection_name)
|
| 86 |
# print(f"Collection name: {result}")
|
| 87 |
# except Exception as e:
|
| 88 |
-
# print(f"Error: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
+
from qdrant_client import QdrantClient, models, grpc
|
| 3 |
+
from qdrant_client.http.models import PayloadSchemaType
|
| 4 |
import logging
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
+
import asyncio
|
| 7 |
|
| 8 |
# Configure logging
|
| 9 |
logger = logging.getLogger(__name__)
|
|
|
|
| 11 |
load_dotenv()
|
| 12 |
|
| 13 |
# Configuration
|
|
|
|
|
|
|
| 14 |
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
|
| 15 |
QDRANT_URL = os.getenv("QDRANT_URL")
|
| 16 |
|
|
|
|
| 19 |
self.qdrant_client = QdrantClient(
|
| 20 |
url=QDRANT_URL,
|
| 21 |
api_key=QDRANT_API_KEY,
|
| 22 |
+
prefer_grpc=True,
|
| 23 |
)
|
| 24 |
print("Connected to Qdrant")
|
| 25 |
|
| 26 |
+
async def get_or_create_company_collection(self, collection_name: str) -> str:
|
| 27 |
"""
|
| 28 |
Get or create a collection for a company.
|
| 29 |
|
|
|
|
| 37 |
ValueError: If collection creation fails
|
| 38 |
"""
|
| 39 |
try:
|
|
|
|
| 40 |
print(f"Creating new collection: {collection_name}")
|
| 41 |
|
| 42 |
# Vector size must match the embedding model's output dimension: 384 here (1536 would be text-embedding-3-small)
|
| 43 |
vector_size = 384
|
| 44 |
|
| 45 |
# Create collection with vector configuration
|
| 46 |
+
await self.qdrant_client.create_collection(
|
| 47 |
collection_name=collection_name,
|
| 48 |
vectors_config=models.VectorParams(
|
| 49 |
size=vector_size,
|
|
|
|
| 62 |
}
|
| 63 |
|
| 64 |
for field_name, schema_type in payload_indices.items():
|
| 65 |
+
await self.qdrant_client.create_payload_index(
|
| 66 |
collection_name=collection_name,
|
| 67 |
field_name=field_name,
|
| 68 |
field_schema=schema_type
|
|
|
|
| 77 |
raise ValueError(error_msg) from e
|
| 78 |
|
| 79 |
# # Example usage
|
| 80 |
+
# async def main():
|
| 81 |
# try:
|
| 82 |
# qdrant_manager = QdrantManager()
|
| 83 |
# collection_name = "ca-documents"
|
| 84 |
+
# result = await qdrant_manager.get_or_create_company_collection(collection_name)
|
| 85 |
# print(f"Collection name: {result}")
|
| 86 |
# except Exception as e:
|
| 87 |
+
# print(f"Error: {e}")
|
| 88 |
+
|
| 89 |
+
# if __name__ == "__main__":
|
| 90 |
+
# asyncio.run(main())
|
backend/backend_api.py
CHANGED
|
@@ -96,7 +96,7 @@ async def ask_question_stream(request: QuestionRequest):
|
|
| 96 |
|
| 97 |
async def event_generator():
|
| 98 |
try:
|
| 99 |
-
for chunk in rag_system.ask_question_stream(request.question):
|
| 100 |
if chunk: # Only yield non-empty chunks
|
| 101 |
yield chunk
|
| 102 |
except Exception as e:
|
|
@@ -138,7 +138,7 @@ async def upload_document(file: UploadFile = File(...)):
|
|
| 138 |
try:
|
| 139 |
# Process the uploaded file
|
| 140 |
logger.info(f"Processing uploaded file: {file.filename}")
|
| 141 |
-
success = rag_system.upload_document(temp_file_path)
|
| 142 |
|
| 143 |
if success:
|
| 144 |
return {
|
|
@@ -170,7 +170,7 @@ async def search_documents(request: SearchRequest):
|
|
| 170 |
if not rag_system:
|
| 171 |
raise HTTPException(status_code=500, detail="RAG system not initialized")
|
| 172 |
|
| 173 |
-
results = rag_system.vector_store.search_similar(request.query, limit=request.limit)
|
| 174 |
|
| 175 |
return {
|
| 176 |
"status": "success",
|
|
@@ -210,7 +210,7 @@ async def get_collection_info():
|
|
| 210 |
if not rag_system:
|
| 211 |
raise HTTPException(status_code=500, detail="RAG system not initialized")
|
| 212 |
|
| 213 |
-
info = rag_system.vector_store.get_collection_info()
|
| 214 |
return {
|
| 215 |
"status": "success",
|
| 216 |
"collection_info": info
|
|
|
|
| 96 |
|
| 97 |
async def event_generator():
|
| 98 |
try:
|
| 99 |
+
async for chunk in rag_system.ask_question_stream(request.question):
|
| 100 |
if chunk: # Only yield non-empty chunks
|
| 101 |
yield chunk
|
| 102 |
except Exception as e:
|
|
|
|
| 138 |
try:
|
| 139 |
# Process the uploaded file
|
| 140 |
logger.info(f"Processing uploaded file: {file.filename}")
|
| 141 |
+
success = await rag_system.upload_document(temp_file_path)
|
| 142 |
|
| 143 |
if success:
|
| 144 |
return {
|
|
|
|
| 170 |
if not rag_system:
|
| 171 |
raise HTTPException(status_code=500, detail="RAG system not initialized")
|
| 172 |
|
| 173 |
+
results = await rag_system.vector_store.search_similar(request.query, limit=request.limit)
|
| 174 |
|
| 175 |
return {
|
| 176 |
"status": "success",
|
|
|
|
| 210 |
if not rag_system:
|
| 211 |
raise HTTPException(status_code=500, detail="RAG system not initialized")
|
| 212 |
|
| 213 |
+
info = await rag_system.vector_store.get_collection_info()
|
| 214 |
return {
|
| 215 |
"status": "success",
|
| 216 |
"collection_info": info
|
backend/rag.py
CHANGED
|
@@ -5,6 +5,7 @@ from docx import Document
|
|
| 5 |
from typing import List
|
| 6 |
import os
|
| 7 |
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
|
|
| 8 |
|
| 9 |
class RAG:
|
| 10 |
def __init__(self, google_api_key: str, collection_name: str = "ca-documents"):
|
|
@@ -17,10 +18,7 @@ class RAG:
|
|
| 17 |
self.vector_store = VectorStore()
|
| 18 |
|
| 19 |
# Verify vector store is properly initialized
|
| 20 |
-
|
| 21 |
-
print("Warning: Vector store collection health check failed")
|
| 22 |
-
else:
|
| 23 |
-
print("Vector store initialized successfully")
|
| 24 |
|
| 25 |
# Setup Text Splitter
|
| 26 |
self.text_splitter = RecursiveCharacterTextSplitter(
|
|
@@ -53,7 +51,7 @@ class RAG:
|
|
| 53 |
chunks = self.text_splitter.split_text(full_text)
|
| 54 |
return [chunk.strip() for chunk in chunks if chunk.strip()]
|
| 55 |
|
| 56 |
-
def upload_document(self, file_path: str) -> bool:
|
| 57 |
"""Upload and process document"""
|
| 58 |
try:
|
| 59 |
filename = os.path.basename(file_path)
|
|
@@ -73,7 +71,7 @@ class RAG:
|
|
| 73 |
|
| 74 |
# Store chunks in Qdrant
|
| 75 |
for i, chunk in enumerate(chunks):
|
| 76 |
-
self.vector_store.add_document(
|
| 77 |
text=chunk,
|
| 78 |
metadata={"source": filename, "chunk_id": i}
|
| 79 |
)
|
|
@@ -136,7 +134,7 @@ class RAG:
|
|
| 136 |
|
| 137 |
return False
|
| 138 |
|
| 139 |
-
def ask_question_stream(self, question: str):
|
| 140 |
"""Ask a question and get a streaming answer"""
|
| 141 |
try:
|
| 142 |
# 1. Check if this is casual conversation
|
|
@@ -154,7 +152,7 @@ Respond naturally and warmly as a CA study assistant. Be helpful and mention tha
|
|
| 154 |
return
|
| 155 |
|
| 156 |
# 2. For CA-specific questions, search for similar documents
|
| 157 |
-
similar_docs = self.vector_store.search_similar(question, limit=3)
|
| 158 |
|
| 159 |
if similar_docs and len(similar_docs) > 0:
|
| 160 |
# 3. Create context from similar documents
|
|
|
|
| 5 |
from typing import List
|
| 6 |
import os
|
| 7 |
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 8 |
+
import asyncio
|
| 9 |
|
| 10 |
class RAG:
|
| 11 |
def __init__(self, google_api_key: str, collection_name: str = "ca-documents"):
|
|
|
|
| 18 |
self.vector_store = VectorStore()
|
| 19 |
|
| 20 |
# Verify vector store is properly initialized
|
| 21 |
+
asyncio.run(self.vector_store.verify_collection_health())
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
# Setup Text Splitter
|
| 24 |
self.text_splitter = RecursiveCharacterTextSplitter(
|
|
|
|
| 51 |
chunks = self.text_splitter.split_text(full_text)
|
| 52 |
return [chunk.strip() for chunk in chunks if chunk.strip()]
|
| 53 |
|
| 54 |
+
async def upload_document(self, file_path: str) -> bool:
|
| 55 |
"""Upload and process document"""
|
| 56 |
try:
|
| 57 |
filename = os.path.basename(file_path)
|
|
|
|
| 71 |
|
| 72 |
# Store chunks in Qdrant
|
| 73 |
for i, chunk in enumerate(chunks):
|
| 74 |
+
await self.vector_store.add_document(
|
| 75 |
text=chunk,
|
| 76 |
metadata={"source": filename, "chunk_id": i}
|
| 77 |
)
|
|
|
|
| 134 |
|
| 135 |
return False
|
| 136 |
|
| 137 |
+
async def ask_question_stream(self, question: str):
|
| 138 |
"""Ask a question and get a streaming answer"""
|
| 139 |
try:
|
| 140 |
# 1. Check if this is casual conversation
|
|
|
|
| 152 |
return
|
| 153 |
|
| 154 |
# 2. For CA-specific questions, search for similar documents
|
| 155 |
+
similar_docs = await self.vector_store.search_similar(question, limit=3)
|
| 156 |
|
| 157 |
if similar_docs and len(similar_docs) > 0:
|
| 158 |
# 3. Create context from similar documents
|
backend/vector_store.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
-
from qdrant_client import QdrantClient, models
|
| 2 |
-
from qdrant_client.models import PointStruct, PayloadSchemaType
|
| 3 |
from sentence_transformers import SentenceTransformer
|
| 4 |
import uuid
|
| 5 |
import os
|
|
@@ -7,6 +7,7 @@ import logging
|
|
| 7 |
from typing import List, Dict, Any
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
import time
|
|
|
|
| 10 |
|
| 11 |
# Load environment variables
|
| 12 |
load_dotenv()
|
|
@@ -29,6 +30,7 @@ class VectorStore:
|
|
| 29 |
self.client = QdrantClient(
|
| 30 |
url=qdrant_url,
|
| 31 |
api_key=qdrant_api_key,
|
|
|
|
| 32 |
)
|
| 33 |
print("Connected to Qdrant")
|
| 34 |
|
|
@@ -36,7 +38,7 @@ class VectorStore:
|
|
| 36 |
self.embedding_model = self._initialize_embedding_model()
|
| 37 |
|
| 38 |
# Create collection with proper indices
|
| 39 |
-
self._ensure_collection_exists()
|
| 40 |
|
| 41 |
def _initialize_embedding_model(self):
|
| 42 |
"""Initialize the embedding model from a local directory"""
|
|
@@ -52,7 +54,7 @@ class VectorStore:
|
|
| 52 |
print(f"Failed to load local model: {e}")
|
| 53 |
raise RuntimeError("Failed to initialize embedding model from local path")
|
| 54 |
|
| 55 |
-
def _collection_exists_and_accessible(self) -> bool:
|
| 56 |
"""
|
| 57 |
Check if collection exists and is accessible by trying to get its info.
|
| 58 |
|
|
@@ -61,14 +63,14 @@ class VectorStore:
|
|
| 61 |
"""
|
| 62 |
try:
|
| 63 |
# Try to get collection info - this is more reliable than just listing collections
|
| 64 |
-
collection_info = self.client.get_collection(self.collection_name)
|
| 65 |
print(f"Collection '{self.collection_name}' exists and is accessible")
|
| 66 |
return True
|
| 67 |
except Exception as e:
|
| 68 |
print(f"Collection '{self.collection_name}' is not accessible: {e}")
|
| 69 |
return False
|
| 70 |
|
| 71 |
-
def _create_collection(self) -> bool:
|
| 72 |
"""
|
| 73 |
Create the collection with proper configuration.
|
| 74 |
|
|
@@ -82,7 +84,7 @@ class VectorStore:
|
|
| 82 |
vector_size = 384
|
| 83 |
|
| 84 |
# Create collection with vector configuration
|
| 85 |
-
self.client.create_collection(
|
| 86 |
collection_name=self.collection_name,
|
| 87 |
vectors_config=models.VectorParams(
|
| 88 |
size=vector_size,
|
|
@@ -95,7 +97,7 @@ class VectorStore:
|
|
| 95 |
)
|
| 96 |
|
| 97 |
# Wait a moment for collection to be fully created
|
| 98 |
-
|
| 99 |
|
| 100 |
# Create payload indices
|
| 101 |
payload_indices = {
|
|
@@ -105,7 +107,7 @@ class VectorStore:
|
|
| 105 |
|
| 106 |
for field_name, schema_type in payload_indices.items():
|
| 107 |
try:
|
| 108 |
-
self.client.create_payload_index(
|
| 109 |
collection_name=self.collection_name,
|
| 110 |
field_name=field_name,
|
| 111 |
field_schema=schema_type
|
|
@@ -122,7 +124,7 @@ class VectorStore:
|
|
| 122 |
print(error_msg)
|
| 123 |
return False
|
| 124 |
|
| 125 |
-
def _ensure_collection_exists(self) -> bool:
|
| 126 |
"""
|
| 127 |
Ensure collection exists and is accessible, create if necessary.
|
| 128 |
|
|
@@ -131,12 +133,12 @@ class VectorStore:
|
|
| 131 |
"""
|
| 132 |
try:
|
| 133 |
# First, check if collection exists and is accessible
|
| 134 |
-
if self._collection_exists_and_accessible():
|
| 135 |
return True
|
| 136 |
|
| 137 |
# If not accessible, try to create it
|
| 138 |
print(f"Collection '{self.collection_name}' not found or not accessible, creating...")
|
| 139 |
-
return self._create_collection()
|
| 140 |
|
| 141 |
except Exception as e:
|
| 142 |
error_msg = f"Failed to ensure collection exists: {str(e)}"
|
|
@@ -144,7 +146,7 @@ class VectorStore:
|
|
| 144 |
print(error_msg)
|
| 145 |
return False
|
| 146 |
|
| 147 |
-
def add_document(self, text: str, metadata: Dict = None) -> bool:
|
| 148 |
"""Add a document to the collection with retry logic"""
|
| 149 |
max_retries = 3
|
| 150 |
retry_delay = 1
|
|
@@ -152,9 +154,9 @@ class VectorStore:
|
|
| 152 |
for attempt in range(max_retries):
|
| 153 |
try:
|
| 154 |
# Ensure collection exists before adding document
|
| 155 |
-
if not self._collection_exists_and_accessible():
|
| 156 |
print("Collection not accessible, trying to recreate...")
|
| 157 |
-
if not self._create_collection():
|
| 158 |
raise Exception("Failed to create collection")
|
| 159 |
|
| 160 |
# Generate embedding
|
|
@@ -181,7 +183,7 @@ class VectorStore:
|
|
| 181 |
)
|
| 182 |
|
| 183 |
# Store in Qdrant
|
| 184 |
-
result = self.client.upsert(
|
| 185 |
collection_name=self.collection_name,
|
| 186 |
points=[point]
|
| 187 |
)
|
|
@@ -200,11 +202,11 @@ class VectorStore:
|
|
| 200 |
if "Not found" in str(e) and "doesn't exist" in str(e):
|
| 201 |
# Collection doesn't exist, try to recreate
|
| 202 |
print("Collection not found, attempting to recreate...")
|
| 203 |
-
self._create_collection()
|
| 204 |
|
| 205 |
if attempt < max_retries - 1:
|
| 206 |
print(f"Retrying in {retry_delay} seconds...")
|
| 207 |
-
|
| 208 |
retry_delay *= 2 # Exponential backoff
|
| 209 |
else:
|
| 210 |
print(f"Failed to add document after {max_retries} attempts")
|
|
@@ -212,11 +214,11 @@ class VectorStore:
|
|
| 212 |
|
| 213 |
return False
|
| 214 |
|
| 215 |
-
def search_similar(self, query: str, limit: int = 5) -> List[Dict]:
|
| 216 |
"""Search for similar documents with error handling"""
|
| 217 |
try:
|
| 218 |
# Ensure collection exists before searching
|
| 219 |
-
if not self._collection_exists_and_accessible():
|
| 220 |
print("Collection not accessible for search")
|
| 221 |
return []
|
| 222 |
|
|
@@ -224,7 +226,7 @@ class VectorStore:
|
|
| 224 |
query_embedding = self.embedding_model.encode([query])[0]
|
| 225 |
|
| 226 |
# Search in Qdrant
|
| 227 |
-
results = self.client.search(
|
| 228 |
collection_name=self.collection_name,
|
| 229 |
query_vector=query_embedding.tolist(),
|
| 230 |
limit=limit
|
|
@@ -246,10 +248,10 @@ class VectorStore:
|
|
| 246 |
print(f"Error searching: {e}")
|
| 247 |
return []
|
| 248 |
|
| 249 |
-
def get_collection_info(self) -> Dict:
|
| 250 |
"""Get information about the collection"""
|
| 251 |
try:
|
| 252 |
-
collection_info = self.client.get_collection(self.collection_name)
|
| 253 |
return {
|
| 254 |
"name": collection_info.config.name,
|
| 255 |
"vector_size": collection_info.config.params.vectors.size,
|
|
@@ -261,16 +263,16 @@ class VectorStore:
|
|
| 261 |
print(f"Error getting collection info: {e}")
|
| 262 |
return {}
|
| 263 |
|
| 264 |
-
def verify_collection_health(self) -> bool:
|
| 265 |
"""Verify that the collection is healthy and accessible"""
|
| 266 |
try:
|
| 267 |
# Try to get collection info
|
| 268 |
-
info = self.get_collection_info()
|
| 269 |
if not info:
|
| 270 |
return False
|
| 271 |
|
| 272 |
# Try a simple search to verify functionality
|
| 273 |
-
test_results = self.search_similar("test query", limit=1)
|
| 274 |
# This should not fail even if no results are found
|
| 275 |
|
| 276 |
print(f"Collection health check passed. Points count: {info.get('points_count', 0)}")
|
|
|
|
| 1 |
+
from qdrant_client import QdrantClient, models, grpc
|
| 2 |
+
from qdrant_client.http.models import PointStruct, PayloadSchemaType
|
| 3 |
from sentence_transformers import SentenceTransformer
|
| 4 |
import uuid
|
| 5 |
import os
|
|
|
|
| 7 |
from typing import List, Dict, Any
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
import time
|
| 10 |
+
import asyncio
|
| 11 |
|
| 12 |
# Load environment variables
|
| 13 |
load_dotenv()
|
|
|
|
| 30 |
self.client = QdrantClient(
|
| 31 |
url=qdrant_url,
|
| 32 |
api_key=qdrant_api_key,
|
| 33 |
+
prefer_grpc=True,
|
| 34 |
)
|
| 35 |
print("Connected to Qdrant")
|
| 36 |
|
|
|
|
| 38 |
self.embedding_model = self._initialize_embedding_model()
|
| 39 |
|
| 40 |
# Create collection with proper indices
|
| 41 |
+
asyncio.run(self._ensure_collection_exists())
|
| 42 |
|
| 43 |
def _initialize_embedding_model(self):
|
| 44 |
"""Initialize the embedding model from a local directory"""
|
|
|
|
| 54 |
print(f"Failed to load local model: {e}")
|
| 55 |
raise RuntimeError("Failed to initialize embedding model from local path")
|
| 56 |
|
| 57 |
+
async def _collection_exists_and_accessible(self) -> bool:
|
| 58 |
"""
|
| 59 |
Check if collection exists and is accessible by trying to get its info.
|
| 60 |
|
|
|
|
| 63 |
"""
|
| 64 |
try:
|
| 65 |
# Try to get collection info - this is more reliable than just listing collections
|
| 66 |
+
collection_info = await self.client.get_collection(self.collection_name)
|
| 67 |
print(f"Collection '{self.collection_name}' exists and is accessible")
|
| 68 |
return True
|
| 69 |
except Exception as e:
|
| 70 |
print(f"Collection '{self.collection_name}' is not accessible: {e}")
|
| 71 |
return False
|
| 72 |
|
| 73 |
+
async def _create_collection(self) -> bool:
|
| 74 |
"""
|
| 75 |
Create the collection with proper configuration.
|
| 76 |
|
|
|
|
| 84 |
vector_size = 384
|
| 85 |
|
| 86 |
# Create collection with vector configuration
|
| 87 |
+
await self.client.create_collection(
|
| 88 |
collection_name=self.collection_name,
|
| 89 |
vectors_config=models.VectorParams(
|
| 90 |
size=vector_size,
|
|
|
|
| 97 |
)
|
| 98 |
|
| 99 |
# Wait a moment for collection to be fully created
|
| 100 |
+
await asyncio.sleep(1)
|
| 101 |
|
| 102 |
# Create payload indices
|
| 103 |
payload_indices = {
|
|
|
|
| 107 |
|
| 108 |
for field_name, schema_type in payload_indices.items():
|
| 109 |
try:
|
| 110 |
+
await self.client.create_payload_index(
|
| 111 |
collection_name=self.collection_name,
|
| 112 |
field_name=field_name,
|
| 113 |
field_schema=schema_type
|
|
|
|
| 124 |
print(error_msg)
|
| 125 |
return False
|
| 126 |
|
| 127 |
+
async def _ensure_collection_exists(self) -> bool:
|
| 128 |
"""
|
| 129 |
Ensure collection exists and is accessible, create if necessary.
|
| 130 |
|
|
|
|
| 133 |
"""
|
| 134 |
try:
|
| 135 |
# First, check if collection exists and is accessible
|
| 136 |
+
if await self._collection_exists_and_accessible():
|
| 137 |
return True
|
| 138 |
|
| 139 |
# If not accessible, try to create it
|
| 140 |
print(f"Collection '{self.collection_name}' not found or not accessible, creating...")
|
| 141 |
+
return await self._create_collection()
|
| 142 |
|
| 143 |
except Exception as e:
|
| 144 |
error_msg = f"Failed to ensure collection exists: {str(e)}"
|
|
|
|
| 146 |
print(error_msg)
|
| 147 |
return False
|
| 148 |
|
| 149 |
+
async def add_document(self, text: str, metadata: Dict = None) -> bool:
|
| 150 |
"""Add a document to the collection with retry logic"""
|
| 151 |
max_retries = 3
|
| 152 |
retry_delay = 1
|
|
|
|
| 154 |
for attempt in range(max_retries):
|
| 155 |
try:
|
| 156 |
# Ensure collection exists before adding document
|
| 157 |
+
if not await self._collection_exists_and_accessible():
|
| 158 |
print("Collection not accessible, trying to recreate...")
|
| 159 |
+
if not await self._create_collection():
|
| 160 |
raise Exception("Failed to create collection")
|
| 161 |
|
| 162 |
# Generate embedding
|
|
|
|
| 183 |
)
|
| 184 |
|
| 185 |
# Store in Qdrant
|
| 186 |
+
result = await self.client.upsert(
|
| 187 |
collection_name=self.collection_name,
|
| 188 |
points=[point]
|
| 189 |
)
|
|
|
|
| 202 |
if "Not found" in str(e) and "doesn't exist" in str(e):
|
| 203 |
# Collection doesn't exist, try to recreate
|
| 204 |
print("Collection not found, attempting to recreate...")
|
| 205 |
+
await self._create_collection()
|
| 206 |
|
| 207 |
if attempt < max_retries - 1:
|
| 208 |
print(f"Retrying in {retry_delay} seconds...")
|
| 209 |
+
await asyncio.sleep(retry_delay)
|
| 210 |
retry_delay *= 2 # Exponential backoff
|
| 211 |
else:
|
| 212 |
print(f"Failed to add document after {max_retries} attempts")
|
|
|
|
| 214 |
|
| 215 |
return False
|
| 216 |
|
| 217 |
+
async def search_similar(self, query: str, limit: int = 5) -> List[Dict]:
|
| 218 |
"""Search for similar documents with error handling"""
|
| 219 |
try:
|
| 220 |
# Ensure collection exists before searching
|
| 221 |
+
if not await self._collection_exists_and_accessible():
|
| 222 |
print("Collection not accessible for search")
|
| 223 |
return []
|
| 224 |
|
|
|
|
| 226 |
query_embedding = self.embedding_model.encode([query])[0]
|
| 227 |
|
| 228 |
# Search in Qdrant
|
| 229 |
+
results = await self.client.search(
|
| 230 |
collection_name=self.collection_name,
|
| 231 |
query_vector=query_embedding.tolist(),
|
| 232 |
limit=limit
|
|
|
|
| 248 |
print(f"Error searching: {e}")
|
| 249 |
return []
|
| 250 |
|
| 251 |
+
async def get_collection_info(self) -> Dict:
|
| 252 |
"""Get information about the collection"""
|
| 253 |
try:
|
| 254 |
+
collection_info = await self.client.get_collection(self.collection_name)
|
| 255 |
return {
|
| 256 |
"name": collection_info.config.name,
|
| 257 |
"vector_size": collection_info.config.params.vectors.size,
|
|
|
|
| 263 |
print(f"Error getting collection info: {e}")
|
| 264 |
return {}
|
| 265 |
|
| 266 |
+
async def verify_collection_health(self) -> bool:
|
| 267 |
"""Verify that the collection is healthy and accessible"""
|
| 268 |
try:
|
| 269 |
# Try to get collection info
|
| 270 |
+
info = await self.get_collection_info()
|
| 271 |
if not info:
|
| 272 |
return False
|
| 273 |
|
| 274 |
# Try a simple search to verify functionality
|
| 275 |
+
test_results = await self.search_similar("test query", limit=1)
|
| 276 |
# This should not fail even if no results are found
|
| 277 |
|
| 278 |
print(f"Collection health check passed. Points count: {info.get('points_count', 0)}")
|