sanatan_ai / server.py
vikramvasudevan's picture
Upload folder using huggingface_hub
0930865 verified
# server.py
import os
import random
import traceback
from typing import Optional
import uuid
from fastapi import APIRouter, HTTPException, Header, Request, Query
from fastapi.responses import JSONResponse
import pycountry
from pydantic import BaseModel, Field
from chat_utils import chat
from config import SanatanConfig
from db import SanatanDatabase
from metadata import MetadataWhereClause
from modules.audio.model import AudioRequest, AudioType
from modules.audio.service import svc_get_audio_urls, svc_get_indices_with_audio
from modules.config.categories import get_scripture_categories
from modules.dropbox.discources import get_discourse_by_id, get_discourse_summaries
from modules.firebase.messaging import FcmRequest, fcm_service
from modules.languages.get_v2 import handle_fetch_languages_v2
from modules.llm.summarizer.helpers.db_helper import get_scripture_from_db
from modules.llm.summarizer.models import ScriptureRequest
from modules.llm.summarizer.service import svc_summarize_scripture_verse
from modules.quiz.answer_validator import validate_answer
from modules.quiz.models import Question
from modules.quiz.quiz_helper import generate_question
import logging
from modules.video.model import VideoRequest
from modules.video.service import svc_get_video_urls
from modules.languages.models import TranslationRequest
from modules.languages.translator import svc_translate_text
from slowapi.util import get_remote_address
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
# Rate limiter keyed on the caller's remote address; applied per-route via
# the `@limiter.limit(...)` decorator.
limiter = Limiter(key_func=get_remote_address)

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Single router instance that every endpoint in this module registers on.
# (It used to be instantiated twice; the first instance was immediately
# discarded, so only one assignment is kept.)
router = APIRouter()

# In-memory mapping from session_id -> thread_id
# For production, you may want Redis or a DB for persistence
thread_map: dict[str, str] = {}
class Message(BaseModel):
    """Request body accepted by the `/greet` and `/chat` endpoints."""

    # Preferred reply language; `/chat` falls back to "English" when it is empty.
    language: str
    # The user's message; `/chat` forwards it to `chat(...)`, `/greet` does not read it.
    text: str
    session_id: str | None = None # Optional session ID from client
class QuizGeneratePayload(BaseModel):
    """Request body accepted by the `/quiz/generate` endpoint.

    Every field is optional; the handler picks a random value for
    `scripture`, `complexity` and `mode` when they are missing or empty.
    """

    language: Optional[str] = "English"
    # Collection name to draw the question from.
    scripture: Optional[str] = None
    # The handler's random fallback chooses among "beginner", "intermediate", "advanced".
    complexity: Optional[str] = None
    # The handler's random fallback chooses between "mcq" and "open".
    mode: Optional[str] = None
    session_id: Optional[str] = None # Optional session ID from client
class QuizEvalPayload(BaseModel):
    """Request body accepted by the `/quiz/eval` endpoint."""

    language: Optional[str] = "English"
    # The question being answered, passed to `validate_answer` unchanged.
    q: Question
    # The user's answer to `q`.
    answer: str
    session_id: Optional[str] = None # Optional session ID from client
# Language code -> the language's name written in its own script.
# Looked up by the `/languages` endpoint, which falls back to the English
# name for any code that is missing here.
LANG_NATIVE_NAMES = {
    "en": "English",
    "fr": "Français",
    "es": "Español",
    "hi": "हिन्दी",
    "bn": "বাংলা",
    "te": "తెలుగు",
    "mr": "मराठी",
    "ta": "தமிழ்",
    "ur": "اردو",
    "gu": "ગુજરાતી",
    "kn": "ಕನ್ನಡ",
    "ml": "മലയാളം",
    "pa": "ਪੰਜਾਬੀ",
    "as": "অসমীয়া",
    "mai": "मैथिली",
    "sd": "سنڌي",
    "sat": "ᱥᱟᱱᱛᱟᱲᱤ",
}
@router.get("/languages")
async def handle_fetch_languages():
    """List the supported languages, ordered by their English name.

    Returns:
        A list of dicts with the keys ``code``, ``name`` (the name reported
        by pycountry) and ``native_name`` (taken from ``LANG_NATIVE_NAMES``,
        falling back to ``name``). Codes pycountry cannot resolve are left out.
    """
    codes = (
        "en", "fr", "es", "hi", "bn", "te", "mr", "ta", "ur",
        "gu", "kn", "ml", "pa", "as", "mai", "sd", "sat",
    )

    entries = []
    for lang_code in codes:
        # Try the two-letter lookup first, then the three-letter one.
        record = pycountry.languages.get(alpha_2=lang_code)
        if not record:
            record = pycountry.languages.get(alpha_3=lang_code)
        if record is None:
            continue  # skip unknown codes

        display_name = record.name
        entries.append(
            {
                "code": lang_code,
                "name": display_name,
                "native_name": LANG_NATIVE_NAMES.get(lang_code, display_name),
            }
        )

    return sorted(entries, key=lambda entry: entry["name"])
