# sanatan_ai / config.py
# vikramvasudevan's picture
# Upload folder using huggingface_hub
# 90dc9aa verified
# raw
# history blame
# 6.52 kB
from metadata import MetadataWhereClause
from typing import List, Dict
from modules.config import scripture_configurations
from modules.languages.transliterator import fn_transliterate
class SanatanConfig:
    """Lookup, validation and normalization helpers over the scripture configurations.

    ``scriptures`` is the list of scripture configuration dicts imported from
    ``modules.config``. The methods below read the keys ``name``,
    ``collection_name``, ``metadata_fields``, ``collection_embedding_fn``,
    ``field_mapping``, ``unit_field``, ``unit``, ``title``, ``source`` and
    ``language`` from those dicts.
    """

    # Presumably the on-disk location of the Chroma store; not read in this class.
    dbStorePath: str = "./chromadb-store"
    scriptures = scripture_configurations

    # Keys that ``canonicalize_document`` copies from a scripture's field mapping.
    _CANONICAL_KEYS = frozenset(
        {
            "_global_index",
            "id",
            "verse",
            "text",
            "title",
            "unit",
            "unit_index",
            "word_by_word_native",
            "translation",
            "transliteration",
            "transliteration_v2",  # support v2
            "reference_link",
            "author",
            "chapter_name",
            "relative_path",
            "location",
        }
    )

    def _find_scripture(self, key: str, value):
        """Return the first scripture config whose ``key`` equals ``value``, or ``None``."""
        return next((s for s in self.scriptures if s[key] == value), None)

    def get_scripture_by_collection(self, collection_name: str):
        """Return the scripture configuration for ``collection_name``.

        Raises:
            IndexError: if no configured scripture uses that collection name.
        """
        scripture = self._find_scripture("collection_name", collection_name)
        if scripture is None:
            # IndexError is what the earlier ``[...][0]`` lookup raised; the
            # type is kept so existing handlers keep working.
            raise IndexError(
                f"No scripture configured for collection [{collection_name}]"
            )
        return scripture

    def get_scripture_by_name(self, scripture_name: str):
        """Return the scripture configuration whose ``name`` is ``scripture_name``.

        Raises:
            IndexError: if no configured scripture has that name.
        """
        scripture = self._find_scripture("name", scripture_name)
        if scripture is None:
            # Same exception type as the earlier ``[...][0]`` lookup.
            raise IndexError(f"No scripture configured with name [{scripture_name}]")
        return scripture

    def is_metadata_field_allowed(
        self, collection_name: str, metadata_where_clause: MetadataWhereClause
    ):
        """Check that every filter in a where-clause targets a configured metadata field.

        The clause's ``filters`` are checked directly and its ``groups`` are
        validated recursively.

        Returns:
            ``True`` when every referenced field is allowed.

        Raises:
            Exception: naming the first disallowed field and listing the
                collection's configured metadata fields.
            IndexError: if ``collection_name`` is not configured.
        """
        scripture = self.get_scripture_by_collection(collection_name=collection_name)
        allowed_fields = [field["name"] for field in scripture["metadata_fields"]]

        def validate_clause(clause: MetadataWhereClause):
            # validate direct filters
            if clause.filters:
                for f in clause.filters:
                    if f.metadata_field not in allowed_fields:
                        raise Exception(
                            f"metadata_field: [{f.metadata_field}] not allowed in collection [{collection_name}]. "
                            f"Here are the allowed fields with their descriptions: {scripture['metadata_fields']}"
                        )
            # recurse into groups
            if clause.groups:
                for g in clause.groups:
                    validate_clause(g)

        validate_clause(metadata_where_clause)
        return True

    def get_embedding_for_collection(self, collection_name: str):
        """Return the embedding-function identifier configured for a collection.

        Falls back to ``"hf"`` (Hugging Face sentence transformers) when the
        scripture does not define ``collection_embedding_fn``.
        """
        scripture = self.get_scripture_by_collection(collection_name)
        return scripture.get("collection_embedding_fn", "hf")

    def remove_callables(self, obj):
        """Return a copy of ``obj`` with callable values stripped from dicts and lists.

        Dicts and lists are rebuilt recursively; any other value is returned
        unchanged.
        """
        if isinstance(obj, dict):
            return {
                k: self.remove_callables(v) for k, v in obj.items() if not callable(v)
            }
        if isinstance(obj, list):
            return [self.remove_callables(v) for v in obj if not callable(v)]
        return obj

    def filter_scriptures_fields(self, fields_to_keep: List[str]) -> List[Dict]:
        """
        Return a list of scripture dicts containing only the specified fields.

        Fields a scripture does not define are skipped, and callable values
        are removed from the result.
        """
        filtered = [
            {k: s[k] for k in fields_to_keep if k in s} for s in self.scriptures
        ]
        return self.remove_callables(filtered)

    def canonicalize_document(
        self, scripture_name: str, document_text: str, metadata_doc: dict
    ):
        """
        Convert scripture-specific document to a flattened canonical form.
        Supports strings, lambdas, or nested dicts in field mapping.
        Only allows keys from the allowed canonical fields list.

        Args:
            scripture_name: ``name`` of the scripture configuration to use.
            document_text: stored under ``document``, or under ``text`` when
                the mapping yields no text (``None`` or ``"-"``).
            metadata_doc: source record that the field mapping is resolved against.

        Returns:
            A dict with the mapped canonical fields plus ``scripture_name``,
            ``scripture_title``, ``source``, ``language``, ``unit``,
            ``document``, ``verse``, ``id`` and ``_global_index``.

        Raises:
            ValueError: if ``scripture_name`` is not configured, or if the
                resolved unit value is a string ``int()`` cannot parse.
            TypeError: if the resolved unit value is truthy but not a type
                ``int()`` accepts.
        """
        config = self._find_scripture("name", scripture_name)
        if not config:
            raise ValueError(f"Unknown scripture: {scripture_name}")

        # Work on a copy: injecting transliteration_v2 below must not write
        # into the shared scripture configuration.
        mapping = dict(config.get("field_mapping") or {})

        # ------------------------------------
        # Inject transliteration_v2 if missing
        # ------------------------------------
        if "transliteration_v2" not in mapping:
            text_field = mapping.get("text", "text")  # fallback to "text"
            mapping["transliteration_v2"] = lambda doc: dict(
                fn_transliterate(doc.get(text_field, "")).items()
            )

        def resolve_field(field):
            """Resolve a field: string key, callable, or nested dict"""
            if isinstance(field, dict):
                # Recursively resolve nested dict values
                return {
                    subkey: resolve_field(subval) for subkey, subval in field.items()
                }
            if callable(field):
                # Best effort: a mapping callable that fails yields None
                # instead of aborting the whole document.
                try:
                    return field(metadata_doc)
                except Exception:
                    return None
            if isinstance(field, str):
                return metadata_doc.get(field)
            return None

        canonical_doc = {
            key: resolve_field(field)
            for key, field in mapping.items()
            if key in self._CANONICAL_KEYS
        }

        # Add standard fields from config
        canonical_doc["scripture_name"] = config.get("name")
        canonical_doc["scripture_title"] = config.get("title")
        canonical_doc["source"] = config.get("source")
        canonical_doc["language"] = config.get("language")
        canonical_doc["unit"] = config.get("unit")
        canonical_doc["document"] = document_text

        # Handle text/document swap if text missing
        if canonical_doc.get("text") in (None, "-"):
            canonical_doc["text"] = canonical_doc["document"]
            canonical_doc["document"] = "-"

        # Verse resolution: "-" maps to -1, any other falsy value to 0.
        verse = resolve_field(config.get("unit_field", config.get("unit")))
        if verse == "-":
            canonical_doc["verse"] = -1
        else:
            canonical_doc["verse"] = int(verse) if verse else 0

        # ID and global index are always read straight from the metadata.
        canonical_doc["id"] = resolve_field("id")
        canonical_doc["_global_index"] = resolve_field("_global_index")
        return canonical_doc

    def get_collection_name(self, scripture_name):
        """Return the ``collection_name`` configured for ``scripture_name``.

        Raises:
            ValueError: if ``scripture_name`` is not configured.
        """
        config = self._find_scripture("name", scripture_name)
        if config is None:
            raise ValueError(f"Unknown scripture: {scripture_name}")
        return config.get("collection_name")
if __name__ == "__main__":
    # Debug entry point: dump the configured scriptures, then their collection
    # names. The names were previously computed and discarded without output.
    print(SanatanConfig.scriptures)
    print([scripture["collection_name"] for scripture in SanatanConfig.scriptures])