Spaces: Running on CPU Upgrade
| import json | |
| import os | |
| import random | |
| import asyncio | |
| import logging | |
| import time | |
| import traceback | |
| import uuid | |
| from html import escape | |
| import gradio as gr | |
| from dotenv import load_dotenv | |
| from langchain_core.messages.ai import AIMessageChunk, AIMessage | |
| from langchain_core.messages.system import SystemMessage | |
| from langchain_core.messages.tool import ToolMessage | |
| from config import SanatanConfig | |
| from db import SanatanDatabase | |
| from drive_downloader import ZipDownloader | |
| from graph_helper import generate_graph | |
# Logging
# Configure the root logger (basicConfig installs a default stderr handler
# only if the root logger has none yet) and let INFO and above through.
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Built once at import time; chat() and chat_streaming() use this module-level
# instance for every request.
# NOTE(review): this runs before init() (which calls load_dotenv) further down
# the module, so values that only exist in .env are not loaded yet here —
# confirm generate_graph() does not need them at construction time.
graph = generate_graph()
def init():
    """Load environment settings and make sure the scripture database is usable.

    Environment variables are loaded from a .env file with ``override=True``
    (file values win over existing ones). If the database sanity test raises
    for any reason, the zipped database is downloaded with ``ZipDownloader``
    (Drive file id from ``CHROMADB_FILE_ID``, credentials from
    ``GOOGLE_SERVICE_ACCOUNT_JSON``) to ``SanatanConfig.dbStorePath`` and
    extracted into the current directory. Download errors propagate.
    """
    load_dotenv(override=True)
    try:
        SanatanDatabase().test_sanity()
    except Exception as sanity_error:
        # Best effort: any failure of the sanity check triggers a re-download.
        logger.warning("Sanity Test Failed - %s", sanity_error)
        logger.info("Downloading database ...")
        credentials = os.getenv("GOOGLE_SERVICE_ACCOUNT_JSON")
        downloader = ZipDownloader(service_account_json=credentials)
        drive_file_id = os.getenv("CHROMADB_FILE_ID")
        archive_path = downloader.download_zip_from_drive(
            file_id=drive_file_id,
            output_path=SanatanConfig.dbStorePath,
        )
        downloader.unzip(archive_path, extract_to="./")
def init_session():
    """Return a new random (version 4) UUID, as a string, to use as a thread id."""
    session_uuid = uuid.uuid4()
    return f"{session_uuid}"
def render_message_with_tooltip(content: str, max_chars=200):
    """Render *content* as a ``<div>`` showing a short preview with a full-text tooltip.

    The preview is the first *max_chars* characters, followed by "…" when the
    text was cut. Both the preview and the ``title`` attribute are
    HTML-escaped (quotes included). Truncation happens before escaping, so
    no entity is ever split.
    """
    was_cut = len(content) > max_chars
    preview = escape(content[:max_chars])
    if was_cut:
        preview += "…"
    full_text = escape(content)
    return f"<div title='{full_text}'>{preview}</div>"
# Short status phrases; one is picked with random.choice for each animated
# "thinking" bubble rendered while a request is in progress.
thinking_verbs = [
    "thinking",
    "processing",
    "crunching data",
    "please wait",
    "just a few more seconds",
    "closing in",
    "analyzing",
    "reasoning",
    "computing",
    "synthesizing insight",
    "searching through the cosmos",
    "decoding ancient knowledge",
    "scanning the scriptures",
    "accessing divine memory",
    "gathering wisdom",
    "consulting the rishis",
    "listening to the ātmā",
    "channeling sacred energy",
    "unfolding the divine word",
    "meditating on the meaning",
    "reciting from memory",
    "traversing the Vedas",
    "seeking the inner light",
    "invoking paramārtha",
    "putting it all together",
    "digging deeper",
    "making sense of it",
    "connecting the dots",
    "almost there",
    "getting closer",
    "wrapping it up",
    "piecing it together",
    "swirling through verses",
    "diving into the ocean of knowledge",
    "lighting the lamp of understanding",
    "walking the path of inquiry",
    "aligning stars of context",
]
async def chat_wrapper(message, history, thread_id, debug):
    """Gradio ChatInterface callback: stream progress or return a single answer.

    Args:
        message: The user's new message.
        history: Prior chat messages supplied by Gradio.
        thread_id: Per-session conversation id.
        debug: When truthy, the graph's progress is streamed via
            ``chat_streaming``; otherwise only the final answer from
            ``chat`` is produced.

    Yields:
        str: Content for the bot bubble; Gradio re-renders it with each value.
    """
    if debug:
        async for chunk in chat_streaming(debug, message, history, thread_id):
            yield chunk
        return
    # chat() performs a synchronous, potentially long graph.invoke. Calling it
    # directly here would block the event loop (and every other session served
    # by it) until the model answers, so run it in a worker thread instead.
    yield await asyncio.to_thread(chat, debug, message, history, thread_id)
