LiamKhoaLe committed on
Commit
18b46d7
·
1 Parent(s): f93070b

Update MongoDB resolver when creating a new project

Browse files
Files changed (2) hide show
  1. app.py +280 -58
  2. utils/rag.py +79 -6
app.py CHANGED
@@ -1,13 +1,17 @@
1
  # https://binkhoale1812-edsummariser.hf.space/
2
  import os, io, re, uuid, json, time, logging
3
  from typing import List, Dict, Any, Optional
4
- from datetime import datetime
 
5
 
6
  from fastapi import FastAPI, UploadFile, File, Form, Request, HTTPException, BackgroundTasks
7
  from fastapi.responses import FileResponse, JSONResponse, HTMLResponse
8
  from fastapi.staticfiles import StaticFiles
9
  from fastapi.middleware.cors import CORSMiddleware
10
 
 
 
 
11
  from utils.rotator import APIKeyRotator
12
  from utils.parser import parse_pdf_bytes, parse_docx_bytes
13
  from utils.caption import BlipCaptioner
@@ -19,6 +23,48 @@ from utils.summarizer import cheap_summarize
19
  from utils.common import trim_text
20
  from utils.logger import get_logger
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  # ────────────────────────────── App Setup ──────────────────────────────
23
  logger = get_logger("APP", name="studybuddy")
24
 
@@ -45,8 +91,19 @@ captioner = BlipCaptioner()
45
  embedder = EmbeddingClient(model_name=os.getenv("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2"))
46
 
47
  # Mongo / RAG store
48
- rag = RAGStore(mongo_uri=os.getenv("MONGO_URI"), db_name=os.getenv("MONGO_DB", "studybuddy"))
49
- ensure_indexes(rag)
 
 
 
 
 
 
 
 
 
 
 
50
 
51
 
52
  # ────────────────────────────── Auth Helpers/Routes ───────────────────────────
@@ -100,50 +157,108 @@ async def login(email: str = Form(...), password: str = Form(...)):
100
 
101
 
102
  # ────────────────────────────── Project Management ───────────────────────────
103
- @app.post("/projects/create")
104
  async def create_project(user_id: str = Form(...), name: str = Form(...), description: str = Form("")):
105
  """Create a new project for a user"""
106
- if not name.strip():
107
- raise HTTPException(400, detail="Project name is required")
108
-
109
- project_id = str(uuid.uuid4())
110
- project = {
111
- "project_id": project_id,
112
- "user_id": user_id,
113
- "name": name.strip(),
114
- "description": description.strip(),
115
- "created_at": datetime.utcnow(),
116
- "updated_at": datetime.utcnow()
117
- }
118
-
119
- rag.db["projects"].insert_one(project)
120
- logger.info(f"[PROJECT] Created project {name} for user {user_id}")
121
- return project
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
 
124
- @app.get("/projects")
125
  async def list_projects(user_id: str):
126
  """List all projects for a user"""
127
- projects = list(rag.db["projects"].find(
128
- {"user_id": user_id},
129
- {"_id": 0}
130
- ).sort("updated_at", -1))
131
- return {"projects": projects}
 
 
 
 
 
 
 
 
 
 
 
132
 
133
 
134
- @app.get("/projects/{project_id}")
135
  async def get_project(project_id: str, user_id: str):
136
  """Get a specific project (with user ownership check)"""
137
  project = rag.db["projects"].find_one(
138
- {"project_id": project_id, "user_id": user_id},
139
- {"_id": 0}
140
  )
141
  if not project:
142
  raise HTTPException(404, detail="Project not found")
143
- return project
 
 
 
 
 
 
 
 
144
 
145
 
146
- @app.delete("/projects/{project_id}")
147
  async def delete_project(project_id: str, user_id: str):
148
  """Delete a project and all its associated data"""
149
  # Check ownership
@@ -158,11 +273,11 @@ async def delete_project(project_id: str, user_id: str):
158
  rag.db["chat_sessions"].delete_many({"project_id": project_id})
159
 
160
  logger.info(f"[PROJECT] Deleted project {project_id} for user {user_id}")
161
- return {"message": "Project deleted successfully"}
162
 
163
 
164
  # ────────────────────────────── Chat Sessions ──────────────────────────────