@router.get("/languages_v2")
async def fn_handle_fetch_languages_v2():
    """Return the result of `handle_fetch_languages_v2()` unchanged."""
    return await handle_fetch_languages_v2()
@router.post("/greet")
async def handle_greet(msg: Message):
    """Build the markdown greeting that lists every configured scripture.

    Args:
        msg: Request body; only `session_id` is read.

    Returns:
        A dict with `reply` (markdown text, one bullet per scripture sorted
        by title, including the unit count of its collection) and
        `session_id` (the client's ID, or a freshly generated UUID when the
        client sent none).
    """
    config = SanatanConfig()
    # One database handle is enough for all the per-collection counts.
    db = SanatanDatabase()

    bullets = []
    for scripture in sorted(config.scriptures, key=lambda doc: doc["title"]):
        num_units = db.count(collection_name=scripture["collection_name"])
        # Single quotes inside the replacement fields keep this f-string
        # valid on interpreters older than Python 3.12.
        bullets.append(
            f"- {scripture['title']} : `{num_units}` {scripture['unit']}s\n"
        )

    markdown = (
        "Namaskaram 🙏 I am **bhashyam.ai** and I can help you explore the following scriptures:\n---\n"
        + "".join(bullets)
    )

    session_id = msg.session_id or str(uuid.uuid4())
    return {"reply": markdown, "session_id": session_id}
@router.post("/chat")
async def handle_chat(msg: Message, request: Request):
    """Forward a user message to `chat(...)` and return its reply.

    Args:
        msg: Request body carrying the text, the preferred language and an
            optional session ID.
        request: The incoming request (not read by this handler).

    Returns:
        `{"reply": ..., "session_id": ...}` on success. Any exception is
        logged and reported as a 500 `JSONResponse` whose `reply` carries the
        error text.
    """
    try:
        # Use existing session_id if provided, else generate new
        session_id = msg.session_id or str(uuid.uuid4())
        logger.info("%s : user sent message : %s", session_id, msg.text)

        # Get or create a persistent thread_id for this session.
        # NOTE: thread_map lives in process memory and is never pruned.
        thread_id = thread_map.get(session_id)
        if thread_id is None:
            thread_id = thread_map[session_id] = str(uuid.uuid4())

        # `chat` is called synchronously (it is not awaited).
        reply_text = chat(
            debug_mode=False,
            message=msg.text,
            history=None,
            thread_id=thread_id,
            preferred_language=msg.language or "English",
        )

        # Return both reply and session_id to the client
        return {"reply": reply_text, "session_id": session_id}
    except Exception as e:
        # logger.exception records the traceback through the module logger,
        # matching how the other handlers in this file report errors.
        logger.exception("handle_chat failed")
        return JSONResponse(status_code=500, content={"reply": f"Error: {e}"})
@router.post("/quiz/generate")
async def handle_quiz_generate(payload: QuizGeneratePayload, request: Request):
    """Generate one quiz question, randomising whatever the client left out.

    Args:
        payload: Optional scripture / complexity / mode / language choices.
        request: The incoming request (not read by this handler).

    Returns:
        The generated question as a plain dict (`q.model_dump()`).
    """
    collection = payload.scripture
    if not collection:
        # Instantiate the config like every other handler in this module
        # instead of reading `scriptures` off the class object.
        collection = random.choice(
            [
                s["collection_name"]
                for s in SanatanConfig().scriptures
                if s["collection_name"] != "yt_metadata"
            ]
        )
    complexity = payload.complexity or random.choice(
        ["beginner", "intermediate", "advanced"]
    )
    mode = payload.mode or random.choice(["mcq", "open"])

    q = generate_question(
        collection=collection,
        complexity=complexity,
        mode=mode,
        # NOTE(review): "preferred_lamguage" looks like a typo, but it must
        # match the parameter name declared by generate_question, which is
        # not visible here -- confirm before renaming.
        preferred_lamguage=payload.language or "English",
    )
    logger.info("handle_quiz_generate: %s", q.model_dump_json(indent=1))
    return q.model_dump()
