Spaces:
Sleeping
Sleeping
| from fastapi import APIRouter, HTTPException, Body, UploadFile, File, Form, Request | |
| from pydantic import BaseModel | |
| from typing import Optional, List | |
| from app.ai_agent.agent import handle_user_query, create_medical_agent, search_cases_with_timeout, register_attachment | |
| import logging | |
| import asyncio | |
logger = logging.getLogger(__name__)
router = APIRouter()

# Upper bounds, in bytes, on attachments uploaded through the multipart
# endpoint. Note the units differ: the image cap is expressed in decimal
# megabytes, while the audio and generic-file caps are binary mebibytes.
MAX_IMAGE_BYTES = 6 * 10**6  # 6,000,000 bytes (~6 MB)
MAX_AUDIO_BYTES = 10 * 2**20  # 10,485,760 bytes (10 MiB)
MAX_FILE_BYTES = 2 * 2**20  # 2,097,152 bytes (2 MiB)
class AIRequest(BaseModel):
    """JSON body accepted by ``ai_endpoint``; every field is optional.

    Media fields hold either a URL or a base64 string; ``ai_endpoint`` passes
    them to ``handle_user_query`` as-is.
    """

    text: Optional[str] = None  # free-form user text
    image: Optional[str] = None  # URL or base64
    images: Optional[List[str]] = None  # multiple URLs or base64 strings
    audio: Optional[str] = None  # URL or base64
    audios: Optional[List[str]] = None  # multiple URLs or base64 strings
    want_stats: Optional[bool] = False  # when true, the response includes input statistics
    location: Optional[str] = None  # used for searching nearby facilities
    files: Optional[List[str]] = None  # URLs or base64 of (small) files
    file_names: Optional[List[str]] = None  # names matching the entries of `files`
    agent_mode: Optional[str] = None  # 'messages' (zero-shot), 'string', or 'legacy'
class AIResponse(BaseModel):
    """Response model returned by the JSON and multipart AI endpoints."""

    result: str  # text returned by handle_user_query
    stats: Optional[dict] = None  # input statistics; left None unless the request asked for them
async def ai_endpoint(req: AIRequest = Body(...)):
    """Run the medical agent on a JSON request and return its answer.

    The agent receives a short textual summary of the request (free text plus
    markers for attached media and location) together with the raw media
    references, which are forwarded unchanged to ``handle_user_query``.

    Args:
        req: Request payload; every field is optional.

    Returns:
        AIResponse holding the agent result and, when ``req.want_stats`` is
        set, simple size/count statistics about the inputs.

    Raises:
        HTTPException: 400 if ``agent_mode`` is given and is not one of
            'messages', 'string' or 'legacy' (case-insensitive).
    """
    # Reject unknown agent modes up front, exactly like the multipart endpoint.
    if req.agent_mode and req.agent_mode.lower() not in {"messages", "string", "legacy"}:
        raise HTTPException(status_code=400, detail="agent_mode invalide: utilisez 'messages', 'string' ou 'legacy'")

    # Build the user query for the agent, one newline-terminated part per input.
    parts: List[str] = []
    if req.text:
        parts.append(req.text)
    # Media may arrive only through the plural fields; flag those as well.
    if req.image or req.images:
        parts.append("[Image fournie]")
    if req.audio or req.audios:
        parts.append("[Audio fourni]")
    if req.location:
        parts.append(f"[Localisation: {req.location}]")
    user_query = "".join(f"{part}\n" for part in parts)

    # Call the (blocking) LangChain agent in a worker thread so the event loop stays free.
    result = await asyncio.to_thread(
        handle_user_query,
        user_query,
        req.location,
        req.image,
        req.audio,
        req.files or [],
        req.file_names or [],
        req.images or [],
        req.audios or [],
        req.agent_mode,
    )

    stats = None
    if req.want_stats:
        stats = {}
        if req.text:
            stats["word_count"] = len(req.text.split())
        if req.image:
            stats["image_url_or_b64_length"] = len(req.image)
        if req.images:
            stats["images_count"] = len(req.images)
        if req.audio:
            stats["audio_url_or_b64_length"] = len(req.audio)
        if req.audios:
            stats["audios_count"] = len(req.audios)
        if req.files:
            stats["files_count"] = len(req.files)
    return AIResponse(result=result, stats=stats)
