Refactor roleplay agent implementation and update session handling for improved message processing
- sessions.json +0 -16
- src/agents/role_play/flow.py +35 -14
- src/agents/role_play/func.py +50 -18
- src/apis/routes/chat_route.py +10 -7
sessions.json
CHANGED
@@ -1,16 +0,0 @@
-[
-    {
-        "id": "82a6779d-ad13-4edd-a046-575e563a4348",
-        "name": "New Conversation",
-        "created_at": "2025-08-21T11:57:23.992279",
-        "last_message": "[Audio message]",
-        "message_count": 37
-    },
-    {
-        "id": "4fbf6c50-6054-4f3d-ac4e-d8281c306d72",
-        "name": "New Conversation",
-        "created_at": "2025-08-21T11:57:23.993885",
-        "last_message": null,
-        "message_count": 0
-    }
-]
src/agents/role_play/flow.py
CHANGED
@@ -1,24 +1,45 @@
 from langgraph.graph import StateGraph, START, END
-from .func import State
+from .func import State, trim_history, call_roleplay, call_guiding_agent
 from langgraph.graph.state import CompiledStateGraph
-from langgraph.
+from langgraph.checkpoint.memory import InMemorySaver


-class
+class RolePlayAgent:
     def __init__(self):
-
+        pass

     @staticmethod
-    def
-
+    def route_to_active_agent(state: State) -> str:
+        if state["active_agent"] == "Roleplay Agent":
+            return "Roleplay Agent"
+        elif state["active_agent"] == "Guiding Agent":
+            return "Guiding Agent"

-    def node(self):
-
+    def node(self, graph: StateGraph):
+        graph.add_node("trim_history", trim_history)
+        graph.add_node("Roleplay Agent", call_roleplay, destinations=("Guiding Agent",))
+        graph.add_node(
+            "Guiding Agent", call_guiding_agent, destinations=("Roleplay Agent",)
+        )
+        return graph

-    def edge(self):
-        pass
+    def edge(self, graph: StateGraph):
+        graph.add_edge(START, "trim_history")
+        graph.add_conditional_edges(
+            "trim_history",
+            self.route_to_active_agent,
+            {
+                "Roleplay Agent": "Roleplay Agent",
+                "Guiding Agent": "Guiding Agent",
+            },
+        )
+        return graph

-
-        self.node()
-        self.edge()
-        return self.builder.compile(checkpointer=InMemoryStore())
+    def __call__(self, checkpointer=InMemorySaver()) -> CompiledStateGraph:
+        graph = StateGraph(State)
+        graph: StateGraph = self.node(graph)
+        graph: StateGraph = self.edge(graph)
+        return graph.compile(checkpointer=checkpointer)


+role_play_agent = RolePlayAgent()
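For reference, a minimal sketch of how the compiled graph could be driven end to end (assuming the project's `model` in src/config/llm.py is configured; the scenario values below are made-up placeholders, not part of the commit):

import asyncio
from src.agents.role_play.flow import role_play_agent

async def demo():
    graph = role_play_agent()  # compiles the StateGraph with the default InMemorySaver
    result = await graph.ainvoke(
        {
            "messages": [{"role": "user", "content": "Hi, I'd like to check in."}],
            "scenario_title": "Hotel check-in",  # placeholder
            "scenario_description": "Practice checking into a hotel.",
            "scenario_context": "Front desk, evening.",
            "your_role": "Hotel guest",
            "key_vocabulary": "reservation, deposit, key card",
        },
        {"configurable": {"thread_id": "demo-thread"}},
    )
    print(result["messages"][-1].content)

asyncio.run(demo())

Note that trim_history fills in "Roleplay Agent" as the active agent on the first turn, so the conditional edge has a destination even before any handoff happens.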
src/agents/role_play/func.py
CHANGED
@@ -1,17 +1,40 @@
 from typing import TypedDict
 from src.config.llm import model
-from langgraph.checkpoint.memory import InMemorySaver
 from langgraph.prebuilt import create_react_agent
-from langgraph_swarm import create_handoff_tool
+from langgraph_swarm import create_handoff_tool
+from langchain_core.messages import RemoveMessage
 from .prompt import roleplay_prompt, guiding_prompt
+from typing_extensions import TypedDict, Annotated
+from langchain_core.messages import AnyMessage
+from langgraph.graph import add_messages
+from loguru import logger


 class State(TypedDict):
-
+    active_agent: str | None
+    messages: Annotated[list[AnyMessage], add_messages]
+    scenario_title: str
+    scenario_description: str
+    scenario_context: str
+    your_role: str
+    key_vocabulary: str


-def create_agents(scenario, checkpointer=InMemorySaver()):
+def trim_history(state: State):
+    if not state.get("active_agent"):
+        state["active_agent"] = "Roleplay Agent"
+    history = state.get("messages", [])
+    if len(history) > 25:
+        num_to_remove = len(history) - 5
+        remove_messages = [
+            RemoveMessage(id=history[i].id) for i in range(num_to_remove)
+        ]
+        state["messages"] = remove_messages
+    return state

+
+async def call_roleplay(state: State):
+    logger.info("Calling roleplay agent...")
     roleplay_agent = create_react_agent(
         model,
         [
@@ -21,15 +44,21 @@ def create_agents(scenario, checkpointer=InMemorySaver()):
             ),
         ],
         prompt=roleplay_prompt.format(
-            scenario_title=
-            scenario_description=
-            scenario_context=
-            your_role=
-            key_vocabulary=
+            scenario_title=state["scenario_title"],
+            scenario_description=state["scenario_description"],
+            scenario_context=state["scenario_context"],
+            your_role=state["your_role"],
+            key_vocabulary=state["key_vocabulary"],
         ),
         name="Roleplay Agent",
     )
+    response = await roleplay_agent.ainvoke({"messages": state["messages"]})
+
+    return {"messages": response["messages"]}

+
+async def call_guiding_agent(state: State):
+    logger.info("Calling guiding agent...")
     guiding_agent = create_react_agent(
         model,
         [
@@ -39,17 +68,20 @@ def create_agents(scenario, checkpointer=InMemorySaver()):
             ),
         ],
         prompt=guiding_prompt.format(
-            scenario_title=
-            scenario_description=
-            scenario_context=
-            your_role=
-            key_vocabulary=
+            scenario_title=state["scenario_title"],
+            scenario_description=state["scenario_description"],
+            scenario_context=state["scenario_context"],
+            your_role=state["your_role"],
+            key_vocabulary=state["key_vocabulary"],
        ),
         name="Guiding Agent",
     )
+    response = await guiding_agent.ainvoke({"messages": state["messages"]})
+    return {"messages": response["messages"]}

-    workflow = create_swarm(
-        [roleplay_agent, guiding_agent], default_active_agent="Roleplay Agent"
-    )

-
+
+def route_to_active_agent(state: State) -> str:
+    if state["active_agent"] == "Roleplay Agent":
+        return "Roleplay Agent"
+    elif state["active_agent"] == "Guiding Agent":
+        return "Guiding Agent"
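To make the trimming rule concrete, here is a hypothetical sanity check (not part of the commit): once the history exceeds 25 messages, trim_history returns RemoveMessage markers for everything but the last five turns, and the add_messages reducer then applies those as deletions.

from langchain_core.messages import HumanMessage, RemoveMessage
from src.agents.role_play.func import trim_history

# 30 dummy turns with explicit ids so RemoveMessage can reference them
state = {
    "active_agent": None,
    "messages": [HumanMessage(content=f"turn {i}", id=str(i)) for i in range(30)],
    "scenario_title": "",
    "scenario_description": "",
    "scenario_context": "",
    "your_role": "",
    "key_vocabulary": "",
}

result = trim_history(state)
assert result["active_agent"] == "Roleplay Agent"  # default agent filled in
assert len(result["messages"]) == 25               # 30 - 5 kept turns marked for removal
assert all(isinstance(m, RemoveMessage) for m in result["messages"])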
src/apis/routes/chat_route.py
CHANGED
@@ -10,7 +10,7 @@ from fastapi import (
 )
 from fastapi.responses import JSONResponse
 from src.utils.logger import logger
-from src.agents.role_play.
+from src.agents.role_play.flow import role_play_agent
 from pydantic import BaseModel, Field
 from typing import List, Dict, Any, Optional
 from src.agents.role_play.scenarios import get_scenarios, get_scenario_by_id
@@ -200,21 +200,24 @@ async def roleplay(
     message = {"role": "user", "content": message_content}

     try:
-        response = await
+        response = await role_play_agent().ainvoke(
             {
                 "messages": [message],
+                "scenario_title": scenario_dict["scenario_title"],
+                "scenario_description": scenario_dict["scenario_description"],
+                "scenario_context": scenario_dict["scenario_context"],
+                "your_role": scenario_dict["your_role"],
+                "key_vocabulary": scenario_dict["key_vocabulary"],
             },
             {"configurable": {"thread_id": session_id}},
         )
-
-        # Update session with last message (use text if available, otherwise indicate audio)
         last_message = text_message if text_message else "[Audio message]"
        update_session_last_message(session_id, last_message)
-
+
         # Extract AI response content
         ai_response = response["messages"][-1].content
         logger.info(f"AI response: {ai_response}")
-
+
         return JSONResponse(content={"response": ai_response})

     except Exception as e:
@@ -228,7 +231,7 @@ async def get_messages(request: SessionRequest):
     try:

         # Create agent instance
-        agent =
+        agent = role_play_agent()

         # Get current state
         current_state = agent.get_state(
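The get_messages handler reads a thread's history back through the checkpointer; a rough sketch of the same idea (assuming, as in flow.py, that every role_play_agent() call shares the single InMemorySaver created as the default argument, which is what lets separate calls see the same threads):

from src.agents.role_play.flow import role_play_agent

def read_history(session_id: str) -> list:
    agent = role_play_agent()  # new compile, same shared in-memory checkpointer
    snapshot = agent.get_state({"configurable": {"thread_id": session_id}})
    return snapshot.values.get("messages", [])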