Commit 18b46d7 · Parent: f93070b
Upd MongoDB resolver when create new proj

Files changed:
- app.py: +280 -58
- utils/rag.py: +79 -6

app.py
CHANGED
@@ -1,13 +1,17 @@
  # https://binkhoale1812-edsummariser.hf.space/
  import os, io, re, uuid, json, time, logging
  from typing import List, Dict, Any, Optional
- from datetime import datetime
+ from datetime import datetime, timezone
+ from pydantic import BaseModel

  from fastapi import FastAPI, UploadFile, File, Form, Request, HTTPException, BackgroundTasks
  from fastapi.responses import FileResponse, JSONResponse, HTMLResponse
  from fastapi.staticfiles import StaticFiles
  from fastapi.middleware.cors import CORSMiddleware

+ # MongoDB imports
+ from pymongo.errors import PyMongoError, ConnectionFailure, ServerSelectionTimeoutError
+
  from utils.rotator import APIKeyRotator
  from utils.parser import parse_pdf_bytes, parse_docx_bytes
  from utils.caption import BlipCaptioner
@@ -19,6 +23,48 @@ from utils.summarizer import cheap_summarize
  from utils.common import trim_text
  from utils.logger import get_logger

+ # ────────────────────────────── Response Models ──────────────────────────────
+ class ProjectResponse(BaseModel):
+     project_id: str
+     user_id: str
+     name: str
+     description: str
+     created_at: str
+     updated_at: str
+
+ class ProjectsListResponse(BaseModel):
+     projects: List[ProjectResponse]
+
+ class ChatMessageResponse(BaseModel):
+     user_id: str
+     project_id: str
+     role: str
+     content: str
+     timestamp: float
+     created_at: str
+
+ class ChatHistoryResponse(BaseModel):
+     messages: List[ChatMessageResponse]
+
+ class MessageResponse(BaseModel):
+     message: str
+
+ class UploadResponse(BaseModel):
+     job_id: str
+     status: str
+
+ class FileSummaryResponse(BaseModel):
+     filename: str
+     summary: str
+
+ class ChatAnswerResponse(BaseModel):
+     answer: str
+     sources: List[Dict[str, Any]]
+     relevant_files: Optional[List[str]] = None
+
+ class HealthResponse(BaseModel):
+     ok: bool
+
  # ────────────────────────────── App Setup ──────────────────────────────
  logger = get_logger("APP", name="studybuddy")

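Note on the new response models: with response_model set on a route, FastAPI validates whatever the handler returns against the declared Pydantic model and serializes only its fields. A minimal, self-contained sketch of that behaviour (the route and values here are illustrative, not part of this commit):

    from fastapi import FastAPI
    from fastapi.testclient import TestClient
    from pydantic import BaseModel

    app = FastAPI()

    class ProjectResponse(BaseModel):
        project_id: str
        name: str

    @app.get("/demo", response_model=ProjectResponse)
    def demo():
        # Extra keys are dropped at serialization time; missing required keys raise a validation error.
        return {"project_id": "p1", "name": "Sample", "internal_flag": True}

    client = TestClient(app)
    print(client.get("/demo").json())  # {'project_id': 'p1', 'name': 'Sample'}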
@@ -45,8 +91,19 @@ captioner = BlipCaptioner()
  embedder = EmbeddingClient(model_name=os.getenv("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2"))

  # Mongo / RAG store
-
-
+ try:
+     rag = RAGStore(mongo_uri=os.getenv("MONGO_URI"), db_name=os.getenv("MONGO_DB", "studybuddy"))
+     # Test the connection
+     rag.client.admin.command('ping')
+     logger.info("[APP] MongoDB connection successful")
+     ensure_indexes(rag)
+     logger.info("[APP] MongoDB indexes ensured")
+ except Exception as e:
+     logger.error(f"[APP] Failed to initialize MongoDB/RAG store: {str(e)}")
+     logger.error(f"[APP] MONGO_URI: {os.getenv('MONGO_URI', 'Not set')}")
+     logger.error(f"[APP] MONGO_DB: {os.getenv('MONGO_DB', 'studybuddy')}")
+     # Create a dummy RAG store for now - this will cause errors but prevents the app from crashing
+     rag = None


  # ────────────────────────────── Auth Helpers/Routes ───────────────────────────
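Note: because rag is set to None when MongoDB is unreachable, every route that touches the database needs a guard like the one added to create_project below. One possible way to centralize that check is a FastAPI dependency; a sketch (require_rag is a hypothetical helper, not part of this commit, and would live in app.py next to the rag global):

    from fastapi import Depends, HTTPException

    def require_rag():
        # Fail fast with 503 before the route body runs if the store never initialized.
        if rag is None:
            raise HTTPException(status_code=503, detail="Database connection not available")
        return rag

    # e.g. async def create_project(..., store=Depends(require_rag)): ...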
@@ -100,50 +157,108 @@ async def login(email: str = Form(...), password: str = Form(...)):


  # ────────────────────────────── Project Management ───────────────────────────
- @app.post("/projects/create")
+ @app.post("/projects/create", response_model=ProjectResponse)
  async def create_project(user_id: str = Form(...), name: str = Form(...), description: str = Form("")):
      """Create a new project for a user"""
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+     try:
+         if not rag:
+             raise HTTPException(500, detail="Database connection not available")
+
+         if not name.strip():
+             raise HTTPException(400, detail="Project name is required")
+
+         if not user_id.strip():
+             raise HTTPException(400, detail="User ID is required")
+
+         project_id = str(uuid.uuid4())
+         current_time = datetime.now(timezone.utc)
+
+         project = {
+             "project_id": project_id,
+             "user_id": user_id,
+             "name": name.strip(),
+             "description": description.strip(),
+             "created_at": current_time,
+             "updated_at": current_time
+         }
+
+         logger.info(f"[PROJECT] Creating project {name} for user {user_id}")
+
+         # Insert the project
+         try:
+             result = rag.db["projects"].insert_one(project)
+             logger.info(f"[PROJECT] Created project {name} with ID {project_id}, MongoDB result: {result.inserted_id}")
+         except PyMongoError as mongo_error:
+             logger.error(f"[PROJECT] MongoDB error creating project: {str(mongo_error)}")
+             raise HTTPException(500, detail=f"Database error: {str(mongo_error)}")
+         except Exception as db_error:
+             logger.error(f"[PROJECT] Database error creating project: {str(db_error)}")
+             raise HTTPException(500, detail=f"Database error: {str(db_error)}")
+
+         # Return a properly formatted response
+         response = ProjectResponse(
+             project_id=project_id,
+             user_id=user_id,
+             name=name.strip(),
+             description=description.strip(),
+             created_at=current_time.isoformat(),
+             updated_at=current_time.isoformat()
+         )
+
+         logger.info(f"[PROJECT] Successfully created project {name} for user {user_id}")
+         return response
+
+     except HTTPException:
+         # Re-raise HTTP exceptions
+         raise
+     except Exception as e:
+         logger.error(f"[PROJECT] Error creating project: {str(e)}")
+         logger.error(f"[PROJECT] Error type: {type(e)}")
+         logger.error(f"[PROJECT] Error details: {e}")
+         raise HTTPException(500, detail=f"Failed to create project: {str(e)}")


- @app.get("/projects")
+ @app.get("/projects", response_model=ProjectsListResponse)
  async def list_projects(user_id: str):
      """List all projects for a user"""
-
-     {"user_id": user_id}
-
-
-
+     projects_cursor = rag.db["projects"].find(
+         {"user_id": user_id}
+     ).sort("updated_at", -1)
+
+     projects = []
+     for project in projects_cursor:
+         projects.append(ProjectResponse(
+             project_id=project["project_id"],
+             user_id=project["user_id"],
+             name=project["name"],
+             description=project.get("description", ""),
+             created_at=project["created_at"].isoformat() if isinstance(project["created_at"], datetime) else str(project["created_at"]),
+             updated_at=project["updated_at"].isoformat() if isinstance(project["updated_at"], datetime) else str(project["updated_at"])
+         ))
+
+     return ProjectsListResponse(projects=projects)


- @app.get("/projects/{project_id}")
+ @app.get("/projects/{project_id}", response_model=ProjectResponse)
  async def get_project(project_id: str, user_id: str):
      """Get a specific project (with user ownership check)"""
      project = rag.db["projects"].find_one(
-         {"project_id": project_id, "user_id": user_id}
-         {"_id": 0}
+         {"project_id": project_id, "user_id": user_id}
      )
      if not project:
          raise HTTPException(404, detail="Project not found")
-
+
+     return ProjectResponse(
+         project_id=project["project_id"],
+         user_id=project["user_id"],
+         name=project["name"],
+         description=project.get("description", ""),
+         created_at=project["created_at"].isoformat() if isinstance(project["created_at"], datetime) else str(project["created_at"]),
+         updated_at=project["updated_at"].isoformat() if isinstance(project["updated_at"], datetime) else str(project["updated_at"])
+     )


- @app.delete("/projects/{project_id}")
+ @app.delete("/projects/{project_id}", response_model=MessageResponse)
  async def delete_project(project_id: str, user_id: str):
      """Delete a project and all its associated data"""
      # Check ownership
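The create endpoint takes form fields rather than JSON. An illustrative client call (host taken from the comment at the top of app.py; values are placeholders):

    import requests

    resp = requests.post(
        "https://binkhoale1812-edsummariser.hf.space/projects/create",
        data={"user_id": "demo-user", "name": "Biology 101", "description": "Lecture notes"},
        timeout=30,
    )
    print(resp.status_code)  # 200 on success; 400/500 carry a detail message
    print(resp.json())       # project_id, user_id, name, description, created_at, updated_at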
@@ -158,11 +273,11 @@ async def delete_project(project_id: str, user_id: str):
      rag.db["chat_sessions"].delete_many({"project_id": project_id})

      logger.info(f"[PROJECT] Deleted project {project_id} for user {user_id}")
-     return
+     return MessageResponse(message="Project deleted successfully")


  # ────────────────────────────── Chat Sessions ──────────────────────────────
- @app.post("/chat/save")
+ @app.post("/chat/save", response_model=MessageResponse)
  async def save_chat_message(
      user_id: str = Form(...),
      project_id: str = Form(...),
@@ -180,21 +295,32 @@ async def save_chat_message(
          "role": role,
          "content": content,
          "timestamp": timestamp or time.time(),
-         "created_at": datetime.
+         "created_at": datetime.now(timezone.utc)
      }

      rag.db["chat_sessions"].insert_one(message)
-     return
+     return MessageResponse(message="Chat message saved")


- @app.get("/chat/history")
+ @app.get("/chat/history", response_model=ChatHistoryResponse)
  async def get_chat_history(user_id: str, project_id: str, limit: int = 100):
      """Get chat history for a project"""
-
-     {"user_id": user_id, "project_id": project_id}
-
-
-
+     messages_cursor = rag.db["chat_sessions"].find(
+         {"user_id": user_id, "project_id": project_id}
+     ).sort("timestamp", 1).limit(limit)
+
+     messages = []
+     for message in messages_cursor:
+         messages.append(ChatMessageResponse(
+             user_id=message["user_id"],
+             project_id=message["project_id"],
+             role=message["role"],
+             content=message["content"],
+             timestamp=message["timestamp"],
+             created_at=message["created_at"].isoformat() if isinstance(message["created_at"], datetime) else str(message["created_at"])
+         ))
+
+     return ChatHistoryResponse(messages=messages)


  # ────────────────────────────── Helpers ──────────────────────────────
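The history query filters on user_id and project_id and sorts by timestamp, so it benefits from a compound index. The actual definitions live in ensure_indexes() in utils/rag.py and are not shown in this diff; a sketch of an index of that shape (assumed, not read from the repo):

    rag.db["chat_sessions"].create_index([("user_id", 1), ("project_id", 1), ("timestamp", 1)])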
@@ -226,7 +352,7 @@ def index():
      return FileResponse(index_path)


- @app.post("/upload")
+ @app.post("/upload", response_model=UploadResponse)
  async def upload_files(
      request: Request,
      background_tasks: BackgroundTasks,
@@ -307,23 +433,39 @@ async def upload_files(

      # Kick off processing in background to keep UI responsive
      background_tasks.add_task(_process)
-     return
+     return UploadResponse(job_id=job_id, status="processing")


  @app.get("/cards")
  def list_cards(user_id: str, project_id: str, filename: Optional[str] = None, limit: int = 50, skip: int = 0):
-
+     """List cards for a project"""
+     cards = rag.list_cards(user_id=user_id, project_id=project_id, filename=filename, limit=limit, skip=skip)
+
+     # Ensure all cards are JSON serializable
+     serializable_cards = []
+     for card in cards:
+         serializable_card = {}
+         for key, value in card.items():
+             if key == '_id':
+                 serializable_card[key] = str(value)  # Convert ObjectId to string
+             elif isinstance(value, datetime):
+                 serializable_card[key] = value.isoformat()  # Convert datetime to ISO string
+             else:
+                 serializable_card[key] = value
+         serializable_cards.append(serializable_card)
+
+     return {"cards": serializable_cards}


- @app.get("/file-summary")
+ @app.get("/file-summary", response_model=FileSummaryResponse)
  def get_file_summary(user_id: str, project_id: str, filename: str):
      doc = rag.get_file_summary(user_id=user_id, project_id=project_id, filename=filename)
      if not doc:
          raise HTTPException(404, detail="No summary found for that file.")
-     return
+     return FileSummaryResponse(filename=filename, summary=doc.get("summary", ""))


- @app.post("/chat")
+ @app.post("/chat", response_model=ChatAnswerResponse)
  async def chat(
      user_id: str = Form(...),
      project_id: str = Form(...),
@@ -337,6 +479,7 @@ async def chat(
      - Bring in recent chat memory: last 3 via NVIDIA relevance; remaining 17 via semantic search
      - After answering, summarize (q,a) via NVIDIA and store into LRU (last 20)
      """
+     import sys
      from memo.memory import MemoryLRU
      from memo.history import summarize_qa_with_nvidia, files_relevance, related_recent_and_semantic_context
      from utils.router import NVIDIA_SMALL  # reuse default name
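The docstring above describes a 20-entry LRU of question/answer summaries. MemoryLRU itself lives in memo/memory.py and is outside this diff; a minimal sketch of that idea, assuming it is keyed per (user_id, project_id) and capped at 20 entries:

    from collections import OrderedDict

    class SimpleMemoryLRU:
        """Keep the most recent QA summaries per (user_id, project_id)."""

        def __init__(self, capacity: int = 20):
            self.capacity = capacity
            self.store = {}  # (user_id, project_id) -> OrderedDict of summaries

        def add(self, user_id, project_id, summary):
            bucket = self.store.setdefault((user_id, project_id), OrderedDict())
            bucket[summary] = None
            while len(bucket) > self.capacity:
                bucket.popitem(last=False)  # evict the oldest entry

        def recent(self, user_id, project_id, n=3):
            bucket = self.store.get((user_id, project_id), OrderedDict())
            return list(bucket)[-n:]  # newest n summaries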
@@ -349,9 +492,15 @@ async def chat(
          fn = m.group(1)
          doc = rag.get_file_summary(user_id=user_id, project_id=project_id, filename=fn)
          if doc:
-             return
+             return ChatAnswerResponse(
+                 answer=doc.get("summary", ""),
+                 sources=[{"filename": fn, "file_summary": True}]
+             )
          else:
-             return
+             return ChatAnswerResponse(
+                 answer="I couldn't find a summary for that file in your library.",
+                 sources=[]
+             )

      # 1) Preload file list + summaries
      files_list = rag.list_files(user_id=user_id, project_id=project_id)  # [{filename, summary}]
@@ -391,11 +540,11 @@ async def chat(
      q_vec = embedder.embed([question])[0]
      hits = rag.vector_search(user_id=user_id, project_id=project_id, query_vector=q_vec, k=k, filenames=relevant_files if relevant_files else None)
      if not hits:
-         return
-
-
-
-
+         return ChatAnswerResponse(
+             answer="I don't know based on your uploaded materials. Try uploading more sources or rephrasing the question.",
+             sources=[],
+             relevant_files=relevant_files
+         )
      # Compose context
      contexts = []
      sources_meta = []
@@ -408,7 +557,7 @@ async def chat(
              "topic_name": doc.get("topic_name"),
              "page_span": doc.get("page_span"),
              "score": float(score),
-             "chunk_id": str(doc.get("_id", ""))
+             "chunk_id": str(doc.get("_id", ""))  # Convert ObjectId to string
          })
      context_text = "\n\n---\n\n".join(contexts)

@@ -462,9 +611,82 @@ async def chat(
          logger.warning(f"QA summarize/store failed: {e}")
      # Trim for logging
      logger.info("LLM answer (trimmed): %s", trim_text(answer, 200).replace("\n", " "))
-     return
+     return ChatAnswerResponse(answer=answer, sources=sources_meta, relevant_files=relevant_files)


- @app.get("/healthz")
+ @app.get("/healthz", response_model=HealthResponse)
  def health():
-     return
+     return HealthResponse(ok=True)
+
+
+ @app.get("/test-db")
+ async def test_database():
+     """Test database connection and basic operations"""
+     try:
+         if not rag:
+             return {
+                 "status": "error",
+                 "message": "RAG store not initialized",
+                 "error_type": "RAGStoreNotInitialized"
+             }
+
+         # Test basic connection
+         rag.client.admin.command('ping')
+
+         # Test basic insert/query
+         test_collection = rag.db["test_collection"]
+         test_doc = {"test": True, "timestamp": datetime.now(timezone.utc)}
+         result = test_collection.insert_one(test_doc)
+
+         # Test query
+         found = test_collection.find_one({"_id": result.inserted_id})
+
+         # Clean up
+         test_collection.delete_one({"_id": result.inserted_id})
+
+         return {
+             "status": "success",
+             "message": "Database connection and operations working correctly",
+             "test_id": str(result.inserted_id),
+             "found_doc": str(found["_id"]) if found else None
+         }
+
+     except Exception as e:
+         logger.error(f"[TEST-DB] Database test failed: {str(e)}")
+         return {
+             "status": "error",
+             "message": f"Database test failed: {str(e)}",
+             "error_type": str(type(e))
+         }
+
+
+ @app.get("/rag-status")
+ async def rag_status():
+     """Check the status of the RAG store"""
+     if not rag:
+         return {
+             "status": "error",
+             "message": "RAG store not initialized",
+             "rag_available": False
+         }
+
+     try:
+         # Test connection
+         rag.client.admin.command('ping')
+         return {
+             "status": "success",
+             "message": "RAG store is available and connected",
+             "rag_available": True,
+             "database": rag.db.name,
+             "collections": {
+                 "chunks": rag.chunks.name,
+                 "files": rag.files.name
+             }
+         }
+     except Exception as e:
+         return {
+             "status": "error",
+             "message": f"RAG store connection failed: {str(e)}",
+             "rag_available": False,
+             "error": str(e)
+         }
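The read-only routes added at the end of app.py give a quick way to check a deployment (paths from this commit; host taken from the comment at the top of app.py):

    import requests

    base = "https://binkhoale1812-edsummariser.hf.space"
    print(requests.get(f"{base}/healthz").json())     # {"ok": true}
    print(requests.get(f"{base}/rag-status").json())  # rag_available plus database/collection names
    print(requests.get(f"{base}/test-db").json())     # insert/find/delete round trip on test_collection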
utils/rag.py
CHANGED
|
| 49 |
if filename:
|
| 50 |
q["filename"] = filename
|
| 51 |
cur = self.chunks.find(q, {"embedding": 0}).skip(skip).limit(limit).sort([("_id", ASCENDING)])
|
| 52 |
+
# Convert MongoDB documents to JSON-serializable format
|
| 53 |
+
cards = []
|
| 54 |
+
for card in cur:
|
| 55 |
+
serializable_card = {}
|
| 56 |
+
for key, value in card.items():
|
| 57 |
+
if key == '_id':
|
| 58 |
+
serializable_card[key] = str(value) # Convert ObjectId to string
|
| 59 |
+
elif hasattr(value, 'isoformat'): # Handle datetime objects
|
| 60 |
+
serializable_card[key] = value.isoformat()
|
| 61 |
+
else:
|
| 62 |
+
serializable_card[key] = value
|
| 63 |
+
cards.append(serializable_card)
|
| 64 |
+
return cards
|
| 65 |
|
| 66 |
def get_file_summary(self, user_id: str, project_id: str, filename: str):
|
| 67 |
+
doc = self.files.find_one({"user_id": user_id, "project_id": project_id, "filename": filename})
|
| 68 |
+
if doc:
|
| 69 |
+
# Convert MongoDB document to JSON-serializable format
|
| 70 |
+
serializable_doc = {}
|
| 71 |
+
for key, value in doc.items():
|
| 72 |
+
if key == '_id':
|
| 73 |
+
serializable_doc[key] = str(value) # Convert ObjectId to string
|
| 74 |
+
elif hasattr(value, 'isoformat'): # Handle datetime objects
|
| 75 |
+
serializable_doc[key] = value.isoformat()
|
| 76 |
+
else:
|
| 77 |
+
serializable_doc[key] = value
|
| 78 |
+
return serializable_doc
|
| 79 |
+
return None
|
| 80 |
|
| 81 |
def list_files(self, user_id: str, project_id: str):
|
| 82 |
"""List all files for a project with their summaries"""
|
| 83 |
+
files_cursor = self.files.find(
|
| 84 |
{"user_id": user_id, "project_id": project_id},
|
| 85 |
{"_id": 0, "filename": 1, "summary": 1}
|
| 86 |
+
).sort("filename", ASCENDING)
|
| 87 |
+
|
| 88 |
+
# Convert MongoDB documents to JSON-serializable format
|
| 89 |
+
files = []
|
| 90 |
+
for file_doc in files_cursor:
|
| 91 |
+
serializable_file = {}
|
| 92 |
+
for key, value in file_doc.items():
|
| 93 |
+
if hasattr(value, 'isoformat'): # Handle datetime objects
|
| 94 |
+
serializable_file[key] = value.isoformat()
|
| 95 |
+
else:
|
| 96 |
+
serializable_file[key] = value
|
| 97 |
+
files.append(serializable_file)
|
| 98 |
+
return files
|
| 99 |
|
| 100 |
def vector_search(self, user_id: str, project_id: str, query_vector: List[float], k: int = 6, filenames: Optional[List[str]] = None):
|
| 101 |
if USE_ATLAS_VECTOR:
|
|
|
|
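The ObjectId/datetime conversion loop now appears in list_cards, get_file_summary, list_files, both vector_search branches, and again in app.py's /cards route. A possible follow-up (not part of this commit) is to factor it into a single helper; a sketch:

    from datetime import datetime
    from bson import ObjectId  # installed with pymongo

    def to_jsonable(doc: dict) -> dict:
        """Return a copy of a MongoDB document that is safe to emit as JSON."""
        out = {}
        for key, value in doc.items():
            if isinstance(value, ObjectId):
                out[key] = str(value)
            elif isinstance(value, datetime):
                out[key] = value.isoformat()
            else:
                out[key] = value
        return out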
@@ -88,7 +124,26 @@ class RAGStore:
              ]
              # Append hit scoring algorithm
              hits = list(self.chunks.aggregate(pipeline))
-
+
+             # Convert MongoDB documents to JSON-serializable format
+             serializable_hits = []
+             for hit in hits:
+                 doc = hit["doc"]
+                 serializable_doc = {}
+                 for key, value in doc.items():
+                     if key == '_id':
+                         serializable_doc[key] = str(value)  # Convert ObjectId to string
+                     elif hasattr(value, 'isoformat'):  # Handle datetime objects
+                         serializable_doc[key] = value.isoformat()
+                     else:
+                         serializable_doc[key] = value
+
+                 serializable_hits.append({
+                     "doc": serializable_doc,
+                     "score": float(hit["score"])  # Ensure score is a regular float
+                 })
+
+             return serializable_hits
          else:
              # Fallback: scan limited sample and compute cosine locally
              q = {"user_id": user_id, "project_id": project_id}
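The Atlas branch builds an aggregation pipeline that is collapsed in this diff; it ends in documents shaped as {"doc": ..., "score": ...}. For reference, an Atlas vector query of that shape is typically written with the $vectorSearch stage; a hedged sketch (the index and field names are assumptions, not read from this repo):

    pipeline = [
        {
            "$vectorSearch": {
                "index": "vector_index",      # assumed Atlas search index name
                "path": "embedding",          # field holding the chunk embedding
                "queryVector": query_vector,
                "numCandidates": 200,
                "limit": k,
            }
        },
        {"$project": {"doc": "$$ROOT", "score": {"$meta": "vectorSearchScore"}}},
    ]
    hits = list(self.chunks.aggregate(pipeline))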
@@ -107,7 +162,25 @@ class RAGStore:
              scores.sort(key=lambda x: x[0], reverse=True)
              top = scores[:k]
              logger.info(f"Vector search sample={len(sample)} returned top={len(top)}")
-
+
+             # Convert MongoDB documents to JSON-serializable format
+             serializable_results = []
+             for score, doc in top:
+                 serializable_doc = {}
+                 for key, value in doc.items():
+                     if key == '_id':
+                         serializable_doc[key] = str(value)  # Convert ObjectId to string
+                     elif hasattr(value, 'isoformat'):  # Handle datetime objects
+                         serializable_doc[key] = value.isoformat()
+                     else:
+                         serializable_doc[key] = value
+
+                 serializable_results.append({
+                     "doc": serializable_doc,
+                     "score": float(score)  # Ensure score is a regular float
+                 })
+
+             return serializable_results


  def ensure_indexes(store: RAGStore):
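The fallback branch above scores a sampled set of chunks by computing cosine similarity in Python instead of in Atlas. The scoring helper itself is outside this diff; a minimal sketch of that computation, assuming embeddings are plain lists of floats:

    import math

    def cosine(a, b):
        """Cosine similarity of two vectors; 0.0 if either has zero norm."""
        dot = sum(x * y for x, y in zip(a, b))
        na = math.sqrt(sum(x * x for x in a))
        nb = math.sqrt(sum(y * y for y in b))
        return dot / (na * nb) if na and nb else 0.0

    # scores = [(cosine(query_vector, d["embedding"]), d) for d in sample]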