import logging

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode

from modules.nodes.chat import chatNode, tools
from modules.nodes.conditions import branching_condition
from modules.nodes.dedup import dedup_tool_call
from modules.nodes.init import init_system_prompt_node
from modules.nodes.state import ChatState
from modules.nodes.tool_calls import increment_tool_calls
from modules.nodes.validator import validatorNode

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def generate_graph() -> CompiledStateGraph:
    """Build and compile the chat agent graph.

    Topology::

        START -> init -> llm -> dedup -> branching_condition(state)
            "tools"     -> tools -> count_tools -> llm   (loop)
            "validator" -> validator -> END
            "__end__"   -> END

    A fresh ``MemorySaver`` checkpointer is created on every call, so each
    compiled graph keeps its own in-memory checkpoint store.

    Returns:
        The compiled ``StateGraph`` over ``ChatState``.
    """
    memory = MemorySaver()
    graph = StateGraph(ChatState)

    graph.add_node("init", init_system_prompt_node)
    graph.add_node("llm", chatNode)
    graph.add_node("dedup", dedup_tool_call)
    graph.add_node("tools", ToolNode(tools))
    graph.add_node("count_tools", increment_tool_calls)
    graph.add_node("validator", validatorNode)

    graph.add_edge(START, "init")
    graph.add_edge("init", "llm")

    # Previously "dedup" was registered but had no edges, so it never ran,
    # and branching came straight off "llm". Route the LLM output through
    # dedup first so that branching happens *after* dedup, as intended.
    # NOTE(review): dedup_tool_call presumably strips repeated tool calls
    # from the latest LLM message; confirm in modules.nodes.dedup.
    graph.add_edge("llm", "dedup")
    graph.add_conditional_edges(
        "dedup",
        branching_condition,
        {"tools": "tools", "validator": "validator", "__end__": END},
    )

    # Tool execution loops back to the LLM after the tool-call counter updates.
    graph.add_edge("tools", "count_tools")
    graph.add_edge("count_tools", "llm")

    graph.add_edge("validator", END)

    return graph.compile(checkpointer=memory)