import time
import pandas as pd
import numpy as np
import random
from typing import Literal
import chromadb
import re, unicodedata
from config import SanatanConfig
from embeddings import get_embedding
import logging
from metadata import MetadataWhereClause
from modules.db.relevance import validate_relevance_queryresult
from tqdm import tqdm
import nalayiram_helper

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class SanatanDatabase:
    _instance = None

    def __new__(cls, *args, **kwargs):
        # ✅ Ensure only one instance exists
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._init_once()
        return cls._instance

    def _init_once(self):
        """Initialize once per process."""
        self.chroma_client = chromadb.PersistentClient(path=SanatanConfig.dbStorePath)
        self._count_cache = {}  # {collection_name: (timestamp, count)}
        self._cache_ttl = 86400  # seconds (24 hours)
        logger.info("✅ SanatanDatabase singleton initialized")

    def does_data_exist(self, collection_name: str) -> bool:
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        num_rows = collection.count()
        logger.info("num_rows in %s = %d", collection_name, num_rows)
        return num_rows > 0

    def load(self, collection_name: str, ids, documents, embeddings, metadatas):
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        collection.add(
            ids=ids,
            documents=documents,
            embeddings=embeddings,
            metadatas=metadatas,
        )

    def get(self, collection_name: str, where, n_results=5):
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        return collection.get(where=where, limit=n_results)

    def fetch_random_data(
        self,
        collection_name: str,
        metadata_where_clause: MetadataWhereClause = None,
        n_results=1,
    ):
        # fetch all documents once
        logger.info(
            "getting %d random verses from [%s] | metadata_where_clause = %s",
            n_results,
            collection_name,
            metadata_where_clause,
        )
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        data = collection.get(
            include=["metadatas", "documents"],
            where=(
                metadata_where_clause.to_chroma_where()
                if metadata_where_clause is not None
                else None
            ),
        )
        docs = data["documents"]  # list of all verse texts
        ids = data["ids"]
        metas = data["metadatas"]

        if not docs:
            logger.warning("No data found! data=%s", data)
            return chromadb.QueryResult(ids=[], documents=[], metadatas=[])

        # pick k random indices
        indices = random.sample(range(len(docs)), k=min(n_results, len(docs)))

        return chromadb.QueryResult(
            ids=[ids[i] for i in indices],
            documents=[docs[i] for i in indices],
            metadatas=[metas[i] for i in indices],
        )
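
    # Usage sketch (illustrative, kept as comments so nothing runs at import
    # time). Because __new__ caches the instance, repeated constructions share
    # one PersistentClient and one count cache. "divya_prabandham" is a
    # collection name used elsewhere in this module; any configured collection
    # works the same way.
    #
    #   db1 = SanatanDatabase()
    #   db2 = SanatanDatabase()
    #   assert db1 is db2  # singleton
    #
    #   picks = db1.fetch_random_data("divya_prabandham", n_results=2)
    #   for verse, meta in zip(picks["documents"], picks["metadatas"]):
    #       print(meta.get("_global_index"), verse)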

    def fetch_first_match(
        self, collection_name: str, metadata_where_clause: MetadataWhereClause = None
    ):
        """This version is created to support the browse module with fallback regex matching."""

        def normalize_for_match(s: str) -> str:
            # Convert to canonical decomposed form (NFD), then strip combining marks
            s = unicodedata.normalize("NFD", s)
            s = "".join(ch for ch in s if not unicodedata.combining(ch))
            return s

        logger.info(
            "getting first matching verses from [%s] | metadata_where_clause = %s",
            collection_name,
            metadata_where_clause,
        )
        collection = self.chroma_client.get_or_create_collection(name=collection_name)

        where_clause = (
            metadata_where_clause.to_chroma_where()
            if metadata_where_clause is not None
            else None
        )
        # If the conversion returns an empty dict, treat it as None
        if isinstance(where_clause, dict) and not where_clause:
            where_clause = None

        data = collection.get(include=["metadatas", "documents"], where=where_clause)

        if data["metadatas"]:
            # ✅ normal path
            min_index = min(
                range(len(data["metadatas"])),
                key=lambda i: data["metadatas"][i].get("_global_index", float("inf")),
            )
            return {
                "ids": [data["ids"][min_index]],
                "documents": [data["documents"][min_index]],
                "metadatas": [data["metadatas"][min_index]],
            }

        # ⚠️ fallback path
        logger.warning("No data found using strict filter. Attempting regex fallback.")

        if not metadata_where_clause or not metadata_where_clause.filters:
            return chromadb.GetResult(ids=[], documents=[], metadatas=[])

        # find filters with $eq string type
        regex_filters = [
            f
            for f in metadata_where_clause.filters
            if f.metadata_search_operator == "$eq" and isinstance(f.metadata_value, str)
        ]
        if not regex_filters:
            return chromadb.GetResult(ids=[], documents=[], metadatas=[])

        # Pull all documents for manual regex scan
        all_data = collection.get(include=["metadatas", "documents"])

        matched_indices = []
        for i, meta in enumerate(all_data["metadatas"]):
            ok = True
            for f in regex_filters:
                field_val = str(meta.get(f.metadata_field, ""))
                # Normalize both the stored field and the search value
                norm_val = normalize_for_match(field_val)
                norm_query = normalize_for_match(f.metadata_value)
                # Do case-insensitive substring/regex search
                if not re.search(re.escape(norm_query), norm_val, flags=re.IGNORECASE):
                    ok = False
                    break
            if ok:
                matched_indices.append(i)

        if not matched_indices:
            logger.warning("Regex fallback also found no matches.")
            return chromadb.GetResult(ids=[], documents=[], metadatas=[])

        # Pick lowest _global_index among matches
        min_index = min(
            matched_indices,
            key=lambda i: all_data["metadatas"][i].get("_global_index", float("inf")),
        )

        return {
            "ids": [all_data["ids"][min_index]],
            "documents": [all_data["documents"][min_index]],
            "metadatas": [all_data["metadatas"][min_index]],
        }

    def count_where(
        self,
        collection_name: str,
        metadata_where_clause: MetadataWhereClause = None,
    ) -> int:
        """
        Count the number of matching verses in the collection without fetching documents.
        Uses the same filtering and fallback logic as fetch_all_matches.
        """

        def normalize_for_match(s: str) -> str:
            s = unicodedata.normalize("NFD", s)
            s = "".join(ch for ch in s if not unicodedata.combining(ch))
            return s

        logger.info(
            "count_where: counting matches in [%s] | filters=%s",
            collection_name,
            metadata_where_clause,
        )
        collection = self.chroma_client.get_or_create_collection(name=collection_name)

        where_clause = (
            metadata_where_clause.to_chroma_where() if metadata_where_clause else None
        )
        # If conversion returns an empty dict, treat as None
        if isinstance(where_clause, dict) and not where_clause:
            where_clause = None

        # Strict filter first
        data = collection.get(include=["metadatas"], where=where_clause)

        if not data["metadatas"]:
            # fallback regex
            logger.warning(
                "count_where: No matches found with strict filter. Trying regex fallback."
            )
            if not metadata_where_clause or not metadata_where_clause.filters:
                return 0

            regex_filters = [
                f
                for f in metadata_where_clause.filters
                if f.metadata_search_operator == "$eq"
                and isinstance(f.metadata_value, str)
            ]
            if regex_filters:
                all_data = collection.get(include=["metadatas"])
                matched_count = 0
                for meta in all_data["metadatas"]:
                    ok = True
                    for f in regex_filters:
                        field_val = str(meta.get(f.metadata_field, ""))
                        norm_val = normalize_for_match(field_val)
                        norm_query = normalize_for_match(f.metadata_value)
                        if not re.search(
                            re.escape(norm_query), norm_val, flags=re.IGNORECASE
                        ):
                            ok = False
                            break
                    if ok:
                        matched_count += 1
                return matched_count
            else:
                return 0

        # Direct count
        return len(data["metadatas"])

    def fetch_all_matches(
        self,
        collection_name: str,
        metadata_where_clause: MetadataWhereClause = None,
        page: int = 1,
        page_size: int = 20,
    ):
        """
        Fetch all matching verses from the collection with optional pagination,
        sorted by _global_index ascending.
        If page or page_size is None, return all results without pagination.
        """

        def normalize_for_match(s: str) -> str:
            s = unicodedata.normalize("NFD", s)
            s = "".join(ch for ch in s if not unicodedata.combining(ch))
            return s

        logger.info(
            "fetching all matches from [%s] | filters=%s | page=%s | page_size=%s",
            collection_name,
            metadata_where_clause,
            page,
            page_size,
        )
        collection = self.chroma_client.get_or_create_collection(name=collection_name)

        where_clause = (
            metadata_where_clause.to_chroma_where() if metadata_where_clause else None
        )
        # If the conversion returns an empty dict, treat it as None
        if isinstance(where_clause, dict) and not where_clause:
            where_clause = None

        # First, try strict filter
        data = collection.get(include=["metadatas", "documents"], where=where_clause)

        if not data["metadatas"]:
            # fallback regex
            logger.warning("No data found using strict filter. Trying regex fallback.")
            if not metadata_where_clause or not metadata_where_clause.filters:
                return {"ids": [], "documents": [], "metadatas": [], "total_matches": 0}

            regex_filters = [
                f
                for f in metadata_where_clause.filters
                if f.metadata_search_operator == "$eq"
                and isinstance(f.metadata_value, str)
            ]
            if regex_filters:
                all_data = collection.get(include=["metadatas", "documents"])
                matched_indices = []
                for i, meta in enumerate(all_data["metadatas"]):
                    ok = True
                    for f in regex_filters:
                        field_val = str(meta.get(f.metadata_field, ""))
                        norm_val = normalize_for_match(field_val)
                        norm_query = normalize_for_match(f.metadata_value)
                        if not re.search(
                            re.escape(norm_query), norm_val, flags=re.IGNORECASE
                        ):
                            ok = False
                            break
                    if ok:
                        matched_indices.append(i)

                data = {
                    "ids": [all_data["ids"][i] for i in matched_indices],
                    "documents": [all_data["documents"][i] for i in matched_indices],
                    "metadatas": [all_data["metadatas"][i] for i in matched_indices],
                }

        total_matches = len(data["ids"])
        if total_matches == 0:
            return {"ids": [], "documents": [], "metadatas": [], "total_matches": 0}

        # --- Sort by _global_index ascending ---
        combined = list(zip(data["ids"], data["documents"], data["metadatas"]))
        combined.sort(key=lambda x: x[2].get("_global_index", float("inf")))
        ids_sorted, documents_sorted, metadatas_sorted = zip(*combined)

        # --- Apply pagination only if both page and page_size are not None ---
        if page is not None and page_size is not None:
            start = (page - 1) * page_size
            end = start + page_size
            paged_data = {
                "ids": list(ids_sorted[start:end]),
                "documents": list(documents_sorted[start:end]),
                "metadatas": list(metadatas_sorted[start:end]),
                "total_matches": total_matches,
            }
            return paged_data
        else:
            # Return all results
            return {
                "ids": list(ids_sorted),
                "documents": list(documents_sorted),
                "metadatas": list(metadatas_sorted),
                "total_matches": total_matches,
            }

    def search(
        self,
        collection_name: str,
        query: str = None,
        metadata_where_clause: MetadataWhereClause = None,
        n_results=2,
        search_type: Literal["semantic", "literal", "random"] = "semantic",
    ):
        logger.info(
            "Search for [%s] in [%s] | metadata_where_clause=%s | search_type=%s | n_results=%d",
            query,
            collection_name,
            metadata_where_clause,
            search_type,
            n_results,
        )
        if search_type == "semantic":
            return self.search_semantic(
                collection_name=collection_name,
                query=query,
                metadata_where_clause=metadata_where_clause,
                n_results=n_results,
            )
        elif search_type == "literal":
            return self.search_for_literal(
                collection_name=collection_name,
                literal_to_search_for=query,
                metadata_where_clause=metadata_where_clause,
                n_results=n_results,
            )
        else:
            # random
            return self.fetch_random_data(
                collection_name=collection_name,
                metadata_where_clause=metadata_where_clause,
                n_results=n_results,
            )

    def fetch_document_by_index(self, collection_name: str, index: int):
        """
        Fetch a single document from a ChromaDB collection by its _global_index
        metadata value.

        Args:
            collection_name: Name of the ChromaDB collection.
            index: The document's _global_index value (assigned from 1 upward by
                build_global_index_for_scripture).

        Returns:
            dict: {
                "document": <document text>,
                <metadata_key>: <metadata_value>,
                ...
                "id": <document id>,
            }
            Or a dict with an "error" key if something went wrong.
        """
        logger.info("fetching index %d from [%s]", index, collection_name)
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        try:
            response = collection.get(
                limit=1,
                # offset=index,  # pagination via offset
                include=["metadatas", "documents"],
                where={"_global_index": index},
            )
        except Exception as e:
            logger.error("Error fetching document: %s", e, exc_info=True)
            return {"error": f"There was an error fetching the document: {str(e)}"}

        documents = response.get("documents", [])
        metadatas = response.get("metadatas", [])
        ids = response.get("ids", [])

        if documents:
            # merge document text with metadata
            result = {"document": documents[0]}
            if metadatas:
                result.update(metadatas[0])
            if ids:
                result["id"] = ids[0]
            # print("raw data = ", result)
            return result
        else:
            print("No data available")
            # show a sample data record
            response1 = collection.get(
                limit=2,
                # offset=index,  # pagination via offset
                include=["metadatas", "documents"],
            )
            # print("sample data : ", response1)
            return {"error": "No data available."}

    def search_semantic(
        self,
        collection_name: str,
        query: str | None = None,
        metadata_where_clause: MetadataWhereClause | None = None,
        n_results=2,
    ):
        logger.info(
            "Vector Semantic Search for [%s] in [%s] | metadata_where_clause = %s",
            query,
            collection_name,
            metadata_where_clause,
        )
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        try:
            q = query.strip() if query is not None else ""
            if not q:
                # fallback: fetch random verse
                return self.fetch_random_data(
                    collection_name=collection_name,
                    metadata_where_clause=metadata_where_clause,
                    n_results=n_results,
                )
            else:
                response = collection.query(
                    query_embeddings=get_embedding(
                        [query],
                        SanatanConfig().get_embedding_for_collection(collection_name),
                    ),
                    # query_texts=[query],
                    n_results=n_results,
                    where=(
                        metadata_where_clause.to_chroma_where()
                        if metadata_where_clause is not None
                        else None
                    ),
                    include=["metadatas", "documents", "distances"],
                )
        except Exception as e:
            logger.error("Error in search: %s", e, exc_info=True)
            return chromadb.QueryResult(
                documents=[],
                ids=[],
                metadatas=[],
                distances=[],
            )

        validated_response = validate_relevance_queryresult(query, response)
        logger.info(
            "status = %s | reason = %s",
            validated_response.status,
            validated_response.reason,
        )
        return validated_response.result

    def search_for_literal(
        self,
        collection_name: str,
        literal_to_search_for: str | None = None,
        metadata_where_clause: MetadataWhereClause | None = None,
        n_results=2,
    ):
        logger.info(
            "Searching literally for [%s] in [%s] | metadata_where_clause = %s",
            literal_to_search_for,
            collection_name,
            metadata_where_clause,
        )
        if literal_to_search_for is None or literal_to_search_for.strip() == "":
            logger.warning("Nothing to search literally.")
            raise Exception("query cannot be None or empty for a literal search!")
            # return self.fetch_random_data(
            #     collection_name=collection_name,
            # )

        collection = self.chroma_client.get_or_create_collection(name=collection_name)

        def normalize(text):
            return unicodedata.normalize("NFKC", text).lower()

        # 1. Try native contains
        response = collection.get(
            where=(
                metadata_where_clause.to_chroma_where()
                if metadata_where_clause is not None
                else None
            ),
            where_document={"$contains": literal_to_search_for},
            limit=n_results,
        )
        if response["documents"] and any(response["documents"]):
            return chromadb.QueryResult(
                ids=response["ids"],
                documents=response["documents"],
                metadatas=response["metadatas"],
            )

        # 2. Regex fallback (normalized)
        logger.info("⚠ No luck. Falling back to regex for %s", literal_to_search_for)
        regex = re.compile(re.escape(normalize(literal_to_search_for)))
        logger.info("regex = %s", regex)

        all_docs = collection.get(
            where=(
                metadata_where_clause.to_chroma_where()
                if metadata_where_clause is not None
                else None
            ),
        )

        matched_docs = []
        for doc_list, metadata_list, doc_id_list in zip(
            all_docs["documents"], all_docs["metadatas"], all_docs["ids"]
        ):
            # Ensure all are lists
            if isinstance(doc_list, str):
                doc_list = [doc_list]
            if isinstance(metadata_list, dict):
                metadata_list = [metadata_list]
            if isinstance(doc_id_list, str):
                doc_id_list = [doc_id_list]

            for i in range(len(doc_list)):
                d = doc_list[i]
                current_metadata = metadata_list[i]
                current_id = doc_id_list[i]

                doc_match = regex.search(normalize(d))
                metadata_match = False
                for key, value in current_metadata.items():
                    if isinstance(value, str) and regex.search(normalize(value)):
                        metadata_match = True
                        break
                    elif isinstance(value, list):
                        if any(
                            isinstance(v, str) and regex.search(normalize(v))
                            for v in value
                        ):
                            metadata_match = True
                            break

                if doc_match or metadata_match:
                    matched_docs.append(
                        {
                            "id": current_id,
                            "document": d,
                            "metadata": current_metadata,
                        }
                    )
                    if len(matched_docs) >= n_results:
                        break
            if len(matched_docs) >= n_results:
                break

        return chromadb.QueryResult(
            {
                "documents": [[d["document"] for d in matched_docs]],
                "ids": [[d["id"] for d in matched_docs]],
                "metadatas": [[d["metadata"] for d in matched_docs]],
            }
        )

    def count(self, collection_name: str):
        # check cache first
        now = time.time()
        cached_entry = self._count_cache.get(collection_name)
        if cached_entry:
            ts, cached_count = cached_entry
            if now - ts < self._cache_ttl and cached_count > 0:
                logger.debug(
                    "Cache hit for collection [%s]: %d", collection_name, cached_count
                )
                return cached_count
            else:
                logger.debug("Cache expired for [%s]", collection_name)

        # fetch fresh count
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        total_count = collection.count()
        logger.info("Total records in [%s] = %d", collection_name, total_count)

        # update cache
        self._count_cache[collection_name] = (now, total_count)
        return total_count

    def test_sanity(self):
        for scripture in SanatanConfig().scriptures:
            count = self.count(collection_name=scripture["collection_name"])
            if count == 0:
                raise Exception(
                    f"No data in collection {scripture['collection_name']}"
                )

    def reembed_collection_openai(self, collection_name: str, batch_size: int = 50):
        """
        Re-embeds all documents of a Chroma collection with OpenAI
        text-embedding-3-large embeddings and inserts them into a new
        collection named '<collection_name>_openai'.

        Args:
            collection_name: The name of the source collection.
            batch_size: Number of documents to process per batch.
        """
        # Step 1: Fetch old collection data (if exists)
        try:
            old_collection = self.chroma_client.get_collection(name=collection_name)
            old_data = old_collection.get(include=["documents", "metadatas"])
            documents = old_data["documents"]
            metadatas = old_data["metadatas"]
            ids = old_data["ids"]
            print(f"Fetched {len(documents)} documents from old collection.")

            # Step 2: Delete old collection
            # self.chroma_client.delete_collection(collection_name)
            # print(f"Deleted old collection '{collection_name}'.")
        except chromadb.errors.NotFoundError:
            print(f"No existing collection named '{collection_name}', starting fresh.")
            documents, metadatas, ids = [], [], []

        # Step 3: Create new collection with correct embedding dimension
        new_collection = self.chroma_client.create_collection(
            name=f"{collection_name}_openai",
            embedding_function=None,  # embeddings will be provided manually
        )
        print(
            f"Created new collection '{collection_name}_openai' with embedding_dim=3072."
        )

        # Step 4: Re-embed and insert documents in batches
        for i in tqdm(
            range(0, len(documents), batch_size), desc="Re-embedding batches"
        ):
            batch_docs = documents[i : i + batch_size]
            batch_metadatas = metadatas[i : i + batch_size]
            batch_ids = ids[i : i + batch_size]

            embeddings = get_embedding(batch_docs, backend="openai")

            new_collection.add(
                ids=batch_ids,
                documents=batch_docs,
                metadatas=batch_metadatas,
                embeddings=embeddings,
            )

        print("All documents re-embedded and added to new collection successfully!")

    def add_unit_index_to_collection(self, collection_name: str, unit_field: str):
        if collection_name != "yt_metadata":  # safeguard, just in case
            return
        collection = self.chroma_client.get_collection(name=collection_name)

        # fetch everything in batches (in case your collection is large)
        batch_size = 100
        offset = 0
        unit_counter = 1

        while True:
            result = collection.get(
                limit=batch_size,
                offset=offset,
                include=["documents", "metadatas", "embeddings"],
            )
            ids = result["ids"]
            if not ids:
                break  # no more docs

            docs = result["documents"]
            metas = result["metadatas"]
            embeddings = result["embeddings"]

            # add unit_index to metadata
            updated_metas = []
            for meta in metas:
                # ensure meta is not None
                m = meta.copy() if meta else {}
                m[unit_field] = unit_counter
                updated_metas.append(m)
                unit_counter += 1

            # upsert with same IDs (will overwrite metadata but keep same id+doc)
            collection.upsert(
                ids=ids,
                documents=docs,
                metadatas=updated_metas,
                embeddings=embeddings,
            )

            offset += batch_size

        print(
            f"✅ Finished adding {unit_field} to {unit_counter-1} documents in {collection_name}."
        )

    def get_list_of_values(
        self, collection_name: str, metadata_field_name: str
    ) -> list:
        """
        Returns the unique values for a given metadata field in a collection.
        """
        # Get the collection
        collection = self.chroma_client.get_or_create_collection(name=collection_name)

        # Fetch all metadata from the collection
        query_result = collection.get(include=["metadatas"])

        values = set()  # use a set to automatically deduplicate
        metadatas = query_result.get("metadatas", [])

        if metadatas:
            # Handle both flat list and nested list formats
            if isinstance(metadatas[0], dict):
                # flat list of dicts
                for md in metadatas:
                    if metadata_field_name in md:
                        values.add(md[metadata_field_name])
            elif isinstance(metadatas[0], list):
                # nested list
                for md_list in metadatas:
                    for md in md_list:
                        if metadata_field_name in md:
                            values.add(md[metadata_field_name])

        return sorted(list(values))

    def build_global_index_for_scripture(self, scripture: dict, force: bool = False):
        scripture_name = scripture["name"]
        chapter_order = scripture.get("chapter_order", None)
        # if scripture_name != "vishnu_sahasranamam":
        #     continue
        logger.info(
            "build_global_index_for_all_scriptures:%s: Processing", scripture_name
        )
        collection_name = scripture["collection_name"]
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        metadata_fields = scripture.get("metadata_fields", [])

        # Get metadata field names marked as unique
        unique_fields = [f["name"] for f in metadata_fields if f.get("is_unique")]
        if not unique_fields:
            if metadata_fields:
                unique_fields = [metadata_fields[0]["name"]]
            else:
                logger.warning(
                    f"No metadata fields defined for {collection_name}, skipping"
                )
                return

        logger.info(
            "build_global_index_for_all_scriptures:%s:unique fields: %s",
            scripture_name,
            unique_fields,
        )

        # Build chapter_order mapping if defined
        chapter_order_mapping = {}
        if callable(chapter_order):
            chapter_order_mapping = chapter_order()

        logger.info(
            "build_global_index_for_all_scriptures:%s:chapter_order_mapping: %s",
            scripture_name,
            chapter_order_mapping,
        )

        # Fetch all records (keep embeddings for upsert)
        try:
            results = collection.get(include=["metadatas", "documents", "embeddings"])
        except Exception as e:
            logger.error(
                "build_global_index_for_all_scriptures:%s Error getting data from chromadb",
                scripture_name,
                exc_info=True,
            )
            return

        ids = results["ids"]
        metadatas = results["metadatas"]
        documents = results["documents"]
        embeddings = results.get("embeddings", [None] * len(ids))

        if not force and metadatas and "_global_index" in metadatas[0]:
            logger.warning(
                "build_global_index_for_all_scriptures:%s: global index already available. skipping collection",
                scripture_name,
            )
            return

        # Create a DataFrame for metadata sorting
        df = pd.DataFrame(metadatas)
        df["_id"] = ids
        df["_doc"] = documents

        logger.info(
            "build_global_index_for_all_scriptures:%s:unique_fields: %s",
            scripture_name,
            unique_fields,
        )

        # Add sortable columns for each unique field
        for field_name in unique_fields:
            if (
                field_name.lower() in ("chapter", "prabandham_name")
                and chapter_order_mapping
            ):
                logger.info(
                    "build_global_index_for_all_scriptures:%s:sorting",
                    scripture_name,
                )
                # Map chapter names to their defined order
                df["_sort_" + field_name] = (
                    df[field_name].map(chapter_order_mapping).fillna(np.inf)
                )
            else:
                # Try numeric, fallback to string lowercase
                def parse_val(v):
                    if v is None:
                        return float("inf")
                    if isinstance(v, int):
                        return v
                    if isinstance(v, str):
                        v = v.strip()
                        return int(v) if v.isdigit() else v.lower()
                    return str(v)

                df["_sort_" + field_name] = df[field_name].apply(parse_val)

        sort_cols = ["_sort_" + f for f in unique_fields]
        logger.info(
            "build_global_index_for_all_scriptures:%s:sort_cols=%s",
            scripture_name,
            sort_cols,
        )
        df = df.sort_values(by=sort_cols, kind="stable").reset_index(drop=True)

        # Assign global index
        df["_global_index"] = range(1, len(df) + 1)

        logger.info(
            "build_global_index_for_all_scriptures:%s: updating database",
            scripture_name,
        )

        # Batch upsert
        BATCH_SIZE = 5000  # safely below max batch size
        for i in range(0, len(df), BATCH_SIZE):
            batch_df = df.iloc[i : i + BATCH_SIZE]
            batch_ids = batch_df["_id"].tolist()
            batch_docs = batch_df["_doc"].tolist()
            # Use original metadata keys for upsert
            batch_metas = [
                {k: record[k] for k in metadatas[0].keys() if k in record}
                | {"_global_index": record["_global_index"]}
                for record in batch_df.to_dict(orient="records")
            ]
            batch_embeds = [embeddings[idx] for idx in batch_df.index]

            collection.update(
                ids=batch_ids,
                # documents=batch_docs,
                metadatas=batch_metas,
                # embeddings=batch_embeds,
            )

        logger.info(
            "build_global_index_for_all_scriptures:%s: ✅ Updated with %d records",
            scripture_name,
            len(df),
        )

    def build_global_index_for_all_scriptures(self, force: bool = False):
        logger.info("build_global_index_for_all_scriptures: started")
        config = SanatanConfig()
        for scripture in config.scriptures:
            self.build_global_index_for_scripture(scripture=scripture, force=force)

    def fix_taniyans_in_divya_prabandham(self):
        nalayiram_helper.reorder_taniyan(
            self.chroma_client.get_collection("divya_prabandham")
        )

    def delete_taniyans_in_divya_prabandham(self):
        nalayiram_helper.delete_taniyan(
            self.chroma_client.get_collection("divya_prabandham")
        )
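

# Minimal smoke-test sketch, assuming the configured collections already hold
# data. Guarded so that importing this module stays side-effect free.
if __name__ == "__main__":
    db = SanatanDatabase()
    db.test_sanity()  # raises if any configured collection is empty
    for scripture in SanatanConfig().scriptures:
        name = scripture["collection_name"]
        logger.info("collection [%s] has %d records", name, db.count(name))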