# LangGraph + Gradio streaming chat demo.
import asyncio
from typing import List, AsyncIterator, TypedDict

from langgraph.graph import StateGraph, END
import gradio as gr


# 1. Define chat state
class ChatState(TypedDict):
    # Conversation history in OpenAI-style message format:
    # each entry is a dict like {"role": "user" | "assistant", "content": str}.
    messages: List[dict]


def start_node(state: ChatState):
    """Append an initial assistant placeholder ("Thinking ...") to the history."""
    placeholder = {"role": "assistant", "content": "start node: Thinking ...⏳"}
    return {"messages": [*state["messages"], placeholder]}


# 2. Respond Node — stream word by word, then ⏳
async def respond_node(state: ChatState) -> AsyncIterator[ChatState]:
    """Simulate a streamed model response by rewriting the placeholder message.

    Mutates the content of the trailing assistant message (the placeholder
    added by start_node) in place, then yields the updated state so the UI
    can re-render the intermediate result.

    Args:
        state: Current chat state; ``state["messages"]`` must end with the
            assistant placeholder message.

    Yields:
        The state with the last message's content replaced by an
        in-progress response (suffixed with ⏳).
    """
    messages = state["messages"]

    await asyncio.sleep(1)  # simulate model latency; adjust timing as needed
    thinking_msg = "respond_node: give me some more time. please"
    await asyncio.sleep(0.2)  # adjust timing as needed
    # In-place update so the streaming UI re-renders the same message bubble.
    messages[-1]["content"] = thinking_msg + " ⏳"
    yield {"messages": messages}


# 3. Post-process node — replace ⏳ with ✅
async def post_process_node(state: ChatState) -> ChatState:
    """Finalize the pending assistant message once processing completes."""
    history = state["messages"]
    await asyncio.sleep(1.5)

    last = history[-1]
    # Only rewrite a message that is still marked as in-progress (⏳).
    if last["role"] == "assistant" and "⏳" in last["content"]:
        last["content"] = "post_process_node: here is your final result!✅"

    return {"messages": history}


# 4. Define the graph: start -> respond -> post_process -> END
builder = StateGraph(ChatState)
for node_name, node_fn in (
    ("start", start_node),
    ("respond", respond_node),
    ("post_process", post_process_node),
):
    builder.add_node(node_name, node_fn)

builder.set_entry_point("start")
builder.add_edge("start", "respond")
builder.add_edge("respond", "post_process")
builder.add_edge("post_process", END)

graph = builder.compile()


# 5. Gradio streaming handler
async def bot_respond_streaming(message: str, history: List[dict]) -> AsyncIterator:
    """Run the graph over the conversation and stream each node's messages.

    Args:
        message: Text from the textbox. After ``user_submit`` has run it is
            ``""`` (the textbox was cleared), and the user turn is already in
            ``history``.
        history: Chat history in ``{"role", "content"}`` message format.

    Yields:
        The full message list after each graph node executes, so Gradio
        re-renders the chat as the graph progresses.
    """
    messages = list(history or [])
    # Bug fix: user_submit already appended the user turn and cleared the
    # textbox, so the chained call receives message == "". Appending
    # unconditionally created a spurious empty user message; only append
    # when a non-empty message was actually passed.
    if message:
        messages.append({"role": "user", "content": message})
    state = {"messages": messages}

    async for step in graph.astream(state):
        # LangGraph yields one dict per executed node, keyed by node name,
        # e.g. {"respond": ChatState} then {"post_process": ChatState}.
        for node_output in step.values():
            yield node_output["messages"]


# 6. Gradio UI
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Sanatan AI", type="messages")
    with gr.Row():
        textbox = gr.Textbox(placeholder="Ask something...", scale=8, container=False)
        send_btn = gr.Button("Send", scale=1)

    def user_submit(message, history):
        # Clear the textbox and append the user's turn to the chat history.
        return "", history + [{"role": "user", "content": message}]

    # Button click: record the user turn, then stream the bot response.
    click_chain = send_btn.click(user_submit, [textbox, chatbot], [textbox, chatbot])
    click_chain.then(bot_respond_streaming, [textbox, chatbot], chatbot)

    # Pressing Enter in the textbox behaves the same as clicking Send.
    enter_chain = textbox.submit(user_submit, [textbox, chatbot], [textbox, chatbot])
    enter_chain.then(bot_respond_streaming, [textbox, chatbot], chatbot)


if __name__ == "__main__":
    # Launch the Gradio app when this file is run as a script.
    demo.launch()