# ────────────────────────────── utils/rag.py ──────────────────────────────
import os
from typing import List, Dict, Any, Optional
from pymongo import MongoClient, ASCENDING, TEXT
from pymongo.collection import Collection
from pymongo.errors import PyMongoError
import numpy as np

from ..logger import get_logger

VECTOR_DIM = 384  # all-MiniLM-L6-v2
INDEX_NAME = os.getenv("MONGO_VECTOR_INDEX", "vector_index")
USE_ATLAS_VECTOR = os.getenv("ATLAS_VECTOR", "0") == "1"

logger = get_logger("RAG", __name__)


class RAGStore:
    def __init__(self, mongo_uri: str, db_name: str = "studybuddy"):
        self.client = MongoClient(mongo_uri)
        self.db = self.client[db_name]
        self.chunks: Collection = self.db["chunks"]
        self.files: Collection = self.db["files"]

    # ── Write ────────────────────────────────────────────────────────────────
    def store_cards(self, cards: List[Dict[str, Any]]):
        if not cards:
            return
        for c in cards:
            # Basic validation: every card must carry a full-size embedding.
            emb = c.get("embedding")
            if not emb or len(emb) != VECTOR_DIM:
                raise ValueError(
                    f"Invalid embedding length {len(emb) if emb else 0}; expected {VECTOR_DIM}"
                )
        self.chunks.insert_many(cards, ordered=False)
        logger.info(f"Inserted {len(cards)} cards into MongoDB")
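
    # A sketch of the card document shape this store expects; the field names are
    # inferred from the queries and indexes elsewhere in this module, and the
    # values are illustrative only:
    #
    #     {
    #         "user_id": "u1",
    #         "project_id": "p1",
    #         "filename": "notes.pdf",
    #         "content": "chunk text ...",
    #         "topic_name": "Linear Algebra",
    #         "summary": "one-line gist of the chunk",
    #         "embedding": [0.013, -0.044, ...],  # VECTOR_DIM (384) floats
    #     }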

    def upsert_file_summary(self, user_id: str, project_id: str, filename: str, summary: str):
        self.files.update_one(
            {"user_id": user_id, "project_id": project_id, "filename": filename},
            {"$set": {"summary": summary}},
            upsert=True
        )
        logger.info(f"Upserted summary for {filename} (user {user_id}, project {project_id})")

    # ── Read ────────────────────────────────────────────────────────────────
    def list_cards(self, user_id: str, project_id: str, filename: Optional[str], limit: int, skip: int):
        q = {"user_id": user_id, "project_id": project_id}
        if filename:
            q["filename"] = filename
        cur = self.chunks.find(q, {"embedding": 0}).skip(skip).limit(limit).sort([("_id", ASCENDING)])
        # Convert MongoDB documents to a JSON-serializable format
        return [self._serialize_doc(card) for card in cur]

    def get_file_summary(self, user_id: str, project_id: str, filename: str):
        doc = self.files.find_one({"user_id": user_id, "project_id": project_id, "filename": filename})
        # Convert the MongoDB document to a JSON-serializable format
        return self._serialize_doc(doc) if doc else None

    def list_files(self, user_id: str, project_id: str):
        """List all files for a project with their summaries"""
        files_cursor = self.files.find(
            {"user_id": user_id, "project_id": project_id},
            {"_id": 0, "filename": 1, "summary": 1}
        ).sort("filename", ASCENDING)
        
        # Convert MongoDB documents to a JSON-serializable format
        return [self._serialize_doc(f) for f in files_cursor]

    def vector_search(self, user_id: str, project_id: str, query_vector: List[float], k: int = 6, filenames: Optional[List[str]] = None, search_type: str = "hybrid"):
        """
        Enhanced vector search with multiple strategies:
        - hybrid: Combines Atlas and local search
        - flat: Exhaustive search for maximum accuracy
        - atlas: Uses Atlas Vector Search only
        - local: Uses local cosine similarity only
        """
        if search_type == "flat" or (search_type == "hybrid" and not USE_ATLAS_VECTOR):
            return self._flat_vector_search(user_id, project_id, query_vector, k, filenames)
        elif search_type == "atlas" and USE_ATLAS_VECTOR:
            return self._atlas_vector_search(user_id, project_id, query_vector, k, filenames)
        elif search_type == "local":
            return self._local_vector_search(user_id, project_id, query_vector, k, filenames)
        else:
            # Default hybrid approach
            if USE_ATLAS_VECTOR:
                atlas_results = self._atlas_vector_search(user_id, project_id, query_vector, k, filenames)
                if atlas_results:
                    return atlas_results
            return self._local_vector_search(user_id, project_id, query_vector, k, filenames)

    def _atlas_vector_search(self, user_id: str, project_id: str, query_vector: List[float], k: int, filenames: Optional[List[str]] = None):
        """Atlas Vector Search implementation (legacy knnBeta operator)."""
        match_stage = {"user_id": user_id, "project_id": project_id}
        if filenames:
            match_stage["filename"] = {"$in": filenames}
        pipeline = [
            {
                "$search": {
                    "index": INDEX_NAME,
                    "knnBeta": {
                        "vector": query_vector,
                        "path": "embedding",
                        # Over-fetch neighbors: $match runs *after* $search, so
                        # hits belonging to other users/projects would otherwise
                        # eat into the k results and we could return fewer than k.
                        "k": max(k * 10, 100),
                    }
                }
            },
            {"$match": match_stage},
            {"$project": {"doc": "$$ROOT", "score": {"$meta": "searchScore"}}},
            {"$limit": k},
        ]
        hits = list(self.chunks.aggregate(pipeline))
        return self._serialize_hits(hits)

    def _cosine_top_k(self, docs, query_vector: List[float], k: int):
        """Score docs by cosine similarity against query_vector; return the top k as (score, doc) pairs."""
        qv = np.array(query_vector, dtype="float32")
        qnorm = float(np.linalg.norm(qv)) or 1.0  # hoisted out of the loop; guards a zero query vector
        scores = []
        for d in docs:
            v = np.array(d.get("embedding", [0.0] * VECTOR_DIM), dtype="float32")
            denom = (qnorm * float(np.linalg.norm(v))) or 1.0
            scores.append((float(np.dot(qv, v) / denom), d))
        scores.sort(key=lambda x: x[0], reverse=True)
        return scores[:k]

    def _local_vector_search(self, user_id: str, project_id: str, query_vector: List[float], k: int, filenames: Optional[List[str]] = None):
        """Local cosine similarity over a bounded sample of the most recent chunks."""
        q = {"user_id": user_id, "project_id": project_id}
        if filenames:
            q["filename"] = {"$in": filenames}

        # Bound the candidate set so large projects stay responsive; newest chunks first.
        sample_limit = max(5000, k * 50)
        sample = list(self.chunks.find(q).sort([("_id", -1)]).limit(sample_limit))
        if not sample:
            return []

        top = self._cosine_top_k(sample, query_vector, k)
        logger.info(f"Local vector search: {len(sample)} docs sampled, {len(top)} results")
        return self._serialize_results(top)

    def _flat_vector_search(self, user_id: str, project_id: str, query_vector: List[float], k: int, filenames: Optional[List[str]] = None):
        """Flat exhaustive search over every matching chunk for maximum accuracy."""
        q = {"user_id": user_id, "project_id": project_id}
        if filenames:
            q["filename"] = {"$in": filenames}

        # Load ALL matching documents; accurate but O(collection) in time and memory.
        all_docs = list(self.chunks.find(q))
        if not all_docs:
            return []

        top = self._cosine_top_k(all_docs, query_vector, k)
        logger.info(f"Flat vector search: {len(all_docs)} docs searched, {len(top)} results")
        return self._serialize_results(top)
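
    # A vectorized alternative to the per-document loop in _cosine_top_k; a
    # sketch that assumes every stored embedding already has the full
    # VECTOR_DIM length (ragged embeddings would need padding first).
    @staticmethod
    def _cosine_top_k_vectorized(docs, query_vector: List[float], k: int):
        mat = np.array([d.get("embedding", [0.0] * VECTOR_DIM) for d in docs], dtype="float32")
        qv = np.array(query_vector, dtype="float32")
        # Row-wise cosine similarity in one matrix-vector product.
        denoms = np.linalg.norm(mat, axis=1) * (float(np.linalg.norm(qv)) or 1.0)
        sims = (mat @ qv) / np.where(denoms == 0.0, 1.0, denoms)
        order = np.argsort(sims)[::-1][:k]
        return [(float(sims[i]), docs[i]) for i in order]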

    def _serialize_hits(self, hits):
        """Serialize Atlas search hits"""
        serializable_hits = []
        for hit in hits:
            doc = hit["doc"]
            serializable_doc = self._serialize_doc(doc)
            serializable_hits.append({
                "doc": serializable_doc,
                "score": float(hit.get("score", 0.0))
            })
        return serializable_hits

    def _serialize_results(self, results):
        """Serialize local search results"""
        serializable_results = []
        for score, doc in results:
            serializable_doc = self._serialize_doc(doc)
            serializable_results.append({
                "doc": serializable_doc,
                "score": float(score)
            })
        return serializable_results

    def _serialize_doc(self, doc):
        """Convert MongoDB document to JSON-serializable format"""
        serializable_doc = {}
        for key, value in doc.items():
            if key == '_id':
                serializable_doc[key] = str(value)
            elif hasattr(value, 'isoformat'):
                serializable_doc[key] = value.isoformat()
            else:
                serializable_doc[key] = value
        return serializable_doc


def ensure_indexes(store: RAGStore):
    # Basic text index for fallback keyword search (optional)
    try:
        store.chunks.create_index([("user_id", ASCENDING), ("project_id", ASCENDING), ("filename", ASCENDING)])
        store.chunks.create_index([("content", TEXT), ("topic_name", TEXT), ("summary", TEXT)], name="text_idx")
        store.files.create_index([("user_id", ASCENDING), ("project_id", ASCENDING), ("filename", ASCENDING)], unique=True)
    except PyMongoError as e:
        logger.warning(f"Index creation warning: {e}")
    # Note: For Atlas Vector, create an Atlas Search index named INDEX_NAME on field "embedding" with vector options.
    # Example (in Atlas UI):
    # {
    #   "mappings": {
    #     "dynamic": false,
    #     "fields": {
    #       "embedding": {
    #         "type": "knnVector",
    #         "dimensions": 384,
    #         "similarity": "cosine"
    #       }
    #     }
    #   }
    # }
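

# A minimal sketch of creating that Atlas Search index programmatically instead
# of through the UI; assumes PyMongo >= 4.5 (which added create_search_index)
# and an Atlas deployment. The definition mirrors the JSON example above.
def ensure_atlas_vector_index(store: RAGStore):
    try:
        store.chunks.create_search_index(
            {
                "name": INDEX_NAME,
                "definition": {
                    "mappings": {
                        "dynamic": False,
                        "fields": {
                            "embedding": {
                                "type": "knnVector",
                                "dimensions": VECTOR_DIM,
                                "similarity": "cosine",
                            }
                        },
                    }
                },
            }
        )
    except PyMongoError as e:
        logger.warning(f"Atlas search index creation warning: {e}")


# Hedged usage sketch: the MONGO_URI env var, the sentence-transformers
# dependency, and the sample card contents are assumptions for illustration,
# not part of this module's contract. Because of the relative import at the
# top, run it as a module, e.g. python -m <yourpkg>.utils.rag (package name
# depends on your layout), not as a standalone script.
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer  # assumed installed

    store = RAGStore(os.getenv("MONGO_URI", "mongodb://localhost:27017"))
    ensure_indexes(store)

    model = SentenceTransformer("all-MiniLM-L6-v2")  # 384-dim, matches VECTOR_DIM
    text = "Cosine similarity compares the angle between two vectors."
    store.store_cards([{
        "user_id": "demo-user",
        "project_id": "demo-project",
        "filename": "demo.txt",
        "content": text,
        "embedding": model.encode(text).tolist(),
    }])

    query_vec = model.encode("How do I compare two vectors?").tolist()
    for hit in store.vector_search("demo-user", "demo-project", query_vec, k=3):
        print(round(hit["score"], 3), hit["doc"].get("content", ""))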