File size: 3,151 Bytes
6cbca40
 
 
c0827a3
 
6cbca40
61e4b1e
c0827a3
 
 
 
6cbca40
 
 
c0827a3
 
 
 
 
 
 
6cbca40
 
c0827a3
 
 
 
 
 
 
 
 
 
 
6cbca40
c0827a3
 
 
61e4b1e
6cbca40
 
 
 
 
 
 
61e4b1e
6cbca40
 
c0827a3
 
 
 
 
6cbca40
 
 
c0827a3
 
 
6cbca40
c0827a3
 
 
61e4b1e
6cbca40
 
 
 
 
 
 
61e4b1e
6cbca40
 
c0827a3
 
 
 
 
6cbca40
 
 
c0827a3
61e4b1e
c0827a3
6cbca40
 
c0827a3
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from typing import TypedDict
from src.config.llm import model
from langgraph.prebuilt import create_react_agent
from langgraph_swarm import create_handoff_tool
from langchain_core.messages import RemoveMessage
from .prompt import roleplay_prompt, guiding_prompt
from .tools import create_end_conversation_tool
from typing_extensions import TypedDict, Annotated
from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages
from loguru import logger


class State(TypedDict):
    """State schema shared by the node and routing functions in this module."""
    # Name of the agent selected to handle the conversation, or None.
    # ``trim_history`` fills in "Roleplay Agent" when this is unset/None and
    # ``route_to_active_agent`` reads it to pick the agent name to return.
    active_agent: str | None
    # Conversation messages, annotated with langgraph's ``add_messages``.
    # ``trim_history`` writes ``RemoveMessage`` entries into this key.
    messages: Annotated[list[AnyMessage], add_messages]
    # The five scenario fields below are substituted into ``roleplay_prompt``
    # and ``guiding_prompt`` by ``call_roleplay`` / ``call_guiding_agent``.
    scenario_title: str
    scenario_description: str
    scenario_context: str
    your_role: str
    key_vocabulary: str


def trim_history(state: State, *, max_messages: int = 25, keep_last: int = 5):
    """Default the active agent and prune an overlong message history.

    Args:
        state: Current state mapping. It is read but never modified.
        max_messages: Pruning only happens once the history holds more than
            this many messages.
        keep_last: Number of most recent messages left in place when pruning.

    Returns:
        A shallow copy of ``state`` in which ``active_agent`` falls back to
        ``"Roleplay Agent"`` when unset or empty and, only when pruning
        applies, ``messages`` is replaced by one ``RemoveMessage`` per
        message that should be dropped.
    """
    # Work on a copy: overwriting the caller's ``messages`` with removal
    # markers would destroy the caller's only reference to the real history.
    updated = dict(state)
    if not updated.get("active_agent"):
        updated["active_agent"] = "Roleplay Agent"

    history = updated.get("messages", [])
    if len(history) > max_messages:
        # Explicit end index instead of ``history[:-keep_last]``, which would
        # select nothing when ``keep_last`` is 0.
        cutoff = max(len(history) - keep_last, 0)
        # NOTE(review): these markers are presumably resolved by the
        # ``add_messages`` reducer declared on ``State.messages`` - confirm.
        # NOTE(review): the retained tail may start with a tool message whose
        # originating AI message was removed - confirm the model accepts that.
        updated["messages"] = [
            RemoveMessage(id=message.id) for message in history[:cutoff]
        ]
    return updated


async def call_roleplay(state: State):
    """Run the roleplay agent over the current message history.

    A ReAct agent named "Roleplay Agent" is assembled on every call, with its
    prompt filled in from the scenario fields of ``state``. It receives two
    tools: a handoff to "Guiding Agent" and the end-conversation tool.

    Args:
        state: Must contain ``messages`` and the five scenario fields; a
            missing key raises ``KeyError``.

    Returns:
        ``{"messages": ...}`` holding the messages of the agent's response.
    """
    logger.info("Calling roleplay agent...")

    # Tools are built before the prompt, matching the original evaluation order.
    to_guiding_agent = create_handoff_tool(
        agent_name="Guiding Agent",
        description="Hand off to Guiding Agent when user shows signs of needing help, guidance, or struggles with communication",
    )
    tools = [to_guiding_agent, create_end_conversation_tool()]

    scenario_keys = (
        "scenario_title",
        "scenario_description",
        "scenario_context",
        "your_role",
        "key_vocabulary",
    )
    prompt_fields = {key: state[key] for key in scenario_keys}
    system_prompt = roleplay_prompt.format(**prompt_fields)

    agent = create_react_agent(
        model,
        tools,
        prompt=system_prompt,
        name="Roleplay Agent",
    )
    result = await agent.ainvoke({"messages": state["messages"]})
    return {"messages": result["messages"]}


async def call_guiding_agent(state: State):
    """Run the guiding agent over the current message history.

    A ReAct agent named "Guiding Agent" is assembled on every call, with its
    prompt filled in from the scenario fields of ``state``. It receives two
    tools: a handoff back to "Roleplay Agent" and the end-conversation tool.

    Args:
        state: Must contain ``messages`` and the five scenario fields; a
            missing key raises ``KeyError``.

    Returns:
        ``{"messages": ...}`` holding the messages of the agent's response.
    """
    logger.info("Calling guiding agent...")

    # Tools are built before the prompt, matching the original evaluation order.
    to_roleplay_agent = create_handoff_tool(
        agent_name="Roleplay Agent",
        description="Hand off back to Roleplay Agent when user is ready for scenario practice and shows improved confidence",
    )
    tools = [to_roleplay_agent, create_end_conversation_tool()]

    scenario_keys = (
        "scenario_title",
        "scenario_description",
        "scenario_context",
        "your_role",
        "key_vocabulary",
    )
    prompt_fields = {key: state[key] for key in scenario_keys}
    system_prompt = guiding_prompt.format(**prompt_fields)

    agent = create_react_agent(
        model,
        tools,
        prompt=system_prompt,
        name="Guiding Agent",
    )
    result = await agent.ainvoke({"messages": state["messages"]})
    return {"messages": result["messages"]}


def route_to_active_agent(state: State) -> str:
    """Return the name of the agent that should handle the conversation.

    A missing, ``None`` or empty ``active_agent`` resolves to
    ``"Roleplay Agent"``, the same default ``trim_history`` applies.

    Args:
        state: State mapping; only ``active_agent`` is read.

    Returns:
        Either ``"Roleplay Agent"`` or ``"Guiding Agent"``.

    Raises:
        ValueError: If ``active_agent`` names an agent this router does not
            know about.
    """
    known_agents = ("Roleplay Agent", "Guiding Agent")
    active = state.get("active_agent") or "Roleplay Agent"
    if active not in known_agents:
        # Falling through would return None, contradicting the ``-> str``
        # annotation and leaving the caller without a usable destination.
        raise ValueError(
            f"Unknown active_agent {active!r}; expected one of {known_agents}"
        )
    return active