| # ============================================================================= | |
| # Multipart/form-data endpoint for uploads | |
| # ============================================================================= | |
async def _read_capped_upload(upload, max_bytes: int, too_large_detail: str) -> bytes:
    """Read an uploaded file, close it, and enforce a byte-size cap.

    At most ``max_bytes + 1`` bytes are read, which is enough to detect an
    oversized upload without loading the whole file into memory.

    Args:
        upload: Upload object taken from the parsed multipart form.
        max_bytes: Largest accepted size, in bytes.
        too_large_detail: ``detail`` of the 413 error raised when the cap is exceeded.

    Returns:
        The file content (possibly empty).

    Raises:
        HTTPException: 413 if the upload is larger than ``max_bytes``.
    """
    try:
        data = await upload.read(max_bytes + 1)
    finally:
        await upload.close()
    if len(data) > max_bytes:
        raise HTTPException(status_code=413, detail=too_large_detail)
    return data


async def _collect_form_refs(form, keys: tuple, max_bytes: int, too_large_template: str) -> list:
    """Turn the form values under ``keys`` into agent references, in form order.

    Uploaded files are size-checked and stored with ``register_attachment``;
    non-empty strings are kept as-is after stripping. Blank strings and empty
    uploads (e.g. an unused ``<input type="file">``) are skipped.

    Args:
        form: Parsed multipart form (``await request.form()``).
        keys: Form field names to scan, in order (e.g. ``("image", "images")``).
        max_bytes: Size cap applied to each uploaded file.
        too_large_template: 413 detail message; ``{}`` is replaced by the filename.

    Returns:
        List of ``(ref, upload)`` pairs, where ``upload`` is the (closed) upload
        object for uploaded files and ``None`` for string references.
    """
    collected = []
    for key in keys:
        for value in form.getlist(key):
            # Starlette's form parser yields ``str`` for plain fields and
            # ``starlette.datastructures.UploadFile`` for files. FastAPI's
            # ``UploadFile`` is a *subclass* of the latter, so an isinstance
            # check against it does not match parsed uploads; test for str instead.
            if isinstance(value, str):
                if value.strip():
                    collected.append((value.strip(), None))
                continue
            data = await _read_capped_upload(value, max_bytes, too_large_template.format(value.filename))
            if not data:
                continue
            ref = register_attachment(data, filename=value.filename, mime=value.content_type)
            collected.append((ref, value))
    return collected


async def ai_form_endpoint(
    request: Request,
    text: Optional[str] = Form(None),
    location: Optional[str] = Form(None),
    want_stats: Optional[bool] = Form(False),
    agent_mode: Optional[str] = Form(None),
):
    """Run the medical agent on a multipart/form-data request.

    Besides the declared form fields, the raw form is scanned for attachments
    under ``image``/``images``, ``audio``/``audios`` and ``file``/``files``.
    Each value may be an uploaded file (size-checked, then stored with
    ``register_attachment``) or a non-empty string reference. Names for string
    file references are taken, positionally, from the ``file_names`` field.

    Returns:
        AIResponse holding the agent result and, when ``want_stats`` is set,
        word and attachment counts.

    Raises:
        HTTPException: 400 for an unknown ``agent_mode``; 413 when an uploaded
            file exceeds its size limit.
    """
    # Validate before touching uploads so a rejected request registers no attachments.
    if agent_mode and agent_mode.lower() not in {"messages", "string", "legacy"}:
        raise HTTPException(status_code=400, detail="agent_mode invalide: utilisez 'messages', 'string' ou 'legacy'")

    # Read the raw form to reach fields that are not declared parameters.
    # A parse failure is tolerated and simply means "no attachments".
    try:
        form = await request.form()
    except Exception as exc:
        logger.warning("Could not parse multipart form, ignoring attachments: %s", exc)
        form = None

    image_refs: List[str] = []
    audio_refs: List[str] = []
    file_refs: List[str] = []
    file_names: List[str] = []
    if form:
        image_entries = await _collect_form_refs(
            form, ("image", "images"), MAX_IMAGE_BYTES, "Image '{}' trop volumineuse (> 6 Mo)"
        )
        image_refs = [ref for ref, _ in image_entries]

        audio_entries = await _collect_form_refs(
            form, ("audio", "audios"), MAX_AUDIO_BYTES, "Audio '{}' trop volumineux (> 10 Mo)"
        )
        audio_refs = [ref for ref, _ in audio_entries]

        # Files (text/PDF): uploads keep their own filename; string references
        # consume the "file_names" values in order.
        file_entries = await _collect_form_refs(
            form, ("file", "files"), MAX_FILE_BYTES, "Fichier '{}' trop volumineux (> 2 Mo)"
        )
        provided_names = form.getlist("file_names")
        string_ref_index = 0
        for ref, upload in file_entries:
            file_refs.append(ref)
            if upload is not None:
                file_names.append(upload.filename or "file")
                continue
            name = None
            if string_ref_index < len(provided_names):
                candidate = provided_names[string_ref_index]
                if isinstance(candidate, str) and candidate.strip():
                    name = candidate.strip()
            file_names.append(name or "file")
            string_ref_index += 1

    # Short summary of the inputs for the agent (everything is optional).
    user_query = (text or "").strip()
    if image_refs:
        user_query += ("\n" if user_query else "") + "[Image(s) fournie(s)]"
    if audio_refs:
        user_query += ("\n" if user_query else "") + "[Audio(s) fourni(s)]"
    if location:
        user_query += ("\n" if user_query else "") + f"[Localisation: {location}]"

    # All inputs are optional; proceed even if user_query is empty. Uploaded
    # media travel as references returned by register_attachment.
    result = await asyncio.to_thread(
        handle_user_query,
        user_query,
        location,
        None,  # single image param not used here
        None,  # single audio param not used here
        file_refs,
        file_names,
        image_refs,
        audio_refs,
        agent_mode,
    )

    stats = None
    if want_stats:
        stats = {
            "word_count": len(text.split()) if text else 0,
            "images_count": len(image_refs),
            "audios_count": len(audio_refs),
            "files_count": len(file_refs),
        }
    return AIResponse(result=result, stats=stats)