165
- @app.post("/chat/save")
166
  async def save_chat_message(
167
  user_id: str = Form(...),
168
  project_id: str = Form(...),
@@ -180,21 +295,32 @@ async def save_chat_message(
180
  "role": role,
181
  "content": content,
182
  "timestamp": timestamp or time.time(),
183
- "created_at": datetime.utcnow()
184
  }
185
 
186
  rag.db["chat_sessions"].insert_one(message)
187
- return {"message": "Chat message saved"}
188
 
189
 
190
- @app.get("/chat/history")
191
  async def get_chat_history(user_id: str, project_id: str, limit: int = 100):
192
  """Get chat history for a project"""
193
- messages = list(rag.db["chat_sessions"].find(
194
- {"user_id": user_id, "project_id": project_id},
195
- {"_id": 0}
196
- ).sort("timestamp", 1).limit(limit))
197
- return {"messages": messages}
 
 
 
 
 
 
 
 
 
 
 
198
 
199
 
200
  # ────────────────────────────── Helpers ──────────────────────────────
@@ -226,7 +352,7 @@ def index():
226
  return FileResponse(index_path)
227
 
228
 
229
- @app.post("/upload")
230
  async def upload_files(
231
  request: Request,
232
  background_tasks: BackgroundTasks,
@@ -307,23 +433,39 @@ async def upload_files(
307
 
308
  # Kick off processing in background to keep UI responsive
309
  background_tasks.add_task(_process)
310
- return {"job_id": job_id, "status": "processing"}
311
 
312
 
313
  @app.get("/cards")
314
  def list_cards(user_id: str, project_id: str, filename: Optional[str] = None, limit: int = 50, skip: int = 0):
315
- return rag.list_cards(user_id=user_id, project_id=project_id, filename=filename, limit=limit, skip=skip)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
 
317
 
318
- @app.get("/file-summary")
319
  def get_file_summary(user_id: str, project_id: str, filename: str):
320
  doc = rag.get_file_summary(user_id=user_id, project_id=project_id, filename=filename)
321
  if not doc:
322
  raise HTTPException(404, detail="No summary found for that file.")
323
- return {"filename": filename, "summary": doc.get("summary", "")}
324
 
325
 
326
- @app.post("/chat")
327
  async def chat(
328
  user_id: str = Form(...),
329
  project_id: str = Form(...),
@@ -337,6 +479,7 @@ async def chat(
337
  - Bring in recent chat memory: last 3 via NVIDIA relevance; remaining 17 via semantic search
338
  - After answering, summarize (q,a) via NVIDIA and store into LRU (last 20)
339
  """
 
340
  from memo.memory import MemoryLRU
341
  from memo.history import summarize_qa_with_nvidia, files_relevance, related_recent_and_semantic_context
342
  from utils.router import NVIDIA_SMALL # reuse default name
@@ -349,9 +492,15 @@ async def chat(
349
  fn = m.group(1)
350
  doc = rag.get_file_summary(user_id=user_id, project_id=project_id, filename=fn)
351
  if doc:
352
- return {"answer": doc.get("summary", ""), "sources": [{"filename": fn, "file_summary": True}]}
 
 
 
353
  else:
354
- return {"answer": "I couldn't find a summary for that file in your library.", "sources": []}
 
 
 
355
 
356
  # 1) Preload file list + summaries
357
  files_list = rag.list_files(user_id=user_id, project_id=project_id) # [{filename, summary}]
@@ -391,11 +540,11 @@ async def chat(
391
  q_vec = embedder.embed([question])[0]
392
  hits = rag.vector_search(user_id=user_id, project_id=project_id, query_vector=q_vec, k=k, filenames=relevant_files if relevant_files else None)
393
  if not hits:
394
- return {
395
- "answer": "I don't know based on your uploaded materials. Try uploading more sources or rephrasing the question.",
396
- "sources": [],
397
- "relevant_files": relevant_files
398
- }
399
  # Compose context
400
  contexts = []
401
  sources_meta = []
@@ -408,7 +557,7 @@ async def chat(
408
  "topic_name": doc.get("topic_name"),
409
  "page_span": doc.get("page_span"),
410
  "score": float(score),
411
- "chunk_id": str(doc.get("_id", ""))
412
  })
413
  context_text = "\n\n---\n\n".join(contexts)
414
 
@@ -462,9 +611,82 @@ async def chat(
462
  logger.warning(f"QA summarize/store failed: {e}")
463
  # Trim for logging
464
  logger.info("LLM answer (trimmed): %s", trim_text(answer, 200).replace("\n", " "))
465
- return {"answer": answer, "sources": sources_meta}
466
 
467
 
468
- @app.get("/healthz")
469
  def health():
470
- return {"ok": True}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # https://binkhoale1812-edsummariser.hf.space/
2
  import os, io, re, uuid, json, time, logging
3
  from typing import List, Dict, Any, Optional
4
+ from datetime import datetime, timezone
5
+ from pydantic import BaseModel
6
 
7
  from fastapi import FastAPI, UploadFile, File, Form, Request, HTTPException, BackgroundTasks
8
  from fastapi.responses import FileResponse, JSONResponse, HTMLResponse
9
  from fastapi.staticfiles import StaticFiles
10
  from fastapi.middleware.cors import CORSMiddleware
11
 
12
+ # MongoDB imports
13
+ from pymongo.errors import PyMongoError, ConnectionFailure, ServerSelectionTimeoutError
14
+
15
  from utils.rotator import APIKeyRotator
16
  from utils.parser import parse_pdf_bytes, parse_docx_bytes
17
  from utils.caption import BlipCaptioner
 
23
  from utils.common import trim_text
24
  from utils.logger import get_logger
25
 
26
+ # ────────────────────────────── Response Models ──────────────────────────────
27
class ProjectResponse(BaseModel):
    """Project record as returned by the project endpoints."""

    project_id: str
    user_id: str
    name: str
    description: str
    # Handlers fill both timestamps with datetime.isoformat(), or str() of
    # whatever non-datetime value was stored.
    created_at: str
    updated_at: str


class ProjectsListResponse(BaseModel):
    """Envelope for the project listing endpoint."""

    projects: List[ProjectResponse]


class ChatMessageResponse(BaseModel):
    """One stored chat message."""

    user_id: str
    project_id: str
    role: str
    content: str
    timestamp: float
    # Rendered the same way as the project timestamps (isoformat() or str()).
    created_at: str


class ChatHistoryResponse(BaseModel):
    """Envelope for the chat history endpoint."""

    messages: List[ChatMessageResponse]


class MessageResponse(BaseModel):
    """Plain status message returned by endpoints with no other payload."""

    message: str


class UploadResponse(BaseModel):
    """Acknowledgement returned by the upload endpoint."""

    job_id: str
    status: str


class FileSummaryResponse(BaseModel):
    """Summary text for a single file."""

    filename: str
    summary: str


class ChatAnswerResponse(BaseModel):
    """Answer produced by the chat endpoint together with its sources."""

    answer: str
    sources: List[Dict[str, Any]]
    # Left as None by the file-summary shortcut responses.
    relevant_files: Optional[List[str]] = None


class HealthResponse(BaseModel):
    """Payload of the health endpoint."""

    ok: bool
67
+
68
  # ────────────────────────────── App Setup ──────────────────────────────
69
  logger = get_logger("APP", name="studybuddy")
70
 
 
91
  embedder = EmbeddingClient(model_name=os.getenv("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2"))
92
 
93
  # Mongo / RAG store
94
# Build the store and verify the connection at import time. Any failure is
# logged and leaves `rag` as None so the module can still be imported.
try:
    rag = RAGStore(mongo_uri=os.getenv("MONGO_URI"), db_name=os.getenv("MONGO_DB", "studybuddy"))
    # Test the connection
    rag.client.admin.command('ping')
    logger.info("[APP] MongoDB connection successful")
    ensure_indexes(rag)
    logger.info("[APP] MongoDB indexes ensured")
except Exception as e:
    logger.error(f"[APP] Failed to initialize MongoDB/RAG store: {str(e)}")
    # Report only whether the URI is configured: a MongoDB connection string
    # can embed the username and password, which must not reach the logs.
    logger.error(f"[APP] MONGO_URI: {'set' if os.getenv('MONGO_URI') else 'Not set'}")
    logger.error(f"[APP] MONGO_DB: {os.getenv('MONGO_DB', 'studybuddy')}")
    # No usable store: code that touches `rag` has to handle None.
    rag = None
107
 
108
 
109
  # ────────────────────────────── Auth Helpers/Routes ───────────────────────────
 
157
 
158
 
159
  # ────────────────────────────── Project Management ───────────────────────────
160
@app.post("/projects/create", response_model=ProjectResponse)
async def create_project(user_id: str = Form(...), name: str = Form(...), description: str = Form("")):
    """Create a new project for a user.

    Args:
        user_id: Owner of the project; must not be blank.
        name: Project name; must not be blank, stored stripped.
        description: Optional description, stored stripped.

    Returns:
        ProjectResponse for the stored project, with both timestamps rendered
        via ``isoformat()``.

    Raises:
        HTTPException: 400 for a blank name or user id; 500 when the store is
            unavailable or the insert fails.
    """
    # `rag` is set to None at start-up when MongoDB could not be initialised.
    # Compare against None explicitly instead of relying on the truthiness of
    # the store object.
    if rag is None:
        raise HTTPException(500, detail="Database connection not available")

    project_name = name.strip()
    if not project_name:
        raise HTTPException(400, detail="Project name is required")

    if not user_id.strip():
        raise HTTPException(400, detail="User ID is required")

    project_description = description.strip()
    project_id = str(uuid.uuid4())
    current_time = datetime.now(timezone.utc)

    project = {
        "project_id": project_id,
        "user_id": user_id,
        "name": project_name,
        "description": project_description,
        "created_at": current_time,
        "updated_at": current_time,
    }

    logger.info(f"[PROJECT] Creating project {name} for user {user_id}")

    # Only the insert can realistically fail, so it is the only guarded call.
    try:
        result = rag.db["projects"].insert_one(project)
    except PyMongoError as mongo_error:
        logger.error(f"[PROJECT] MongoDB error creating project: {str(mongo_error)}")
        raise HTTPException(500, detail=f"Database error: {str(mongo_error)}") from mongo_error
    except Exception as db_error:
        logger.error(f"[PROJECT] Database error creating project: {str(db_error)}")
        raise HTTPException(500, detail=f"Database error: {str(db_error)}") from db_error

    logger.info(f"[PROJECT] Created project {name} with ID {project_id}, MongoDB result: {result.inserted_id}")

    # Build the response from the local values rather than from `project`:
    # the response model wants the timestamps as strings.
    response = ProjectResponse(
        project_id=project_id,
        user_id=user_id,
        name=project_name,
        description=project_description,
        created_at=current_time.isoformat(),
        updated_at=current_time.isoformat(),
    )

    logger.info(f"[PROJECT] Successfully created project {name} for user {user_id}")
    return response
219
 
220
 
221
@app.get("/projects", response_model=ProjectsListResponse)
async def list_projects(user_id: str):
    """List all projects for a user, most recently updated first.

    Args:
        user_id: Owner whose projects are returned.

    Returns:
        ProjectsListResponse wrapping one ProjectResponse per stored project.

    Raises:
        HTTPException: 500 when the store was not initialised.
    """
    # `rag` is None when start-up could not initialise MongoDB; without this
    # guard the lookup below fails with AttributeError on None.
    if rag is None:
        raise HTTPException(500, detail="Database connection not available")

    def _as_text(value):
        # datetimes become ISO strings; any other stored value goes through str().
        return value.isoformat() if isinstance(value, datetime) else str(value)

    projects_cursor = rag.db["projects"].find({"user_id": user_id}).sort("updated_at", -1)

    projects = [
        ProjectResponse(
            project_id=project["project_id"],
            user_id=project["user_id"],
            name=project["name"],
            description=project.get("description", ""),
            created_at=_as_text(project["created_at"]),
            updated_at=_as_text(project["updated_at"]),
        )
        for project in projects_cursor
    ]

    return ProjectsListResponse(projects=projects)
240
 
241
 
242
@app.get("/projects/{project_id}", response_model=ProjectResponse)
async def get_project(project_id: str, user_id: str):
    """Get a specific project (with user ownership check).

    Args:
        project_id: Identifier of the project to fetch.
        user_id: Must match the project's stored owner.

    Returns:
        ProjectResponse for the matching project.

    Raises:
        HTTPException: 500 when the store was not initialised; 404 when no
            project matches both ids.
    """
    # `rag` is None when start-up could not initialise MongoDB; without this
    # guard the lookup below fails with AttributeError on None.
    if rag is None:
        raise HTTPException(500, detail="Database connection not available")

    def _as_text(value):
        # datetimes become ISO strings; any other stored value goes through str().
        return value.isoformat() if isinstance(value, datetime) else str(value)

    # Filtering on both ids is the ownership check: a project that belongs to
    # another user is reported as not found.
    project = rag.db["projects"].find_one(
        {"project_id": project_id, "user_id": user_id}
    )
    if not project:
        raise HTTPException(404, detail="Project not found")

    return ProjectResponse(
        project_id=project["project_id"],
        user_id=project["user_id"],
        name=project["name"],
        description=project.get("description", ""),
        created_at=_as_text(project["created_at"]),
        updated_at=_as_text(project["updated_at"]),
    )
259
 
260
 
261
+ @app.delete("/projects/{project_id}", response_model=MessageResponse)
262
  async def delete_project(project_id: str, user_id: str):
263
  """Delete a project and all its associated data"""
264
  # Check ownership
 
273
  rag.db["chat_sessions"].delete_many({"project_id": project_id})
274
 
275
  logger.info(f"[PROJECT] Deleted project {project_id} for user {user_id}")
276
+ return MessageResponse(message="Project deleted successfully")
277
 
278
 
279
  # ────────────────────────────── Chat Sessions ──────────────────────────────
280
+ @app.post("/chat/save", response_model=MessageResponse)
281
  async def save_chat_message(
282
  user_id: str = Form(...),
283
  project_id: str = Form(...),
 
295
  "role": role,
296
  "content": content,
297
  "timestamp": timestamp or time.time(),
298
+ "created_at": datetime.now(timezone.utc)
299
  }
300
 
301
  rag.db["chat_sessions"].insert_one(message)
302
+ return MessageResponse(message="Chat message saved")
303
 
304
 
305
@app.get("/chat/history", response_model=ChatHistoryResponse)
async def get_chat_history(user_id: str, project_id: str, limit: int = 100):
    """Get chat history for a project, oldest message first.

    Args:
        user_id: Owner of the messages.
        project_id: Project the messages belong to.
        limit: Passed to the cursor's ``limit()``.

    Returns:
        ChatHistoryResponse wrapping one ChatMessageResponse per stored message.

    Raises:
        HTTPException: 500 when the store was not initialised.
    """
    # `rag` is None when start-up could not initialise MongoDB; without this
    # guard the lookup below fails with AttributeError on None.
    if rag is None:
        raise HTTPException(500, detail="Database connection not available")

    def _as_text(value):
        # datetimes become ISO strings; any other stored value goes through str().
        return value.isoformat() if isinstance(value, datetime) else str(value)

    messages_cursor = rag.db["chat_sessions"].find(
        {"user_id": user_id, "project_id": project_id}
    ).sort("timestamp", 1).limit(limit)

    messages = [
        ChatMessageResponse(
            user_id=message["user_id"],
            project_id=message["project_id"],
            role=message["role"],
            content=message["content"],
            timestamp=message["timestamp"],
            created_at=_as_text(message["created_at"]),
        )
        for message in messages_cursor
    ]

    return ChatHistoryResponse(messages=messages)
324
 
325
 
326
  # ────────────────────────────── Helpers ──────────────────────────────
 
352
  return FileResponse(index_path)
353
 
354
 
355
+ @app.post("/upload", response_model=UploadResponse)
356
  async def upload_files(
357
  request: Request,
358
  background_tasks: BackgroundTasks,
 
433
 
434
  # Kick off processing in background to keep UI responsive
435
  background_tasks.add_task(_process)
436
+ return UploadResponse(job_id=job_id, status="processing")
437
 
438
 
439
  @app.get("/cards")
440
  def list_cards(user_id: str, project_id: str, filename: Optional[str] = None, limit: int = 50, skip: int = 0):
441
+ """List cards for a project"""
442
+ cards = rag.list_cards(user_id=user_id, project_id=project_id, filename=filename, limit=limit, skip=skip)
443
+
444
+ # Ensure all cards are JSON serializable
445
+ serializable_cards = []
446
+ for card in cards:
447
+ serializable_card = {}
448
+ for key, value in card.items():
449
+ if key == '_id':
450
+ serializable_card[key] = str(value) # Convert ObjectId to string
451
+ elif isinstance(value, datetime):
452
+ serializable_card[key] = value.isoformat() # Convert datetime to ISO string
453
+ else:
454
+ serializable_card[key] = value
455
+ serializable_cards.append(serializable_card)
456
+
457
+ return {"cards": serializable_cards}
458
 
459
 
460
@app.get("/file-summary", response_model=FileSummaryResponse)
def get_file_summary(user_id: str, project_id: str, filename: str):
    """Return the stored summary for one file of a project.

    Raises:
        HTTPException: 404 when the store has no summary document for the file.
    """
    summary_doc = rag.get_file_summary(user_id=user_id, project_id=project_id, filename=filename)
    if summary_doc:
        # A document without a "summary" key yields an empty summary string.
        return FileSummaryResponse(filename=filename, summary=summary_doc.get("summary", ""))
    raise HTTPException(404, detail="No summary found for that file.")
466
 
467
 
468
+ @app.post("/chat", response_model=ChatAnswerResponse)
469
  async def chat(
470
  user_id: str = Form(...),
471
  project_id: str = Form(...),
 
479
  - Bring in recent chat memory: last 3 via NVIDIA relevance; remaining 17 via semantic search
480
  - After answering, summarize (q,a) via NVIDIA and store into LRU (last 20)
481
  """
482
+ import sys
483
  from memo.memory import MemoryLRU
484
  from memo.history import summarize_qa_with_nvidia, files_relevance, related_recent_and_semantic_context
485
  from utils.router import NVIDIA_SMALL # reuse default name
 
492
  fn = m.group(1)
493
  doc = rag.get_file_summary(user_id=user_id, project_id=project_id, filename=fn)
494
  if doc:
495
+ return ChatAnswerResponse(
496
+ answer=doc.get("summary", ""),
497
+ sources=[{"filename": fn, "file_summary": True}]
498
+ )
499
  else:
500
+ return ChatAnswerResponse(
501
+ answer="I couldn't find a summary for that file in your library.",
502
+ sources=[]
503
+ )
504
 
505
  # 1) Preload file list + summaries
506
  files_list = rag.list_files(user_id=user_id, project_id=project_id) # [{filename, summary}]
 
540
  q_vec = embedder.embed([question])[0]
541
  hits = rag.vector_search(user_id=user_id, project_id=project_id, query_vector=q_vec, k=k, filenames=relevant_files if relevant_files else None)
542
  if not hits:
543
+ return ChatAnswerResponse(
544
+ answer="I don't know based on your uploaded materials. Try uploading more sources or rephrasing the question.",
545
+ sources=[],
546
+ relevant_files=relevant_files
547
+ )
548
  # Compose context
549
  contexts = []
550
  sources_meta = []
 
557
  "topic_name": doc.get("topic_name"),
558
  "page_span": doc.get("page_span"),
559
  "score": float(score),
560
+ "chunk_id": str(doc.get("_id", "")) # Convert ObjectId to string
561
  })
562
  context_text = "\n\n---\n\n".join(contexts)
563
 
 
611
  logger.warning(f"QA summarize/store failed: {e}")
612
  # Trim for logging
613
  logger.info("LLM answer (trimmed): %s", trim_text(answer, 200).replace("\n", " "))
614
+ return ChatAnswerResponse(answer=answer, sources=sources_meta, relevant_files=relevant_files)
615
 
616
 
617
@app.get("/healthz", response_model=HealthResponse)
def health():
    """Health endpoint: unconditionally reports ``ok=True``."""
    status = HealthResponse(ok=True)
    return status
620
+
621
+
622
@app.get("/test-db")
async def test_database():
    """Test database connection and basic operations.

    Pings the server, then inserts, reads back and deletes one probe document
    in the ``test_collection`` collection.

    Returns:
        A dict with ``status`` ("success" or "error") and ``message``. On
        success it also carries ``test_id`` and ``found_doc``; on error it
        carries ``error_type``. Failures are reported in the body rather than
        raised.
    """
    try:
        # `rag` is None when start-up could not initialise MongoDB.
        if rag is None:
            return {
                "status": "error",
                "message": "RAG store not initialized",
                "error_type": "RAGStoreNotInitialized"
            }

        # Test basic connection
        rag.client.admin.command('ping')

        # Test basic insert/query
        test_collection = rag.db["test_collection"]
        test_doc = {"test": True, "timestamp": datetime.now(timezone.utc)}
        result = test_collection.insert_one(test_doc)

        try:
            # Test query
            found = test_collection.find_one({"_id": result.inserted_id})
        finally:
            # Clean up even when the read fails, so probe documents do not
            # accumulate in the collection.
            test_collection.delete_one({"_id": result.inserted_id})

        return {
            "status": "success",
            "message": "Database connection and operations working correctly",
            "test_id": str(result.inserted_id),
            "found_doc": str(found["_id"]) if found else None
        }

    except Exception as e:
        logger.error(f"[TEST-DB] Database test failed: {str(e)}")
        return {
            "status": "error",
            "message": f"Database test failed: {str(e)}",
            "error_type": str(type(e))
        }
661
+
662
+
663
@app.get("/rag-status")
async def rag_status():
    """Check the status of the RAG store.

    Returns:
        A dict with ``status``, ``message`` and ``rag_available``. When the
        ping succeeds it also names the database and the chunks/files
        collections; when the ping fails it carries the error text under
        ``error``.
    """
    if rag:
        try:
            # Test connection
            rag.client.admin.command('ping')
            collection_names = {
                "chunks": rag.chunks.name,
                "files": rag.files.name
            }
            return {
                "status": "success",
                "message": "RAG store is available and connected",
                "rag_available": True,
                "database": rag.db.name,
                "collections": collection_names
            }
        except Exception as e:
            failure_text = str(e)
            return {
                "status": "error",
                "message": f"RAG store connection failed: {failure_text}",
                "rag_available": False,
                "error": failure_text
            }

    # Store was never initialised (or is otherwise falsy).
    return {
        "status": "error",
        "message": "RAG store not initialized",
        "rag_available": False
    }
utils/rag.py CHANGED
@@ -49,17 +49,53 @@ class RAGStore:
49
  if filename:
50
  q["filename"] = filename
51
  cur = self.chunks.find(q, {"embedding": 0}).skip(skip).limit(limit).sort([("_id", ASCENDING)])
52
- return list(cur)
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  def get_file_summary(self, user_id: str, project_id: str, filename: str):
55
- return self.files.find_one({"user_id": user_id, "project_id": project_id, "filename": filename})
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  def list_files(self, user_id: str, project_id: str):
58
  """List all files for a project with their summaries"""
59
- return list(self.files.find(
60
  {"user_id": user_id, "project_id": project_id},
61
  {"_id": 0, "filename": 1, "summary": 1}
62
- ).sort("filename", ASCENDING))
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  def vector_search(self, user_id: str, project_id: str, query_vector: List[float], k: int = 6, filenames: Optional[List[str]] = None):
65
  if USE_ATLAS_VECTOR:
@@ -88,7 +124,26 @@ class RAGStore:
88
  ]
89
  # Append hit scoring algorithm
90
  hits = list(self.chunks.aggregate(pipeline))
91
- return [{"doc": h["doc"], "score": h["score"]} for h in hits]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  else:
93
  # Fallback: scan limited sample and compute cosine locally
94
  q = {"user_id": user_id, "project_id": project_id}
@@ -107,7 +162,25 @@ class RAGStore:
107
  scores.sort(key=lambda x: x[0], reverse=True)
108
  top = scores[:k]
109
  logger.info(f"Vector search sample={len(sample)} returned top={len(top)}")
110
- return [{"doc": d, "score": s} for (s, d) in top]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
 
113
  def ensure_indexes(store: RAGStore):
 
49
  if filename:
50
  q["filename"] = filename
51
  cur = self.chunks.find(q, {"embedding": 0}).skip(skip).limit(limit).sort([("_id", ASCENDING)])
52
+ # Convert MongoDB documents to JSON-serializable format
53
+ cards = []
54
+ for card in cur:
55
+ serializable_card = {}
56
+ for key, value in card.items():
57
+ if key == '_id':
58
+ serializable_card[key] = str(value) # Convert ObjectId to string
59
+ elif hasattr(value, 'isoformat'): # Handle datetime objects
60
+ serializable_card[key] = value.isoformat()
61
+ else:
62
+ serializable_card[key] = value
63
+ cards.append(serializable_card)
64
+ return cards
65
 
66
  def get_file_summary(self, user_id: str, project_id: str, filename: str):
67
+ doc = self.files.find_one({"user_id": user_id, "project_id": project_id, "filename": filename})
68
+ if doc:
69
+ # Convert MongoDB document to JSON-serializable format
70
+ serializable_doc = {}
71
+ for key, value in doc.items():
72
+ if key == '_id':
73
+ serializable_doc[key] = str(value) # Convert ObjectId to string
74
+ elif hasattr(value, 'isoformat'): # Handle datetime objects
75
+ serializable_doc[key] = value.isoformat()
76
+ else:
77
+ serializable_doc[key] = value
78
+ return serializable_doc
79
+ return None
80
 
81
  def list_files(self, user_id: str, project_id: str):
82
  """List all files for a project with their summaries"""
83
+ files_cursor = self.files.find(
84
  {"user_id": user_id, "project_id": project_id},
85
  {"_id": 0, "filename": 1, "summary": 1}
86
+ ).sort("filename", ASCENDING)
87
+
88
+ # Convert MongoDB documents to JSON-serializable format
89
+ files = []
90
+ for file_doc in files_cursor:
91
+ serializable_file = {}
92
+ for key, value in file_doc.items():
93
+ if hasattr(value, 'isoformat'): # Handle datetime objects
94
+ serializable_file[key] = value.isoformat()
95
+ else:
96
+ serializable_file[key] = value
97
+ files.append(serializable_file)
98
+ return files
99
 
100
  def vector_search(self, user_id: str, project_id: str, query_vector: List[float], k: int = 6, filenames: Optional[List[str]] = None):
101
  if USE_ATLAS_VECTOR:
 
124
  ]
125
  # Append hit scoring algorithm
126
  hits = list(self.chunks.aggregate(pipeline))
127
+
128
+ # Convert MongoDB documents to JSON-serializable format
129
+ serializable_hits = []
130
+ for hit in hits:
131
+ doc = hit["doc"]
132
+ serializable_doc = {}
133
+ for key, value in doc.items():
134
+ if key == '_id':
135
+ serializable_doc[key] = str(value) # Convert ObjectId to string
136
+ elif hasattr(value, 'isoformat'): # Handle datetime objects
137
+ serializable_doc[key] = value.isoformat()
138
+ else:
139
+ serializable_doc[key] = value
140
+
141
+ serializable_hits.append({
142
+ "doc": serializable_doc,
143
+ "score": float(hit["score"]) # Ensure score is a regular float
144
+ })
145
+
146
+ return serializable_hits
147
  else:
148
  # Fallback: scan limited sample and compute cosine locally
149
  q = {"user_id": user_id, "project_id": project_id}
 
162
  scores.sort(key=lambda x: x[0], reverse=True)
163
  top = scores[:k]
164
  logger.info(f"Vector search sample={len(sample)} returned top={len(top)}")
165
+
166
+ # Convert MongoDB documents to JSON-serializable format
167
+ serializable_results = []
168
+ for score, doc in top:
169
+ serializable_doc = {}
170
+ for key, value in doc.items():
171
+ if key == '_id':
172
+ serializable_doc[key] = str(value) # Convert ObjectId to string
173
+ elif hasattr(value, 'isoformat'): # Handle datetime objects
174
+ serializable_doc[key] = value.isoformat()
175
+ else:
176
+ serializable_doc[key] = value
177
+
178
+ serializable_results.append({
179
+ "doc": serializable_doc,
180
+ "score": float(score) # Ensure score is a regular float
181
+ })
182
+
183
+ return serializable_results
184
 
185
 
186
  def ensure_indexes(store: RAGStore):