@router.post("/quiz/eval")
async def handle_quiz_eval(payload: QuizEvalPayload, request: Request):
    """Evaluate a user's answer to a quiz question.

    Args:
        payload: The question, the user's answer and the preferred language.
        request: The incoming request (not read by this handler).

    Returns:
        Whatever `validate_answer` produced, unchanged.
    """
    result = validate_answer(
        payload.q, payload.answer, preferred_language=payload.language or "English"
    )
    # Report through the module logger, like the other handlers in this file.
    logger.info("handle_quiz_eval: %s", result.model_dump_json(indent=1))
    return result
@router.get("/scriptures")
async def handle_get_scriptures():
    """Map each scripture's collection name to its title, skipping `yt_metadata`."""
    return {
        entry["collection_name"]: entry["title"]
        for entry in SanatanConfig().scriptures
        if entry["collection_name"] != "yt_metadata"
    }
@router.post("/scripture")
async def get_scripture(req: ScriptureRequest):
    """Return the result of `get_scripture_from_db(req)` unchanged."""
    return await get_scripture_from_db(req)
@router.get("/scripture_configs")
async def get_scripture_configs():
    """Describe every configured scripture for the client.

    Returns:
        `{"scriptures": [...]}` sorted by title. Each entry carries the
        scripture's static configuration plus `total` (the unit count of its
        collection) and `enabled` (true when a `field_mapping` is configured).
        Callable `lov` entries in `metadata_fields` are replaced by their
        evaluated result, or by `[]` when evaluation fails.
    """
    config = SanatanConfig()
    # One database handle serves every per-collection count below.
    db = SanatanDatabase()

    scriptures = []
    for s in config.scriptures:
        num_units = db.count(collection_name=s["collection_name"])

        # Shallow-copy each metadata field so that replacing "lov" below does
        # not mutate the original config.
        metadata_fields = []
        for f in s.get("metadata_fields", []):
            f_copy = dict(f)
            lov = f_copy.get("lov")
            if callable(lov):  # evaluate the function
                try:
                    f_copy["lov"] = lov()
                except Exception:
                    # Best effort: keep serving the config, but leave a trace
                    # instead of swallowing the failure silently.
                    logger.warning(
                        "get_scripture_configs: failed to evaluate lov for scripture %s",
                        s.get("name"),
                        exc_info=True,
                    )
                    f_copy["lov"] = []
            metadata_fields.append(f_copy)

        scriptures.append(
            {
                "name": s["name"],  # e.g. "bhagavad_gita"
                "title": s["title"],  # e.g. "Bhagavad Gita"
                "banner_url": s.get("banner_url", None),
                "category": s["category"],  # e.g. "Philosophy"
                "unit": s["unit"],  # e.g. "verse" or "page"
                "unit_field": s.get("unit_field", s.get("unit")),
                "total": num_units,
                "enabled": "field_mapping" in s,
                "source": s.get("source", ""),
                "credits": s.get(
                    "credits", {"art": [], "data": [], "audio": [], "video": []}
                ),
                "metadata_fields": metadata_fields,
                "field_mapping": config.remove_callables(s.get("field_mapping", {})),
            }
        )

    return {"scriptures": sorted(scriptures, key=lambda s: s["title"])}
class ScriptureFirstSearchRequst(BaseModel):
    """Request body accepted by `POST /scripture/{scripture_name}/search`."""

    # Metadata filter forwarded to the database query; None means no filter is sent.
    filter_obj: Optional[MetadataWhereClause] = None
    # Optional audio-availability filter applied to the matches.
    has_audio: Optional[AudioType] = None