| # ============================================================================= | |
| # DEBUG ENDPOINTS to isolate the hanging issue | |
| # ============================================================================= | |
async def debug_create_agent():
    """Debug endpoint: check that building the medical agent completes.

    Agent construction runs in a worker thread, so a slow or hanging build
    does not block the event loop (and the other debug endpoints) while it is
    being diagnosed.

    Returns:
        ``{"status": "Agent created successfully"}`` on success.

    Raises:
        HTTPException: 500 if ``create_medical_agent`` raises or returns a
            falsy value.
    """
    logger.info("--- DEBUG: Testing agent creation ---")
    try:
        agent = await asyncio.to_thread(create_medical_agent)
    except Exception as e:
        logger.exception(f"--- DEBUG: Agent creation failed with exception: {e} ---")
        raise HTTPException(status_code=500, detail=f"Agent creation failed: {e}") from e

    # Checked outside the try so this 500 is not re-caught and re-wrapped by
    # the generic exception handler above.
    if not agent:
        logger.error("--- DEBUG: Agent creation failed, returned None ---")
        raise HTTPException(status_code=500, detail="Agent creation returned None")

    logger.info("--- DEBUG: Agent creation successful ---")
    return {"status": "Agent created successfully"}
async def debug_search_data(q: str = "fever and headache"):
    """Debug endpoint: check that the clinical-case search returns.

    The search helper is called with ``timeout=15`` (presumably seconds) in a
    worker thread, so a slow search does not block the event loop.

    Args:
        q: Search query.

    Returns:
        Dict with a status string, whether any context was found, and the
        context returned by the search.

    Raises:
        HTTPException: 500 if the search raises.
    """
    logger.info(f"--- DEBUG: Testing data search with query: '{q}' ---")
    try:
        context = await asyncio.to_thread(search_cases_with_timeout, q, timeout=15)
    except Exception as e:
        logger.exception(f"--- DEBUG: Data search failed with exception: {e} ---")
        raise HTTPException(status_code=500, detail=f"Data search failed: {e}") from e

    logger.info("--- DEBUG: Data search successful ---")
    return {"status": "Data search completed", "context_found": bool(context), "context": context}
async def debug_invoke_agent(q: str = "hello, how are you?"):
    """Debug endpoint: build the agent and invoke it with a simple query.

    Both the agent build and the invocation run in worker threads, so neither
    blocks the event loop.

    Args:
        q: Text passed to the agent as ``{"input": q}``.

    Returns:
        Dict with a status string and the raw value returned by ``agent.invoke``.

    Raises:
        HTTPException: 500 if building or invoking the agent raises.
    """
    logger.info(f"--- DEBUG: Testing agent invocation with query: '{q}' ---")
    try:
        agent = await asyncio.to_thread(create_medical_agent)
        logger.info("--- DEBUG: Agent created, invoking... ---")
        response = await asyncio.to_thread(agent.invoke, {"input": q})
    except Exception as e:
        logger.exception(f"--- DEBUG: Agent invocation failed with exception: {e} ---")
        raise HTTPException(status_code=500, detail=f"Agent invocation failed: {e}") from e

    logger.info("--- DEBUG: Agent invocation successful ---")
    return {"status": "Agent invoked successfully", "response": response}