sanatan_ai / graph_helper.py
vikramvasudevan's picture
Upload folder using huggingface_hub
63d1774 verified
raw
history blame
1.45 kB
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.state import CompiledStateGraph
from modules.nodes.chat import chatNode, tools
from modules.nodes.conditions import branching_condition
from modules.nodes.dedup import dedup_tool_call
from modules.nodes.init import init_system_prompt_node
from modules.nodes.state import ChatState
from modules.nodes.tool_calls import increment_tool_calls
from modules.nodes.validator import validatorNode
from langgraph.prebuilt import ToolNode
import logging
# Module-scoped logger, explicitly pinned to INFO so this module's INFO-level
# messages are emitted regardless of the root logger's default level.
logger: logging.Logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)
def generate_graph() -> CompiledStateGraph:
    """Build and compile the chat agent's LangGraph state machine.

    Topology::

        START -> init -> llm -> dedup --branching_condition--> tools | validator | END
                          ^                                      |
                          +---------------- count_tools <--------+
        validator -> END

    ``branching_condition`` is evaluated on the state produced by the
    ``dedup`` node, and its return value must be one of ``"tools"``,
    ``"validator"`` or ``"__end__"`` (the keys of the path map below).

    Returns:
        The compiled graph. It uses a ``MemorySaver`` checkpointer created
        fresh on every call, so separately generated graphs do not share
        checkpointed conversation state.
    """
    memory = MemorySaver()
    graph = StateGraph(ChatState)

    graph.add_node("init", init_system_prompt_node)
    graph.add_node("llm", chatNode)
    graph.add_node("dedup", dedup_tool_call)
    graph.add_node("tools", ToolNode(tools))
    graph.add_node("count_tools", increment_tool_calls)
    graph.add_node("validator", validatorNode)

    graph.add_edge(START, "init")
    graph.add_edge("init", "llm")

    # Every LLM turn goes through the dedup node before routing, so the
    # routing decision (and ToolNode) see its output. Presumably dedup strips
    # repeated tool calls from the latest AI message -- confirm in
    # modules.nodes.dedup.
    graph.add_edge("llm", "dedup")

    # Branching happens *after* dedup.
    graph.add_conditional_edges(
        "dedup",
        branching_condition,
        {"tools": "tools", "validator": "validator", "__end__": END},
    )

    # After tools execute, update the tool-call counter before handing control
    # back to the LLM. NOTE(review): presumably the counter is consulted by
    # branching_condition to bound the tool loop -- verify.
    graph.add_edge("tools", "count_tools")
    graph.add_edge("count_tools", "llm")

    graph.add_edge("validator", END)

    return graph.compile(checkpointer=memory)