ABAO77 commited on
Commit
c0827a3
·
1 Parent(s): b7a3e32

Refactor roleplay agent implementation and update session handling for improved message processing

Browse files
sessions.json CHANGED
@@ -1,16 +0,0 @@
1
- [
2
- {
3
- "id": "82a6779d-ad13-4edd-a046-575e563a4348",
4
- "name": "New Conversation",
5
- "created_at": "2025-08-21T11:57:23.992279",
6
- "last_message": "[Audio message]",
7
- "message_count": 37
8
- },
9
- {
10
- "id": "4fbf6c50-6054-4f3d-ac4e-d8281c306d72",
11
- "name": "New Conversation",
12
- "created_at": "2025-08-21T11:57:23.993885",
13
- "last_message": null,
14
- "message_count": 0
15
- }
16
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/agents/role_play/flow.py CHANGED
@@ -1,24 +1,45 @@
1
  from langgraph.graph import StateGraph, START, END
2
- from .func import State
3
  from langgraph.graph.state import CompiledStateGraph
4
- from langgraph.store.memory import InMemoryStore
5
 
6
 
7
- class PrimaryChatBot:
8
  def __init__(self):
9
- self.builder = StateGraph(State)
10
 
11
  @staticmethod
12
- def routing(state: State):
13
- pass
 
 
 
14
 
15
- def node(self):
16
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- def edge(self):
19
- pass
20
 
21
- def __call__(self) -> CompiledStateGraph:
22
- self.node()
23
- self.edge()
24
- return self.builder.compile(checkpointer=InMemoryStore())
 
1
  from langgraph.graph import StateGraph, START, END
2
+ from .func import State, trim_history, call_roleplay, call_guiding_agent
3
  from langgraph.graph.state import CompiledStateGraph
4
+ from langgraph.checkpoint.memory import InMemorySaver
5
 
6
 
7
+ class RolePlayAgent:
8
  def __init__(self):
9
+ pass
10
 
11
  @staticmethod
12
+ def route_to_active_agent(state: State) -> str:
13
+ if state["active_agent"] == "Roleplay Agent":
14
+ return "Roleplay Agent"
15
+ elif state["active_agent"] == "Guiding Agent":
16
+ return "Guiding Agent"
17
 
18
+ def node(self, graph: StateGraph):
19
+ graph.add_node("trim_history", trim_history)
20
+ graph.add_node("Roleplay Agent", call_roleplay, destinations=("Guiding Agent",))
21
+ graph.add_node(
22
+ "Guiding Agent", call_guiding_agent, destinations=("Roleplay Agent",)
23
+ )
24
+ return graph
25
+
26
+ def edge(self, graph: StateGraph):
27
+ graph.add_edge(START, "trim_history")
28
+ graph.add_conditional_edges(
29
+ "trim_history",
30
+ self.route_to_active_agent,
31
+ {
32
+ "Roleplay Agent": "Roleplay Agent",
33
+ "Guiding Agent": "Guiding Agent",
34
+ },
35
+ )
36
+ return graph
37
+
38
+ def __call__(self, checkpointer=InMemorySaver()) -> CompiledStateGraph:
39
+ graph = StateGraph(State)
40
+ graph: StateGraph = self.node(graph)
41
+ graph: StateGraph = self.edge(graph)
42
+ return graph.compile(checkpointer=checkpointer)
43
 
 
 
44
 
45
+ role_play_agent = RolePlayAgent()
 
 
 
src/agents/role_play/func.py CHANGED
@@ -1,17 +1,40 @@
1
  from typing import TypedDict
2
  from src.config.llm import model
3
- from langgraph.checkpoint.memory import InMemorySaver
4
  from langgraph.prebuilt import create_react_agent
5
- from langgraph_swarm import create_handoff_tool, create_swarm
 
6
  from .prompt import roleplay_prompt, guiding_prompt
 
 
 
 
7
 
8
 
9
  class State(TypedDict):
10
- pass
 
 
 
 
 
 
11
 
12
 
13
- def create_agents(scenario, checkpointer=InMemorySaver()):
 
 
 
 
 
 
 
 
 
 
14
 
 
 
 
15
  roleplay_agent = create_react_agent(
16
  model,
17
  [
@@ -21,15 +44,21 @@ def create_agents(scenario, checkpointer=InMemorySaver()):
21
  ),
22
  ],
23
  prompt=roleplay_prompt.format(
24
- scenario_title=scenario["scenario_title"],
25
- scenario_description=scenario["scenario_description"],
26
- scenario_context=scenario["scenario_context"],
27
- your_role=scenario["your_role"],
28
- key_vocabulary=scenario["key_vocabulary"],
29
  ),
30
  name="Roleplay Agent",
31
  )
 
 
 
32
 
 
 
 
33
  guiding_agent = create_react_agent(
34
  model,
35
  [
@@ -39,17 +68,20 @@ def create_agents(scenario, checkpointer=InMemorySaver()):
39
  ),
40
  ],
41
  prompt=guiding_prompt.format(
42
- scenario_title=scenario["scenario_title"],
43
- scenario_description=scenario["scenario_description"],
44
- scenario_context=scenario["scenario_context"],
45
- your_role=scenario["your_role"],
46
- key_vocabulary=scenario["key_vocabulary"],
47
  ),
48
  name="Guiding Agent",
49
  )
 
 
50
 
51
- workflow = create_swarm(
52
- [roleplay_agent, guiding_agent], default_active_agent="Roleplay Agent"
53
- )
54
 
55
- return workflow.compile(checkpointer)
 
 
 
 
 
1
  from typing import TypedDict
2
  from src.config.llm import model
 
3
  from langgraph.prebuilt import create_react_agent