def chat(debug_mode, message, history, thread_id):
    """Run the graph once (non-streaming) and return the content of its last message.

    Only the new user message is sent. ``history`` mirrors chat_streaming's
    signature but is not used here. Earlier turns are presumably restored from
    graph state keyed by ``thread_id``; confirm against generate_graph().
    """
    run_config = {"configurable": {"thread_id": thread_id}}
    user_turn = {"role": "user", "content": message}
    result = graph.invoke(
        {"debug_mode": debug_mode, "messages": [user_turn]},
        config=run_config,
    )
    last_message = result["messages"][-1]
    return last_message.content
def add_node_to_tree(
    node_tree: list[str], node_label: str, tooltip: str = "no arguments to show"
) -> list[str]:
    """Add a graph node to the breadcrumb trail in place, keeping a trailing spinner.

    The last entry of *node_tree* (the spinner placeholder) is replaced by the
    new node, and a fresh spinner is appended after it. *node_tree* must be
    non-empty.

    Args:
        node_tree: Breadcrumb entries (HTML/Markdown snippets); mutated.
        node_label: Text shown for the node; inserted as-is (not escaped).
        tooltip: Hover text. When truthy, it is HTML-escaped and the label is
            wrapped in a ``node-label`` span carrying it as ``title``. Falsy
            values ("" or None) insert the bare label.

    Returns:
        The same *node_tree* list.
    """
    if tooltip:
        # escape(quote=True) already converts ' to &#x27;, so the value is safe
        # inside the single-quoted title attribute without extra replacing.
        safe_tooltip = escape(tooltip, quote=True)
        node_tree[-1] = (
            f"<span class='node-label' title='{safe_tooltip}'>{node_label}</span>"
        )
    else:
        node_tree[-1] = node_label
    node_tree.append("<span class='spinner'> </span>")
    return node_tree
def end_node_tree(node_tree: list[str]) -> list[str]:
    """Swap the trailing entry (the spinner) for a finish flag; returns the same list.

    Raises IndexError when *node_tree* is empty.
    """
    node_tree.pop()
    node_tree.append("🏁")
    return node_tree
def get_args_for_toolcall(tool_calls_buffer: dict, tool_call_id: str):
    """Return the accumulated ``args_str`` recorded for *tool_call_id*.

    Returns "" when the id is not in *tool_calls_buffer* or its entry has no
    ``args_str`` key.
    """
    if tool_call_id not in tool_calls_buffer:
        return ""
    entry = tool_calls_buffer[tool_call_id]
    if "args_str" not in entry:
        return ""
    return entry["args_str"]
def _breadcrumb(node_tree: list[str]) -> str:
    """Return the Markdown heading that shows the trail of visited graph nodes."""
    return f"### {' → '.join(node_tree)}"


def _truncate_middle(text: str, front: int = 50, back: int = 50) -> str:
    """Shorten *text* to its first *front* and last *back* characters.

    Empty/None input yields "". Text that already fits is returned unchanged.
    Otherwise the two ends are joined with "…" and newlines are removed so the
    preview stays on one line.
    """
    if not text:
        return ""
    if len(text) <= front + back:
        return text
    return f"{text[:front]}…{text[-back:]}".replace("\n", "")


def _processing_message() -> str:
    """Return a "thinking" bubble with a randomly chosen phrase."""
    return (
        "<div class='thinking-bubble'><em>"
        f"🤔{random.choice(thinking_verbs)} ...</em></div>"
    )