@router.post("/scripture/{scripture_name}/search")
async def search_scripture_find_first_match(
    scripture_name: str,
    req: ScriptureFirstSearchRequst,
):
    """
    Search one scripture collection and return at most one hit: the match with
    the lowest `_global_index` that survives the optional audio filter.

    - `scripture_name`: name of the configured scripture
    - `req.filter_obj`: metadata filter forwarded to the database
    - `req.has_audio`: optional audio filter

    Returns `{"results": [...]}` on success, or `{"error": ...}` when the
    scripture is unknown or anything raises.
    """
    where_clause = req.filter_obj
    audio_filter = req.has_audio
    try:
        logger.info(
            "search_scripture_find_first_match: searching for %s with filters=%s | has_audio=%s",
            scripture_name,
            where_clause,
            audio_filter,
        )
        db = SanatanDatabase()

        config = None
        for candidate in SanatanConfig().scriptures:
            if candidate["name"] == scripture_name:
                config = candidate
                break
        if not config:
            return {"error": f"Scripture '{scripture_name}' not found"}

        # Step 1: fetch candidates. Every match is needed when an audio filter
        # has to be applied afterwards; otherwise the first match suffices.
        collection = config["collection_name"]
        if audio_filter:
            results = db.fetch_all_matches(
                collection_name=collection,
                metadata_where_clause=where_clause,
                page=None,
                page_size=None,
            )
        else:
            first = db.fetch_first_match(
                collection_name=collection,
                metadata_where_clause=where_clause,
            )
            results = {
                "ids": list(first["ids"]),
                "documents": list(first["documents"]),
                "metadatas": list(first["metadatas"]),
                "total_matches": 1,
            }

        formatted_results = []
        for position, metadata_doc in enumerate(results["metadatas"]):
            metadata_doc["id"] = results["ids"][position]
            document_text = None
            if results.get("documents"):
                document_text = results["documents"][position]
            formatted_results.append(
                SanatanConfig().canonicalize_document(
                    scripture_name, document_text, metadata_doc
                )
            )

        # Step 2: apply the audio filter.
        if audio_filter and formatted_results:
            exclude_audio = audio_filter == AudioType.none
            if exclude_audio or audio_filter == AudioType.any:
                # "none" and "any" both need the union over every audio type.
                wanted_types = [
                    AudioType.recitation,
                    AudioType.virutham,
                    AudioType.upanyasam,
                    AudioType.santhai,
                ]
            else:
                wanted_types = [audio_filter]

            indices_with_audio = set()
            for audio_type in wanted_types:
                indices_with_audio.update(
                    await svc_get_indices_with_audio(scripture_name, audio_type)
                )

            formatted_results = [
                doc
                for doc in formatted_results
                if (doc["_global_index"] in indices_with_audio) != exclude_audio
            ]

        # Step 3: order by global index.
        formatted_results.sort(key=lambda doc: doc["_global_index"])

        # Step 4: hand back only the first surviving result.
        return {"results": formatted_results[:1]}
    except Exception as e:
        logger.error("Error while searching %s", e, exc_info=True)
        return {"error": str(e)}
class ScriptureMultiSearchRequest(BaseModel):
    """Request body accepted by `POST /scripture/{scripture_name}/search/all`."""

    # Metadata filter forwarded to the database query; None means no filter is sent.
    filter_obj: Optional[MetadataWhereClause] = None
    # 1-based page number.
    page: int = 1
    # Number of results per page.
    page_size: int = 20
    # Optional audio-availability filter applied to the matches.
    has_audio: Optional[AudioType] = None
async def _filter_by_audio(scripture_name, documents, has_audio):
    """Return the canonical documents that satisfy the `has_audio` filter."""
    if has_audio in (AudioType.none, AudioType.any):
        # Both "none" and "any" need the union over every audio type.
        audio_types = [
            AudioType.recitation,
            AudioType.virutham,
            AudioType.upanyasam,
            AudioType.santhai,
        ]
    else:
        audio_types = [has_audio]

    indices_with_audio = set()
    for audio_type in audio_types:
        indices_with_audio.update(
            await svc_get_indices_with_audio(scripture_name, audio_type)
        )

    # "none" keeps the documents without audio; everything else keeps the
    # documents that have it.
    want_audio = has_audio != AudioType.none
    return [
        doc
        for doc in documents
        if (doc["_global_index"] in indices_with_audio) == want_audio
    ]