4
+ from langgraph_swarm import create_handoff_tool
5
+ from langchain_core.messages import RemoveMessage
6
  from .prompt import roleplay_prompt, guiding_prompt
7
+ from typing_extensions import TypedDict, Annotated
8
+ from langchain_core.messages import AnyMessage
9
+ from langgraph.graph import add_messages
10
+ from loguru import logger
11
 
12
 
13
  class State(TypedDict):
14
+ active_agent: str | None
15
+ messages: Annotated[list[AnyMessage], add_messages]
16
+ scenario_title: str
17
+ scenario_description: str
18
+ scenario_context: str
19
+ your_role: str
20
+ key_vocabulary: str
21
 
22
 
23
+ def trim_history(state: State):
24
+ if not state.get("active_agent"):
25
+ state["active_agent"] = "Roleplay Agent"
26
+ history = state.get("messages", [])
27
+ if len(history) > 25:
28
+ num_to_remove = len(history) - 5
29
+ remove_messages = [
30
+ RemoveMessage(id=history[i].id) for i in range(num_to_remove)
31
+ ]
32
+ state["messages"] = remove_messages
33
+ return state
34
 
35
+
36
+ async def call_roleplay(state: State):
37
+ logger.info("Calling roleplay agent...")
38
  roleplay_agent = create_react_agent(
39
  model,
40
  [
 
44
  ),
45
  ],
46
  prompt=roleplay_prompt.format(
47
+ scenario_title=state["scenario_title"],
48
+ scenario_description=state["scenario_description"],
49
+ scenario_context=state["scenario_context"],
50
+ your_role=state["your_role"],
51
+ key_vocabulary=state["key_vocabulary"],
52
  ),
53
  name="Roleplay Agent",
54
  )
55
+ response = await roleplay_agent.ainvoke({"messages": state["messages"]})
56
+
57
+ return {"messages": response["messages"]}
58
 
59
+
60
+ async def call_guiding_agent(state: State):
61
+ logger.info("Calling guiding agent...")
62
  guiding_agent = create_react_agent(
63
  model,
64
  [
 
68
  ),
69
  ],
70
  prompt=guiding_prompt.format(
71
+ scenario_title=state["scenario_title"],
72
+ scenario_description=state["scenario_description"],
73
+ scenario_context=state["scenario_context"],
74
+ your_role=state["your_role"],
75
+ key_vocabulary=state["key_vocabulary"],
76
  ),
77
  name="Guiding Agent",
78
  )
79
+ response = await guiding_agent.ainvoke({"messages": state["messages"]})
80
+ return {"messages": response["messages"]}
81
 
 
 
 
82
 
83
+ def route_to_active_agent(state: State) -> str:
84
+ if state["active_agent"] == "Roleplay Agent":
85
+ return "Roleplay Agent"
86
+ elif state["active_agent"] == "Guiding Agent":
87
+ return "Guiding Agent"
src/apis/routes/chat_route.py CHANGED
@@ -10,7 +10,7 @@ from fastapi import (
10
  )
11
  from fastapi.responses import JSONResponse
12
  from src.utils.logger import logger
13
- from src.agents.role_play.func import create_agents
14
  from pydantic import BaseModel, Field
15
  from typing import List, Dict, Any, Optional
16
  from src.agents.role_play.scenarios import get_scenarios, get_scenario_by_id
@@ -200,21 +200,24 @@ async def roleplay(
200
  message = {"role": "user", "content": message_content}
201
 
202
  try:
203
- response = await create_agents(scenario_dict).ainvoke(
204
  {
205
  "messages": [message],
 
 
 
 
 
206
  },
207
  {"configurable": {"thread_id": session_id}},
208
  )
209
-
210
- # Update session with last message (use text if available, otherwise indicate audio)
211
  last_message = text_message if text_message else "[Audio message]"
212
  update_session_last_message(session_id, last_message)
213
-
214
  # Extract AI response content
215
  ai_response = response["messages"][-1].content
216
  logger.info(f"AI response: {ai_response}")
217
-
218
  return JSONResponse(content={"response": ai_response})
219
 
220
  except Exception as e:
@@ -228,7 +231,7 @@ async def get_messages(request: SessionRequest):
228
  try:
229
 
230
  # Create agent instance
231
- agent = create_agents()
232
 
233
  # Get current state
234
  current_state = agent.get_state(
 
10
  )
11
  from fastapi.responses import JSONResponse
12
  from src.utils.logger import logger
13
+ from src.agents.role_play.flow import role_play_agent
14
  from pydantic import BaseModel, Field
15
  from typing import List, Dict, Any, Optional
16
  from src.agents.role_play.scenarios import get_scenarios, get_scenario_by_id
 
200
  message = {"role": "user", "content": message_content}
201
 
202
  try:
203
+ response = await role_play_agent().ainvoke(
204
  {
205
  "messages": [message],
206
+ "scenario_title": scenario_dict["scenario_title"],
207
+ "scenario_description": scenario_dict["scenario_description"],
208
+ "scenario_context": scenario_dict["scenario_context"],
209
+ "your_role": scenario_dict["your_role"],
210
+ "key_vocabulary": scenario_dict["key_vocabulary"],
211
  },
212
  {"configurable": {"thread_id": session_id}},
213
  )
 
 
214
  last_message = text_message if text_message else "[Audio message]"
215
  update_session_last_message(session_id, last_message)
216
+
217
  # Extract AI response content
218
  ai_response = response["messages"][-1].content
219
  logger.info(f"AI response: {ai_response}")
220
+
221
  return JSONResponse(content={"response": ai_response})
222
 
223
  except Exception as e:
 
231
  try:
232
 
233
  # Create agent instance
234
+ agent = role_play_agent()
235
 
236
  # Get current state
237
  current_state = agent.get_state(