"""HTTP routes for the roleplay AI agent and its JSON-file-backed session list.

Sessions are kept as a list of dicts in ``SESSIONS_FILE``; every helper
re-reads and re-writes the whole file.
"""

import json
import os
import uuid
from datetime import datetime
from typing import List, Dict, Any, Optional

from fastapi import APIRouter, status, Depends, BackgroundTasks, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

from src.utils.logger import logger
from src.agents.role_play.func import create_agents
from src.agents.role_play.scenarios import get_scenarios, get_scenario_by_id

router = APIRouter(prefix="/ai", tags=["AI"])


class RoleplayRequest(BaseModel):
    """Request body for sending one user message to the roleplay agent."""

    query: str = Field(..., description="User's query for the AI agent")
    session_id: str = Field(
        ..., description="Session ID for tracking user interactions"
    )
    scenario: Dict[str, Any] = Field(..., description="The scenario for the roleplay")


class SessionRequest(BaseModel):
    """Request body that identifies a single session."""

    session_id: str = Field(..., description="Session ID to perform operations on")


class CreateSessionRequest(BaseModel):
    """Request body for creating a session."""

    name: str = Field(..., description="Name for the new session")


class UpdateSessionRequest(BaseModel):
    """Request body for renaming a session."""

    session_id: str = Field(..., description="Session ID to update")
    name: str = Field(..., description="New name for the session")


# Session management helper functions
SESSIONS_FILE = "sessions.json"


def load_sessions() -> List[Dict[str, Any]]:
    """Load sessions from the JSON file.

    Returns:
        The parsed content of ``SESSIONS_FILE``, or an empty list when the
        file does not exist or cannot be read/parsed (the error is logged).
    """
    try:
        if os.path.exists(SESSIONS_FILE):
            with open(SESSIONS_FILE, "r", encoding="utf-8") as f:
                return json.load(f)
        return []
    except Exception as e:
        logger.error(f"Error loading sessions: {str(e)}")
        return []


def save_sessions(sessions: List[Dict[str, Any]]) -> None:
    """Write the full session list to the JSON file.

    Failures are logged and swallowed, so callers cannot tell whether the
    write succeeded.
    NOTE(review): endpoints therefore report success even if persisting
    failed - confirm this best-effort behaviour is intended.
    """
    try:
        with open(SESSIONS_FILE, "w", encoding="utf-8") as f:
            # default=str stringifies any value json cannot encode natively.
            json.dump(sessions, f, ensure_ascii=False, indent=2, default=str)
    except Exception as e:
        logger.error(f"Error saving sessions: {str(e)}")


def create_session(name: str) -> Dict[str, Any]:
    """Create a new session record, persist it, and return it.

    Args:
        name: Display name for the session.

    Returns:
        The new session dict with a fresh UUID4 ``id``, an ISO-formatted
        ``created_at``, ``last_message`` set to ``None`` and
        ``message_count`` set to 0.
    """
    session = {
        "id": str(uuid.uuid4()),
        "name": name,
        "created_at": datetime.now().isoformat(),
        "last_message": None,
        "message_count": 0,
    }
    sessions = load_sessions()
    sessions.append(session)
    save_sessions(sessions)
    return session


def get_session_by_id(session_id: str) -> Optional[Dict[str, Any]]:
    """Return the stored session whose ``id`` equals *session_id*, else ``None``."""
    sessions = load_sessions()
    return next((s for s in sessions if s["id"] == session_id), None)


def update_session_last_message(session_id: str, message: str) -> None:
    """Record *message* as the session's last message and bump its counter.

    The session file is rewritten even when no session matches.
    """
    sessions = load_sessions()
    for session in sessions:
        if session["id"] == session_id:
            session["last_message"] = message
            session["message_count"] = session.get("message_count", 0) + 1
            break
    save_sessions(sessions)


def delete_session_by_id(session_id: str) -> bool:
    """Remove the session with the given ID.

    Returns:
        ``True`` if at least one session was removed (and the file was
        rewritten), ``False`` if nothing matched.
    """
    sessions = load_sessions()
    remaining = [s for s in sessions if s["id"] != session_id]
    if len(remaining) < len(sessions):
        save_sessions(remaining)
        return True
    return False


def _serialize_message(msg: Any) -> Dict[str, Any]:
    """Convert one stored message object into a JSON-friendly dict."""
    if not hasattr(msg, "content"):
        # Fallback for unexpected message formats.
        return {"role": "unknown", "content": str(msg), "timestamp": None}

    if hasattr(msg, "type"):
        role = msg.type
    else:
        # No ``type`` attribute: infer the role from the class name.
        role = "human" if "Human" in msg.__class__.__name__ else "ai"

    return {
        "role": role,
        "content": msg.content,
        "timestamp": getattr(msg, "timestamp", None),
    }


