Spaces:
Running
on
CPU Upgrade
| import json | |
| import chromadb | |
| import re, unicodedata | |
| from config import SanatanConfig | |
| from embeddings import get_embedding | |
| import logging | |
| from pydantic import BaseModel | |
| from metadata import MetadataFilter, MetadataWhereClause | |
| from modules.db.relevance import validate_relevance_queryresult | |
| from tqdm import tqdm | |
# Module-level logger; pinned to INFO so search/count diagnostics are emitted
# regardless of the root logger's configuration.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class SanatanDatabase:
    """Persistent ChromaDB-backed store for scripture collections.

    Wraps a single ``chromadb.PersistentClient`` (data lives at
    ``SanatanConfig.dbStorePath``) and exposes loading, semantic search,
    literal search, metadata-filtered search, and maintenance helpers.
    """

    def __init__(self) -> None:
        self.chroma_client = chromadb.PersistentClient(path=SanatanConfig.dbStorePath)

    def does_data_exist(self, collection_name: str) -> bool:
        """Return True when *collection_name* holds at least one record."""
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        num_rows = collection.count()
        logger.info("num_rows in %s = %d", collection_name, num_rows)
        return num_rows > 0

    def load(self, collection_name: str, ids, documents, embeddings, metadatas):
        """Add pre-embedded documents (with ids and metadatas) to a collection."""
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        collection.add(
            ids=ids,
            documents=documents,
            embeddings=embeddings,
            metadatas=metadatas,
        )

    def search(self, collection_name: str, query: str, n_results=2):
        """Vector semantic search for *query* inside *collection_name*.

        The query is embedded with the backend configured for this
        collection, then the raw Chroma response is passed through
        relevance validation and only the validated result is returned.
        On any failure an empty QueryResult is returned instead of
        raising (best-effort: callers see "no results", not a crash).
        """
        logger.info("Vector Semantic Search for [%s] in [%s]", query, collection_name)
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        try:
            response = collection.query(
                query_embeddings=get_embedding(
                    [query], SanatanConfig().get_embedding_for_collection(collection_name)
                ),
                n_results=n_results,
                include=["metadatas", "documents", "distances"],
            )
        except Exception:
            # logger.exception records the traceback, not just the message.
            logger.exception("Error in search")
            return chromadb.QueryResult(
                documents=[],
                ids=[],
                metadatas=[],
                distances=[],
            )
        validated_response = validate_relevance_queryresult(query, response)
        return validated_response["result"]

    def search_for_literal(
        self, collection_name: str, literal_to_search_for: str, n_results=2
    ):
        """Find documents containing *literal_to_search_for*.

        Tries Chroma's native ``$contains`` document filter first; when
        nothing matches, falls back to a case-insensitive, NFKC-normalized
        regex scan over every document and its string metadata values.
        Returns a Chroma-style dict of nested lists.
        """
        logger.info(
            "Searching literally for [%s] in [%s]",
            literal_to_search_for,
            collection_name,
        )
        collection = self.chroma_client.get_or_create_collection(name=collection_name)

        def normalize(text):
            # NFKC folds visually-equivalent Unicode forms before matching.
            return unicodedata.normalize("NFKC", text).lower()

        # 1. Try native contains
        response = collection.query(
            query_embeddings=get_embedding(
                [literal_to_search_for],
                SanatanConfig().get_embedding_for_collection(collection_name),
            ),
            where_document={"$contains": literal_to_search_for},
            n_results=n_results,
        )
        if response["documents"] and any(response["documents"]):
            return response

        # 2. Regex fallback (normalized)
        logger.info("⚠ No luck. Falling back to regex for %s", literal_to_search_for)
        regex = re.compile(re.escape(normalize(literal_to_search_for)))
        logger.info("regex = %s", regex)
        all_docs = collection.get()
        matched_docs = []
        for doc_list, metadata_list, doc_id_list in zip(
            all_docs["documents"], all_docs["metadatas"], all_docs["ids"]
        ):
            # collection.get() entries may be scalars; coerce to parallel lists.
            if isinstance(doc_list, str):
                doc_list = [doc_list]
            if isinstance(metadata_list, dict):
                metadata_list = [metadata_list]
            if isinstance(doc_id_list, str):
                doc_id_list = [doc_id_list]
            for doc, meta, doc_id in zip(doc_list, metadata_list, doc_id_list):
                doc_match = regex.search(normalize(doc))
                # Match against string metadata values and lists of strings;
                # any() short-circuits on the first hit.
                metadata_match = any(
                    (isinstance(value, str) and regex.search(normalize(value)))
                    or (
                        isinstance(value, list)
                        and any(
                            isinstance(v, str) and regex.search(normalize(v))
                            for v in value
                        )
                    )
                    for value in meta.values()
                )
                if doc_match or metadata_match:
                    matched_docs.append(
                        {"id": doc_id, "document": doc, "metadata": meta}
                    )
                if len(matched_docs) >= n_results:
                    break
            if len(matched_docs) >= n_results:
                break
        # Mirror Chroma's nested-list result shape.
        return {
            "documents": [[d["document"] for d in matched_docs]],
            "ids": [[d["id"] for d in matched_docs]],
            "metadatas": [[d["metadata"] for d in matched_docs]],
        }

    def search_by_metadata(
        self,
        collection_name: str,
        query: str,
        metadata_where_clause: MetadataWhereClause,
        n_results=2,
    ):
        """Semantic search constrained by a metadata where-clause.

        Example clause: ``{"azhwar_name": {"$in": ["Thirumangai Azhwar"]}}``
        (note: Chroma's ``$in`` operator takes a list of values).
        """
        logger.info(
            "Searching by metadata for [%s] in [%s] with metadata_filters=%s",
            query,
            collection_name,
            metadata_where_clause,
        )
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        response = collection.query(
            query_embeddings=get_embedding(
                [query], SanatanConfig().get_embedding_for_collection(collection_name)
            ),
            where=metadata_where_clause.to_chroma_where(),
            n_results=n_results,
        )
        return response

    def count(self, collection_name: str):
        """Return the number of records in *collection_name* (created if missing)."""
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        total_count = collection.count()
        logger.info("Total records in [%s] = %d", collection_name, total_count)
        return total_count

    def test_sanity(self):
        """Raise if any configured scripture collection is empty."""
        for scripture in SanatanConfig().scriptures:
            count = self.count(collection_name=scripture["collection_name"])
            if count == 0:
                # Single quotes inside the f-string: reusing double quotes is a
                # SyntaxError on Python < 3.12 (PEP 701).
                raise Exception(f"No data in collection {scripture['collection_name']}")

    def reembed_collection_openai(self, collection_name: str, batch_size: int = 50):
        """
        Re-embed every document of *collection_name* with OpenAI embeddings
        into a new collection named ``{collection_name}_openai``.

        Args:
            collection_name: The source collection to read documents from.
            batch_size: Number of documents embedded per batch.
        """
        # Step 1: Fetch old collection data (if it exists).
        try:
            old_collection = self.chroma_client.get_collection(name=collection_name)
            old_data = old_collection.get(include=["documents", "metadatas"])
            documents = old_data["documents"]
            metadatas = old_data["metadatas"]
            ids = old_data["ids"]
            print(f"Fetched {len(documents)} documents from old collection.")
            # NOTE: the source collection is intentionally NOT deleted, so a
            # failed re-embedding run cannot lose data.
        except chromadb.errors.NotFoundError:
            print(f"No existing collection named '{collection_name}', starting fresh.")
            documents, metadatas, ids = [], [], []

        # Step 2: Create the target collection; embeddings are supplied manually,
        # so no embedding function is attached.
        new_collection = self.chroma_client.create_collection(
            name=f"{collection_name}_openai",
            embedding_function=None,
        )
        print(f"Created new collection '{collection_name}_openai' with embedding_dim=3072.")

        # Step 3: Re-embed and insert documents in batches.
        for i in tqdm(range(0, len(documents), batch_size), desc="Re-embedding batches"):
            batch_docs = documents[i : i + batch_size]
            batch_metadatas = metadatas[i : i + batch_size]
            batch_ids = ids[i : i + batch_size]
            embeddings = get_embedding(batch_docs, backend="openai")
            new_collection.add(
                ids=batch_ids,
                documents=batch_docs,
                metadatas=batch_metadatas,
                embeddings=embeddings,
            )
        print("All documents re-embedded and added to new collection successfully!")