@router.post("/scripture/{scripture_name}/search/all")
async def search_scripture_find_all_matches(
    scripture_name: str, req: ScriptureMultiSearchRequest
):
    """
    Search scripture collection and return all matching results with pagination.

    - `scripture_name`: Name of the collection
    - `filter_obj`: MetadataWhereClause (filters, groups, operator)
    - `page`: 1-based page number (missing or non-positive values fall back to 1)
    - `page_size`: Number of results per page (missing or non-positive values
      fall back to 20)
    - `has_audio`: optional. can take values
      any|none|recitation|virutham|upanyasam|santhai

    Returns `{"results", "total_matches", "page", "page_size"}` on success, or
    `{"error": ...}` when the scripture is unknown or anything raises.
    """
    filter_obj = req.filter_obj
    # A negative page or page size would produce a wrong slice below, so treat
    # it like a missing value.
    page = req.page if req.page and req.page > 0 else 1
    page_size = req.page_size if req.page_size and req.page_size > 0 else 20
    has_audio = req.has_audio
    logger.info(
        "search_scripture_find_all_matches: searching for %s | filters=%s | page=%s | page_size=%s | has_audio=%s",
        scripture_name,
        filter_obj,
        page,
        page_size,
        has_audio,
    )
    try:
        db = SanatanDatabase()
        sanatan_config = SanatanConfig()
        config = next(
            (s for s in sanatan_config.scriptures if s["name"] == scripture_name),
            None,
        )
        if not config:
            return {"error": f"Scripture '{scripture_name}' not found"}

        # Step 1: decide how much to fetch. With an audio filter the database
        # cannot paginate for us, because filtering happens afterwards.
        fetch_all = has_audio is not None
        results = db.fetch_all_matches(
            collection_name=config["collection_name"],
            metadata_where_clause=filter_obj,
            page=None if fetch_all else page,
            page_size=None if fetch_all else page_size,
        )

        documents = results["documents"] if results.get("documents") else None
        formatted_results = []
        for i, metadata_doc in enumerate(results["metadatas"]):
            metadata_doc["id"] = results["ids"][i]
            document_text = documents[i] if documents else None
            formatted_results.append(
                sanatan_config.canonicalize_document(
                    scripture_name, document_text, metadata_doc
                )
            )

        if fetch_all:
            # Step 2: apply the audio filter (nothing to do on an empty list).
            if has_audio and formatted_results:
                formatted_results = await _filter_by_audio(
                    scripture_name, formatted_results, has_audio
                )

            # Step 3: paginate *after* filtering.
            total_matches = len(formatted_results)
            start_idx = (page - 1) * page_size
            paginated_results = formatted_results[start_idx : start_idx + page_size]
        else:
            # The database already returned just the requested page.
            total_matches = db.count_where(
                collection_name=config["collection_name"],
                metadata_where_clause=filter_obj,
            )
            paginated_results = formatted_results

        return {
            "results": paginated_results,
            "total_matches": total_matches,
            "page": page,
            "page_size": page_size,
        }
    except Exception as e:
        logger.error("Error while searching %s", e, exc_info=True)
        return {"error": str(e)}
@router.post("/audio")
async def generate_audio_urls(req: AudioRequest):
    """Log the request and return the result of `svc_get_audio_urls(req)`."""
    logger.info("generate_audio_urls: %s", req)
    return await svc_get_audio_urls(req)
# NOTE(review): this handler reuses the name of the `/audio` handler, so the
# module-level name `generate_audio_urls` ends up bound to this function.
# Both routes are still registered, because the decorator runs at definition
# time. The name is kept so nothing that refers to it changes; consider
# renaming it to `generate_video_urls`.
@router.post("/video")
async def generate_audio_urls(req: VideoRequest):
    """Log the request and return the result of `svc_get_video_urls(req)`."""
    # The message used to read "generate_audio_urls", which made video
    # requests indistinguishable from audio requests in the logs.
    logger.info("generate_video_urls: %s", req)
    video_urls = await svc_get_video_urls(req)
    return video_urls