@router.get("/scenarios", status_code=status.HTTP_200_OK)
async def list_scenarios():
    """Get all available scenarios."""
    return JSONResponse(content=get_scenarios())


@router.post("/roleplay", status_code=status.HTTP_200_OK)
async def roleplay(request: RoleplayRequest):
    """Send a message to the roleplay agent and return its latest reply.

    Raises:
        HTTPException: 400 when the supplied scenario is empty.
    """
    scenario = request.scenario
    if not scenario:
        raise HTTPException(status_code=400, detail="Scenario not provided")

    response = await create_agents(scenario).ainvoke(
        {
            "messages": [request.query],
        },
        {"configurable": {"thread_id": request.session_id}},
    )

    # Update session with last message
    update_session_last_message(request.session_id, request.query)

    return JSONResponse(content=response["messages"][-1].content)


@router.post("/get-messages", status_code=status.HTTP_200_OK)
async def get_messages(request: SessionRequest):
    """Get all messages from a conversation session.

    Raises:
        HTTPException: 500 when reading the agent state fails.
    """
    try:
        # NOTE(review): create_agents is called without a scenario here,
        # unlike in `roleplay` - confirm it accepts no arguments and that the
        # resulting agent can see the state written by the roleplay agent.
        agent = create_agents()

        current_state = agent.get_state(
            {"configurable": {"thread_id": request.session_id}}
        )

        if not current_state or not current_state.values:
            return JSONResponse(
                content={
                    "session_id": request.session_id,
                    "messages": [],
                    "total_messages": 0,
                }
            )

        raw_messages = (
            current_state.values["messages"]
            if "messages" in current_state.values
            else []
        )
        messages = [_serialize_message(msg) for msg in raw_messages]

        return JSONResponse(
            content={
                "session_id": request.session_id,
                "messages": messages,
                "total_messages": len(messages),
            }
        )

    except Exception as e:
        logger.error(
            f"Error getting messages for session {request.session_id}: {str(e)}"
        )
        raise HTTPException(
            status_code=500, detail=f"Failed to get messages: {str(e)}"
        ) from e


@router.get("/sessions", status_code=status.HTTP_200_OK)
async def get_sessions():
    """Get all sessions."""
    try:
        return JSONResponse(content={"sessions": load_sessions()})
    except Exception as e:
        logger.error(f"Error getting sessions: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to get sessions: {str(e)}"
        ) from e


@router.post("/sessions", status_code=status.HTTP_201_CREATED)
async def create_new_session(request: CreateSessionRequest):
    """Create a new session and return it with HTTP 201."""
    try:
        session = create_session(request.name)
        # A Response returned directly ignores the decorator's status_code,
        # so 201 has to be set on the response itself.
        return JSONResponse(
            status_code=status.HTTP_201_CREATED, content={"session": session}
        )
    except Exception as e:
        logger.error(f"Error creating session: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to create session: {str(e)}"
        ) from e


@router.get("/sessions/{session_id}", status_code=status.HTTP_200_OK)
async def get_session(session_id: str):
    """Get a specific session by ID.

    Raises:
        HTTPException: 404 when no session matches, 500 on unexpected errors.
    """
    try:
        session = get_session_by_id(session_id)
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")
        return JSONResponse(content={"session": session})
    except HTTPException:
        # Let the 404 above through untouched.
        raise
    except Exception as e:
        logger.error(f"Error getting session {session_id}: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to get session: {str(e)}"
        ) from e


@router.put("/sessions/{session_id}", status_code=status.HTTP_200_OK)
async def update_session(session_id: str, request: UpdateSessionRequest):
    """Rename a session.

    The session is selected by the path parameter; ``request.session_id`` is
    not consulted.

    Raises:
        HTTPException: 404 when no session matches, 500 on unexpected errors.
    """
    try:
        sessions = load_sessions()
        target = next((s for s in sessions if s["id"] == session_id), None)
        if target is None:
            raise HTTPException(status_code=404, detail="Session not found")

        target["name"] = request.name
        save_sessions(sessions)

        # Re-read from disk so the response reflects what was persisted.
        updated_session = get_session_by_id(session_id)
        return JSONResponse(content={"session": updated_session})
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error updating session {session_id}: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to update session: {str(e)}"
        ) from e


@router.delete("/sessions/{session_id}", status_code=status.HTTP_200_OK)
async def delete_session(session_id: str):
    """Delete a session.

    Raises:
        HTTPException: 404 when no session matches, 500 on unexpected errors.
    """
    try:
        if not delete_session_by_id(session_id):
            raise HTTPException(status_code=404, detail="Session not found")
        return JSONResponse(content={"message": "Session deleted successfully"})
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error deleting session {session_id}: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to delete session: {str(e)}"
        ) from e