# sanatan_ai/db.py
import logging
import random
import re
import unicodedata
from typing import Literal

import chromadb
from tqdm import tqdm

from config import SanatanConfig
from embeddings import get_embedding
from metadata import MetadataWhereClause
from modules.db.relevance import validate_relevance_queryresult

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class SanatanDatabase:
    """Wrapper around a persistent ChromaDB client for the scripture
    collections configured in SanatanConfig."""

    def __init__(self) -> None:
        self.chroma_client = chromadb.PersistentClient(path=SanatanConfig.dbStorePath)

    def does_data_exist(self, collection_name: str) -> bool:
        """Return True if the collection exists and holds at least one record."""
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        num_rows = collection.count()
        logger.info("num_rows in %s = %d", collection_name, num_rows)
        return num_rows > 0

    def load(self, collection_name: str, ids, documents, embeddings, metadatas):
        """Bulk-insert parallel lists of ids, documents, embeddings and metadatas."""
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        collection.add(
            ids=ids,
            documents=documents,
            embeddings=embeddings,
            metadatas=metadatas,
        )
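
    # A minimal usage sketch for load(); the collection name and texts are
    # illustrative, and get_embedding's backend argument follows its use
    # elsewhere in this module:
    #
    #     db = SanatanDatabase()
    #     texts = ["verse one", "verse two"]
    #     db.load(
    #         collection_name="gita",
    #         ids=["v1", "v2"],
    #         documents=texts,
    #         embeddings=get_embedding(texts, backend="openai"),
    #         metadatas=[{"chapter": 1}, {"chapter": 1}],
    #     )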

    def fetch_random_data(
        self,
        collection_name: str,
        metadata_where_clause: MetadataWhereClause | None = None,
        n_results: int = 1,
    ):
        """Return up to n_results randomly sampled records from a collection."""
# fetch all documents once
logger.info(
"getting %d random verses from [%s] | metadata_where_clause = %s",
n_results,
collection_name,
metadata_where_clause,
)
collection = self.chroma_client.get_or_create_collection(name=collection_name)
data = collection.get(
where=(
metadata_where_clause.to_chroma_where()
if metadata_where_clause is not None
else None
)
)
docs = data["documents"] # list of all verse texts
ids = data["ids"]
metas = data["metadatas"]
        if not docs:
            logger.warning("No data found! - data=%s", data)
            # return an empty, correctly shaped QueryResult
            return chromadb.QueryResult(ids=[[]], documents=[[]], metadatas=[[]])
        # pick up to n_results random indices
        indices = random.sample(range(len(docs)), k=min(n_results, len(docs)))
        # QueryResult fields are nested lists (one inner list per query), so wrap
        # the single result set to match the shape returned by the other search paths.
        return chromadb.QueryResult(
            ids=[[ids[i] for i in indices]],
            documents=[[docs[i] for i in indices]],
            metadatas=[[metas[i] for i in indices]],
        )
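
    # Hedged usage sketch for fetch_random_data(); the collection name is
    # illustrative:
    #
    #     db = SanatanDatabase()
    #     random_verse = db.fetch_random_data(collection_name="gita", n_results=1)
    #     print(random_verse["documents"][0])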

    def search(
        self,
        collection_name: str,
        query: str | None = None,
        metadata_where_clause: MetadataWhereClause | None = None,
        n_results: int = 2,
        search_type: Literal["semantic", "literal", "random"] = "semantic",
    ):
        """Dispatch to semantic, literal, or random search."""
logger.info(
"Search for [%s] in [%s]| metadata_where_clause=%s | search_type=%s | n_results=%d",
query,
collection_name,
metadata_where_clause,
search_type,
n_results,
)
if search_type == "semantic":
return self.search_semantic(
collection_name=collection_name,
query=query,
metadata_where_clause=metadata_where_clause,
n_results=n_results,
)
elif search_type == "literal":
return self.search_for_literal(
collection_name=collection_name,
literal_to_search_for=query,
metadata_where_clause=metadata_where_clause,
n_results=n_results,
)
else:
# random
return self.fetch_random_data(
collection_name=collection_name,
metadata_where_clause=metadata_where_clause,
n_results=n_results,
)
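
    # Hedged usage sketch for search(); the collection name and query are
    # illustrative only:
    #
    #     db = SanatanDatabase()
    #     hits = db.search(
    #         collection_name="gita",
    #         query="what is dharma",
    #         n_results=3,
    #         search_type="semantic",
    #     )
    #     for doc in hits["documents"][0]:
    #         print(doc)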

    def search_semantic(
        self,
        collection_name: str,
        query: str | None = None,
        metadata_where_clause: MetadataWhereClause | None = None,
        n_results: int = 2,
    ):
        """Vector search; falls back to a random fetch when the query is empty."""
logger.info(
"Vector Semantic Search for [%s] in [%s] | metadata_where_clause = %s",
query,
collection_name,
metadata_where_clause,
)
collection = self.chroma_client.get_or_create_collection(name=collection_name)
try:
q = query.strip() if query is not None else ""
if not q:
# fallback: fetch random verse
return self.fetch_random_data(
collection_name=collection_name,
metadata_where_clause=metadata_where_clause,
n_results=n_results,
)
            else:
                response = collection.query(
                    # embed the (stripped) query with the backend configured
                    # for this collection
                    query_embeddings=get_embedding(
                        [q],
                        SanatanConfig().get_embedding_for_collection(collection_name),
                    ),
                    n_results=n_results,
                    where=(
                        metadata_where_clause.to_chroma_where()
                        if metadata_where_clause is not None
                        else None
                    ),
                    include=["metadatas", "documents", "distances"],
                )
        except Exception:
            # log the full traceback and return an empty, correctly shaped result
            logger.exception("Error in semantic search")
            return chromadb.QueryResult(
                documents=[[]],
                ids=[[]],
                metadatas=[[]],
                distances=[[]],
            )
validated_response = validate_relevance_queryresult(query, response)
logger.info(
"status = %s | reason= %s",
validated_response.status,
validated_response.reason,
)
return validated_response.result
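
    # Hedged sketch of a metadata-filtered semantic search. MetadataWhereClause
    # construction is not shown in this module, so it is left abstract here:
    #
    #     where = ...  # a MetadataWhereClause, e.g. filtering on a chapter field
    #     result = db.search_semantic(
    #         collection_name="gita", query="karma yoga", metadata_where_clause=where
    #     )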

    def search_for_literal(
        self,
        collection_name: str,
        literal_to_search_for: str | None = None,
        metadata_where_clause: MetadataWhereClause | None = None,
        n_results: int = 2,
    ):
        """Substring search via Chroma's $contains, with a normalized-regex fallback."""
logger.info(
"Searching literally for [%s] in [%s] | metadata_where_clause = %s",
literal_to_search_for,
collection_name,
metadata_where_clause,
)
        if literal_to_search_for is None or literal_to_search_for.strip() == "":
            logger.warning("Nothing to search literally.")
            raise ValueError("query cannot be None or empty for a literal search!")
        collection = self.chroma_client.get_or_create_collection(name=collection_name)

        def normalize(text: str) -> str:
            # NFKC-normalize and lowercase so matching ignores case and
            # Unicode compatibility variants
            return unicodedata.normalize("NFKC", text).lower()

        # 1. Try Chroma's native $contains substring match first
        response = collection.get(
            where=(
                metadata_where_clause.to_chroma_where()
                if metadata_where_clause is not None
                else None
            ),
            where_document={"$contains": literal_to_search_for},
            limit=n_results,
        )
        if response["documents"] and any(response["documents"]):
            # collection.get() returns flat lists; wrap them so the shape matches
            # the nested QueryResult built in the regex fallback below.
            return chromadb.QueryResult(
                ids=[response["ids"]],
                documents=[response["documents"]],
                metadatas=[response["metadatas"]],
            )
        # 2. Regex fallback on normalized text
        logger.info("No native match; falling back to regex for %s", literal_to_search_for)
regex = re.compile(re.escape(normalize(literal_to_search_for)))
logger.info("regex = %s", regex)
all_docs = collection.get(
where=(
metadata_where_clause.to_chroma_where()
if metadata_where_clause is not None
else None
),
)
        matched_docs = []
        # collection.get() returns flat lists, but wrap scalars defensively so the
        # loop also tolerates nested (per-query) shapes.
        for doc_list, metadata_list, doc_id_list in zip(
            all_docs["documents"], all_docs["metadatas"], all_docs["ids"]
        ):
            # Ensure all are lists
if isinstance(doc_list, str):
doc_list = [doc_list]
if isinstance(metadata_list, dict):
metadata_list = [metadata_list]
if isinstance(doc_id_list, str):
doc_id_list = [doc_id_list]
for i in range(len(doc_list)):
d = doc_list[i]
current_metadata = metadata_list[i]
current_id = doc_id_list[i]
                doc_match = regex.search(normalize(d))
                # also match the literal against string (or list-of-string) metadata values
                metadata_match = False
for key, value in current_metadata.items():
if isinstance(value, str) and regex.search(normalize(value)):
metadata_match = True
break
elif isinstance(value, list):
if any(
isinstance(v, str) and regex.search(normalize(v))
for v in value
):
metadata_match = True
break
if doc_match or metadata_match:
matched_docs.append(
{
"id": current_id,
"document": d,
"metadata": current_metadata,
}
)
if len(matched_docs) >= n_results:
break
if len(matched_docs) >= n_results:
break
        return chromadb.QueryResult(
            documents=[[d["document"] for d in matched_docs]],
            ids=[[d["id"] for d in matched_docs]],
            metadatas=[[d["metadata"] for d in matched_docs]],
        )
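
    # Hedged usage sketch for search_for_literal(); the collection and search
    # string are illustrative:
    #
    #     db = SanatanDatabase()
    #     result = db.search_for_literal(
    #         collection_name="gita", literal_to_search_for="dharma", n_results=5
    #     )
    #     print(result["ids"][0])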

    def count(self, collection_name: str) -> int:
        """Return the total number of records in the given collection."""
        collection = self.chroma_client.get_or_create_collection(name=collection_name)
        total_count = collection.count()
        logger.info("Total records in [%s] = %d", collection_name, total_count)
        return total_count

    def test_sanity(self):
        """Raise if any configured scripture collection is empty."""
        for scripture in SanatanConfig().scriptures:
            count = self.count(collection_name=scripture["collection_name"])
            if count == 0:
                raise Exception(f"No data in collection {scripture['collection_name']}")

    def reembed_collection_openai(self, collection_name: str, batch_size: int = 50):
        """
        Re-embeds every document of an existing collection with OpenAI
        text-embedding-3-large and writes the results into a new collection
        named '{collection_name}_openai'. The original collection is left
        intact (the delete step below is intentionally commented out).

        Args:
            collection_name: The name of the source collection.
            batch_size: Number of documents to embed per batch.
        """
        # Step 1: Fetch old collection data (if it exists)
        try:
            old_collection = self.chroma_client.get_collection(name=collection_name)
            old_data = old_collection.get(include=["documents", "metadatas"])
            documents = old_data["documents"]
            metadatas = old_data["metadatas"]
            ids = old_data["ids"]
            logger.info("Fetched %d documents from old collection.", len(documents))
            # Step 2 (disabled): delete the old collection
            # self.chroma_client.delete_collection(collection_name)
        except chromadb.errors.NotFoundError:
            logger.info("No existing collection named '%s', starting fresh.", collection_name)
            documents, metadatas, ids = [], [], []
        # Step 3: Create the target collection; embeddings are supplied manually,
        # so no embedding function is attached.
        new_collection = self.chroma_client.create_collection(
            name=f"{collection_name}_openai",
            embedding_function=None,
        )
        logger.info(
            "Created new collection '%s_openai' (text-embedding-3-large, 3072 dims).",
            collection_name,
        )
# Step 4: Re-embed and insert documents in batches
for i in tqdm(
range(0, len(documents), batch_size), desc="Re-embedding batches"
):
batch_docs = documents[i : i + batch_size]
batch_metadatas = metadatas[i : i + batch_size]
batch_ids = ids[i : i + batch_size]
embeddings = get_embedding(batch_docs, backend="openai")
new_collection.add(
ids=batch_ids,
documents=batch_docs,
metadatas=batch_metadatas,
embeddings=embeddings,
)
print("All documents re-embedded and added to new collection successfully!")