async def chat_streaming(debug_mode: bool, message, history, thread_id):
    """Run the graph in streaming mode and yield progressive chat renderings.

    Each yielded string is the complete content to display (Gradio replaces
    the bubble with each value). It starts with a "### " breadcrumb of the
    graph nodes seen so far, followed by a thinking bubble and/or a one-line
    preview of the AI text streamed so far. When the graph finishes, the
    accumulated AI text plus the elapsed time is typed out character by
    character.

    Args:
        debug_mode: Forwarded to the graph state as ``debug_mode``.
        message: The new user message text.
        history: Prior chat messages, prepended to the new message (or None).
        thread_id: Conversation id passed as the graph's ``thread_id``.

    Yields:
        str: Markdown/HTML snapshot of the current progress.

    Cancellation (``asyncio.CancelledError``) and any other exception are not
    propagated. A last snapshot with the text streamed so far is yielded and
    the generator ends.
    """
    # NOTE(review): unlike chat(), the whole chat history is sent along with
    # the new message. If the graph persists per-thread state (see thread_id in
    # config), earlier turns may be appended again on every request — confirm.
    state = {
        "debug_mode": debug_mode,
        "messages": (history or []) + [{"role": "user", "content": message}],
    }
    config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 15}
    start_time = time.time()
    # All non-empty AI chunk text seen so far: used for the live preview and
    # for the partial answer on cancellation/error.
    streamed_response = ""
    # Text typed out at the end. It receives the same AI chunks as
    # streamed_response from *every* node (an earlier per-node filter on a
    # "validator" node is disabled), so intermediate AI text is included.
    final_response = ""
    MAX_CONTENT = 500  # max escaped chars shown for non-chunk messages
    # Breadcrumb of visited nodes. The last entry is always a spinner until the
    # run finishes ("✅") or is aborted ("🏁").
    node_tree = ["🚩", "<span class='spinner'> </span>"]
    # Tool calls reassembled from streamed fragments, keyed by tool-call id.
    tool_calls_buffer = {}
    # Id of the tool call currently streaming. Continuation chunks carry id None
    # and are attributed to the most recent id.
    # NOTE(review): interleaved parallel tool calls (distinguished only by
    # "index") would be merged — verify for the model in use.
    tool_call_id = None
    try:
        async for msg, metadata in graph.astream(
            state, config=config, stream_mode="messages"
        ):
            node = metadata.get("langgraph_node", "?")
            name = getattr(msg, "name", "-")
            is_tool_message = isinstance(msg, ToolMessage)
            node_icon = "⚙️" if is_tool_message else "🧠"
            label = f"{node_icon} {node}"
            if name:
                label += f":{name}"
            tooltip = ""
            if is_tool_message:
                tooltip = get_args_for_toolcall(tool_calls_buffer, msg.tool_call_id)
            # node_tree[-1] is always the spinner, so [-2] is the last node added.
            if node_tree[-2] != label:
                add_node_to_tree(node_tree, label, tooltip)

            if not isinstance(msg, (ToolMessage, SystemMessage, AIMessageChunk)):
                logger.info("msg = %s", msg)

            if is_tool_message:
                logger.debug("tool message = %s", msg)
                html = (
                    f"<div class='thinking-bubble'><em>🤔 {msg.name} tool: "
                    f"{random.choice(thinking_verbs)} ...</em></div>"
                )
                yield f"{_breadcrumb(node_tree)}\n{html}"
            elif isinstance(msg, AIMessageChunk):
                if not msg.content:
                    # Chunk with no text: show a thinking bubble plus the
                    # preview so far.
                    yield (
                        f"{_breadcrumb(node_tree)}\n{_processing_message()}\n"
                        "<div class='intermediate-output'>"
                        f"{escape(_truncate_middle(streamed_response))}</div>"
                    )
                else:
                    streamed_response += msg.content
                    yield (
                        f"{_breadcrumb(node_tree)}\n"
                        "<div class='intermediate-output'>"
                        f"{escape(_truncate_middle(streamed_response))}</div>"
                    )
                    final_response += msg.content
                for tool_call_chunk in msg.tool_call_chunks or []:
                    logger.debug("*** tool_call_chunk = %s", tool_call_chunk)
                    if tool_call_chunk["id"] is not None:
                        tool_call_id = tool_call_chunk["id"]
                    if tool_call_id is None:
                        # Continuation fragment before any fragment with an id:
                        # nothing to attach it to (previously UnboundLocalError).
                        continue
                    entry = tool_calls_buffer.setdefault(
                        tool_call_id,
                        {
                            "name": "",
                            "args_str": "",
                            "id": tool_call_id,
                            "type": "tool_call",
                        },
                    )
                    # Name and argument text arrive in fragments; concatenate.
                    if tool_call_chunk["name"] is not None:
                        entry["name"] += tool_call_chunk["name"]
                    if tool_call_chunk["args"] is not None:
                        entry["args_str"] += tool_call_chunk["args"]
            else:
                logger.debug("message = %s %s", type(msg), msg.content[:100])
                full: str = escape(msg.content)
                truncated = (
                    (full[:MAX_CONTENT] + "…") if len(full) > MAX_CONTENT else full
                )
                html = (
                    f"<div class='thinking-bubble'><em>🤔 {random.choice(thinking_verbs)} ...</em></div>"
                    f"<div style='opacity: 0.1'>"
                    f"<strong>Telling myself:</strong> {truncated or '...'}"
                    f"</div>"
                )
                yield f"{_breadcrumb(node_tree)}\n{html}"
                if getattr(msg, "tool_calls", []):
                    logger.info("ELSE::tool_calls = %s", msg.tool_calls)

        node_tree[-1] = "✅"  # graph finished: replace the spinner
        duration = time.time() - start_time
        final_response = (
            f"\n{final_response}" f"\n\n⏱️ Processed in {duration:.2f} seconds"
        )
        # Typing effect: every yield re-renders the whole bubble.
        buffer = f"{_breadcrumb(node_tree)}\n"
        yield buffer
        for character in final_response:
            buffer += character
            yield buffer
            await asyncio.sleep(0.0005)

        # Diagnostics only: skip the JSON parsing entirely unless DEBUG is on.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("************************************")
            for accumulated in tool_calls_buffer.values():
                if not accumulated["args_str"]:
                    continue
                try:
                    parsed_args = json.loads(accumulated["args_str"])
                except json.JSONDecodeError:
                    logger.debug(
                        "Partial arguments for tool %s: %s",
                        accumulated["name"],
                        accumulated["args_str"],
                    )
                else:
                    logger.debug("Tool Name: %s", accumulated["name"])
                    logger.debug("Tool Arguments: %s", parsed_args)
    except asyncio.CancelledError:
        # Deliberately not re-raised so the user still sees the partial answer.
        logger.warning("⚠️ Request cancelled by user")
        end_node_tree(node_tree)
        yield (
            f"{_breadcrumb(node_tree)}"
            "\n⚠️⚠️⚠️ Request cancelled by user"
            "\nhere is what I got so far ...\n"
            f"\n{streamed_response}"
        )
    except Exception as e:
        # Top-level boundary of the request: log with traceback, show partial text.
        logger.exception("❌❌❌ Error processing request: %s", e)
        end_node_tree(node_tree)
        yield (
            f"{_breadcrumb(node_tree)}"
            f"\n❌❌❌ Error processing request : {str(e)}"
            "\nhere is what I got so far ...\n"
            f"\n{streamed_response}"
        )