@router.get("/scripture_categories")
def route_get_scripture_categories():
    """Return the result of `get_scripture_categories()` unchanged."""
    categories = get_scripture_categories()
    return categories
@router.get("/donation/products")
def route_get_donation_product_ids(include_tests: bool = False):
    """List the donation product IDs, each tagged with its target OS.

    Args:
        include_tests: When true, the Android test product IDs are appended.

    Returns:
        A list of `{"id": ..., "os": ...}` dicts: Android products first,
        then iOS products, then (optionally) the Android test products.
    """
    android_ids = (
        "donation_unit_0100",
        "donation_unit_0500",
        "donation_unit_1000",
        "donation_unit_2500",
        "donation_unit_5000",
    )
    # NOTE(review): two of these IDs start with "ioc_" rather than "ios_".
    # They are reproduced exactly; confirm against the IDs registered in the
    # store before "correcting" them.
    ios_ids = (
        "ios_donation_unit_0100",
        "ioc_donation_unit_0500",
        "ios_donation_unit_1000",
        "ios_donation_unit_2500",
        "ioc_donation_unit_5000",
    )
    android_test_ids = (
        "android.test.purchased",
        "android.test.canceled",
        "android.test.refunded",
        "android.test.item_unavailable",
    )

    products = [{"id": product_id, "os": "android"} for product_id in android_ids]
    products.extend({"id": product_id, "os": "ios"} for product_id in ios_ids)
    if include_tests:
        products.extend(
            {"id": product_id, "os": "android"} for product_id in android_test_ids
        )
    return products
@router.get("/discourse/list")
async def get_all_discourses(
    page: int = Query(1, ge=1, description="Page number (1-indexed)"),
    per_page: int = Query(10, ge=1, le=100, description="Number of items per page"),
):
    """
    Return one page of discourse topics, as produced by
    `get_discourse_summaries`.

    Each topic includes:
    - id
    - topic_name
    - thumbnail_url
    """
    return await get_discourse_summaries(page=page, per_page=per_page)
@router.get("/discourse/find/{topic_id}")
async def get_discourse_detail(topic_id: int):
    """
    Return the discourse topic with the given ID.

    Raises a 404 `HTTPException` when the lookup yields nothing.
    """
    found = await get_discourse_by_id(topic_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Discourse topic not found")
@router.post("/translate")
@limiter.limit("5/minute")
async def translate_text(request: Request, body: TranslationRequest):
    """Return the result of `svc_translate_text`; limited to 5 calls per minute."""
    return await svc_translate_text(request, body)
# Shared secret that callers of `/send_fcm` must present in the `X-Admin-Key`
# header. There is deliberately no built-in fallback value: a default key
# that ships in the source is known to anyone who can read the file. When the
# environment variable is missing or empty the endpoint rejects every request.
ADMIN_KEY = os.getenv("FIREBASE_API_ADMIN_KEY")


@router.post("/send_fcm")
async def send_fcm_endpoint(
    request: FcmRequest,
    x_admin_key: str = Header(None)
):
    """Send an FCM message on behalf of an authenticated admin caller.

    Args:
        request: The FCM payload, forwarded to `fcm_service.send_fcm`.
        x_admin_key: Value of the `X-Admin-Key` request header.

    Returns:
        Whatever `fcm_service.send_fcm` returns.

    Raises:
        HTTPException: 403 when the admin key is not configured on the
            server, or the header is missing or does not match it.
    """
    import hmac  # function-scope import: only this handler needs it

    if not ADMIN_KEY:
        # Fail closed rather than accept a well-known default key.
        logger.error(
            "FIREBASE_API_ADMIN_KEY is not configured; rejecting /send_fcm request"
        )
        raise HTTPException(status_code=403, detail="Unauthorized")

    # Compare as bytes (compare_digest only accepts ASCII-only str) and in
    # constant time, so response timing does not leak how much of the key
    # matched.
    if not x_admin_key or not hmac.compare_digest(
        x_admin_key.encode("utf-8"), ADMIN_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Unauthorized")

    return await fcm_service.send_fcm(request)
@router.post("/summarize_scripture_verse")
async def summarize_scripture_verse(req: ScriptureRequest):
    """Return the result of `svc_summarize_scripture_verse(req)` unchanged."""
    return await svc_summarize_scripture_verse(req)