#!/usr/bin/env python3
"""
QUADRANT RAG Core Module
Clean RAG implementation without Flask dependencies
Optimized for both Streamlit and Flask integration
"""
import os
import json
import uuid
import re
import time
from typing import List, Dict, Any, Optional
from pathlib import Path
from datetime import datetime, timezone
from dotenv import load_dotenv
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue, PayloadSchemaType
import openai
# Load environment variables
load_dotenv()
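
# Expected environment (a sketch; values below are placeholders, not real credentials):
#   OPENAI_API_KEY=sk-...                     # required for embeddings and chat
#   OPENAI_COMPLETIONS_MODEL=gpt-4o-mini      # optional chat-model override
#   QDRANT_URL=https://<cluster>.qdrant.io    # together with QDRANT_API_KEY -> Qdrant Cloud
#   QDRANT_API_KEY=...
#   QDRANT_COLLECTION_NAME=documents          # optional, defaults to "documents"
#   USE_MEMORY_DB=true                        # optional: in-memory Qdrant (development only)
#   QDRANT_STORAGE_PATH=./qdrant_storage      # optional: local file-storage fallback
#   QDRANT_TIMEOUT=60                         # optional: client timeout in seconds
#   QDRANT_PREFER_GRPC=false                  # optional: prefer gRPC transport
#   QDRANT_UPSERT_BATCH=32                    # optional: upsert batch size
# SPACE_ID is set automatically on Hugging Face Spaces and shortens the default timeout.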
class DynamicRAG:
"""
Dynamic RAG System with Qdrant Vector Database and OpenAI GPT-4o-mini
Real semantic search with proper LLM responses
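
    Example (a sketch; assumes the environment variables above are configured):
        rag = DynamicRAG()
        rag.store_document("my_doc", create_chunks(extract_pdf_pages("my.pdf")))
        hits = rag.search("my question", doc_id="my_doc")
        print(rag.generate_answer("my question", hits))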
"""
def __init__(self):
# Environment variables
self.openai_api_key = os.environ.get('OPENAI_API_KEY', 'your-openai-api-key-here')
self.use_memory_db = os.environ.get('USE_MEMORY_DB', 'false').lower() == 'true'
self.qdrant_url = os.environ.get('QDRANT_URL')
self.qdrant_api_key = os.environ.get('QDRANT_API_KEY')
self.collection_name = os.environ.get('QDRANT_COLLECTION_NAME', 'documents')
# Initialize clients
self._init_openai()
self._init_qdrant()
self._init_embedding_model()
# Ensure collection exists
self.ensure_collection()
def _init_openai(self):
"""Initialize OpenAI client"""
try:
if not self.openai_api_key or self.openai_api_key == 'your-openai-api-key-here':
print("❌ OpenAI API key not provided. Please set OPENAI_API_KEY environment variable.")
self.openai_client = None
return
            # Keep the legacy module-level key for backward compatibility,
            # then create the modern client object used everywhere else
            openai.api_key = self.openai_api_key
            self.openai_client = openai.OpenAI(api_key=self.openai_api_key)
            print("✅ OpenAI client initialized")
except Exception as e:
print(f"⚠️ OpenAI initialization error: {e}")
self.openai_client = None
def _init_qdrant(self):
"""Initialize Qdrant client with cloud priority"""
try:
# Configure client timeouts and transport - use shorter timeout for HF Spaces
default_timeout = '30' if os.environ.get('SPACE_ID') else '60'
qdrant_timeout = float(os.environ.get('QDRANT_TIMEOUT', default_timeout))
prefer_grpc = os.environ.get('QDRANT_PREFER_GRPC', 'false').lower() == 'true'
if self.qdrant_url and self.qdrant_api_key:
print(f"🌐 Using Qdrant Cloud: {self.qdrant_url}")
self.qdrant_client = QdrantClient(
url=self.qdrant_url,
api_key=self.qdrant_api_key,
timeout=qdrant_timeout,
prefer_grpc=prefer_grpc,
)
elif self.use_memory_db:
print("πŸ’Ύ Using in-memory Qdrant (development only)")
self.qdrant_client = QdrantClient(":memory:", timeout=qdrant_timeout)
else:
# Fallback to local file storage
storage_path = os.environ.get('QDRANT_STORAGE_PATH', './qdrant_storage')
print(f"πŸ—„οΈ Using file-based Qdrant storage: {storage_path}")
self.qdrant_client = QdrantClient(path=storage_path, timeout=qdrant_timeout)
print(f"βœ… Qdrant client initialized (timeout={qdrant_timeout}s, gRPC preferred={prefer_grpc})")
except Exception as e:
print(f"❌ Qdrant initialization error: {e}")
raise
def _init_embedding_model(self):
"""Initialize OpenAI embedding model settings"""
try:
print("πŸ”„ Configuring OpenAI embeddings...")
self.embedding_model_name = 'text-embedding-3-small'
self.embedding_size = 1536 # OpenAI text-embedding-3-small dimension
# Chat model can be overridden via env; default per user request
self.chat_model_name = os.environ.get('OPENAI_COMPLETIONS_MODEL', 'gpt-4o-mini')
print("βœ… OpenAI embeddings configured")
except Exception as e:
print(f"❌ Embedding configuration error: {e}")
raise
def ensure_collection(self):
"""Create Qdrant collection if it doesn't exist"""
try:
collections = self.qdrant_client.get_collections().collections
collection_names = [c.name for c in collections]
if self.collection_name not in collection_names:
print(f"πŸ”„ Creating Qdrant collection: {self.collection_name}")
self.qdrant_client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(size=self.embedding_size, distance=Distance.COSINE)
)
print("βœ… Collection created")
else:
print(f"βœ… Collection {self.collection_name} already exists")
# Create payload index for doc_id to enable filtering
try:
self.qdrant_client.create_payload_index(
collection_name=self.collection_name,
field_name="doc_id",
field_schema=PayloadSchemaType.KEYWORD
)
print("βœ… Created index for doc_id field")
except Exception as e:
# Index might already exist, which is fine
if "already exists" not in str(e):
print(f"⚠️ Note: Could not create index for doc_id: {e}")
except Exception as e:
print(f"⚠️ Error with collection: {e}")
def create_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Create embeddings for texts using OpenAI API with batch processing"""
        # Bail out early if the OpenAI client never initialized
        if not self.openai_client:
            print("❌ OpenAI client not initialized; returning zero vectors")
            return [[0.0] * self.embedding_size for _ in texts]
        # Replace empty strings so the embeddings API does not reject the batch
        texts = [text if text.strip() else "empty" for text in texts]
# Process in batches to avoid timeout
batch_size = 20 # OpenAI recommends smaller batches
all_embeddings = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
retries = 3
while retries > 0:
try:
# Create embeddings for this batch
response = self.openai_client.embeddings.create(
model=self.embedding_model_name,
input=batch
)
# Extract embedding vectors
batch_embeddings = [data.embedding for data in response.data]
all_embeddings.extend(batch_embeddings)
# Show progress
progress = min(i + batch_size, len(texts))
print(f" βœ… Processed {progress}/{len(texts)} texts")
break
except Exception as e:
retries -= 1
if retries > 0:
print(f" ⚠️ Retry {3-retries}/3 for batch {i//batch_size + 1}: {str(e)}")
time.sleep(2) # Wait before retry
else:
print(f" ❌ Failed batch {i//batch_size + 1}: {str(e)}")
# Return zero vectors for failed batch
all_embeddings.extend([[0.0] * self.embedding_size for _ in batch])
return all_embeddings
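    # Usage sketch (hypothetical inputs): with rag = DynamicRAG(),
    #   vecs = rag.create_embeddings(["renal tubular acidosis", "anion gap"])
    # returns one 1536-dimensional list of floats per input text.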
def store_document(self, doc_id: str, chunks: List[Dict[str, Any]]):
"""Store document chunks in Qdrant with embeddings"""
print(f"πŸ”„ Processing {len(chunks)} chunks...")
# Check if chunks already exist for this document
try:
existing = self.qdrant_client.scroll(
collection_name=self.collection_name,
scroll_filter=Filter(
must=[FieldCondition(key="doc_id", match=MatchValue(value=doc_id))]
),
limit=1
)
if existing[0]:
print(f"⚠️ Document {doc_id} already exists. Skipping...")
return
        except Exception:
            pass  # Collection might be empty or the doc_id index missing
print(f"🧠 Creating embeddings for {len(chunks)} chunks...")
texts = [chunk['text'] for chunk in chunks]
embeddings = self.create_embeddings(texts)
print(f"πŸ“¦ Preparing vectors for storage...")
points = []
upload_time = datetime.now(timezone.utc).isoformat()
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
# Generate a proper UUID for each point
point_id = str(uuid.uuid4())
point = PointStruct(
id=point_id,
vector=embedding,
payload={
"doc_id": doc_id,
"chunk_id": i,
"text": chunk['text'],
"page": chunk['page'],
"section": chunk.get('section', 'Unknown'),
"upload_time": upload_time
}
)
points.append(point)
# Store in batches with retry and adaptive downsizing on timeout
default_batch_size = int(os.environ.get('QDRANT_UPSERT_BATCH', '32'))
i = 0
batch_index = 0
while i < len(points):
batch_size = min(default_batch_size, len(points) - i)
batch = points[i:i + batch_size]
attempts = 0
while attempts < 3:
try:
self.qdrant_client.upsert(
collection_name=self.collection_name,
points=batch,
)
batch_index += 1
print(f"πŸ“¦ Stored batch {batch_index}/{(len(points)+default_batch_size-1)//default_batch_size} ({len(batch)} points)")
i += batch_size
break
except Exception as e:
attempts += 1
if 'Timeout' in str(e) or 'timed out' in str(e):
# Halve the batch size and retry
new_size = max(5, batch_size // 2)
print(f"⚠️ Upsert timeout. Reducing batch from {batch_size} to {new_size} and retrying ({attempts}/3)...")
batch_size = new_size
batch = points[i:i + batch_size]
time.sleep(1.0)
continue
else:
print(f"❌ Upsert error on batch starting at {i}: {e}")
raise
print(f"βœ… Stored {len(chunks)} chunks in Qdrant")
def get_all_documents(self) -> List[Dict[str, Any]]:
"""Retrieve all unique documents from Qdrant with metadata"""
try:
print("πŸ”„ Fetching all documents from Qdrant...")
# Use scroll to get all points
all_points = []
offset = None
limit = 100
while True:
records, next_offset = self.qdrant_client.scroll(
collection_name=self.collection_name,
limit=limit,
offset=offset,
with_payload=True,
with_vectors=False
)
all_points.extend(records)
if next_offset is None:
break
offset = next_offset
# Group by doc_id to get unique documents
documents = {}
for point in all_points:
doc_id = point.payload.get('doc_id')
if doc_id and doc_id not in documents:
# Initialize document info
documents[doc_id] = {
'doc_id': doc_id,
'title': doc_id.replace('_', ' ').replace('.pdf', ''),
'chunks': 0,
'pages': set(),
'upload_time': point.payload.get('upload_time', 'Unknown')
}
if doc_id:
# Update chunk count and pages
documents[doc_id]['chunks'] += 1
page = point.payload.get('page', 0)
if page:
documents[doc_id]['pages'].add(page)
# Convert to list and finalize
result = []
for doc_id, doc_info in documents.items():
doc_info['pages'] = len(doc_info['pages']) # Convert set to count
result.append(doc_info)
# Sort by upload time (newest first)
result.sort(key=lambda x: x.get('upload_time', ''), reverse=True)
print(f"βœ… Found {len(result)} documents in Qdrant")
return result
except Exception as e:
print(f"❌ Error retrieving documents: {e}")
return []
def delete_document(self, doc_id: str) -> bool:
"""Delete all chunks for a specific document"""
try:
print(f"πŸ—‘οΈ Deleting document {doc_id}...")
self.qdrant_client.delete(
collection_name=self.collection_name,
points_selector=Filter(
must=[FieldCondition(key="doc_id", match=MatchValue(value=doc_id))]
)
)
print(f"βœ… Deleted document {doc_id}")
return True
except Exception as e:
print(f"❌ Error deleting document: {e}")
return False
def search(self, query: str, doc_id: str, top_k: int = 10) -> List[Dict[str, Any]]:
"""Search for relevant chunks using vector similarity with improved retrieval"""
print(f"πŸ” Searching for: '{query}'")
# Expand query for better medical term matching
expanded_query = self.expand_query(query)
print(f"πŸ” Expanded query: '{expanded_query}'")
# Primary search with expanded query
results = self._perform_search(expanded_query, doc_id, top_k)
# If no good results, try fallback searches
        if not any(r['score'] > 0.15 for r in results):
            print("🔍 Trying fallback search with key terms...")
# Extract key medical terms for fallback search
key_terms = self._extract_key_terms(query)
for term in key_terms:
fallback_results = self._perform_search(term, doc_id, top_k//2)
results.extend(fallback_results)
# Remove duplicates and sort by score
seen_chunks = set()
unique_results = []
for result in results:
chunk_key = f"{result['chunk_id']}_{result['page']}"
if chunk_key not in seen_chunks:
seen_chunks.add(chunk_key)
unique_results.append(result)
# Sort by score descending
unique_results.sort(key=lambda x: x['score'], reverse=True)
# Filter results with minimum relevance score - very lenient threshold
filtered_results = [r for r in unique_results if r['score'] > 0.10]
print(f"πŸ“Š Found {len(filtered_results)} relevant chunks (score > 0.10)")
# If still no results, return top 5 results anyway for fallback
if not filtered_results and unique_results:
filtered_results = unique_results[:5]
print(f"πŸ“Š No high-relevance chunks found, using top {len(filtered_results)} results as fallback")
return filtered_results[:top_k]
def _perform_search(self, query: str, doc_id: str, limit: int) -> List[Dict[str, Any]]:
"""Perform a single search operation"""
query_embedding = self.create_embeddings([query])[0]
# If doc_id is 'any' or we want to search all documents, don't filter
if doc_id == 'any':
search_results = self.qdrant_client.query_points(
collection_name=self.collection_name,
query=query_embedding,
limit=limit,
with_payload=True
)
else:
# Filter strictly by the provided doc_id; fallback to no filter on error
try:
search_results = self.qdrant_client.query_points(
collection_name=self.collection_name,
query=query_embedding,
query_filter=Filter(
must=[FieldCondition(key="doc_id", match=MatchValue(value=doc_id))]
),
limit=limit,
with_payload=True
)
except Exception:
search_results = self.qdrant_client.query_points(
collection_name=self.collection_name,
query=query_embedding,
limit=limit,
with_payload=True
)
results = []
for result in search_results.points:
results.append({
"text": result.payload["text"],
"page": result.payload["page"],
"section": result.payload["section"],
"score": float(result.score),
"chunk_id": result.payload["chunk_id"],
"doc_id": result.payload.get("doc_id", "unknown")
})
return results
def _extract_key_terms(self, query: str) -> List[str]:
"""Extract key medical terms from query for fallback search"""
# Extract important terms
terms = []
# Medical abbreviations and key terms
medical_terms = ["acidosis", "RTA", "anion gap", "metabolic", "urine pH", "differential", "MUDPILES", "GOLDMARK"]
query_lower = query.lower()
for term in medical_terms:
if term.lower() in query_lower:
terms.append(term)
return terms[:3] # Return top 3 terms
def expand_query(self, query: str) -> str:
"""Expand query with synonyms and related terms for better search"""
# Common medical and general expansions
expansions = {
"fuo": "fever unknown origin fever of unknown origin pyrexia unexplained fever",
"classic": "classical traditional standard typical",
"nosocomial": "hospital acquired healthcare associated hospital-acquired",
"neutropenic": "neutropenia immunocompromised low neutrophil count",
"hiv": "human immunodeficiency virus AIDS HIV-associated",
"diagnostic": "diagnosis workup evaluation investigation",
"pet/ct": "PET-CT positron emission tomography computed tomography PET scan",
"pet": "positron emission tomography PET scan PET-CT",
"workup": "work up evaluation investigation diagnostic approach",
"first-line": "initial primary first line baseline",
"imaging": "radiologic radiology scan imaging studies",
"labs": "laboratory tests blood work investigations",
"categories": "types classifications groups subtypes",
"major": "main primary principal important key"
}
        # Append matched expansions once at the end instead of replacing in place;
        # in-place replacement cascades (e.g. "pet" re-expands inside an already-
        # expanded "pet/ct") and duplicates text when a term occurs more than once
        query_lower = query.lower()
        extra_terms = [expansion for term, expansion in expansions.items() if term in query_lower]
        return query_lower + (" " + " ".join(extra_terms) if extra_terms else "")
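    # Example (a sketch): expand_query("classic FUO workup") yields the lowercased
    # query followed by the matched expansions, e.g.
    # "classic fuo workup fever unknown origin ... classical traditional ... work up evaluation ..."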
def generate_answer(self, query: str, context_chunks: List[Dict[str, Any]]) -> str:
"""Generate answer using OpenAI GPT-4o-mini with improved context"""
print(f"🧠 generate_answer called with {len(context_chunks)} chunks")
if not self.openai_client:
print("❌ OpenAI client not initialized")
return "OpenAI client not initialized. Please check your API key."
if not context_chunks:
print("❌ No context chunks provided")
return "I couldn't find any relevant information in the document to answer your question."
# Use fewer but more relevant chunks with size limit
relevant_chunks = [chunk for chunk in context_chunks if chunk['score'] > 0.3][:5]
if not relevant_chunks:
relevant_chunks = context_chunks[:3] # Fallback to top 3
context_parts = []
total_length = 0
max_context_length = 8000 # Limit context to 8K characters
# Derive source names from doc_id (strip trailing timestamp if present)
source_names = []
seen_sources = set()
for chunk in relevant_chunks:
doc_id = chunk.get('doc_id', 'unknown')
base = doc_id.rsplit('_', 1)[0] if '_' in doc_id else doc_id
if base and base not in seen_sources:
seen_sources.add(base)
source_names.append(base)
for chunk in relevant_chunks:
chunk_text = f"[Page {chunk['page']}, Score: {chunk['score']:.3f}] {chunk['text'][:1000]}..."
if total_length + len(chunk_text) > max_context_length:
break
context_parts.append(chunk_text)
total_length += len(chunk_text)
# Prepend sources and page summary to aid citations
sources_header = "; ".join(source_names) if source_names else "Unknown source"
page_list = sorted({c['page'] for c in relevant_chunks if 'page' in c})
pages_summary = ", ".join(str(p) for p in page_list)
context = (
f"Sources: {sources_header}\n"
f"Pages in retrieved context: {pages_summary}\n\n"
+ "\n\n".join(context_parts)
)
print(f"πŸ“„ Context length: {len(context)} characters")
print(f"πŸ” First chunk preview: {context_chunks[0]['text'][:100]}...")
# New system prompt (per user specification) + user content with context
system_prompt = (
"# Role and Objective\n"
"You are a senior medical tutor specializing in preparing students for Indian medical entrance exams (NEET-PG, INI-CET, FMGE).\n"
"# Instructions\n"
"- Always answer strictly based on information from standard textbooks (e.g., Harrison, Robbins, Bailey & Love, DC Dutta, Shaw, Park, Ganong, Guyton).\n"
"- If there is insufficient data available in these textbooks, respond: β€œInsufficient evidence from standard textbooks.”\n"
"- Do not fabricate or introduce non-standard material into your answers.\n"
"- Begin with a concise checklist (3-5 bullets) outlining the conceptual steps you will use to construct your answer (e.g., identify relevant information, reference textbooks, analyze options, format answer, cite sources).\n"
"## Output Format\n"
"- **Explanation:**\n"
"- Start with why the correct answer fits, using textbook references to support your explanation.\n"
"- **Why other options are wrong:**\n"
"- Briefly rule out each incorrect choice with textbook-based reasoning.\n"
"- **Clinical Pearl:**\n"
"- Highlight clinical pearls (e.g., β€œphysiologic leucorrhea never causes pruritus,” β€œmost common site of endometriosis = ovary”) as appropriate.\n"
"- **References:**\n"
"- Cite the textbook name, edition, and page number (if available). Place this section at the end of the answer, after all explanations and pearls.\n"
"- Keep explanations exam-friendly, high-yield, and structured (use short paragraphs or bullet points).\n"
"- If an image is provided, integrate it naturally into the reasoning but do not describe the image explicitlyβ€”only use it as a supportive clue.\n"
"- Keep answers concise but concept-rich, resembling a mini textbook explanation rather than a long essay.\n"
"## Reasoning Effort & Validation\n"
"- Set reasoning_effort=medium to ensure thorough but efficient explanations appropriate for exam-level concepts.\n"
"- After drafting the response, quickly validate whether all parts are completed as per the Output Format; if any part is missing or insufficiently referenced, self-correct before finalizing the answer."
)
user_content = (
f"Document Context (textbook excerpts):\n{context}\n\n"
f"Question: {query}\n\n"
"Use only the provided excerpts. When citing, include textbook name and exact page from the pages listed above."
)
try:
print("πŸ”„ Making OpenAI API call...")
params = {
"model": self.chat_model_name,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_content},
],
}
# gpt-4o models use 'max_tokens'; set temperature for consistency
params["max_tokens"] = 1500
params["temperature"] = 0.0
response = self.openai_client.chat.completions.create(**params)
# Try to extract text safely
text = ""
try:
text = (response.choices[0].message.content or "").strip()
except Exception:
text = ""
# Fallback to Responses API when empty
if not text:
try:
combined_input = system_prompt + "\n\n" + user_content
resp2 = self.openai_client.responses.create(
model=self.chat_model_name,
input=combined_input,
max_output_tokens=1500,
)
if hasattr(resp2, "output_text") and resp2.output_text:
text = resp2.output_text.strip()
elif hasattr(resp2, "choices") and resp2.choices:
m = getattr(resp2.choices[0], "message", None)
if m and getattr(m, "content", None):
text = m.content.strip()
except Exception as e2:
print(f"⚠️ Responses API fallback error: {e2}")
if not text:
raise RuntimeError("Empty response content from model")
print(f"βœ… OpenAI response received: {len(text)} characters")
print(f"πŸ“ Answer preview: {text[:100]}...")
return text
except Exception as e:
print(f"❌ OpenAI API error: {e}")
error_message = f"I found relevant information but couldn't generate a proper response due to an API error: {str(e)}"
if context_chunks:
error_message += f"\n\nHere's what I found: {context_chunks[0]['text'][:300]}... [Page {context_chunks[0]['page']}]"
return error_message
def extract_pdf_pages(pdf_path: str) -> List[str]:
"""Extract text from PDF pages"""
try:
import pypdf
reader = pypdf.PdfReader(pdf_path)
        pages = []
        for page in reader.pages:
            try:
                text = (page.extract_text() or "").strip()
            except Exception:
                text = ""
            # Keep one entry per page (even if blank) so page numbers assigned
            # downstream in create_chunks stay aligned with the source PDF
            pages.append(text)
        return pages
except Exception as e:
print(f"PDF extraction error: {e}")
return []
def create_chunks(pages: List[str], chunk_size: int = 3000, overlap: int = 500) -> List[Dict[str, Any]]:
"""Create overlapping chunks from pages with optimized sizing"""
chunks = []
print(f"πŸ“„ Processing {len(pages)} pages into chunks...")
for page_num, page_text in enumerate(pages, 1):
if len(page_text) < 100: # Skip very short pages
continue
# For very long pages, split into smaller sections
if len(page_text) > chunk_size * 2:
# Split by paragraphs (double newline)
paragraphs = page_text.split('\n\n')
current_chunk = ""
for para in paragraphs:
para = para.strip()
if not para:
continue
# If adding this paragraph exceeds chunk size, save current chunk
if len(current_chunk) + len(para) > chunk_size and current_chunk:
chunk_text = current_chunk.strip()
if len(chunk_text) > 200: # Only save substantial chunks
chunks.append({
"text": chunk_text,
"page": page_num,
"section": f"Page {page_num}"
})
# Keep last part for context
words = current_chunk.split()
if len(words) > 100:
overlap_text = ' '.join(words[-100:])
current_chunk = overlap_text + "\n\n" + para
else:
current_chunk = para
else:
current_chunk += "\n\n" + para if current_chunk else para
# Add remaining content
if current_chunk.strip() and len(current_chunk.strip()) > 200:
chunks.append({
"text": current_chunk.strip(),
"page": page_num,
"section": f"Page {page_num}"
})
else:
# For shorter pages, add the whole page as one chunk
if len(page_text.strip()) > 200:
chunks.append({
"text": page_text.strip(),
"page": page_num,
"section": f"Page {page_num}"
})
print(f"βœ… Created {len(chunks)} chunks from {len(pages)} pages")
return chunks
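

if __name__ == "__main__":
    # Minimal end-to-end smoke test (a sketch, not part of the library API).
    # SAMPLE_PDF and the question below are hypothetical; the demo assumes the
    # environment variables documented at the top of this module are set.
    rag = DynamicRAG()
    pdf_path = os.environ.get("SAMPLE_PDF", "./sample.pdf")
    if Path(pdf_path).exists():
        doc_id = Path(pdf_path).stem
        pages = extract_pdf_pages(pdf_path)
        chunks = create_chunks(pages)
        rag.store_document(doc_id, chunks)
        question = "What are the major categories of fever of unknown origin?"
        hits = rag.search(question, doc_id=doc_id, top_k=5)
        print(rag.generate_answer(question, hits))
    else:
        print(f"Sample PDF not found at {pdf_path}; set SAMPLE_PDF to run this demo.")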