# UI Elements
# Conversation id holder; gr.State is given init_session (a callable) so a
# fresh UUID string is presumably produced per session — see Gradio State docs.
thread_id = gr.State(init_session)
# NOTE(review): supported_scriptures is not referenced by the visible UI code.
supported_scriptures = "\n - ".join(
    f"📖 **{entry['title']}** [source]({entry['source']})"
    for entry in SanatanConfig.scriptures
)
# Verify (and if needed download) the database before the UI is built.
init()
message_textbox = gr.Textbox(
    placeholder="Search the scriptures ...",
    submit_btn=True,
    stop_btn=True,
)
# Main UI: a sidebar listing the supported scriptures with clickable example
# queries, plus the chat area wired to chat_wrapper.
with gr.Blocks(
    theme=gr.themes.Citrus(),
    title="Sanatan-AI",
    # Styles for result tables, the breadcrumb spinner, the pulsing "thinking"
    # bubble, node-label tooltips and the faded one-line intermediate preview.
    css="""
table {
border-collapse: collapse;
width: 90%;
}
table, th, td {
border: 1px solid #ddd;
padding: 6px;
font-size: small;
}
td {
word-wrap: break-word;
white-space: pre-wrap; /* preserves line breaks but wraps long lines */
max-width: 300px; /* control width */
vertical-align: top;
}
.spinner {
display: inline-block;
width: 1em;
height: 1em;
border: 2px solid transparent;
border-top: 2px solid #333;
border-radius: 50%;
animation: spin 0.8s linear infinite;
vertical-align: middle;
margin-left: 0.5em;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.thinking-bubble {
opacity: 0.5;
font-style: italic;
animation: pulse 1.5s infinite;
margin-bottom: 5px;
}
@keyframes pulse {
0% { opacity: 0.3; }
50% { opacity: 1; }
100% { opacity: 0.3; }
}
.node-label {
cursor: help;
border-bottom: 1px dotted #aaa;
}
.intermediate-output {
opacity: 0.4;
font-style: italic;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
""",
) as gradio_app:
    show_sidebar = gr.State(True)
    # with gr.Column(scale=1, visible=show_sidebar.value) as sidebar_container:
    # NOTE(review): show_sidebar.value is read once, when the layout is built.
    with gr.Sidebar(open=show_sidebar.value) as sidebar:
        # session_id = gr.Textbox(value=f"Thread: {thread_id}")
        # gr.Markdown(value=f"{'\n'.join([msg['content'] for msg in intro_messages])}")
        gr.Markdown(
            value="Namaskaram 🙏 I am Sanatan-Bot and I can help you explore the following scriptures:\n\n"
        )

        async def populate_chat_input(text: str):
            """Yield *text* growing one character at a time (typing effect)."""
            buffer = ""
            for c in text:
                buffer += c
                yield buffer
                await asyncio.sleep(0.05)
            return

        def close_side_bar():
            """Event handler that collapses the sidebar (output: ``sidebar``)."""
            print("close_side_bar invoked")
            yield gr.update(open=False)

        # One accordion per scripture, sorted by title, with its source,
        # language and one button per example query.
        for scripture in sorted(SanatanConfig.scriptures, key=lambda d: d.get("title")):
            with gr.Accordion(label=f"{scripture['title']}", open=False):
                gr.Markdown(f"* Source: [🔗 click here]({scripture['source']})")
                gr.Markdown(f"* Language: {scripture['language']}")
                gr.Markdown(f"* Examples :")
                with gr.Row():
                    # Labels and example texts are paired by position (zip
                    # stops at the shorter list).
                    for example_label, example_text in zip(
                        scripture["example_labels"], scripture["examples"]
                    ):
                        btn = gr.Button(value=f"{example_label}", size="sm")
                        # Click: close the sidebar, then type the example into
                        # the chat textbox. gr.State(example_text) is created in
                        # this iteration, so each button keeps its own text.
                        btn.click(close_side_bar, outputs=[sidebar]).then(
                            populate_chat_input,
                            inputs=[gr.State(example_text)],
                            outputs=[message_textbox],
                        )
        gr.Markdown(value="------")
        # Checked: chat_wrapper streams progress; unchecked: single final answer.
        debug_checkbox = gr.Checkbox(label="Debug (Streaming)", value=True)
    chatbot = gr.Chatbot(
        elem_id="chatbot",
        avatar_images=("assets/avatar_user.png", "assets/adiyen_bot.png"),
        # value=intro_messages,
        label="Sanatan-AI-Bot",
        show_copy_button=True,
        show_copy_all_button=True,
        # "messages" format: history entries are role/content dicts.
        type="messages",
        height=700,
        render_markdown=True,
    )
    # chat_wrapper receives (message, history) followed by the additional
    # inputs in order: thread_id, debug.
    chatInterface = gr.ChatInterface(
        title="Sanatan-AI",
        fn=chat_wrapper,
        additional_inputs=[thread_id, debug_checkbox],
        chatbot=chatbot,
        textbox=message_textbox,
    )
# app.launch()