import json
import random
from typing import Literal
import chromadb
import re
import unicodedata
from config import SanatanConfig
from embeddings import get_embedding
import logging
from pydantic import BaseModel
from metadata import MetadataFilter, MetadataWhereClause
from modules.db.relevance import validate_relevance_queryresult
from tqdm import tqdm

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class SanatanDatabase:
    def __init__(self) -> None:
        self.chroma_client = chromadb.PersistentClient(path=SanatanConfig.dbStorePath)

    def does_data_exist(self, collection_name: str) -> bool:
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        num_rows = collection.count()
        logger.info("num_rows in %s = %d", collection_name, num_rows)
        return num_rows > 0

    def load(self, collection_name: str, ids, documents, embeddings, metadatas):
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        collection.add(
            ids=ids,
            documents=documents,
            embeddings=embeddings,
            metadatas=metadatas,
        )

    def fetch_random_data(
        self,
        collection_name: str,
        metadata_where_clause: MetadataWhereClause | None = None,
        n_results=1,
    ):
        # fetch all matching documents once, then sample from them
        logger.info(
            "getting %d random verses from [%s] | metadata_where_clause = %s",
            n_results,
            collection_name,
            metadata_where_clause,
        )
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        data = collection.get(
            where=(
                metadata_where_clause.to_chroma_where()
                if metadata_where_clause is not None
                else None
            )
        )
        docs = data["documents"]  # list of all verse texts
        ids = data["ids"]
        metas = data["metadatas"]
        if not docs:
            logger.warning("No data found! - data=%s", data)
            return chromadb.QueryResult(ids=[], documents=[], metadatas=[])
        # pick k random indices without replacement
        indices = random.sample(range(len(docs)), k=min(n_results, len(docs)))
        return chromadb.QueryResult(
            ids=[ids[i] for i in indices],
            documents=[docs[i] for i in indices],
            metadatas=[metas[i] for i in indices],
        )

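    # Note: unlike collection.query(), collection.get() returns flat lists, so the
    # QueryResult built above holds one-level lists rather than the nested
    # lists-of-lists that semantic search returns. Callers that index into
    # result["documents"][0] may need to account for this difference.
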
    def search(
        self,
        collection_name: str,
        query: str | None = None,
        metadata_where_clause: MetadataWhereClause | None = None,
        n_results=2,
        search_type: Literal["semantic", "literal", "random"] = "semantic",
    ):
        logger.info(
            "Search for [%s] in [%s] | metadata_where_clause=%s | search_type=%s | n_results=%d",
            query,
            collection_name,
            metadata_where_clause,
            search_type,
            n_results,
        )
        if search_type == "semantic":
            return self.search_semantic(
                collection_name=collection_name,
                query=query,
                metadata_where_clause=metadata_where_clause,
                n_results=n_results,
            )
        elif search_type == "literal":
            return self.search_for_literal(
                collection_name=collection_name,
                literal_to_search_for=query,
                metadata_where_clause=metadata_where_clause,
                n_results=n_results,
            )
        else:
            # random
            return self.fetch_random_data(
                collection_name=collection_name,
                metadata_where_clause=metadata_where_clause,
                n_results=n_results,
            )

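    # Illustrative use of the dispatcher above (the collection name "gita" is a
    # hypothetical placeholder; real names come from SanatanConfig().scriptures):
    #   db = SanatanDatabase()
    #   db.search(collection_name="gita", query="karma yoga", search_type="semantic")
    #   db.search(collection_name="gita", query="धर्म", search_type="literal")
    #   db.search(collection_name="gita", search_type="random", n_results=3)
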
    def search_semantic(
        self,
        collection_name: str,
        query: str | None = None,
        metadata_where_clause: MetadataWhereClause | None = None,
        n_results=2,
    ):
        logger.info(
            "Vector Semantic Search for [%s] in [%s] | metadata_where_clause = %s",
            query,
            collection_name,
            metadata_where_clause,
        )
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        try:
            q = query.strip() if query is not None else ""
            if not q:
                # fallback: fetch random verses when no query is given
                return self.fetch_random_data(
                    collection_name=collection_name,
                    metadata_where_clause=metadata_where_clause,
                    n_results=n_results,
                )
            else:
                response = collection.query(
                    query_embeddings=get_embedding(
                        [query],
                        SanatanConfig().get_embedding_for_collection(collection_name),
                    ),
                    # query_texts=[query],
                    n_results=n_results,
                    where=(
                        metadata_where_clause.to_chroma_where()
                        if metadata_where_clause is not None
                        else None
                    ),
                    include=["metadatas", "documents", "distances"],
                )
        except Exception as e:
            logger.error("Error in search: %s", e)
            return chromadb.QueryResult(
                documents=[],
                ids=[],
                metadatas=[],
                distances=[],
            )
        validated_response = validate_relevance_queryresult(query, response)
        logger.info(
            "status = %s | reason = %s",
            validated_response.status,
            validated_response.reason,
        )
        return validated_response.result

    def search_for_literal(
        self,
        collection_name: str,
        literal_to_search_for: str | None = None,
        metadata_where_clause: MetadataWhereClause | None = None,
        n_results=2,
    ):
        logger.info(
            "Searching literally for [%s] in [%s] | metadata_where_clause = %s",
            literal_to_search_for,
            collection_name,
            metadata_where_clause,
        )
        if literal_to_search_for is None or literal_to_search_for.strip() == "":
            logger.warning("Nothing to search literally.")
            raise Exception("query cannot be None or empty for a literal search!")
            # return self.fetch_random_data(
            #     collection_name=collection_name,
            # )

        collection = self.chroma_client.get_or_create_collection(name=collection_name)

        def normalize(text):
            return unicodedata.normalize("NFKC", text).lower()

        # 1. Try native contains
        response = collection.get(
            where=(
                metadata_where_clause.to_chroma_where()
                if metadata_where_clause is not None
                else None
            ),
            where_document={"$contains": literal_to_search_for},
            limit=n_results,
        )
        if response["documents"] and any(response["documents"]):
            return chromadb.QueryResult(
                ids=response["ids"],
                documents=response["documents"],
                metadatas=response["metadatas"],
            )

        # 2. Regex fallback (normalized)
        logger.info("⚠ No luck. Falling back to regex for %s", literal_to_search_for)
        regex = re.compile(re.escape(normalize(literal_to_search_for)))
        logger.info("regex = %s", regex)
        all_docs = collection.get(
            where=(
                metadata_where_clause.to_chroma_where()
                if metadata_where_clause is not None
                else None
            ),
        )
        matched_docs = []
        for doc_list, metadata_list, doc_id_list in zip(
            all_docs["documents"], all_docs["metadatas"], all_docs["ids"]
        ):
            # Ensure all are lists
            if isinstance(doc_list, str):
                doc_list = [doc_list]
            if isinstance(metadata_list, dict):
                metadata_list = [metadata_list]
            if isinstance(doc_id_list, str):
                doc_id_list = [doc_id_list]
            for i in range(len(doc_list)):
                d = doc_list[i]
                current_metadata = metadata_list[i]
                current_id = doc_id_list[i]
                doc_match = regex.search(normalize(d))
                metadata_match = False
                for key, value in current_metadata.items():
                    if isinstance(value, str) and regex.search(normalize(value)):
                        metadata_match = True
                        break
                    elif isinstance(value, list):
                        if any(
                            isinstance(v, str) and regex.search(normalize(v))
                            for v in value
                        ):
                            metadata_match = True
                            break
                if doc_match or metadata_match:
                    matched_docs.append(
                        {
                            "id": current_id,
                            "document": d,
                            "metadata": current_metadata,
                        }
                    )
                if len(matched_docs) >= n_results:
                    break
            if len(matched_docs) >= n_results:
                break
        return chromadb.QueryResult(
            {
                "documents": [[d["document"] for d in matched_docs]],
                "ids": [[d["id"] for d in matched_docs]],
                "metadatas": [[d["metadata"] for d in matched_docs]],
            }
        )

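    # Note: the regex fallback above pulls the whole (metadata-filtered) collection
    # into memory and scans it in Python, so it can be slow on large collections;
    # the native $contains path is tried first to avoid that cost. Also, the
    # $contains branch returns flat lists (from collection.get) while the regex
    # branch returns nested lists, so callers should handle either shape.
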
    def count(self, collection_name: str):
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        total_count = collection.count()
        logger.info("Total records in [%s] = %d", collection_name, total_count)
        return total_count

    def test_sanity(self):
        for scripture in SanatanConfig().scriptures:
            count = self.count(collection_name=scripture["collection_name"])
            if count == 0:
                raise Exception(f"No data in collection {scripture['collection_name']}")

    def reembed_collection_openai(self, collection_name: str, batch_size: int = 50):
        """
        Re-embeds a Chroma collection with OpenAI text-embedding-3-large embeddings.

        All existing documents are re-embedded and inserted into a new
        "<collection_name>_openai" collection; the original collection is left
        in place (the delete step is currently commented out).

        Args:
            collection_name: The name of the collection to re-embed.
            batch_size: Number of documents to process per batch.
        """
        # Step 1: Fetch old collection data (if it exists)
        try:
            old_collection = self.chroma_client.get_collection(name=collection_name)
            old_data = old_collection.get(include=["documents", "metadatas"])
            documents = old_data["documents"]
            metadatas = old_data["metadatas"]
            ids = old_data["ids"]
            print(f"Fetched {len(documents)} documents from old collection.")
            # Step 2: Delete old collection
            # self.chroma_client.delete_collection(collection_name)
            # print(f"Deleted old collection '{collection_name}'.")
        except chromadb.errors.NotFoundError:
            print(f"No existing collection named '{collection_name}', starting fresh.")
            documents, metadatas, ids = [], [], []

        # Step 3: Create new collection with correct embedding dimension
        new_collection = self.chroma_client.create_collection(
            name=f"{collection_name}_openai",
            embedding_function=None,  # embeddings will be provided manually
        )
        print(
            f"Created new collection '{collection_name}_openai' with embedding_dim=3072."
        )

        # Step 4: Re-embed and insert documents in batches
        for i in tqdm(
            range(0, len(documents), batch_size), desc="Re-embedding batches"
        ):
            batch_docs = documents[i : i + batch_size]
            batch_metadatas = metadatas[i : i + batch_size]
            batch_ids = ids[i : i + batch_size]
            embeddings = get_embedding(batch_docs, backend="openai")
            new_collection.add(
                ids=batch_ids,
                documents=batch_docs,
                metadatas=batch_metadatas,
                embeddings=embeddings,
            )
        print("All documents re-embedded and added to new collection successfully!")