# Spaces: Running on CPU Upgrade
import logging
from typing import List, Dict

from metadata import MetadataWhereClause
from modules.config import scripture_configurations
from modules.languages.transliterator import fn_transliterate
class SanatanConfig:
    """Lookup and canonicalisation helpers over the configured scriptures.

    All methods read the scripture definitions through ``self.scriptures``,
    which is bound to the imported ``scripture_configurations``.
    """

    # Path of the database store directory (the value suggests a ChromaDB
    # store; it is not read anywhere inside this class).
    dbStorePath: str = "./chromadb-store"
    scriptures = scripture_configurations

    # Keys of a scripture's "field_mapping" that canonicalize_document is
    # allowed to copy into the canonical document. Built once instead of on
    # every call.
    _CANONICAL_KEYS = frozenset(
        {
            "_global_index",
            "id",
            "verse",
            "text",
            "title",
            "unit",
            "unit_index",
            "word_by_word_native",
            "translation",
            "transliteration",
            "transliteration_v2",  # support v2
            "reference_link",
            "author",
            "chapter_name",
            "relative_path",
            "location",
        }
    )

    def _find_scripture(self, key: str, value: str):
        """Return the first scripture config whose ``key`` equals ``value``, or None."""
        return next(
            (scripture for scripture in self.scriptures if scripture[key] == value),
            None,
        )

    def get_scripture_by_collection(self, collection_name: str):
        """Return the scripture config whose "collection_name" matches.

        Raises:
            IndexError: if no scripture uses ``collection_name``. The type is
                kept from the original ``[...][0]`` lookup so existing
                handlers keep working; the message now names the collection.
        """
        scripture = self._find_scripture("collection_name", collection_name)
        if scripture is None:
            raise IndexError(f"Unknown collection: {collection_name}")
        return scripture

    def get_scripture_by_name(self, scripture_name: str):
        """Return the scripture config whose "name" matches.

        Raises:
            IndexError: if no scripture is called ``scripture_name`` (same
                exception type as the original ``[...][0]`` lookup).
        """
        scripture = self._find_scripture("name", scripture_name)
        if scripture is None:
            raise IndexError(f"Unknown scripture: {scripture_name}")
        return scripture

    def is_metadata_field_allowed(
        self, collection_name: str, metadata_where_clause: MetadataWhereClause
    ):
        """Check that a where-clause only filters on fields the collection declares.

        The clause's ``filters`` are checked first, then each entry of its
        ``groups`` is validated recursively in order.

        Returns:
            True when every filter uses an allowed metadata field.

        Raises:
            IndexError: if ``collection_name`` is unknown.
            ValueError: on the first filter whose ``metadata_field`` is not
                declared in the scripture's "metadata_fields". ValueError is
                an ``Exception`` subclass, so ``except Exception`` callers
                are unaffected.
        """
        scripture = self.get_scripture_by_collection(collection_name=collection_name)
        allowed_fields = {field["name"] for field in scripture["metadata_fields"]}

        def validate_clause(clause: MetadataWhereClause):
            # validate direct filters
            for f in clause.filters or ():
                if f.metadata_field not in allowed_fields:
                    raise ValueError(
                        f"metadata_field: [{f.metadata_field}] not allowed in collection [{collection_name}]. "
                        f"Here are the allowed fields with their descriptions: {scripture['metadata_fields']}"
                    )
            # recurse into groups
            for g in clause.groups or ():
                validate_clause(g)

        validate_clause(metadata_where_clause)
        return True

    def get_embedding_for_collection(self, collection_name: str):
        """Return the embedding function key configured for a collection.

        Falls back to "hf" (Hugging Face sentence transformers, per the
        original comment) when the scripture config has no
        "collection_embedding_fn" entry.

        Raises:
            IndexError: if ``collection_name`` is unknown.
        """
        scripture = self.get_scripture_by_collection(collection_name)
        return scripture.get("collection_embedding_fn", "hf")

    def remove_callables(self, obj):
        """Return a copy of ``obj`` with callable dict values / list items dropped.

        Dicts and lists are rebuilt recursively; any other object is returned
        unchanged.
        """
        if isinstance(obj, dict):
            return {
                k: self.remove_callables(v) for k, v in obj.items() if not callable(v)
            }
        if isinstance(obj, list):
            return [self.remove_callables(v) for v in obj if not callable(v)]
        return obj

    def filter_scriptures_fields(self, fields_to_keep: List[str]) -> List[Dict]:
        """
        Return a list of scripture dicts containing only the specified fields.

        Fields missing from a scripture are skipped, and callables nested in
        the kept values are removed via ``remove_callables``.
        """
        filtered = [
            {k: s[k] for k in fields_to_keep if k in s} for s in self.scriptures
        ]
        return self.remove_callables(filtered)

    def canonicalize_document(
        self, scripture_name: str, document_text: str, metadata_doc: dict
    ):
        """
        Convert scripture-specific document to a flattened canonical form.

        Supports strings, callables, or nested dicts in field mapping.
        Only allows keys from the allowed canonical fields list.

        Args:
            scripture_name: "name" of the scripture config to use.
            document_text: stored under "document", or moved to "text" when
                the mapping produced no text (``None`` or "-").
            metadata_doc: source metadata that mapping entries are resolved
                against.

        Returns:
            The canonical document dict.

        Raises:
            ValueError: if ``scripture_name`` is unknown, or if the resolved
                verse value cannot be converted by ``int``.
        """
        config = self._find_scripture("name", scripture_name)
        if not config:
            raise ValueError(f"Unknown scripture: {scripture_name}")

        # Work on a copy: the original injected transliteration_v2 straight
        # into the shared configuration dict, mutating module-level state.
        mapping = dict(config.get("field_mapping", {}))

        # ------------------------------------
        # Inject transliteration_v2 if missing
        # ------------------------------------
        if "transliteration_v2" not in mapping:
            text_field = mapping.get("text", "text")  # fallback to "text"
            # NOTE(review): when the "text" mapping is a callable or dict
            # rather than a key name, this lookup finds nothing and an empty
            # string is transliterated - confirm whether that is intended.
            mapping["transliteration_v2"] = lambda doc: dict(
                fn_transliterate(doc.get(text_field, "")).items()
            )

        log = logging.getLogger(__name__)

        def resolve_field(field):
            """Resolve a field: string key, callable, or nested dict"""
            if isinstance(field, dict):
                # Recursively resolve nested dict values
                return {
                    subkey: resolve_field(subval) for subkey, subval in field.items()
                }
            if callable(field):
                try:
                    return field(metadata_doc)
                except Exception:
                    # Best-effort by design: a failing mapping callable still
                    # yields None, but the failure is no longer invisible.
                    log.debug(
                        "Field mapping callable failed for scripture %r",
                        scripture_name,
                        exc_info=True,
                    )
                    return None
            if isinstance(field, str):
                return metadata_doc.get(field)
            return None

        canonical_doc = {
            key: resolve_field(field)
            for key, field in mapping.items()
            if key in self._CANONICAL_KEYS
        }

        # Add standard fields from config
        canonical_doc["scripture_name"] = config.get("name")
        canonical_doc["scripture_title"] = config.get("title")
        canonical_doc["source"] = config.get("source")
        canonical_doc["language"] = config.get("language")
        canonical_doc["unit"] = config.get("unit")
        canonical_doc["document"] = document_text

        # Handle text/document swap if text missing
        if canonical_doc.get("text") in (None, "-"):
            canonical_doc["text"] = canonical_doc["document"]
            canonical_doc["document"] = "-"

        # Verse resolution: "-" maps to -1, any other falsy value to 0
        verse = resolve_field(config.get("unit_field", config.get("unit")))
        if verse == "-":
            canonical_doc["verse"] = -1
        else:
            canonical_doc["verse"] = int(verse) if verse else 0

        # ID and global index always come straight from the metadata
        canonical_doc["id"] = resolve_field("id")
        canonical_doc["_global_index"] = resolve_field("_global_index")
        return canonical_doc

    def get_collection_name(self, scripture_name):
        """Return the "collection_name" configured for a scripture.

        Raises:
            ValueError: if ``scripture_name`` is unknown (previously an
                AttributeError from calling ``.get`` on ``None``).
        """
        # Use this instance's scriptures rather than building a new config.
        config = self._find_scripture("name", scripture_name)
        if config is None:
            raise ValueError(f"Unknown scripture: {scripture_name}")
        return config.get("collection_name")
if __name__ == "__main__":
    # Debug entry point: dump the raw scripture configuration, then the
    # collection names. The names were previously computed in a bare list
    # comprehension whose result was discarded.
    print(SanatanConfig.scriptures)
    print([scripture["collection_name"] for scripture in SanatanConfig.scriptures])