import time
import pandas as pd
import numpy as np
import random
from typing import Literal
import chromadb
import re, unicodedata
from config import SanatanConfig
from embeddings import get_embedding
import logging
from metadata import MetadataWhereClause
from modules.db.relevance import validate_relevance_queryresult
from tqdm import tqdm
import nalayiram_helper

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class SanatanDatabase:
    _instance = None

    def __new__(cls, *args, **kwargs):
        # ✅ Ensure only one instance exists
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._init_once()
        return cls._instance

    def _init_once(self):
        """Initialize once per process"""
        self.chroma_client = chromadb.PersistentClient(path=SanatanConfig.dbStorePath)
        self._count_cache = {}  # {collection_name: (timestamp, count)}
        self._cache_ttl = 86400  # seconds (24 hours)
        logger.info("✅ SanatanDatabase singleton initialized")

    def does_data_exist(self, collection_name: str) -> bool:
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        num_rows = collection.count()
        logger.info("num_rows in %s = %d", collection_name, num_rows)
        return num_rows > 0

    def load(self, collection_name: str, ids, documents, embeddings, metadatas):
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        collection.add(
            ids=ids,
            documents=documents,
            embeddings=embeddings,
            metadatas=metadatas,
        )

    def get(self, collection_name: str, where, n_results=5):
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        return collection.get(where=where, limit=n_results)

    def fetch_random_data(
        self,
        collection_name: str,
        metadata_where_clause: MetadataWhereClause = None,
        n_results=1,
    ):
        # fetch all documents once
        logger.info(
            "getting %d random verses from [%s] | metadata_where_clause = %s",
            n_results,
            collection_name,
            metadata_where_clause,
        )
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        data = collection.get(
            include=["metadatas", "documents"],
            where=(
                metadata_where_clause.to_chroma_where()
                if metadata_where_clause is not None
                else None
            ),
        )
        docs = data["documents"]  # list of all verse texts
        ids = data["ids"]
        metas = data["metadatas"]
        if not docs:
            logger.warning("No data found! - data=%s", data)
            return chromadb.QueryResult(ids=[], documents=[], metadatas=[])
        # pick k random indices
        indices = random.sample(range(len(docs)), k=min(n_results, len(docs)))
        return chromadb.QueryResult(
            ids=[ids[i] for i in indices],
            documents=[docs[i] for i in indices],
            metadatas=[metas[i] for i in indices],
        )

    def fetch_first_match(
        self, collection_name: str, metadata_where_clause: MetadataWhereClause = None
    ):
        """This version is created to support the browse module with fallback regex matching"""

        def normalize_for_match(s: str) -> str:
            # Convert to canonical decomposed form (NFD), then strip combining marks
            s = unicodedata.normalize("NFD", s)
            s = "".join(ch for ch in s if not unicodedata.combining(ch))
            return s
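
        # Accent/diacritic folding, e.g. normalize_for_match("rāmāyaṇa") -> "ramayana"
        # (the sample input is illustrative, not taken from the corpus).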
        logger.info(
            "getting first matching verses from [%s] | metadata_where_clause = %s",
            collection_name,
            metadata_where_clause,
        )
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        where_clause = (
            metadata_where_clause.to_chroma_where()
            if metadata_where_clause is not None
            else None
        )
        # If the conversion returns an empty dict, treat it as None
        if isinstance(where_clause, dict) and not where_clause:
            where_clause = None
        data = collection.get(include=["metadatas", "documents"], where=where_clause)
        if data["metadatas"]:
            # ✅ normal path
            min_index = min(
                range(len(data["metadatas"])),
                key=lambda i: data["metadatas"][i].get("_global_index", float("inf")),
            )
            return {
                "ids": [data["ids"][min_index]],
                "documents": [data["documents"][min_index]],
                "metadatas": [data["metadatas"][min_index]],
            }
        # ⚠️ fallback path
        logger.warning("No data found using strict filter. Attempting regex fallback.")
        if not metadata_where_clause or not metadata_where_clause.filters:
            return chromadb.GetResult(ids=[], documents=[], metadatas=[])
        # find filters with $eq string type
        regex_filters = [
            f
            for f in metadata_where_clause.filters
            if f.metadata_search_operator == "$eq" and isinstance(f.metadata_value, str)
        ]
        if not regex_filters:
            return chromadb.GetResult(ids=[], documents=[], metadatas=[])
        # Pull all documents for manual regex scan
        all_data = collection.get(include=["metadatas", "documents"])
        matched_indices = []
        for i, meta in enumerate(all_data["metadatas"]):
            ok = True
            for f in regex_filters:
                field_val = str(meta.get(f.metadata_field, ""))
                # Normalize both the stored field and the search value
                norm_val = normalize_for_match(field_val)
                norm_query = normalize_for_match(f.metadata_value)
                # Do case-insensitive substring/regex search
                if not re.search(re.escape(norm_query), norm_val, flags=re.IGNORECASE):
                    ok = False
                    break
            if ok:
                matched_indices.append(i)
        if not matched_indices:
            logger.warning("Regex fallback also found no matches.")
            return chromadb.GetResult(ids=[], documents=[], metadatas=[])
        # Pick lowest _global_index among matches
        min_index = min(
            matched_indices,
            key=lambda i: all_data["metadatas"][i].get("_global_index", float("inf")),
        )
        return {
            "ids": [all_data["ids"][min_index]],
            "documents": [all_data["documents"][min_index]],
            "metadatas": [all_data["metadatas"][min_index]],
        }

    def count_where(
        self,
        collection_name: str,
        metadata_where_clause: MetadataWhereClause = None,
    ) -> int:
        """
        Count the number of matching verses in the collection without fetching documents.
        Uses the same filtering and fallback logic as fetch_all_matches.
        """

        def normalize_for_match(s: str) -> str:
            s = unicodedata.normalize("NFD", s)
            s = "".join(ch for ch in s if not unicodedata.combining(ch))
            return s

        logger.info(
            "count_where: counting matches in [%s] | filters=%s",
            collection_name,
            metadata_where_clause,
        )
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        where_clause = (
            metadata_where_clause.to_chroma_where() if metadata_where_clause else None
        )
        # If conversion returns an empty dict, treat as None
        if isinstance(where_clause, dict) and not where_clause:
            where_clause = None
        # Strict filter first
        data = collection.get(include=["metadatas"], where=where_clause)
        if not data["metadatas"]:
            # fallback regex
            logger.warning("count_where: No matches found with strict filter. Trying regex fallback.")
            if not metadata_where_clause or not metadata_where_clause.filters:
                return 0
            regex_filters = [
                f
                for f in metadata_where_clause.filters
                if f.metadata_search_operator == "$eq"
                and isinstance(f.metadata_value, str)
            ]
            if regex_filters:
                all_data = collection.get(include=["metadatas"])
                matched_count = 0
                for meta in all_data["metadatas"]:
                    ok = True
                    for f in regex_filters:
                        field_val = str(meta.get(f.metadata_field, ""))
                        norm_val = normalize_for_match(field_val)
                        norm_query = normalize_for_match(f.metadata_value)
                        if not re.search(
                            re.escape(norm_query), norm_val, flags=re.IGNORECASE
                        ):
                            ok = False
                            break
                    if ok:
                        matched_count += 1
                return matched_count
            else:
                return 0
        # Direct count
        return len(data["metadatas"])

    def fetch_all_matches(
        self,
        collection_name: str,
        metadata_where_clause: MetadataWhereClause = None,
        page: int = 1,
        page_size: int = 20,
    ):
        """
        Fetch all matching verses from the collection with optional pagination,
        sorted by _global_index ascending.
        If page or page_size is None, return all results without pagination.
        """

        def normalize_for_match(s: str) -> str:
            s = unicodedata.normalize("NFD", s)
            s = "".join(ch for ch in s if not unicodedata.combining(ch))
            return s

        logger.info(
            "fetching all matches from [%s] | filters=%s | page=%s | page_size=%s",
            collection_name,
            metadata_where_clause,
            page,
            page_size,
        )
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        where_clause = (
            metadata_where_clause.to_chroma_where() if metadata_where_clause else None
        )
        # If the conversion returns an empty dict, treat it as None
        if isinstance(where_clause, dict) and not where_clause:
            where_clause = None
        # First, try strict filter
        data = collection.get(include=["metadatas", "documents"], where=where_clause)
        if not data["metadatas"]:
            # fallback regex
            logger.warning("No data found using strict filter. Trying regex fallback.")
            if not metadata_where_clause or not metadata_where_clause.filters:
                return {"ids": [], "documents": [], "metadatas": [], "total_matches": 0}
            regex_filters = [
                f
                for f in metadata_where_clause.filters
                if f.metadata_search_operator == "$eq"
                and isinstance(f.metadata_value, str)
            ]
            if regex_filters:
                all_data = collection.get(include=["metadatas", "documents"])
                matched_indices = []
                for i, meta in enumerate(all_data["metadatas"]):
                    ok = True
                    for f in regex_filters:
                        field_val = str(meta.get(f.metadata_field, ""))
                        norm_val = normalize_for_match(field_val)
                        norm_query = normalize_for_match(f.metadata_value)
                        if not re.search(
                            re.escape(norm_query), norm_val, flags=re.IGNORECASE
                        ):
                            ok = False
                            break
                    if ok:
                        matched_indices.append(i)
                data = {
                    "ids": [all_data["ids"][i] for i in matched_indices],
                    "documents": [all_data["documents"][i] for i in matched_indices],
                    "metadatas": [all_data["metadatas"][i] for i in matched_indices],
                }
        total_matches = len(data["ids"])
        if total_matches == 0:
            return {"ids": [], "documents": [], "metadatas": [], "total_matches": 0}
        # --- Sort by _global_index ascending ---
        combined = list(zip(data["ids"], data["documents"], data["metadatas"]))
        combined.sort(key=lambda x: x[2].get("_global_index", float("inf")))
        ids_sorted, documents_sorted, metadatas_sorted = zip(*combined)
        # --- Apply pagination only if both page and page_size are not None ---
        if page is not None and page_size is not None:
            start = (page - 1) * page_size
            end = start + page_size
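            # e.g. page=2, page_size=20 -> slice [20:40], i.e. matches 21-40 in _global_index order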
            paged_data = {
                "ids": list(ids_sorted[start:end]),
                "documents": list(documents_sorted[start:end]),
                "metadatas": list(metadatas_sorted[start:end]),
                "total_matches": total_matches,
            }
            return paged_data
        else:
            # Return all results
            return {
                "ids": list(ids_sorted),
                "documents": list(documents_sorted),
                "metadatas": list(metadatas_sorted),
                "total_matches": total_matches,
            }

    def search(
        self,
        collection_name: str,
        query: str = None,
        metadata_where_clause: MetadataWhereClause = None,
        n_results=2,
        search_type: Literal["semantic", "literal", "random"] = "semantic",
    ):
        logger.info(
            "Search for [%s] in [%s] | metadata_where_clause=%s | search_type=%s | n_results=%d",
            query,
            collection_name,
            metadata_where_clause,
            search_type,
            n_results,
        )
        if search_type == "semantic":
            return self.search_semantic(
                collection_name=collection_name,
                query=query,
                metadata_where_clause=metadata_where_clause,
                n_results=n_results,
            )
        elif search_type == "literal":
            return self.search_for_literal(
                collection_name=collection_name,
                literal_to_search_for=query,
                metadata_where_clause=metadata_where_clause,
                n_results=n_results,
            )
        else:
            # random
            return self.fetch_random_data(
                collection_name=collection_name,
                metadata_where_clause=metadata_where_clause,
                n_results=n_results,
            )

    def fetch_document_by_index(self, collection_name: str, index: int):
        """
        Fetch a single document from a ChromaDB collection by its `_global_index`
        metadata value (1-based, as assigned by build_global_index_for_scripture).

        Args:
            collection_name: Name of the ChromaDB collection.
            index: `_global_index` value of the document to fetch.

        Returns:
            dict: {
                "document": <document_text>,
                <metadata_key_1>: <value>,
                <metadata_key_2>: <value>,
                ...
            }
            Or a dict with an "error" key if something went wrong.
        """
        logger.info("fetching index %d from [%s]", index, collection_name)
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        try:
            response = collection.get(
                limit=1,
                # offset=index,  # pagination via offset
                include=["metadatas", "documents"],
                where={"_global_index": index},
            )
        except Exception as e:
            logger.error("Error fetching document: %s", e, exc_info=True)
            return {"error": f"There was an error fetching the document: {str(e)}"}
        documents = response.get("documents", [])
        metadatas = response.get("metadatas", [])
        ids = response.get("ids", [])
        if documents:
            # merge document text with metadata
            result = {"document": documents[0]}
            if metadatas:
                result.update(metadatas[0])
            if ids:
                result["id"] = ids[0]
            # print("raw data = ", result)
            return result
        else:
            logger.warning("No data available for index %d in [%s]", index, collection_name)
            # show a sample data record
            response1 = collection.get(
                limit=2,
                # offset=index,  # pagination via offset
                include=["metadatas", "documents"],
            )
            # print("sample data : ", response1)
            return {"error": "No data available."}

    def search_semantic(
        self,
        collection_name: str,
        query: str | None = None,
        metadata_where_clause: MetadataWhereClause | None = None,
        n_results=2,
    ):
        logger.info(
            "Vector Semantic Search for [%s] in [%s] | metadata_where_clause = %s",
            query,
            collection_name,
            metadata_where_clause,
        )
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        try:
            q = query.strip() if query is not None else ""
            if not q:
                # fallback: fetch random verse
                return self.fetch_random_data(
                    collection_name=collection_name,
                    metadata_where_clause=metadata_where_clause,
                    n_results=n_results,
                )
            else:
                response = collection.query(
                    query_embeddings=get_embedding(
                        [query],
                        SanatanConfig().get_embedding_for_collection(collection_name),
                    ),
                    # query_texts=[query],
                    n_results=n_results,
                    where=(
                        metadata_where_clause.to_chroma_where()
                        if metadata_where_clause is not None
                        else None
                    ),
                    include=["metadatas", "documents", "distances"],
                )
        except Exception as e:
            logger.error("Error in search: %s", e, exc_info=True)
            return chromadb.QueryResult(
                documents=[],
                ids=[],
                metadatas=[],
                distances=[],
            )
        validated_response = validate_relevance_queryresult(query, response)
        logger.info(
            "status = %s | reason = %s",
            validated_response.status,
            validated_response.reason,
        )
        return validated_response.result

    def search_for_literal(
        self,
        collection_name: str,
        literal_to_search_for: str | None = None,
        metadata_where_clause: MetadataWhereClause | None = None,
        n_results=2,
    ):
        logger.info(
            "Searching literally for [%s] in [%s] | metadata_where_clause = %s",
            literal_to_search_for,
            collection_name,
            metadata_where_clause,
        )
        if literal_to_search_for is None or literal_to_search_for.strip() == "":
            logger.warning("Nothing to search literally.")
            raise Exception("query cannot be None or empty for a literal search!")
            # return self.fetch_random_data(
            #     collection_name=collection_name,
            # )
        collection = self.chroma_client.get_or_create_collection(name=collection_name)

        def normalize(text):
            return unicodedata.normalize("NFKC", text).lower()

        # 1. Try native contains
        response = collection.get(
            where=(
                metadata_where_clause.to_chroma_where()
                if metadata_where_clause is not None
                else None
            ),
            where_document={"$contains": literal_to_search_for},
            limit=n_results,
        )
        if response["documents"] and any(response["documents"]):
            return chromadb.QueryResult(
                ids=response["ids"],
                documents=response["documents"],
                metadatas=response["metadatas"],
            )
        # 2. Regex fallback (normalized)
        logger.info("⚠ No luck. Falling back to regex for %s", literal_to_search_for)
        regex = re.compile(re.escape(normalize(literal_to_search_for)))
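        # Because the pattern is re.escape()d, this "regex" is effectively a
        # case-normalized literal substring scan over documents and string metadata.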
| logger.info("regex = %s", regex) | |
| all_docs = collection.get( | |
| where=( | |
| metadata_where_clause.to_chroma_where() | |
| if metadata_where_clause is not None | |
| else None | |
| ), | |
| ) | |
| matched_docs = [] | |
| for doc_list, metadata_list, doc_id_list in zip( | |
| all_docs["documents"], all_docs["metadatas"], all_docs["ids"] | |
| ): | |
| # Ensure all are lists | |
| if isinstance(doc_list, str): | |
| doc_list = [doc_list] | |
| if isinstance(metadata_list, dict): | |
| metadata_list = [metadata_list] | |
| if isinstance(doc_id_list, str): | |
| doc_id_list = [doc_id_list] | |
| for i in range(len(doc_list)): | |
| d = doc_list[i] | |
| current_metadata = metadata_list[i] | |
| current_id = doc_id_list[i] | |
| doc_match = regex.search(normalize(d)) | |
| metadata_match = False | |
| for key, value in current_metadata.items(): | |
| if isinstance(value, str) and regex.search(normalize(value)): | |
| metadata_match = True | |
| break | |
| elif isinstance(value, list): | |
| if any( | |
| isinstance(v, str) and regex.search(normalize(v)) | |
| for v in value | |
| ): | |
| metadata_match = True | |
| break | |
| if doc_match or metadata_match: | |
| matched_docs.append( | |
| { | |
| "id": current_id, | |
| "document": d, | |
| "metadata": current_metadata, | |
| } | |
| ) | |
| if len(matched_docs) >= n_results: | |
| break | |
| if len(matched_docs) >= n_results: | |
| break | |
| return chromadb.QueryResult( | |
| { | |
| "documents": [[d["document"] for d in matched_docs]], | |
| "ids": [[d["id"] for d in matched_docs]], | |
| "metadatas": [[d["metadata"] for d in matched_docs]], | |
| } | |
| ) | |

    def count(self, collection_name: str):
        # check cache first
        now = time.time()
        cached_entry = self._count_cache.get(collection_name)
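        # Zero counts are never served from cache (see the `cached_count > 0` check
        # below), so empty collections are re-counted on every call.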
        if cached_entry:
            ts, cached_count = cached_entry
            if now - ts < self._cache_ttl and cached_count > 0:
                logger.debug("Cache hit for collection [%s]: %d", collection_name, cached_count)
                return cached_count
            else:
                logger.debug("Cache expired for [%s]", collection_name)
        # fetch fresh count
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        total_count = collection.count()
        logger.info("Total records in [%s] = %d", collection_name, total_count)
        # update cache
        self._count_cache[collection_name] = (now, total_count)
        return total_count

    def test_sanity(self):
        for scripture in SanatanConfig().scriptures:
            count = self.count(collection_name=scripture["collection_name"])
            if count == 0:
                raise Exception(f"No data in collection {scripture['collection_name']}")

    def reembed_collection_openai(self, collection_name: str, batch_size: int = 50):
        """
        Re-embeds all documents of an existing Chroma collection with OpenAI
        text-embedding-3-large embeddings and inserts them into a new collection
        named '<collection_name>_openai'. (Deleting the old collection is currently
        commented out.)

        Args:
            collection_name: The name of the source collection to re-embed.
            batch_size: Number of documents to process per batch.
        """
        # Step 1: Fetch old collection data (if exists)
        try:
            old_collection = self.chroma_client.get_collection(name=collection_name)
            old_data = old_collection.get(include=["documents", "metadatas"])
            documents = old_data["documents"]
            metadatas = old_data["metadatas"]
            ids = old_data["ids"]
            print(f"Fetched {len(documents)} documents from old collection.")
            # Step 2: Delete old collection
            # self.chroma_client.delete_collection(collection_name)
            # print(f"Deleted old collection '{collection_name}'.")
        except chromadb.errors.NotFoundError:
            print(f"No existing collection named '{collection_name}', starting fresh.")
            documents, metadatas, ids = [], [], []
        # Step 3: Create new collection with correct embedding dimension
        new_collection = self.chroma_client.create_collection(
            name=f"{collection_name}_openai",
            embedding_function=None,  # embeddings will be provided manually
        )
        print(
            f"Created new collection '{collection_name}_openai' with embedding_dim=3072."
        )
        # Step 4: Re-embed and insert documents in batches
        for i in tqdm(
            range(0, len(documents), batch_size), desc="Re-embedding batches"
        ):
            batch_docs = documents[i : i + batch_size]
            batch_metadatas = metadatas[i : i + batch_size]
            batch_ids = ids[i : i + batch_size]
            embeddings = get_embedding(batch_docs, backend="openai")
            new_collection.add(
                ids=batch_ids,
                documents=batch_docs,
                metadatas=batch_metadatas,
                embeddings=embeddings,
            )
        print("All documents re-embedded and added to new collection successfully!")

    def add_unit_index_to_collection(self, collection_name: str, unit_field: str):
        if collection_name != "yt_metadata":
            # safeguard, just in case
            return
        collection = self.chroma_client.get_collection(name=collection_name)
        # fetch everything in batches (in case your collection is large)
        batch_size = 100
        offset = 0
        unit_counter = 1
        while True:
            result = collection.get(
                limit=batch_size,
                offset=offset,
                include=["documents", "metadatas", "embeddings"],
            )
            ids = result["ids"]
            if not ids:
                break  # no more docs
            docs = result["documents"]
            metas = result["metadatas"]
            embeddings = result["embeddings"]
            # add unit_index to metadata
            updated_metas = []
            for meta in metas:
                # ensure meta is not None
                m = meta.copy() if meta else {}
                m[unit_field] = unit_counter
                updated_metas.append(m)
                unit_counter += 1
            # upsert with same IDs (will overwrite metadata but keep same id+doc)
            collection.upsert(
                ids=ids,
                documents=docs,
                metadatas=updated_metas,
                embeddings=embeddings,
            )
            offset += batch_size
        print(
            f"✅ Finished adding {unit_field} to {unit_counter-1} documents in {collection_name}."
        )

    def get_list_of_values(
        self, collection_name: str, metadata_field_name: str
    ) -> list:
        """
        Returns the unique values for a given metadata field in a collection.
        """
        # Get the collection
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        # Fetch all metadata from the collection
        query_result = collection.get(include=["metadatas"])
        values = set()  # use a set to automatically deduplicate
        metadatas = query_result.get("metadatas", [])
        if metadatas:
            # Handle both flat list and nested list formats
            if isinstance(metadatas[0], dict):
                # flat list of dicts
                for md in metadatas:
                    if metadata_field_name in md:
                        values.add(md[metadata_field_name])
            elif isinstance(metadatas[0], list):
                # nested list
                for md_list in metadatas:
                    for md in md_list:
                        if metadata_field_name in md:
                            values.add(md[metadata_field_name])
        return sorted(list(values))

    def build_global_index_for_scripture(self, scripture: dict, force: bool = False):
        scripture_name = scripture["name"]
        chapter_order = scripture.get("chapter_order", None)
        # if scripture_name != "vishnu_sahasranamam":
        #     continue
        logger.info(
            "build_global_index_for_all_scriptures:%s: Processing", scripture_name
        )
        collection_name = scripture["collection_name"]
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        metadata_fields = scripture.get("metadata_fields", [])
        # Get metadata field names marked as unique
        unique_fields = [f["name"] for f in metadata_fields if f.get("is_unique")]
        if not unique_fields:
            if metadata_fields:
                unique_fields = [metadata_fields[0]["name"]]
            else:
                logger.warning(
                    f"No metadata fields defined for {collection_name}, skipping"
                )
                return
        logger.info(
            "build_global_index_for_all_scriptures:%s:unique fields: %s",
            scripture_name,
            unique_fields,
        )
        # Build chapter_order mapping if defined
        chapter_order_mapping = {}
        if callable(chapter_order):
            chapter_order_mapping = chapter_order()
        logger.info(
            "build_global_index_for_all_scriptures:%s:chapter_order_mapping: %s",
            scripture_name,
            chapter_order_mapping,
        )
        # Fetch all records (keep embeddings for upsert)
        try:
            results = collection.get(include=["metadatas", "documents", "embeddings"])
        except Exception as e:
            logger.error(
                "build_global_index_for_all_scriptures:%s Error getting data from chromadb",
                scripture_name,
                exc_info=True,
            )
            return
        ids = results["ids"]
        metadatas = results["metadatas"]
        documents = results["documents"]
        embeddings = results.get("embeddings", [None] * len(ids))
        if not force and metadatas and "_global_index" in metadatas[0]:
            logger.warning(
                "build_global_index_for_all_scriptures:%s: global index already available. skipping collection",
                scripture_name,
            )
            return
        # Create a DataFrame for metadata sorting
        df = pd.DataFrame(metadatas)
        df["_id"] = ids
        df["_doc"] = documents
        logger.info(
            "build_global_index_for_all_scriptures:%s:unique_fields: %s",
            scripture_name,
            unique_fields,
        )
        # Add sortable columns for each unique field
        for field_name in unique_fields:
            if field_name.lower() in ("chapter", "prabandham_name") and chapter_order_mapping:
                logger.info(
                    "build_global_index_for_all_scriptures:%s:sorting",
                    scripture_name,
                )
                # Map chapter names to their defined order
                df["_sort_" + field_name] = (
                    df[field_name].map(chapter_order_mapping).fillna(np.inf)
                )
            else:
                # Try numeric, fallback to string lowercase
                def parse_val(v):
                    if v is None:
                        return float("inf")
                    if isinstance(v, int):
                        return v
                    if isinstance(v, str):
                        v = v.strip()
                        return int(v) if v.isdigit() else v.lower()
                    return str(v)

                df["_sort_" + field_name] = df[field_name].apply(parse_val)
        sort_cols = ["_sort_" + f for f in unique_fields]
        logger.info(
            "build_global_index_for_all_scriptures:%s:sort_cols=%s",
            scripture_name,
            sort_cols,
        )
        df = df.sort_values(by=sort_cols, kind="stable").reset_index(drop=True)
        # Assign global index
        df["_global_index"] = range(1, len(df) + 1)
        logger.info(
            "build_global_index_for_all_scriptures:%s: updating database",
            scripture_name,
        )
        # Batch upsert
        BATCH_SIZE = 5000  # safely below max batch size
        for i in range(0, len(df), BATCH_SIZE):
            batch_df = df.iloc[i : i + BATCH_SIZE]
            batch_ids = batch_df["_id"].tolist()
            batch_docs = batch_df["_doc"].tolist()
            # Use original metadata keys, plus the new _global_index
            batch_metas = [
                {k: record[k] for k in metadatas[0].keys() if k in record}
                | {"_global_index": record["_global_index"]}
                for record in batch_df.to_dict(orient="records")
            ]
            batch_embeds = [embeddings[idx] for idx in batch_df.index]
            collection.update(
                ids=batch_ids,
                # documents=batch_docs,
                metadatas=batch_metas,
                # embeddings=batch_embeds,
            )
        logger.info(
            "build_global_index_for_all_scriptures:%s: ✅ Updated with %d records",
            scripture_name,
            len(df),
        )

    def build_global_index_for_all_scriptures(self, force: bool = False):
        logger.info("build_global_index_for_all_scriptures: started")
        config = SanatanConfig()
        for scripture in config.scriptures:
            self.build_global_index_for_scripture(scripture=scripture, force=force)

    def fix_taniyans_in_divya_prabandham(self):
        nalayiram_helper.reorder_taniyan(
            self.chroma_client.get_collection("divya_prabandham")
        )

    def delete_taniyans_in_divya_prabandham(self):
        nalayiram_helper.delete_taniyan(
            self.chroma_client.get_collection("divya_prabandham")
        )
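

# A minimal usage sketch, not part of the module's public surface: it assumes the
# "divya_prabandham" collection exists locally and that embeddings are configured;
# adjust the collection name and the sample query to your own data before running.
if __name__ == "__main__":
    db = SanatanDatabase()
    # The singleton always hands back the same instance (and shares the count cache).
    assert db is SanatanDatabase()

    if db.does_data_exist("divya_prabandham"):
        print("verses:", db.count("divya_prabandham"))
        results = db.search(
            collection_name="divya_prabandham",
            query="surrender to the lord",  # illustrative query, not from the corpus
            search_type="semantic",
            n_results=3,
        )
        print(results)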