from fastapi import (
    APIRouter,
    status,
    HTTPException,
    File,
    UploadFile,
    Form,
)
from fastapi.responses import JSONResponse, StreamingResponse
from src.utils.logger import logger
from src.agents.role_play.flow import role_play_agent
from src.services.tts_service import tts_service
from pydantic import BaseModel, Field
from typing import Dict, Any, Optional
from src.agents.role_play.scenarios import get_scenarios
import json
import base64
import asyncio

router = APIRouter(prefix="/ai", tags=["AI"])


class RoleplayRequest(BaseModel):
    query: str = Field(..., description="User's query for the AI agent")
    session_id: str = Field(
        ..., description="Session ID for tracking user interactions"
    )
    scenario: Dict[str, Any] = Field(..., description="The scenario for the roleplay")


@router.get("/scenarios", status_code=status.HTTP_200_OK)
async def list_scenarios():
    """Get all available scenarios"""
    return JSONResponse(content=get_scenarios())


@router.post("/roleplay", status_code=status.HTTP_200_OK)
async def roleplay(
    session_id: str = Form(
        ..., description="Session ID for tracking user interactions"
    ),
    scenario: str = Form(
        ..., description="The scenario for the roleplay as JSON string"
    ),
    text_message: Optional[str] = Form(None, description="Text message from user"),
    audio_file: Optional[UploadFile] = File(None, description="Audio file from user"),
):
    """Send a message (text or audio) to the roleplay agent"""
    logger.info(f"Received roleplay request: {session_id}")

    # Validate that at least one input is provided
    if not text_message and not audio_file:
        raise HTTPException(
            status_code=400,
            detail="Either text_message or audio_file must be provided",
        )

    # Parse scenario from JSON string
    try:
        scenario_dict = json.loads(scenario)
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid scenario JSON format")

    if not scenario_dict:
        raise HTTPException(status_code=400, detail="Scenario not provided")

    # Prepare message content
    message_content = []

    # Handle text input
    if text_message:
        message_content.append({"type": "text", "text": text_message})

    # Handle audio input
    if audio_file:
        try:
            # Read audio file content
            audio_data = await audio_file.read()

            # Convert to base64
            audio_base64 = base64.b64encode(audio_data).decode("utf-8")

            # Determine mime type based on file extension
            file_extension = (
                audio_file.filename.split(".")[-1].lower()
                if audio_file.filename
                else "wav"
            )
            mime_type_map = {
                "wav": "audio/wav",
                "mp3": "audio/mpeg",
                "ogg": "audio/ogg",
                "webm": "audio/webm",
                "m4a": "audio/mp4",
            }
            mime_type = mime_type_map.get(file_extension, "audio/wav")

            message_content.append(
                {
                    "type": "audio",
                    "source_type": "base64",
                    "data": audio_base64,
                    "mime_type": mime_type,
                }
            )
        except Exception as e:
            logger.error(f"Error processing audio file: {str(e)}")
            raise HTTPException(
                status_code=400, detail=f"Error processing audio file: {str(e)}"
            )

    # Create message in the required format
    message = {"role": "user", "content": message_content}

    try:
        response = await role_play_agent().ainvoke(
            {
                "messages": [message],
                "scenario_title": scenario_dict["scenario_title"],
                "scenario_description": scenario_dict["scenario_description"],
                "scenario_context": scenario_dict["scenario_context"],
                "your_role": scenario_dict["your_role"],
                "key_vocabulary": scenario_dict["key_vocabulary"],
            },
            {"configurable": {"thread_id": session_id}},
        )

        # Extract AI response content
        ai_response = response["messages"][-1].content
        logger.info(f"AI response: {ai_response}")

        return JSONResponse(content={"response": ai_response})
    except Exception as e:
        logger.error(f"Error in roleplay: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
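
# Example client call for the endpoint above (an illustrative sketch only, not
# executed by this module). It assumes the app is served locally on port 8000 and
# that this router is included without an extra path prefix, so the full route is
# /ai/roleplay; the scenario values below are made up and only show the keys the
# handler reads from scenario_dict:
#
#     import json
#     import httpx
#
#     scenario = {
#         "scenario_title": "Ordering coffee",
#         "scenario_description": "Practice ordering at a cafe",
#         "scenario_context": "You are a customer at a busy cafe",
#         "your_role": "Customer",
#         "key_vocabulary": ["latte", "to go"],
#     }
#     resp = httpx.post(
#         "http://localhost:8000/ai/roleplay",
#         data={
#             "session_id": "demo-session",
#             "scenario": json.dumps(scenario),
#             "text_message": "Hi, can I get a latte, please?",
#         },
#         # An audio clip could be sent instead of text_message, e.g.:
#         # files={"audio_file": ("clip.wav", wav_bytes, "audio/wav")},
#     )
#     print(resp.json()["response"])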
e: logger.error(f"Error in roleplay: {str(e)}") raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") @router.post("/roleplay/stream", status_code=status.HTTP_200_OK) async def roleplay_stream( session_id: str = Form( ..., description="Session ID for tracking user interactions" ), scenario: str = Form( ..., description="The scenario for the roleplay as JSON string" ), text_message: Optional[str] = Form(None, description="Text message from user"), audio_file: Optional[UploadFile] = File(None, description="Audio file from user"), audio: bool = Form(False, description="Whether to return TTS audio response"), ): """Send a message (text or audio) to the roleplay agent with streaming response""" logger.info(f"Received streaming roleplay request: {session_id}") # Validate that at least one input is provided if not text_message and not audio_file: raise HTTPException( status_code=400, detail="Either text_message or audio_file must be provided" ) # Parse scenario from JSON string try: scenario_dict = json.loads(scenario) except json.JSONDecodeError: raise HTTPException(status_code=400, detail="Invalid scenario JSON format") if not scenario_dict: raise HTTPException(status_code=400, detail="Scenario not provided") # Prepare message content message_content = [] # Handle text input if text_message: message_content.append({"type": "text", "text": text_message}) # Handle audio input if audio_file: try: # Read audio file content audio_data = await audio_file.read() # Convert to base64 audio_base64 = base64.b64encode(audio_data).decode("utf-8") # Determine mime type based on file extension file_extension = ( audio_file.filename.split(".")[-1].lower() if audio_file.filename else "wav" ) mime_type_map = { "wav": "audio/wav", "mp3": "audio/mpeg", "ogg": "audio/ogg", "webm": "audio/webm", "m4a": "audio/mp4", } mime_type = mime_type_map.get(file_extension, "audio/wav") message_content.append( { "type": "audio", "source_type": "base64", "data": audio_base64, "mime_type": mime_type, } ) except Exception as e: logger.error(f"Error processing audio file: {str(e)}") raise HTTPException( status_code=400, detail=f"Error processing audio file: {str(e)}" ) # Create message in the required format message = {"role": "user", "content": message_content} async def generate_stream(): """Generator function for streaming responses""" accumulated_content = "" conversation_ended = False try: input_graph = { "messages": [message], "scenario_title": scenario_dict["scenario_title"], "scenario_description": scenario_dict["scenario_description"], "scenario_context": scenario_dict["scenario_context"], "your_role": scenario_dict["your_role"], "key_vocabulary": scenario_dict["key_vocabulary"], } config = {"configurable": {"thread_id": session_id}} async for event in role_play_agent().astream( input=input_graph, stream_mode=["messages", "values"], config=config, subgraphs=True, ): _, event_type, message_chunk = event if event_type == "messages": # message_chunk is a tuple, get the first element which is the actual AIMessageChunk if isinstance(message_chunk, tuple) and len(message_chunk) > 0: actual_message = message_chunk[0] content = getattr(actual_message, "content", "") else: actual_message = message_chunk content = getattr(message_chunk, "content", "") # Check if this is a tool call message and if it's an end conversation tool call if ( hasattr(actual_message, "tool_calls") and actual_message.tool_calls ): # Check if any tool call is for ending the conversation for tool_call in actual_message.tool_calls: if 
            async for event in role_play_agent().astream(
                input=input_graph,
                stream_mode=["messages", "values"],
                config=config,
                subgraphs=True,
            ):
                _, event_type, message_chunk = event

                if event_type == "messages":
                    # message_chunk is a tuple, get the first element which is
                    # the actual AIMessageChunk
                    if isinstance(message_chunk, tuple) and len(message_chunk) > 0:
                        actual_message = message_chunk[0]
                        content = getattr(actual_message, "content", "")
                    else:
                        actual_message = message_chunk
                        content = getattr(message_chunk, "content", "")

                    # Check if this is a tool call message and if it's an end
                    # conversation tool call
                    if (
                        hasattr(actual_message, "tool_calls")
                        and actual_message.tool_calls
                    ):
                        # Check if any tool call is for ending the conversation
                        for tool_call in actual_message.tool_calls:
                            if (
                                isinstance(tool_call, dict)
                                and tool_call.get("name") == "end_conversation"
                            ):
                                # Send a special termination message to the client
                                termination_data = {
                                    "type": "termination",
                                    "content": "Conversation ended",
                                    "reason": tool_call.get("args", {}).get(
                                        "reason", "Unknown reason"
                                    ),
                                }
                                yield f"data: {json.dumps(termination_data)}\n\n"
                                conversation_ended = True
                                break

                    if content and not conversation_ended:
                        # Accumulate content for TTS
                        accumulated_content += content

                        # Create SSE-formatted response
                        response_data = {
                            "type": "message_chunk",
                            "content": content,
                            "metadata": {
                                "agent": getattr(actual_message, "name", "unknown"),
                                "id": getattr(actual_message, "id", ""),
                                "usage_metadata": getattr(
                                    actual_message, "usage_metadata", {}
                                ),
                            },
                        }

                        yield f"data: {json.dumps(response_data)}\n\n"

                        # Small delay to prevent overwhelming the client
                        await asyncio.sleep(0.01)

            # Only send completion signal if conversation wasn't ended by tool call
            if not conversation_ended:
                # Generate TTS audio if requested
                audio_data = None
                if audio and accumulated_content.strip():
                    try:
                        logger.info(
                            f"Generating TTS for accumulated content: {len(accumulated_content)} chars"
                        )
                        audio_result = await tts_service.text_to_speech(
                            accumulated_content
                        )
                        if audio_result:
                            audio_data = {
                                "audio_data": audio_result["audio_data"],
                                "mime_type": audio_result["mime_type"],
                                "format": audio_result["format"],
                            }
                            logger.info("TTS audio generated successfully")
                        else:
                            logger.warning("TTS generation failed")
                    except Exception as tts_error:
                        logger.error(f"TTS generation error: {str(tts_error)}")

                # Send completion signal with optional audio
                completion_data = {
                    "type": "completion",
                    "content": "",
                    "audio": audio_data,
                }
                yield f"data: {json.dumps(completion_data)}\n\n"

        except Exception as e:
            logger.error(f"Error in streaming roleplay: {str(e)}")
            error_data = {"type": "error", "content": str(e)}
            yield f"data: {json.dumps(error_data)}\n\n"

    return StreamingResponse(
        generate_stream(),
        media_type="text/plain",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Content-Type": "text/event-stream",
        },
    )
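
# The /roleplay/stream endpoint above emits Server-Sent-Events-style lines of the
# form "data: {...}\n\n". The JSON payload's "type" field is one of "message_chunk",
# "termination", "completion", or "error" (see generate_stream); a "completion"
# payload may also carry an "audio" object when TTS was requested. A minimal
# consumption sketch (illustrative only; it reuses the hypothetical base URL and
# scenario dict from the example after the /roleplay endpoint):
#
#     import json
#     import httpx
#
#     with httpx.stream(
#         "POST",
#         "http://localhost:8000/ai/roleplay/stream",
#         data={
#             "session_id": "demo-session",
#             "scenario": json.dumps(scenario),
#             "text_message": "Hi, can I get a latte, please?",
#             "audio": "false",
#         },
#         timeout=None,
#     ) as resp:
#         for line in resp.iter_lines():
#             if line.startswith("data: "):
#                 event = json.loads(line[len("data: "):])
#                 if event["type"] == "message_chunk":
#                     print(event["content"], end="", flush=True)
#                 elif event["type"] in ("completion", "termination", "error"):
#                     break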