File size: 1,453 Bytes
fd1b271
 
 
63d1774
 
 
 
 
 
 
 
fd1b271
 
 
 
 
 
 
 
 
c18d7a8
63d1774
fd1b271
63d1774
7290ba6
63d1774
fd1b271
c18d7a8
63d1774
 
bc05cd4
 
 
ec6bd64
bc05cd4
63d1774
 
 
7290ba6
63d1774
fd1b271
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.state import CompiledStateGraph
from modules.nodes.chat import chatNode, tools
from modules.nodes.conditions import branching_condition
from modules.nodes.dedup import dedup_tool_call
from modules.nodes.init import init_system_prompt_node
from modules.nodes.state import ChatState
from modules.nodes.tool_calls import increment_tool_calls
from modules.nodes.validator import validatorNode
from langgraph.prebuilt import ToolNode
import logging

# Logger for this module, pinned to the INFO threshold.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)

def generate_graph() -> CompiledStateGraph:
    """Build and compile the chat agent's LangGraph state machine.

    Topology::

        START -> init -> llm -> dedup --branching_condition--+
                         ^                                    |- "tools"     -> tools -> count_tools -> llm
                         |____________________________________|- "validator" -> validator -> END
                                                              |- "__end__"   -> END

    Every output of the ``llm`` node passes through ``dedup`` before
    ``branching_condition`` decides where to go next.

    Returns:
        The compiled graph. It is checkpointed by a ``MemorySaver`` created
        fresh on each call, so separately generated graphs do not share
        checkpoints.
    """
    memory = MemorySaver()
    graph = StateGraph(ChatState)
    graph.add_node("init", init_system_prompt_node)
    graph.add_node("llm", chatNode)
    graph.add_node("dedup", dedup_tool_call)
    graph.add_node("tools", ToolNode(tools))
    graph.add_node("count_tools", increment_tool_calls)
    graph.add_node("validator", validatorNode)

    graph.add_edge(START, "init")
    graph.add_edge("init", "llm")

    # The LLM's output is routed through "dedup" first, and branching happens
    # *after* dedup. Without this edge "dedup" would be registered but
    # unreachable, and it would never run.
    # NOTE(review): assumes dedup_tool_call leaves state in a shape that
    # branching_condition and ToolNode accept. Confirm against those modules.
    graph.add_edge("llm", "dedup")
    graph.add_conditional_edges(
        "dedup",
        branching_condition,
        {"tools": "tools", "validator": "validator", "__end__": END},
    )

    graph.add_edge("tools", "count_tools")
    # Loop back so the LLM sees the tool results on its next turn.
    graph.add_edge("count_tools", "llm")
    graph.add_edge("validator", END)

    return graph.compile(checkpointer=memory)