import json
import chromadb
import re
import unicodedata
from config import SanatanConfig
from embeddings import get_embedding
import logging
from metadata import MetadataWhereClause
from modules.db.relevance import validate_relevance_queryresult
from tqdm import tqdm
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class SanatanDatabase:
def __init__(self) -> None:
self.chroma_client = chromadb.PersistentClient(path=SanatanConfig.dbStorePath)
def does_data_exist(self, collection_name: str) -> bool:
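        """Return True if the named collection contains at least one record (the collection is created if missing)."""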
collection = self.chroma_client.get_or_create_collection(name=collection_name)
num_rows = collection.count()
logger.info("num_rows in %s = %d", collection_name, num_rows)
return num_rows > 0
def load(self, collection_name: str, ids, documents, embeddings, metadatas):
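        """Insert pre-computed embeddings with their documents and metadatas; all four arguments must be parallel lists."""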
collection = self.chroma_client.get_or_create_collection(name=collection_name)
collection.add(
ids=ids,
documents=documents,
embeddings=embeddings,
metadatas=metadatas,
)
def search(self, collection_name: str, query: str, n_results=2):
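        """Semantic search: embeds `query` with the collection's configured
        embedding backend, queries Chroma, and filters the hits through
        validate_relevance_queryresult before returning them."""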
logger.info("Vector Semantic Search for [%s] in [%s]", query, collection_name)
collection = self.chroma_client.get_or_create_collection(name=collection_name)
try:
response = collection.query(
query_embeddings=get_embedding(
[query], SanatanConfig().get_embedding_for_collection(collection_name)
),
# query_texts=[query],
n_results=n_results,
include=["metadatas","documents","distances"],
)
except Exception as e:
logger.error("Error in search: %s", e)
            # Return an empty result in the same nested-list shape that
            # collection.query() produces, so callers can index result[...][0].
            return chromadb.QueryResult(
                ids=[[]],
                documents=[[]],
                metadatas=[[]],
                distances=[[]],
            )
validated_response = validate_relevance_queryresult(query, response)
return validated_response["result"]
def search_for_literal(
self, collection_name: str, literal_to_search_for: str, n_results=2
):
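        """Literal substring search. First tries Chroma's native `$contains`
        document filter; if that yields nothing, falls back to a regex scan
        over every document and string metadata value, with both sides
        NFKC-normalized and lower-cased."""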
logger.info(
"Searching literally for [%s] in [%s]",
literal_to_search_for,
collection_name,
)
collection = self.chroma_client.get_or_create_collection(name=collection_name)
def normalize(text):
return unicodedata.normalize("NFKC", text).lower()
# 1. Try native contains
response = collection.query(
query_embeddings=get_embedding(
[literal_to_search_for], SanatanConfig().get_embedding_for_collection(collection_name)
),
where_document={"$contains": literal_to_search_for},
n_results=n_results,
)
if response["documents"] and any(response["documents"]):
return response
# 2. Regex fallback (normalized)
logger.info("⚠ No luck. Falling back to regex for %s", literal_to_search_for)
regex = re.compile(re.escape(normalize(literal_to_search_for)))
logger.info("regex = %s", regex)
all_docs = collection.get()
matched_docs = []
for doc_list, metadata_list, doc_id_list in zip(
all_docs["documents"], all_docs["metadatas"], all_docs["ids"]
):
# Ensure all are lists
if isinstance(doc_list, str):
doc_list = [doc_list]
if isinstance(metadata_list, dict):
metadata_list = [metadata_list]
if isinstance(doc_id_list, str):
doc_id_list = [doc_id_list]
for i in range(len(doc_list)):
d = doc_list[i]
                current_metadata = metadata_list[i] or {}  # metadata may be None for some records
current_id = doc_id_list[i]
doc_match = regex.search(normalize(d))
metadata_match = False
for key, value in current_metadata.items():
if isinstance(value, str) and regex.search(normalize(value)):
metadata_match = True
break
elif isinstance(value, list):
if any(
isinstance(v, str) and regex.search(normalize(v))
for v in value
):
metadata_match = True
break
if doc_match or metadata_match:
matched_docs.append(
{
"id": current_id,
"document": d,
"metadata": current_metadata,
}
)
if len(matched_docs) >= n_results:
break
if len(matched_docs) >= n_results:
break
return {
"documents": [[d["document"] for d in matched_docs]],
"ids": [[d["id"] for d in matched_docs]],
"metadatas": [[d["metadata"] for d in matched_docs]],
}
def search_by_metadata(
self,
collection_name: str,
query: str,
metadata_where_clause: MetadataWhereClause,
n_results=2,
):
"""Search by a metadata field inside a specific collection using a specific operator. For instance {"azhwar_name": {"$in": "Thirumangai Azhwar"}}"""
logger.info(
"Searching by metadata for [%s] in [%s] with metadata_filters=%s",
query,
collection_name,
metadata_where_clause,
)
collection = self.chroma_client.get_or_create_collection(name=collection_name)
response = collection.query(
query_embeddings=get_embedding(
[query], SanatanConfig().get_embedding_for_collection(collection_name)
),
where=metadata_where_clause.to_chroma_where(),
# query_texts=[query],
n_results=n_results,
)
return response
def count(self, collection_name: str):
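        """Return the total number of records in the collection."""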
collection = self.chroma_client.get_or_create_collection(name=collection_name)
total_count = collection.count()
logger.info("Total records in [%s] = %d", collection_name, total_count)
return total_count
def test_sanity(self):
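        """Raise if any configured scripture collection is empty."""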
for scripture in SanatanConfig().scriptures:
count = self.count(collection_name=scripture["collection_name"])
if count == 0:
raise Exception(f"No data in collection {scripture["collection_name"]}")
def reembed_collection_openai(self, collection_name: str, batch_size: int = 50):
"""
Deletes and recreates a Chroma collection with OpenAI text-embedding-3-large embeddings.
All existing documents are re-embedded and inserted into the new collection.
Args:
collection_name: The name of the collection to delete/recreate.
batch_size: Number of documents to process per batch.
"""
# Step 1: Fetch old collection data (if exists)
try:
old_collection = self.chroma_client.get_collection(name=collection_name)
old_data = old_collection.get(include=["documents", "metadatas"])
documents = old_data["documents"]
metadatas = old_data["metadatas"]
ids = old_data["ids"]
print(f"Fetched {len(documents)} documents from old collection.")
# Step 2: Delete old collection
# self.chroma_client.delete_collection(collection_name)
# print(f"Deleted old collection '{collection_name}'.")
except chromadb.errors.NotFoundError:
print(f"No existing collection named '{collection_name}', starting fresh.")
documents, metadatas, ids = [], [], []
# Step 3: Create new collection with correct embedding dimension
new_collection = self.chroma_client.create_collection(
name=f"{collection_name}_openai",
embedding_function=None, # embeddings will be provided manually
)
print(f"Created new collection '{collection_name}_openai' with embedding_dim=3072.")
# Step 4: Re-embed and insert documents in batches
for i in tqdm(range(0, len(documents), batch_size), desc="Re-embedding batches"):
batch_docs = documents[i:i+batch_size]
batch_metadatas = metadatas[i:i+batch_size]
batch_ids = ids[i:i+batch_size]
embeddings = get_embedding(batch_docs, backend="openai")
new_collection.add(
ids=batch_ids,
documents=batch_docs,
metadatas=batch_metadatas,
embeddings=embeddings
)
print("All documents re-embedded and added to new collection successfully!") |