ggg
Browse files- app.py +18 -25
- fastapi_app.py +413 -126
- medical_ai.py +369 -150
- requirements.txt +16 -11
app.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
-
Medical AI Assistant -
|
| 4 |
-
|
| 5 |
"""
|
| 6 |
|
| 7 |
import os
|
|
@@ -20,22 +20,22 @@ logging.basicConfig(
|
|
| 20 |
logger = logging.getLogger(__name__)
|
| 21 |
|
| 22 |
def setup_environment():
|
| 23 |
-
"""Setup environment variables for
|
| 24 |
# Set environment variables for optimal performance
|
| 25 |
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
| 26 |
os.environ.setdefault("HF_HOME", "/tmp/huggingface")
|
| 27 |
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/transformers")
|
| 28 |
|
| 29 |
-
#
|
| 30 |
os.environ.setdefault("HOST", "0.0.0.0")
|
| 31 |
-
os.environ.setdefault("PORT", "
|
| 32 |
|
| 33 |
-
logger.info("✅ Environment configured for
|
| 34 |
|
| 35 |
def main():
|
| 36 |
"""Main application entry point"""
|
| 37 |
try:
|
| 38 |
-
logger.info("🩺 Starting Medical AI Assistant -
|
| 39 |
|
| 40 |
# Setup environment
|
| 41 |
setup_environment()
|
|
@@ -44,22 +44,22 @@ def main():
|
|
| 44 |
from fastapi_app import app
|
| 45 |
import uvicorn
|
| 46 |
|
| 47 |
-
# Get port from environment
|
| 48 |
-
port = int(os.getenv("PORT",
|
| 49 |
host = os.getenv("HOST", "0.0.0.0")
|
| 50 |
|
| 51 |
logger.info(f"🚀 Starting FastAPI server on {host}:{port}")
|
| 52 |
-
logger.info(f"📚 API Documentation available at: http://{host}:{port}/docs")
|
|
|
|
| 53 |
|
| 54 |
-
# Launch the FastAPI application
|
| 55 |
uvicorn.run(
|
| 56 |
app,
|
| 57 |
host=host,
|
| 58 |
port=port,
|
| 59 |
log_level="info",
|
| 60 |
-
reload=False, #
|
| 61 |
-
access_log=True
|
| 62 |
-
workers=1 # Single worker for Spaces
|
| 63 |
)
|
| 64 |
|
| 65 |
except KeyboardInterrupt:
|
|
@@ -71,23 +71,16 @@ def main():
|
|
| 71 |
# For direct import (Hugging Face Spaces compatibility)
|
| 72 |
try:
|
| 73 |
from fastapi_app import app
|
| 74 |
-
logger.info("✅ FastAPI app imported successfully
|
| 75 |
except ImportError as e:
|
| 76 |
logger.error(f"❌ Failed to import FastAPI app: {str(e)}")
|
| 77 |
# Create minimal fallback app
|
| 78 |
from fastapi import FastAPI
|
| 79 |
-
app = FastAPI(
|
| 80 |
-
title="Medical AI - Loading",
|
| 81 |
-
description="Medical AI Assistant is starting up..."
|
| 82 |
-
)
|
| 83 |
|
| 84 |
@app.get("/")
|
| 85 |
-
async def
|
| 86 |
-
return {
|
| 87 |
-
"message": "🩺 Medical AI Assistant is starting up...",
|
| 88 |
-
"status": "loading",
|
| 89 |
-
"docs": "/docs"
|
| 90 |
-
}
|
| 91 |
|
| 92 |
if __name__ == "__main__":
|
| 93 |
main()
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
Medical AI Assistant - FastAPI Only Entry Point
|
| 4 |
+
Simplified for backend integration
|
| 5 |
"""
|
| 6 |
|
| 7 |
import os
|
|
|
|
| 20 |
logger = logging.getLogger(__name__)
|
| 21 |
|
| 22 |
def setup_environment():
|
| 23 |
+
"""Setup environment variables for FastAPI deployment"""
|
| 24 |
# Set environment variables for optimal performance
|
| 25 |
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
| 26 |
os.environ.setdefault("HF_HOME", "/tmp/huggingface")
|
| 27 |
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/transformers")
|
| 28 |
|
| 29 |
+
# FastAPI specific
|
| 30 |
os.environ.setdefault("HOST", "0.0.0.0")
|
| 31 |
+
os.environ.setdefault("PORT", "8000")
|
| 32 |
|
| 33 |
+
logger.info("✅ Environment configured for FastAPI Medical AI")
|
| 34 |
|
| 35 |
def main():
|
| 36 |
"""Main application entry point"""
|
| 37 |
try:
|
| 38 |
+
logger.info("🩺 Starting Medical AI Assistant - FastAPI Edition")
|
| 39 |
|
| 40 |
# Setup environment
|
| 41 |
setup_environment()
|
|
|
|
| 44 |
from fastapi_app import app
|
| 45 |
import uvicorn
|
| 46 |
|
| 47 |
+
# Get port from environment or use default
|
| 48 |
+
port = int(os.getenv("PORT", 8000))
|
| 49 |
host = os.getenv("HOST", "0.0.0.0")
|
| 50 |
|
| 51 |
logger.info(f"🚀 Starting FastAPI server on {host}:{port}")
|
| 52 |
+
logger.info(f"📚 API Documentation will be available at: http://{host}:{port}/docs")
|
| 53 |
+
logger.info(f"🔄 Alternative docs at: http://{host}:{port}/redoc")
|
| 54 |
|
| 55 |
+
# Launch the FastAPI application
|
| 56 |
uvicorn.run(
|
| 57 |
app,
|
| 58 |
host=host,
|
| 59 |
port=port,
|
| 60 |
log_level="info",
|
| 61 |
+
reload=False, # Set to True for development
|
| 62 |
+
access_log=True
|
|
|
|
| 63 |
)
|
| 64 |
|
| 65 |
except KeyboardInterrupt:
|
|
|
|
| 71 |
# For direct import (Hugging Face Spaces compatibility)
|
| 72 |
try:
|
| 73 |
from fastapi_app import app
|
| 74 |
+
logger.info("✅ FastAPI app imported successfully")
|
| 75 |
except ImportError as e:
|
| 76 |
logger.error(f"❌ Failed to import FastAPI app: {str(e)}")
|
| 77 |
# Create minimal fallback app
|
| 78 |
from fastapi import FastAPI
|
| 79 |
+
app = FastAPI(title="Medical AI - Error", description="Failed to load main application")
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
@app.get("/")
|
| 82 |
+
async def error_root():
|
| 83 |
+
return {"error": "Medical AI Assistant failed to load properly"}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
if __name__ == "__main__":
|
| 86 |
main()
|
fastapi_app.py
CHANGED
|
@@ -1,17 +1,20 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
-
Medical AI Assistant -
|
| 4 |
-
|
| 5 |
"""
|
| 6 |
|
| 7 |
-
from fastapi import FastAPI, HTTPException, File, UploadFile
|
| 8 |
from fastapi.middleware.cors import CORSMiddleware
|
| 9 |
from fastapi.responses import JSONResponse
|
|
|
|
|
|
|
| 10 |
from pydantic import BaseModel, Field
|
| 11 |
-
from typing import List, Optional, Dict, Any
|
| 12 |
import logging
|
| 13 |
import uuid
|
| 14 |
import os
|
|
|
|
| 15 |
import asyncio
|
| 16 |
from contextlib import asynccontextmanager
|
| 17 |
import time
|
|
@@ -25,23 +28,37 @@ logger = logging.getLogger(__name__)
|
|
| 25 |
|
| 26 |
# Initialize models globally
|
| 27 |
pipeline = None
|
|
|
|
| 28 |
|
| 29 |
async def load_models():
|
| 30 |
"""Load ML models asynchronously"""
|
| 31 |
-
global pipeline
|
| 32 |
try:
|
| 33 |
-
logger.info("
|
| 34 |
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
pipeline = SpacesMedicalAIPipeline()
|
| 38 |
logger.info("✅ Medical pipeline loaded successfully")
|
| 39 |
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
except Exception as e:
|
| 43 |
logger.error(f"❌ Error loading models: {str(e)}", exc_info=True)
|
| 44 |
-
|
| 45 |
|
| 46 |
@asynccontextmanager
|
| 47 |
async def lifespan(app: FastAPI):
|
|
@@ -51,140 +68,284 @@ async def lifespan(app: FastAPI):
|
|
| 51 |
yield
|
| 52 |
except Exception as e:
|
| 53 |
logger.error(f"❌ Error during startup: {str(e)}", exc_info=True)
|
| 54 |
-
|
| 55 |
-
yield
|
| 56 |
finally:
|
| 57 |
logger.info("🔄 Shutting down...")
|
| 58 |
|
| 59 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
app = FastAPI(
|
| 61 |
title="🩺 Medical AI Assistant",
|
| 62 |
-
description="
|
| 63 |
-
version="2.0.0
|
| 64 |
lifespan=lifespan,
|
| 65 |
docs_url="/docs",
|
| 66 |
-
redoc_url="/redoc"
|
|
|
|
| 67 |
)
|
| 68 |
|
|
|
|
|
|
|
|
|
|
| 69 |
# CORS middleware
|
| 70 |
app.add_middleware(
|
| 71 |
CORSMiddleware,
|
| 72 |
allow_origins=["*"],
|
| 73 |
allow_credentials=True,
|
| 74 |
allow_methods=["*"],
|
| 75 |
-
allow_headers=["*"]
|
|
|
|
| 76 |
)
|
| 77 |
|
| 78 |
# ============================================================================
|
| 79 |
-
# PYDANTIC MODELS
|
| 80 |
# ============================================================================
|
| 81 |
|
| 82 |
class MedicalQuestion(BaseModel):
|
| 83 |
"""Medical question request model"""
|
| 84 |
-
question: str = Field(..., description="The medical question", min_length=3, max_length=
|
| 85 |
-
language: str = Field("auto", description="
|
|
|
|
| 86 |
|
| 87 |
class Config:
|
| 88 |
schema_extra = {
|
| 89 |
"example": {
|
| 90 |
-
"question": "What are the symptoms of malaria?",
|
| 91 |
-
"language": "en"
|
|
|
|
| 92 |
}
|
| 93 |
}
|
| 94 |
|
| 95 |
class MedicalResponse(BaseModel):
|
| 96 |
"""Medical response model"""
|
| 97 |
-
success: bool = Field(..., description="
|
| 98 |
-
response: str = Field(..., description="
|
| 99 |
-
detected_language: str = Field(..., description="Detected language")
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
class Config:
|
| 103 |
schema_extra = {
|
| 104 |
"example": {
|
| 105 |
"success": True,
|
| 106 |
-
"response": "Malaria symptoms include fever, chills, headache...",
|
| 107 |
"detected_language": "en",
|
| 108 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
}
|
| 110 |
}
|
| 111 |
|
| 112 |
class HealthStatus(BaseModel):
|
| 113 |
-
"""System health status"""
|
| 114 |
-
status: str = Field(..., description="
|
| 115 |
-
models_loaded: bool = Field(..., description="
|
|
|
|
|
|
|
| 116 |
version: str = Field(..., description="API version")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
def validate_models():
|
| 119 |
"""Check if models are loaded"""
|
| 120 |
if pipeline is None:
|
| 121 |
raise HTTPException(
|
| 122 |
status_code=503,
|
| 123 |
-
detail="Medical AI models are
|
| 124 |
)
|
| 125 |
|
| 126 |
# ============================================================================
|
| 127 |
-
# API ENDPOINTS
|
| 128 |
# ============================================================================
|
| 129 |
|
| 130 |
@app.get("/", tags=["system"])
|
| 131 |
async def root():
|
| 132 |
-
"""Root endpoint"""
|
| 133 |
return {
|
| 134 |
-
"message": "🩺 Medical AI Assistant
|
| 135 |
-
"version": "2.0.0
|
| 136 |
"status": "running",
|
| 137 |
"docs": "/docs",
|
|
|
|
| 138 |
"endpoints": {
|
| 139 |
"medical_consultation": "/medical/ask",
|
| 140 |
-
"
|
| 141 |
-
|
| 142 |
-
|
|
|
|
| 143 |
}
|
| 144 |
|
| 145 |
@app.get("/health", response_model=HealthStatus, tags=["system"])
|
| 146 |
async def health_check():
|
| 147 |
-
"""
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
return HealthStatus(
|
| 151 |
status="healthy" if pipeline is not None else "loading",
|
| 152 |
models_loaded=pipeline is not None,
|
| 153 |
-
|
|
|
|
|
|
|
| 154 |
)
|
| 155 |
|
| 156 |
@app.post("/medical/ask", response_model=MedicalResponse, tags=["medical"])
|
| 157 |
async def medical_consultation(request: MedicalQuestion):
|
| 158 |
"""
|
| 159 |
-
## Medical Consultation
|
| 160 |
|
| 161 |
-
|
| 162 |
|
| 163 |
**Features:**
|
| 164 |
-
- 🌍 Multilingual support (
|
| 165 |
-
- 🧠
|
| 166 |
-
- ⚡
|
|
|
|
|
|
|
|
|
|
| 167 |
"""
|
| 168 |
start_time = time.time()
|
|
|
|
| 169 |
|
| 170 |
-
|
| 171 |
-
if pipeline is None:
|
| 172 |
-
logger.warning("Models not loaded, using fallback response")
|
| 173 |
-
processing_time = time.time() - start_time
|
| 174 |
-
|
| 175 |
-
fallback_responses = {
|
| 176 |
-
"en": "Medical AI is still initializing. For immediate medical concerns, please consult a healthcare professional. Common symptoms like fever, headache, or persistent pain should be evaluated by a doctor.",
|
| 177 |
-
"fr": "L'IA médicale s'initialise encore. Pour des préoccupations médicales immédiates, veuillez consulter un professionnel de santé. Les symptômes courants comme la fièvre, les maux de tête ou la douleur persistante doivent être évalués par un médecin."
|
| 178 |
-
}
|
| 179 |
-
|
| 180 |
-
detected_lang = "fr" if any(word in request.question.lower() for word in ["quoi", "comment", "pourquoi"]) else "en"
|
| 181 |
-
|
| 182 |
-
return MedicalResponse(
|
| 183 |
-
success=True,
|
| 184 |
-
response=fallback_responses.get(detected_lang, fallback_responses["en"]) + "\n\n⚕️ Medical Disclaimer: Always consult healthcare professionals for proper medical advice.",
|
| 185 |
-
detected_language=detected_lang,
|
| 186 |
-
processing_time=round(processing_time, 2)
|
| 187 |
-
)
|
| 188 |
|
| 189 |
try:
|
| 190 |
logger.info(f"🩺 Processing medical question: {request.question[:50]}...")
|
|
@@ -202,60 +363,201 @@ async def medical_consultation(request: MedicalQuestion):
|
|
| 202 |
success=True,
|
| 203 |
response=result["response"],
|
| 204 |
detected_language=result["source_lang"],
|
| 205 |
-
|
|
|
|
|
|
|
|
|
|
| 206 |
)
|
| 207 |
|
| 208 |
except Exception as e:
|
| 209 |
-
logger.error(f"❌ Error in medical consultation: {str(e)}")
|
| 210 |
processing_time = time.time() - start_time
|
| 211 |
|
| 212 |
raise HTTPException(
|
| 213 |
status_code=500,
|
| 214 |
detail={
|
| 215 |
"success": False,
|
| 216 |
-
"error": "
|
|
|
|
|
|
|
| 217 |
"processing_time": round(processing_time, 2)
|
| 218 |
}
|
| 219 |
)
|
| 220 |
|
| 221 |
-
@app.
|
| 222 |
-
async def
|
|
|
|
|
|
|
| 223 |
"""
|
| 224 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
|
| 226 |
-
|
|
|
|
|
|
|
| 227 |
"""
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
-
@app.get("/medical/
|
| 247 |
-
async def
|
| 248 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
return {
|
| 250 |
-
"supported_languages": ["English", "French"],
|
| 251 |
"specialties": [
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
],
|
| 257 |
-
"
|
| 258 |
-
"
|
| 259 |
}
|
| 260 |
|
| 261 |
# ============================================================================
|
|
@@ -269,11 +571,13 @@ async def not_found_handler(request, exc):
|
|
| 269 |
content={
|
| 270 |
"success": False,
|
| 271 |
"error": "Endpoint not found",
|
|
|
|
| 272 |
"available_endpoints": [
|
| 273 |
"/docs - API Documentation",
|
| 274 |
-
"/medical/ask -
|
|
|
|
| 275 |
"/health - System status",
|
| 276 |
-
"/
|
| 277 |
]
|
| 278 |
}
|
| 279 |
)
|
|
@@ -284,19 +588,9 @@ async def validation_exception_handler(request, exc):
|
|
| 284 |
status_code=422,
|
| 285 |
content={
|
| 286 |
"success": False,
|
| 287 |
-
"error": "Invalid request data",
|
| 288 |
-
"
|
| 289 |
-
|
| 290 |
-
)
|
| 291 |
-
|
| 292 |
-
@app.exception_handler(500)
|
| 293 |
-
async def internal_error_handler(request, exc):
|
| 294 |
-
return JSONResponse(
|
| 295 |
-
status_code=500,
|
| 296 |
-
content={
|
| 297 |
-
"success": False,
|
| 298 |
-
"error": "Internal server error",
|
| 299 |
-
"message": "The Medical AI is experiencing technical difficulties"
|
| 300 |
}
|
| 301 |
)
|
| 302 |
|
|
@@ -304,24 +598,17 @@ async def internal_error_handler(request, exc):
|
|
| 304 |
# STARTUP MESSAGE
|
| 305 |
# ============================================================================
|
| 306 |
|
| 307 |
-
@app.on_event("startup")
|
| 308 |
-
async def startup_event():
|
| 309 |
-
"""Startup event handler"""
|
| 310 |
-
logger.info("🩺 Medical AI Assistant starting up on Hugging Face Spaces")
|
| 311 |
-
logger.info("📚 API Documentation: /docs")
|
| 312 |
-
logger.info("🔄 Alternative docs: /redoc")
|
| 313 |
-
logger.info("⚡ Optimized for Spaces deployment")
|
| 314 |
-
|
| 315 |
if __name__ == "__main__":
|
| 316 |
import uvicorn
|
| 317 |
|
| 318 |
-
print("🩺 Starting Medical AI Assistant
|
| 319 |
-
print("📚 Documentation available at: http://localhost:
|
|
|
|
| 320 |
|
| 321 |
uvicorn.run(
|
| 322 |
app,
|
| 323 |
host="0.0.0.0",
|
| 324 |
-
port=
|
| 325 |
log_level="info",
|
| 326 |
reload=False
|
| 327 |
)
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
Medical AI Assistant - FastAPI Only Version
|
| 4 |
+
Simplified endpoints for backend integration with Swagger UI
|
| 5 |
"""
|
| 6 |
|
| 7 |
+
from fastapi import FastAPI, HTTPException, File, UploadFile, BackgroundTasks
|
| 8 |
from fastapi.middleware.cors import CORSMiddleware
|
| 9 |
from fastapi.responses import JSONResponse
|
| 10 |
+
from fastapi.openapi.docs import get_swagger_ui_html
|
| 11 |
+
from fastapi.openapi.utils import get_openapi
|
| 12 |
from pydantic import BaseModel, Field
|
| 13 |
+
from typing import List, Optional, Dict, Any, Union
|
| 14 |
import logging
|
| 15 |
import uuid
|
| 16 |
import os
|
| 17 |
+
import json
|
| 18 |
import asyncio
|
| 19 |
from contextlib import asynccontextmanager
|
| 20 |
import time
|
|
|
|
| 28 |
|
| 29 |
# Initialize models globally
|
| 30 |
pipeline = None
|
| 31 |
+
whisper_model = None
|
| 32 |
|
| 33 |
async def load_models():
|
| 34 |
"""Load ML models asynchronously"""
|
| 35 |
+
global pipeline, whisper_model
|
| 36 |
try:
|
| 37 |
+
logger.info("Loading Medical AI models...")
|
| 38 |
|
| 39 |
+
from medical_ai import CompetitionMedicalAIPipeline
|
| 40 |
+
pipeline = CompetitionMedicalAIPipeline()
|
|
|
|
| 41 |
logger.info("✅ Medical pipeline loaded successfully")
|
| 42 |
|
| 43 |
+
try:
|
| 44 |
+
from faster_whisper import WhisperModel
|
| 45 |
+
model_cache = os.getenv('HF_HOME', '/tmp/models')
|
| 46 |
+
whisper_model = WhisperModel(
|
| 47 |
+
"medium",
|
| 48 |
+
device="cpu",
|
| 49 |
+
compute_type="int8",
|
| 50 |
+
download_root=model_cache
|
| 51 |
+
)
|
| 52 |
+
logger.info("✅ Whisper model loaded successfully")
|
| 53 |
+
except Exception as e:
|
| 54 |
+
logger.warning(f"⚠️ Could not load Whisper model: {str(e)}")
|
| 55 |
+
whisper_model = None
|
| 56 |
+
|
| 57 |
+
logger.info("🚀 All models loaded successfully")
|
| 58 |
|
| 59 |
except Exception as e:
|
| 60 |
logger.error(f"❌ Error loading models: {str(e)}", exc_info=True)
|
| 61 |
+
raise
|
| 62 |
|
| 63 |
@asynccontextmanager
|
| 64 |
async def lifespan(app: FastAPI):
|
|
|
|
| 68 |
yield
|
| 69 |
except Exception as e:
|
| 70 |
logger.error(f"❌ Error during startup: {str(e)}", exc_info=True)
|
| 71 |
+
raise
|
|
|
|
| 72 |
finally:
|
| 73 |
logger.info("🔄 Shutting down...")
|
| 74 |
|
| 75 |
+
# Custom OpenAPI schema
|
| 76 |
+
def custom_openapi():
|
| 77 |
+
if app.openapi_schema:
|
| 78 |
+
return app.openapi_schema
|
| 79 |
+
|
| 80 |
+
openapi_schema = get_openapi(
|
| 81 |
+
title="🩺 Medical AI Assistant API",
|
| 82 |
+
version="2.0.0",
|
| 83 |
+
description="""
|
| 84 |
+
## 🎯 Advanced Medical AI Assistant
|
| 85 |
+
|
| 86 |
+
**Multilingual medical consultation API** supporting:
|
| 87 |
+
- 🌍 French, English, and local African languages
|
| 88 |
+
- 🎤 Audio processing with speech-to-text
|
| 89 |
+
- 🧠 Advanced medical knowledge retrieval
|
| 90 |
+
- ⚡ Real-time medical consultations
|
| 91 |
+
|
| 92 |
+
### 🔧 Main Endpoints:
|
| 93 |
+
- **POST /medical/ask** - Text-based medical consultation
|
| 94 |
+
- **POST /medical/audio** - Audio-based medical consultation
|
| 95 |
+
- **GET /health** - System health check
|
| 96 |
+
- **POST /feedback** - Submit user feedback
|
| 97 |
+
|
| 98 |
+
### 🔒 Important Medical Disclaimer:
|
| 99 |
+
This API provides educational medical information only. Always consult qualified healthcare professionals for medical advice.
|
| 100 |
+
""",
|
| 101 |
+
routes=app.routes,
|
| 102 |
+
contact={
|
| 103 |
+
"name": "Medical AI Support",
|
| 104 |
+
"email": "support@medicalai.com"
|
| 105 |
+
},
|
| 106 |
+
license_info={
|
| 107 |
+
"name": "MIT License",
|
| 108 |
+
"url": "https://opensource.org/licenses/MIT"
|
| 109 |
+
}
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# Add custom tags
|
| 113 |
+
openapi_schema["tags"] = [
|
| 114 |
+
{
|
| 115 |
+
"name": "medical",
|
| 116 |
+
"description": "Medical consultation endpoints"
|
| 117 |
+
},
|
| 118 |
+
{
|
| 119 |
+
"name": "audio",
|
| 120 |
+
"description": "Audio processing endpoints"
|
| 121 |
+
},
|
| 122 |
+
{
|
| 123 |
+
"name": "system",
|
| 124 |
+
"description": "System monitoring and health"
|
| 125 |
+
},
|
| 126 |
+
{
|
| 127 |
+
"name": "feedback",
|
| 128 |
+
"description": "User feedback and analytics"
|
| 129 |
+
}
|
| 130 |
+
]
|
| 131 |
+
|
| 132 |
+
app.openapi_schema = openapi_schema
|
| 133 |
+
return app.openapi_schema
|
| 134 |
+
|
| 135 |
+
# Initialize FastAPI app
|
| 136 |
app = FastAPI(
|
| 137 |
title="🩺 Medical AI Assistant",
|
| 138 |
+
description="Advanced multilingual medical consultation API",
|
| 139 |
+
version="2.0.0",
|
| 140 |
lifespan=lifespan,
|
| 141 |
docs_url="/docs",
|
| 142 |
+
redoc_url="/redoc",
|
| 143 |
+
openapi_url="/openapi.json"
|
| 144 |
)
|
| 145 |
|
| 146 |
+
# Set custom OpenAPI
|
| 147 |
+
app.openapi = custom_openapi
|
| 148 |
+
|
| 149 |
# CORS middleware
|
| 150 |
app.add_middleware(
|
| 151 |
CORSMiddleware,
|
| 152 |
allow_origins=["*"],
|
| 153 |
allow_credentials=True,
|
| 154 |
allow_methods=["*"],
|
| 155 |
+
allow_headers=["*"],
|
| 156 |
+
expose_headers=["*"]
|
| 157 |
)
|
| 158 |
|
| 159 |
# ============================================================================
|
| 160 |
+
# PYDANTIC MODELS FOR REQUEST/RESPONSE VALIDATION
|
| 161 |
# ============================================================================
|
| 162 |
|
| 163 |
class MedicalQuestion(BaseModel):
|
| 164 |
"""Medical question request model"""
|
| 165 |
+
question: str = Field(..., description="The medical question", min_length=3, max_length=1000)
|
| 166 |
+
language: str = Field("auto", description="Preferred language (auto, en, fr)", pattern="^(auto|en|fr)$")
|
| 167 |
+
conversation_id: Optional[str] = Field(None, description="Optional conversation ID for context")
|
| 168 |
|
| 169 |
class Config:
|
| 170 |
schema_extra = {
|
| 171 |
"example": {
|
| 172 |
+
"question": "What are the symptoms of malaria and how is it treated?",
|
| 173 |
+
"language": "en",
|
| 174 |
+
"conversation_id": "conv_123"
|
| 175 |
}
|
| 176 |
}
|
| 177 |
|
| 178 |
class MedicalResponse(BaseModel):
|
| 179 |
"""Medical response model"""
|
| 180 |
+
success: bool = Field(..., description="Whether the request was successful")
|
| 181 |
+
response: str = Field(..., description="The medical response")
|
| 182 |
+
detected_language: str = Field(..., description="Detected or used language")
|
| 183 |
+
conversation_id: str = Field(..., description="Conversation identifier")
|
| 184 |
+
context_used: List[str] = Field(default_factory=list, description="Medical contexts used")
|
| 185 |
+
processing_time: float = Field(..., description="Response time in seconds")
|
| 186 |
+
confidence: str = Field(..., description="Response confidence level")
|
| 187 |
|
| 188 |
class Config:
|
| 189 |
schema_extra = {
|
| 190 |
"example": {
|
| 191 |
"success": True,
|
| 192 |
+
"response": "Malaria symptoms include high fever, chills, headache...",
|
| 193 |
"detected_language": "en",
|
| 194 |
+
"conversation_id": "conv_123",
|
| 195 |
+
"context_used": ["Malaria treatment protocols", "Symptom guidelines"],
|
| 196 |
+
"processing_time": 2.5,
|
| 197 |
+
"confidence": "high"
|
| 198 |
+
}
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
class AudioResponse(BaseModel):
|
| 202 |
+
"""Audio processing response model"""
|
| 203 |
+
success: bool = Field(..., description="Whether the request was successful")
|
| 204 |
+
transcription: str = Field(..., description="Transcribed text from audio")
|
| 205 |
+
response: str = Field(..., description="The medical response")
|
| 206 |
+
detected_language: str = Field(..., description="Detected audio language")
|
| 207 |
+
conversation_id: str = Field(..., description="Conversation identifier")
|
| 208 |
+
context_used: List[str] = Field(default_factory=list, description="Medical contexts used")
|
| 209 |
+
processing_time: float = Field(..., description="Response time in seconds")
|
| 210 |
+
audio_duration: Optional[float] = Field(None, description="Audio duration in seconds")
|
| 211 |
+
|
| 212 |
+
class Config:
|
| 213 |
+
schema_extra = {
|
| 214 |
+
"example": {
|
| 215 |
+
"success": True,
|
| 216 |
+
"transcription": "What are the symptoms of malaria?",
|
| 217 |
+
"response": "Malaria symptoms include high fever, chills...",
|
| 218 |
+
"detected_language": "en",
|
| 219 |
+
"conversation_id": "conv_456",
|
| 220 |
+
"context_used": ["Malaria diagnosis"],
|
| 221 |
+
"processing_time": 3.2,
|
| 222 |
+
"audio_duration": 4.5
|
| 223 |
+
}
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
class FeedbackRequest(BaseModel):
|
| 227 |
+
"""Feedback request model"""
|
| 228 |
+
conversation_id: str = Field(..., description="Conversation ID")
|
| 229 |
+
rating: int = Field(..., description="Rating from 1-5", ge=1, le=5)
|
| 230 |
+
feedback: Optional[str] = Field(None, description="Optional text feedback", max_length=500)
|
| 231 |
+
|
| 232 |
+
class Config:
|
| 233 |
+
schema_extra = {
|
| 234 |
+
"example": {
|
| 235 |
+
"conversation_id": "conv_123",
|
| 236 |
+
"rating": 5,
|
| 237 |
+
"feedback": "Very helpful and accurate medical information"
|
| 238 |
}
|
| 239 |
}
|
| 240 |
|
| 241 |
class HealthStatus(BaseModel):
|
| 242 |
+
"""System health status model"""
|
| 243 |
+
status: str = Field(..., description="Overall system status")
|
| 244 |
+
models_loaded: bool = Field(..., description="Whether ML models are loaded")
|
| 245 |
+
audio_available: bool = Field(..., description="Whether audio processing is available")
|
| 246 |
+
uptime: float = Field(..., description="System uptime in seconds")
|
| 247 |
version: str = Field(..., description="API version")
|
| 248 |
+
|
| 249 |
+
class Config:
|
| 250 |
+
schema_extra = {
|
| 251 |
+
"example": {
|
| 252 |
+
"status": "healthy",
|
| 253 |
+
"models_loaded": True,
|
| 254 |
+
"audio_available": True,
|
| 255 |
+
"uptime": 3600.0,
|
| 256 |
+
"version": "2.0.0"
|
| 257 |
+
}
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
class ErrorResponse(BaseModel):
|
| 261 |
+
"""Error response model"""
|
| 262 |
+
success: bool = Field(False, description="Always false for errors")
|
| 263 |
+
error: str = Field(..., description="Error message")
|
| 264 |
+
error_code: str = Field(..., description="Error code")
|
| 265 |
+
conversation_id: Optional[str] = Field(None, description="Conversation ID if available")
|
| 266 |
+
|
| 267 |
+
# ============================================================================
|
| 268 |
+
# UTILITY FUNCTIONS
|
| 269 |
+
# ============================================================================
|
| 270 |
+
|
| 271 |
+
def generate_conversation_id() -> str:
|
| 272 |
+
"""Generate a unique conversation ID"""
|
| 273 |
+
return f"conv_{uuid.uuid4().hex[:8]}"
|
| 274 |
|
| 275 |
def validate_models():
|
| 276 |
"""Check if models are loaded"""
|
| 277 |
if pipeline is None:
|
| 278 |
raise HTTPException(
|
| 279 |
status_code=503,
|
| 280 |
+
detail="Medical AI models are not loaded yet. Please try again in a moment."
|
| 281 |
)
|
| 282 |
|
| 283 |
# ============================================================================
|
| 284 |
+
# API ENDPOINTS
|
| 285 |
# ============================================================================
|
| 286 |
|
| 287 |
@app.get("/", tags=["system"])
|
| 288 |
async def root():
|
| 289 |
+
"""Root endpoint with API information"""
|
| 290 |
return {
|
| 291 |
+
"message": "🩺 Medical AI Assistant API",
|
| 292 |
+
"version": "2.0.0",
|
| 293 |
"status": "running",
|
| 294 |
"docs": "/docs",
|
| 295 |
+
"redoc": "/redoc",
|
| 296 |
"endpoints": {
|
| 297 |
"medical_consultation": "/medical/ask",
|
| 298 |
+
"audio_consultation": "/medical/audio",
|
| 299 |
+
"health_check": "/health",
|
| 300 |
+
"feedback": "/feedback"
|
| 301 |
+
}
|
| 302 |
}
|
| 303 |
|
| 304 |
@app.get("/health", response_model=HealthStatus, tags=["system"])
|
| 305 |
async def health_check():
|
| 306 |
+
"""
|
| 307 |
+
## System Health Check
|
| 308 |
+
|
| 309 |
+
Returns the current status of the Medical AI system including:
|
| 310 |
+
- Overall system health
|
| 311 |
+
- Model loading status
|
| 312 |
+
- Audio processing availability
|
| 313 |
+
- System uptime
|
| 314 |
+
"""
|
| 315 |
+
global pipeline, whisper_model
|
| 316 |
+
|
| 317 |
+
# Calculate uptime (simplified)
|
| 318 |
+
uptime = time.time() - getattr(health_check, 'start_time', time.time())
|
| 319 |
+
if not hasattr(health_check, 'start_time'):
|
| 320 |
+
health_check.start_time = time.time()
|
| 321 |
|
| 322 |
return HealthStatus(
|
| 323 |
status="healthy" if pipeline is not None else "loading",
|
| 324 |
models_loaded=pipeline is not None,
|
| 325 |
+
audio_available=whisper_model is not None,
|
| 326 |
+
uptime=uptime,
|
| 327 |
+
version="2.0.0"
|
| 328 |
)
|
| 329 |
|
| 330 |
@app.post("/medical/ask", response_model=MedicalResponse, tags=["medical"])
|
| 331 |
async def medical_consultation(request: MedicalQuestion):
|
| 332 |
"""
|
| 333 |
+
## Text-based Medical Consultation
|
| 334 |
|
| 335 |
+
Process a medical question and return expert medical guidance.
|
| 336 |
|
| 337 |
**Features:**
|
| 338 |
+
- 🌍 Multilingual support (auto-detect or specify language)
|
| 339 |
+
- 🧠 AI-powered medical knowledge retrieval
|
| 340 |
+
- ⚡ Fast response generation
|
| 341 |
+
- 🔒 Medical disclaimers included
|
| 342 |
+
|
| 343 |
+
**Supported Languages:** English (en), French (fr), Auto-detect (auto)
|
| 344 |
"""
|
| 345 |
start_time = time.time()
|
| 346 |
+
validate_models()
|
| 347 |
|
| 348 |
+
conversation_id = request.conversation_id or generate_conversation_id()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
|
| 350 |
try:
|
| 351 |
logger.info(f"🩺 Processing medical question: {request.question[:50]}...")
|
|
|
|
| 363 |
success=True,
|
| 364 |
response=result["response"],
|
| 365 |
detected_language=result["source_lang"],
|
| 366 |
+
conversation_id=conversation_id,
|
| 367 |
+
context_used=result.get("context_used", []),
|
| 368 |
+
processing_time=round(processing_time, 2),
|
| 369 |
+
confidence=result.get("confidence", "medium")
|
| 370 |
)
|
| 371 |
|
| 372 |
except Exception as e:
|
| 373 |
+
logger.error(f"❌ Error in medical consultation: {str(e)}", exc_info=True)
|
| 374 |
processing_time = time.time() - start_time
|
| 375 |
|
| 376 |
raise HTTPException(
|
| 377 |
status_code=500,
|
| 378 |
detail={
|
| 379 |
"success": False,
|
| 380 |
+
"error": "Internal processing error occurred",
|
| 381 |
+
"error_code": "MEDICAL_PROCESSING_ERROR",
|
| 382 |
+
"conversation_id": conversation_id,
|
| 383 |
"processing_time": round(processing_time, 2)
|
| 384 |
}
|
| 385 |
)
|
| 386 |
|
| 387 |
+
@app.post("/medical/audio", response_model=AudioResponse, tags=["audio", "medical"])
|
| 388 |
+
async def audio_medical_consultation(
|
| 389 |
+
file: UploadFile = File(..., description="Audio file (WAV, MP3, M4A, etc.)")
|
| 390 |
+
):
|
| 391 |
"""
|
| 392 |
+
## Audio-based Medical Consultation
|
| 393 |
+
|
| 394 |
+
Process an audio medical question and return expert medical guidance.
|
| 395 |
+
|
| 396 |
+
**Features:**
|
| 397 |
+
- 🎤 Speech-to-text conversion
|
| 398 |
+
- 🌍 Language detection from audio
|
| 399 |
+
- 🧠 Medical AI processing of transcribed text
|
| 400 |
+
- 📝 Full transcription provided
|
| 401 |
|
| 402 |
+
**Supported Audio Formats:** WAV, MP3, M4A, FLAC, OGG
|
| 403 |
+
**Max File Size:** 25MB
|
| 404 |
+
**Max Duration:** 5 minutes
|
| 405 |
"""
|
| 406 |
+
start_time = time.time()
|
| 407 |
+
validate_models()
|
| 408 |
+
|
| 409 |
+
if whisper_model is None:
|
| 410 |
+
raise HTTPException(
|
| 411 |
+
status_code=503,
|
| 412 |
+
detail="Audio processing is currently unavailable"
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
conversation_id = generate_conversation_id()
|
| 416 |
+
|
| 417 |
+
try:
|
| 418 |
+
logger.info(f"🎤 Processing audio file: {file.filename}")
|
| 419 |
+
|
| 420 |
+
# Read audio file
|
| 421 |
+
file_bytes = await file.read()
|
| 422 |
+
|
| 423 |
+
# Process audio
|
| 424 |
+
from audio_utils import preprocess_audio
|
| 425 |
+
processed_audio = preprocess_audio(file_bytes)
|
| 426 |
+
|
| 427 |
+
if len(processed_audio) == 0:
|
| 428 |
+
raise HTTPException(
|
| 429 |
+
status_code=400,
|
| 430 |
+
detail="Could not process audio file. Please check the format and try again."
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
# Transcribe audio
|
| 434 |
+
segments, info = whisper_model.transcribe(
|
| 435 |
+
processed_audio,
|
| 436 |
+
beam_size=5,
|
| 437 |
+
language=None,
|
| 438 |
+
task='transcribe',
|
| 439 |
+
vad_filter=True
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
transcription = "".join([seg.text for seg in segments])
|
| 443 |
+
detected_language = info.language
|
| 444 |
+
|
| 445 |
+
if not transcription.strip():
|
| 446 |
+
raise HTTPException(
|
| 447 |
+
status_code=400,
|
| 448 |
+
detail="Could not transcribe audio. Please ensure clear speech and try again."
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
logger.info(f"🔤 Transcription: {transcription[:100]}...")
|
| 452 |
+
|
| 453 |
+
# Process transcribed text with medical AI
|
| 454 |
+
result = pipeline.process(
|
| 455 |
+
question=transcription,
|
| 456 |
+
user_lang=detected_language,
|
| 457 |
+
conversation_history=[]
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
processing_time = time.time() - start_time
|
| 461 |
+
|
| 462 |
+
return AudioResponse(
|
| 463 |
+
success=True,
|
| 464 |
+
transcription=transcription,
|
| 465 |
+
response=result["response"],
|
| 466 |
+
detected_language=detected_language,
|
| 467 |
+
conversation_id=conversation_id,
|
| 468 |
+
context_used=result.get("context_used", []),
|
| 469 |
+
processing_time=round(processing_time, 2),
|
| 470 |
+
audio_duration=len(processed_audio) / 16000 # Assuming 16kHz sample rate
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
except HTTPException:
|
| 474 |
+
raise
|
| 475 |
+
except Exception as e:
|
| 476 |
+
logger.error(f"❌ Error in audio processing: {str(e)}", exc_info=True)
|
| 477 |
+
processing_time = time.time() - start_time
|
| 478 |
+
|
| 479 |
+
raise HTTPException(
|
| 480 |
+
status_code=500,
|
| 481 |
+
detail={
|
| 482 |
+
"success": False,
|
| 483 |
+
"error": "Audio processing error occurred",
|
| 484 |
+
"error_code": "AUDIO_PROCESSING_ERROR",
|
| 485 |
+
"conversation_id": conversation_id,
|
| 486 |
+
"processing_time": round(processing_time, 2)
|
| 487 |
+
}
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
@app.post("/feedback", tags=["feedback"])
|
| 491 |
+
async def submit_feedback(request: FeedbackRequest):
|
| 492 |
+
"""
|
| 493 |
+
## Submit User Feedback
|
| 494 |
+
|
| 495 |
+
Submit feedback about a medical consultation to help improve the service.
|
| 496 |
+
|
| 497 |
+
**Rating Scale:**
|
| 498 |
+
- 1: Very Poor
|
| 499 |
+
- 2: Poor
|
| 500 |
+
- 3: Average
|
| 501 |
+
- 4: Good
|
| 502 |
+
- 5: Excellent
|
| 503 |
+
"""
|
| 504 |
+
try:
|
| 505 |
+
logger.info(f"📊 Feedback received - ID: {request.conversation_id}, Rating: {request.rating}")
|
| 506 |
+
|
| 507 |
+
# Here you could store feedback in a database
|
| 508 |
+
# For now, just log it
|
| 509 |
+
feedback_data = {
|
| 510 |
+
"conversation_id": request.conversation_id,
|
| 511 |
+
"rating": request.rating,
|
| 512 |
+
"feedback": request.feedback,
|
| 513 |
+
"timestamp": time.time()
|
| 514 |
+
}
|
| 515 |
+
|
| 516 |
+
return {
|
| 517 |
+
"success": True,
|
| 518 |
+
"message": "Thank you for your feedback! This helps us improve our medical AI service.",
|
| 519 |
+
"feedback_id": f"fb_{uuid.uuid4().hex[:8]}"
|
| 520 |
+
}
|
| 521 |
+
|
| 522 |
+
except Exception as e:
|
| 523 |
+
logger.error(f"❌ Error processing feedback: {str(e)}")
|
| 524 |
+
raise HTTPException(
|
| 525 |
+
status_code=500,
|
| 526 |
+
detail="Error processing feedback"
|
| 527 |
+
)
|
| 528 |
|
| 529 |
+
@app.get("/medical/specialties", tags=["medical"])
|
| 530 |
+
async def get_medical_specialties():
|
| 531 |
+
"""
|
| 532 |
+
## Get Supported Medical Specialties
|
| 533 |
+
|
| 534 |
+
Returns a list of medical specialties and conditions supported by the AI.
|
| 535 |
+
"""
|
| 536 |
return {
|
|
|
|
| 537 |
"specialties": [
|
| 538 |
+
{
|
| 539 |
+
"name": "Primary Care",
|
| 540 |
+
"description": "General medical consultations and health guidance",
|
| 541 |
+
"conditions": ["General symptoms", "Preventive care", "Health maintenance"]
|
| 542 |
+
},
|
| 543 |
+
{
|
| 544 |
+
"name": "Infectious Diseases",
|
| 545 |
+
"description": "Infectious disease diagnosis and treatment",
|
| 546 |
+
"conditions": ["Malaria", "Tuberculosis", "HIV/AIDS", "Respiratory infections"]
|
| 547 |
+
},
|
| 548 |
+
{
|
| 549 |
+
"name": "Emergency Medicine",
|
| 550 |
+
"description": "Emergency protocols and urgent care guidance",
|
| 551 |
+
"conditions": ["Stroke recognition", "Cardiac emergencies", "Trauma assessment"]
|
| 552 |
+
},
|
| 553 |
+
{
|
| 554 |
+
"name": "Chronic Disease Management",
|
| 555 |
+
"description": "Management of chronic conditions",
|
| 556 |
+
"conditions": ["Diabetes", "Hypertension", "Gastritis"]
|
| 557 |
+
}
|
| 558 |
],
|
| 559 |
+
"languages_supported": ["English", "French", "Auto-detect"],
|
| 560 |
+
"disclaimer": "This AI provides educational information only. Always consult healthcare professionals for medical advice."
|
| 561 |
}
|
| 562 |
|
| 563 |
# ============================================================================
|
|
|
|
| 571 |
content={
|
| 572 |
"success": False,
|
| 573 |
"error": "Endpoint not found",
|
| 574 |
+
"error_code": "NOT_FOUND",
|
| 575 |
"available_endpoints": [
|
| 576 |
"/docs - API Documentation",
|
| 577 |
+
"/medical/ask - Text consultation",
|
| 578 |
+
"/medical/audio - Audio consultation",
|
| 579 |
"/health - System status",
|
| 580 |
+
"/feedback - Submit feedback"
|
| 581 |
]
|
| 582 |
}
|
| 583 |
)
|
|
|
|
| 588 |
status_code=422,
|
| 589 |
content={
|
| 590 |
"success": False,
|
| 591 |
+
"error": "Invalid request data",
|
| 592 |
+
"error_code": "VALIDATION_ERROR",
|
| 593 |
+
"details": exc.errors()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 594 |
}
|
| 595 |
)
|
| 596 |
|
|
|
|
| 598 |
# STARTUP MESSAGE
|
| 599 |
# ============================================================================
|
| 600 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 601 |
if __name__ == "__main__":
|
| 602 |
import uvicorn
|
| 603 |
|
| 604 |
+
print("🩺 Starting Medical AI Assistant API...")
|
| 605 |
+
print("📚 Documentation available at: http://localhost:8000/docs")
|
| 606 |
+
print("🔄 Alternative docs at: http://localhost:8000/redoc")
|
| 607 |
|
| 608 |
uvicorn.run(
|
| 609 |
app,
|
| 610 |
host="0.0.0.0",
|
| 611 |
+
port=8000,
|
| 612 |
log_level="info",
|
| 613 |
reload=False
|
| 614 |
)
|
medical_ai.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# medical_ai.py -
|
| 2 |
|
| 3 |
import os
|
| 4 |
import json
|
|
@@ -7,7 +7,8 @@ from typing import List, Dict, Any
|
|
| 7 |
from sentence_transformers import SentenceTransformer
|
| 8 |
import faiss
|
| 9 |
from functools import lru_cache
|
| 10 |
-
from transformers import
|
|
|
|
| 11 |
import torch
|
| 12 |
from typing import Optional
|
| 13 |
import logging
|
|
@@ -17,52 +18,64 @@ import re
|
|
| 17 |
logging.basicConfig(level=logging.INFO)
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
|
| 20 |
-
# ===
|
| 21 |
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
|
| 22 |
-
|
| 23 |
-
|
|
|
|
| 24 |
PATIENT_RECORDS_PATH = "patient_records.json"
|
| 25 |
|
| 26 |
-
#
|
| 27 |
DEVICE = "cpu"
|
| 28 |
-
MAX_LENGTH =
|
| 29 |
-
TEMPERATURE = 0.7
|
| 30 |
TOP_P = 0.9
|
| 31 |
-
TOP_K =
|
| 32 |
|
| 33 |
-
# === 1.
|
| 34 |
-
class
|
| 35 |
def __init__(self):
|
| 36 |
-
# Use lightweight langdetect instead of heavy ML model
|
| 37 |
try:
|
| 38 |
-
|
| 39 |
-
self.
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
@lru_cache(maxsize=256)
|
| 46 |
def detect_language(self, text: str) -> str:
|
| 47 |
if not text.strip():
|
| 48 |
return 'en'
|
| 49 |
|
| 50 |
-
#
|
| 51 |
-
if self.
|
| 52 |
try:
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
except:
|
| 57 |
-
|
| 58 |
|
| 59 |
-
# Fallback to keyword detection
|
| 60 |
return self._keyword_detection(text)
|
| 61 |
|
| 62 |
def _keyword_detection(self, text: str) -> str:
|
| 63 |
-
"""
|
| 64 |
-
french_indicators = ['que', 'quoi', 'comment', 'pourquoi', 'symptômes', 'maladie', 'traitement']
|
| 65 |
-
english_indicators = ['what', 'how', 'why', 'symptoms', 'disease', 'treatment']
|
| 66 |
|
| 67 |
text_lower = text.lower()
|
| 68 |
fr_score = sum(2 if indicator in text_lower else 0 for indicator in french_indicators)
|
|
@@ -70,32 +83,60 @@ class SimpleLanguageDetector:
|
|
| 70 |
|
| 71 |
return 'fr' if fr_score > en_score else 'en'
|
| 72 |
|
| 73 |
-
# === 2.
|
| 74 |
-
class
|
| 75 |
-
def __init__(self):
|
| 76 |
-
# Use a lightweight translation approach
|
| 77 |
try:
|
| 78 |
-
|
| 79 |
-
self.
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
except Exception as e:
|
| 82 |
logger.error(f"Error initializing translator: {str(e)}")
|
| 83 |
-
self.
|
|
|
|
| 84 |
|
| 85 |
@lru_cache(maxsize=256)
|
| 86 |
def translate(self, text: str, source_lang: str, target_lang: str) -> str:
|
| 87 |
if not text.strip() or source_lang == target_lang:
|
| 88 |
return text
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
-
# === 3.
|
| 95 |
-
class
|
| 96 |
def __init__(self, embedding_model_name=EMBEDDING_MODEL_NAME, records_path=PATIENT_RECORDS_PATH):
|
| 97 |
try:
|
| 98 |
-
logger.info("Loading lightweight embedder...")
|
| 99 |
self.embedder = SentenceTransformer(embedding_model_name)
|
| 100 |
|
| 101 |
if not os.path.exists(records_path):
|
|
@@ -105,33 +146,43 @@ class LightweightMedicalRAG:
|
|
| 105 |
with open(records_path, 'r', encoding='utf-8') as f:
|
| 106 |
self.records = json.load(f)
|
| 107 |
|
| 108 |
-
#
|
| 109 |
self.medical_chunks = []
|
| 110 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
-
#
|
| 113 |
-
self.medical_index = self._build_faiss_index(self.medical_chunks)
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
-
logger.info(f"
|
|
|
|
| 116 |
|
| 117 |
except Exception as e:
|
| 118 |
-
logger.error(f"Error initializing RAG: {str(e)}")
|
| 119 |
self._initialize_fallback()
|
| 120 |
|
| 121 |
def _create_sample_records(self, path: str):
|
| 122 |
-
"""
|
| 123 |
sample_records = [
|
| 124 |
{
|
| 125 |
"id": "malaria_001",
|
| 126 |
-
"diagnosis": {"en": "Malaria", "fr": "Paludisme"},
|
| 127 |
-
"symptoms": {"en": "
|
| 128 |
-
"
|
|
|
|
| 129 |
},
|
| 130 |
{
|
| 131 |
-
"id": "
|
| 132 |
-
"
|
| 133 |
-
"
|
| 134 |
-
"
|
|
|
|
| 135 |
}
|
| 136 |
]
|
| 137 |
|
|
@@ -139,25 +190,65 @@ class LightweightMedicalRAG:
|
|
| 139 |
json.dump(sample_records, f, ensure_ascii=False, indent=2)
|
| 140 |
|
| 141 |
def _initialize_fallback(self):
|
| 142 |
-
"""
|
| 143 |
-
self.medical_chunks = [
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
self.medical_index = None
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
-
def
|
| 151 |
-
"""
|
| 152 |
for rec in self.records:
|
| 153 |
try:
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
chunk += f"Treatment: {rec['treatment'].get('en', '')}"
|
| 159 |
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
except Exception as e:
|
| 163 |
logger.error(f"Error processing record: {str(e)}")
|
|
@@ -165,152 +256,281 @@ class LightweightMedicalRAG:
|
|
| 165 |
|
| 166 |
def _build_faiss_index(self, chunks):
|
| 167 |
if not chunks:
|
| 168 |
-
return None
|
| 169 |
try:
|
| 170 |
embeddings = self.embedder.encode(chunks, show_progress_bar=False, convert_to_numpy=True)
|
| 171 |
index = faiss.IndexFlatL2(embeddings.shape[1])
|
| 172 |
index.add(embeddings)
|
| 173 |
-
return index
|
| 174 |
except Exception as e:
|
| 175 |
logger.error(f"Error building FAISS index: {str(e)}")
|
| 176 |
-
return None
|
| 177 |
|
| 178 |
-
def
|
| 179 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
try:
|
| 181 |
-
if self.medical_index is None:
|
| 182 |
-
return self.medical_chunks[:2]
|
| 183 |
-
|
| 184 |
q_emb = self.embedder.encode([question], convert_to_numpy=True)
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
except Exception as e:
|
| 189 |
-
logger.error(f"Error getting contexts: {str(e)}")
|
| 190 |
-
|
|
|
|
| 191 |
|
| 192 |
-
# === 4.
|
| 193 |
-
class
|
| 194 |
def __init__(self, model_name: str = MODEL_NAME):
|
| 195 |
self.device = DEVICE
|
| 196 |
-
logger.info(f"Loading
|
| 197 |
|
| 198 |
try:
|
| 199 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
self.generator = pipeline(
|
| 201 |
"text-generation",
|
| 202 |
-
model=
|
|
|
|
| 203 |
device=-1, # CPU
|
| 204 |
-
|
| 205 |
-
model_kwargs={"low_cpu_mem_usage": True}
|
| 206 |
)
|
| 207 |
|
| 208 |
-
logger.info(f"
|
| 209 |
except Exception as e:
|
| 210 |
logger.error(f"Error loading model: {str(e)}")
|
| 211 |
self.generator = None
|
| 212 |
|
| 213 |
-
def
|
| 214 |
-
"""
|
| 215 |
|
| 216 |
if self.generator is None:
|
| 217 |
-
return self.
|
| 218 |
|
| 219 |
try:
|
| 220 |
-
#
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
if lang == "fr":
|
| 224 |
-
prompt = f"Contexte médical: {context_str}\n\nQuestion: {question}\n\nRéponse médicale:"
|
| 225 |
-
else:
|
| 226 |
-
prompt = f"Medical context: {context_str}\n\nQuestion: {question}\n\nMedical response:"
|
| 227 |
|
| 228 |
-
#
|
| 229 |
response = self.generator(
|
| 230 |
prompt,
|
| 231 |
-
max_length=len(prompt) +
|
|
|
|
| 232 |
temperature=TEMPERATURE,
|
| 233 |
top_p=TOP_P,
|
| 234 |
top_k=TOP_K,
|
| 235 |
do_sample=True,
|
| 236 |
-
pad_token_id=self.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
)
|
| 238 |
|
| 239 |
-
#
|
| 240 |
full_text = response[0]['generated_text']
|
| 241 |
response_text = full_text[len(prompt):].strip()
|
| 242 |
|
| 243 |
-
#
|
| 244 |
-
|
| 245 |
-
"en": "\n\n⚕️ Medical Disclaimer: Consult a healthcare professional for proper diagnosis.",
|
| 246 |
-
"fr": "\n\n⚕️ Avertissement médical: Consultez un professionnel de santé pour un diagnostic approprié."
|
| 247 |
-
}
|
| 248 |
-
|
| 249 |
-
if "disclaimer" not in response_text.lower():
|
| 250 |
-
response_text += disclaimer.get(lang, disclaimer["en"])
|
| 251 |
|
| 252 |
-
return response_text
|
| 253 |
|
| 254 |
except Exception as e:
|
| 255 |
-
logger.error(f"Error in generation: {str(e)}")
|
| 256 |
-
return self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
|
| 258 |
-
def
|
| 259 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
|
| 261 |
templates = {
|
| 262 |
-
"en":
|
| 263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
}
|
| 265 |
|
| 266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
|
| 268 |
-
# ===
|
| 269 |
-
class
|
| 270 |
def __init__(self):
|
| 271 |
-
logger.info("
|
| 272 |
try:
|
| 273 |
-
self.lang_detector =
|
| 274 |
-
self.translator =
|
| 275 |
-
self.rag =
|
| 276 |
-
self.llm =
|
| 277 |
-
logger.info("
|
| 278 |
except Exception as e:
|
| 279 |
-
logger.error(f"Error initializing pipeline: {str(e)}")
|
| 280 |
raise
|
| 281 |
|
| 282 |
def process(self, question: str, user_lang: str = "auto", conversation_history: list = None) -> Dict[str, Any]:
|
| 283 |
-
"""
|
| 284 |
try:
|
| 285 |
if not question or not question.strip():
|
| 286 |
return self._empty_question_response(user_lang)
|
| 287 |
|
| 288 |
-
#
|
| 289 |
detected_lang = self.lang_detector.detect_language(question) if user_lang == "auto" else user_lang
|
| 290 |
-
logger.info(f"Processing question in {detected_lang}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
|
| 292 |
-
#
|
| 293 |
-
|
|
|
|
|
|
|
| 294 |
|
| 295 |
-
#
|
| 296 |
-
|
|
|
|
|
|
|
| 297 |
|
| 298 |
return {
|
| 299 |
-
"response":
|
| 300 |
"source_lang": detected_lang,
|
| 301 |
-
"context_used":
|
| 302 |
-
"confidence": "
|
| 303 |
}
|
| 304 |
|
| 305 |
except Exception as e:
|
| 306 |
-
logger.error(f"
|
| 307 |
return self._error_response(str(e), user_lang if user_lang != "auto" else "en")
|
| 308 |
|
| 309 |
def _empty_question_response(self, user_lang: str) -> Dict[str, Any]:
|
| 310 |
-
"""
|
| 311 |
responses = {
|
| 312 |
-
"en": "Please provide a medical question for
|
| 313 |
-
"fr": "Veuillez poser une question médicale pour
|
| 314 |
}
|
| 315 |
lang = user_lang if user_lang != "auto" else "en"
|
| 316 |
return {
|
|
@@ -321,18 +541,17 @@ class SpacesMedicalAIPipeline:
|
|
| 321 |
}
|
| 322 |
|
| 323 |
def _error_response(self, error: str, lang: str) -> Dict[str, Any]:
|
| 324 |
-
"""
|
| 325 |
responses = {
|
| 326 |
-
"en": "I'm experiencing technical difficulties. Please try rephrasing your medical question.",
|
| 327 |
-
"fr": "Je rencontre des difficultés techniques. Veuillez reformuler votre question médicale."
|
| 328 |
}
|
| 329 |
return {
|
| 330 |
"response": responses.get(lang, responses["en"]),
|
| 331 |
"source_lang": lang,
|
| 332 |
"context_used": [],
|
| 333 |
-
"confidence": "
|
| 334 |
}
|
| 335 |
|
| 336 |
-
#
|
| 337 |
-
|
| 338 |
-
MedicalAIPipeline = SpacesMedicalAIPipeline
|
|
|
|
| 1 |
+
# medical_ai.py - VERSION COMPETITION OPTIMISÉE
|
| 2 |
|
| 3 |
import os
|
| 4 |
import json
|
|
|
|
| 7 |
from sentence_transformers import SentenceTransformer
|
| 8 |
import faiss
|
| 9 |
from functools import lru_cache
|
| 10 |
+
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM, pipeline
|
| 11 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 12 |
import torch
|
| 13 |
from typing import Optional
|
| 14 |
import logging
|
|
|
|
| 18 |
logging.basicConfig(level=logging.INFO)
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
| 21 |
+
# === CONFIGURATION COMPÉTITION ===
|
| 22 |
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
|
| 23 |
+
NLLB_MODEL_NAME = "facebook/nllb-200-distilled-600M"
|
| 24 |
+
# MODÈLE PRINCIPAL - MEDIUM pour la compétition
|
| 25 |
+
MODEL_NAME = "microsoft/DialoGPT-medium"
|
| 26 |
PATIENT_RECORDS_PATH = "patient_records.json"
|
| 27 |
|
| 28 |
+
# Configuration optimisée pour CPU avec performance maximale
|
| 29 |
DEVICE = "cpu"
|
| 30 |
+
MAX_LENGTH = 512 # Augmenté pour des réponses plus complètes
|
| 31 |
+
TEMPERATURE = 0.7 # Équilibre créativité/cohérence
|
| 32 |
TOP_P = 0.9
|
| 33 |
+
TOP_K = 50
|
| 34 |
|
| 35 |
+
# === 1. DÉTECTION DE LANGUE AVANCÉE ===
|
| 36 |
+
class AdvancedLanguageDetector:
|
| 37 |
def __init__(self):
|
|
|
|
| 38 |
try:
|
| 39 |
+
# Utilise un modèle plus précis pour la détection
|
| 40 |
+
self.lang_id = pipeline("text-classification",
|
| 41 |
+
model="papluca/xlm-roberta-base-language-detection",
|
| 42 |
+
device=-1) # Force CPU
|
| 43 |
+
self.lang_map = {
|
| 44 |
+
'fr': 'fr', 'en': 'en', 'bss': 'bss', 'dua': 'dua', 'ewo': 'ewo',
|
| 45 |
+
'fr-FR': 'fr', 'en-EN': 'en', 'fr_XX': 'fr', 'en_XX': 'en',
|
| 46 |
+
'LABEL_0': 'en', 'LABEL_1': 'fr' # Fallbacks
|
| 47 |
+
}
|
| 48 |
+
logger.info("Advanced language detector initialized")
|
| 49 |
+
except Exception as e:
|
| 50 |
+
logger.error(f"Error initializing language detector: {str(e)}")
|
| 51 |
+
self.lang_id = None
|
| 52 |
|
| 53 |
@lru_cache(maxsize=256)
|
| 54 |
def detect_language(self, text: str) -> str:
|
| 55 |
if not text.strip():
|
| 56 |
return 'en'
|
| 57 |
|
| 58 |
+
# Méthode hybride : ML + règles
|
| 59 |
+
if self.lang_id:
|
| 60 |
try:
|
| 61 |
+
pred = self.lang_id(text)[0]
|
| 62 |
+
detected = pred['label'] if isinstance(pred, dict) else str(pred)
|
| 63 |
+
confidence = pred.get('score', 0.5) if isinstance(pred, dict) else 0.5
|
| 64 |
+
|
| 65 |
+
# Si confiance faible, utiliser détection par mots-clés
|
| 66 |
+
if confidence < 0.8:
|
| 67 |
+
return self._keyword_detection(text)
|
| 68 |
+
|
| 69 |
+
return self.lang_map.get(detected, 'en')
|
| 70 |
except:
|
| 71 |
+
return self._keyword_detection(text)
|
| 72 |
|
|
|
|
| 73 |
return self._keyword_detection(text)
|
| 74 |
|
| 75 |
def _keyword_detection(self, text: str) -> str:
|
| 76 |
+
"""Détection par mots-clés comme fallback"""
|
| 77 |
+
french_indicators = ['que', 'quoi', 'comment', 'pourquoi', 'symptômes', 'maladie', 'traitement', 'médecin', 'santé']
|
| 78 |
+
english_indicators = ['what', 'how', 'why', 'symptoms', 'disease', 'treatment', 'doctor', 'health']
|
| 79 |
|
| 80 |
text_lower = text.lower()
|
| 81 |
fr_score = sum(2 if indicator in text_lower else 0 for indicator in french_indicators)
|
|
|
|
| 83 |
|
| 84 |
return 'fr' if fr_score > en_score else 'en'
|
| 85 |
|
| 86 |
+
# === 2. TRADUCTION OPTIMISÉE ===
|
| 87 |
+
class OptimizedTranslator:
|
| 88 |
+
def __init__(self, model_name=NLLB_MODEL_NAME):
|
|
|
|
| 89 |
try:
|
| 90 |
+
self.tokenizer = NllbTokenizer.from_pretrained(model_name)
|
| 91 |
+
self.model = AutoModelForSeq2SeqLM.from_pretrained(
|
| 92 |
+
model_name,
|
| 93 |
+
torch_dtype=torch.float32, # CPU optimized
|
| 94 |
+
low_cpu_mem_usage=True
|
| 95 |
+
)
|
| 96 |
+
self.lang_code_map = {
|
| 97 |
+
'fr': 'fra_Latn', 'en': 'eng_Latn', 'bss': 'bss_Latn',
|
| 98 |
+
'dua': 'dua_Latn', 'ewo': 'ewo_Latn',
|
| 99 |
+
}
|
| 100 |
+
logger.info("Optimized translator initialized")
|
| 101 |
except Exception as e:
|
| 102 |
logger.error(f"Error initializing translator: {str(e)}")
|
| 103 |
+
self.tokenizer = None
|
| 104 |
+
self.model = None
|
| 105 |
|
| 106 |
@lru_cache(maxsize=256)
|
| 107 |
def translate(self, text: str, source_lang: str, target_lang: str) -> str:
|
| 108 |
if not text.strip() or source_lang == target_lang:
|
| 109 |
return text
|
| 110 |
+
|
| 111 |
+
if self.tokenizer is None or self.model is None:
|
| 112 |
+
return text
|
| 113 |
|
| 114 |
+
try:
|
| 115 |
+
src = self.lang_code_map.get(source_lang, 'eng_Latn')
|
| 116 |
+
tgt = self.lang_code_map.get(target_lang, 'eng_Latn')
|
| 117 |
+
|
| 118 |
+
self.tokenizer.src_lang = src
|
| 119 |
+
inputs = self.tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
|
| 120 |
+
|
| 121 |
+
with torch.no_grad():
|
| 122 |
+
generated_tokens = self.model.generate(
|
| 123 |
+
**inputs,
|
| 124 |
+
forced_bos_token_id=self.tokenizer.convert_tokens_to_ids(tgt),
|
| 125 |
+
max_length=512,
|
| 126 |
+
num_beams=4, # Améliore la qualité
|
| 127 |
+
early_stopping=True
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
result = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
|
| 131 |
+
return result
|
| 132 |
+
except Exception as e:
|
| 133 |
+
logger.error(f"Translation error: {str(e)}")
|
| 134 |
+
return text
|
| 135 |
|
| 136 |
+
# === 3. RAG MÉDICAL AVANCÉ ===
|
| 137 |
+
class AdvancedMedicalRAG:
|
| 138 |
def __init__(self, embedding_model_name=EMBEDDING_MODEL_NAME, records_path=PATIENT_RECORDS_PATH):
|
| 139 |
try:
|
|
|
|
| 140 |
self.embedder = SentenceTransformer(embedding_model_name)
|
| 141 |
|
| 142 |
if not os.path.exists(records_path):
|
|
|
|
| 146 |
with open(records_path, 'r', encoding='utf-8') as f:
|
| 147 |
self.records = json.load(f)
|
| 148 |
|
| 149 |
+
# Construction d'indices spécialisés
|
| 150 |
self.medical_chunks = []
|
| 151 |
+
self.educational_chunks = []
|
| 152 |
+
self.emergency_chunks = []
|
| 153 |
+
self.prevention_chunks = []
|
| 154 |
+
|
| 155 |
+
self._build_specialized_chunks()
|
| 156 |
|
| 157 |
+
# Indices FAISS multiples pour différents types de requêtes
|
| 158 |
+
self.medical_index, _ = self._build_faiss_index(self.medical_chunks)
|
| 159 |
+
self.edu_index, _ = self._build_faiss_index(self.educational_chunks)
|
| 160 |
+
self.emergency_index, _ = self._build_faiss_index(self.emergency_chunks)
|
| 161 |
+
self.prevention_index, _ = self._build_faiss_index(self.prevention_chunks)
|
| 162 |
|
| 163 |
+
logger.info(f"Advanced RAG initialized: {len(self.medical_chunks)} medical, "
|
| 164 |
+
f"{len(self.educational_chunks)} educational, {len(self.emergency_chunks)} emergency chunks")
|
| 165 |
|
| 166 |
except Exception as e:
|
| 167 |
+
logger.error(f"Error initializing Advanced RAG: {str(e)}")
|
| 168 |
self._initialize_fallback()
|
| 169 |
|
| 170 |
def _create_sample_records(self, path: str):
|
| 171 |
+
"""Crée des enregistrements médicaux de base pour la compétition"""
|
| 172 |
sample_records = [
|
| 173 |
{
|
| 174 |
"id": "malaria_001",
|
| 175 |
+
"diagnosis": {"en": "Malaria (Plasmodium falciparum)", "fr": "Paludisme (Plasmodium falciparum)"},
|
| 176 |
+
"symptoms": {"en": "High fever, chills, headache, nausea, vomiting, fatigue", "fr": "Fièvre élevée, frissons, maux de tête, nausées, vomissements, fatigue"},
|
| 177 |
+
"medications": [{"name": {"en": "Artemether-Lumefantrine", "fr": "Artéméther-Luméfantrine"}, "dosage": "20mg/120mg twice daily for 3 days"}],
|
| 178 |
+
"care_instructions": {"en": "Complete bed rest, increase fluid intake, complete full medication course, return if symptoms worsen or fever persists after 48 hours", "fr": "Repos complet au lit, augmenter l'apport hydrique, terminer le traitement complet, revenir si les symptômes s'aggravent ou si la fièvre persiste après 48 heures"}
|
| 179 |
},
|
| 180 |
{
|
| 181 |
+
"id": "diabetes_prevention",
|
| 182 |
+
"context_type": "prevention",
|
| 183 |
+
"topic": {"en": "Type 2 Diabetes Prevention", "fr": "Prévention du Diabète de Type 2"},
|
| 184 |
+
"educational_content": {"en": "Maintain healthy BMI (18.5-24.9), engage in 150 minutes moderate exercise weekly, consume balanced diet rich in fiber and low in processed sugars, regular blood glucose monitoring for high-risk individuals", "fr": "Maintenir un IMC sain (18,5-24,9), pratiquer 150 minutes d'exercice modéré par semaine, consommer une alimentation équilibrée riche en fibres et pauvre en sucres transformés, surveillance régulière de la glycémie pour les personnes à risque"},
|
| 185 |
+
"target_group": "Adults over 30, family history of diabetes, sedentary lifestyle"
|
| 186 |
}
|
| 187 |
]
|
| 188 |
|
|
|
|
| 190 |
json.dump(sample_records, f, ensure_ascii=False, indent=2)
|
| 191 |
|
| 192 |
def _initialize_fallback(self):
|
| 193 |
+
"""Initialise un système de fallback basique"""
|
| 194 |
+
self.medical_chunks = ["General medical consultation and symptom assessment"]
|
| 195 |
+
self.educational_chunks = ["Health education and prevention guidelines"]
|
| 196 |
+
self.emergency_chunks = ["Emergency medical procedures and protocols"]
|
| 197 |
+
self.prevention_chunks = ["Disease prevention and health maintenance"]
|
| 198 |
+
|
| 199 |
self.medical_index = None
|
| 200 |
+
self.edu_index = None
|
| 201 |
+
self.emergency_index = None
|
| 202 |
+
self.prevention_index = None
|
| 203 |
|
| 204 |
+
def _build_specialized_chunks(self):
|
| 205 |
+
"""Construit des chunks spécialisés pour différents types de requêtes médicales"""
|
| 206 |
for rec in self.records:
|
| 207 |
try:
|
| 208 |
+
# Chunks médicaux (diagnostics, traitements)
|
| 209 |
+
if 'diagnosis' in rec:
|
| 210 |
+
medical_parts = []
|
| 211 |
+
medical_parts.append(f"Condition: {rec['diagnosis'].get('en', '')}")
|
|
|
|
| 212 |
|
| 213 |
+
if 'symptoms' in rec:
|
| 214 |
+
medical_parts.append(f"Symptoms: {rec['symptoms'].get('en', '')}")
|
| 215 |
+
|
| 216 |
+
if 'medications' in rec:
|
| 217 |
+
meds = [f"{m['name'].get('en', '')} ({m.get('dosage', '')})" for m in rec['medications']]
|
| 218 |
+
medical_parts.append(f"Treatment: {', '.join(meds)}")
|
| 219 |
+
|
| 220 |
+
if 'care_instructions' in rec:
|
| 221 |
+
medical_parts.append(f"Care instructions: {rec['care_instructions'].get('en', '')}")
|
| 222 |
+
|
| 223 |
+
if medical_parts:
|
| 224 |
+
self.medical_chunks.append(". ".join(medical_parts))
|
| 225 |
+
|
| 226 |
+
# Chunks éducatifs
|
| 227 |
+
if rec.get('context_type') == 'prevention' or 'educational_content' in rec:
|
| 228 |
+
edu_parts = []
|
| 229 |
+
if 'topic' in rec:
|
| 230 |
+
edu_parts.append(f"Topic: {rec['topic'].get('en', '')}")
|
| 231 |
+
if 'educational_content' in rec:
|
| 232 |
+
edu_parts.append(f"Information: {rec['educational_content'].get('en', '')}")
|
| 233 |
+
if 'target_group' in rec:
|
| 234 |
+
edu_parts.append(f"Target: {rec['target_group']}")
|
| 235 |
+
|
| 236 |
+
if edu_parts:
|
| 237 |
+
chunk = ". ".join(edu_parts)
|
| 238 |
+
self.educational_chunks.append(chunk)
|
| 239 |
+
if 'prevention' in chunk.lower():
|
| 240 |
+
self.prevention_chunks.append(chunk)
|
| 241 |
+
|
| 242 |
+
# Chunks d'urgence
|
| 243 |
+
if rec.get('context_type') == 'emergency_education' or 'emergency' in str(rec).lower():
|
| 244 |
+
emergency_parts = []
|
| 245 |
+
if 'scenario' in rec:
|
| 246 |
+
emergency_parts.append(f"Emergency: {rec['scenario'].get('en', '')}")
|
| 247 |
+
if 'action_steps' in rec:
|
| 248 |
+
emergency_parts.append(f"Actions: {rec['action_steps'].get('en', '')}")
|
| 249 |
+
|
| 250 |
+
if emergency_parts:
|
| 251 |
+
self.emergency_chunks.append(". ".join(emergency_parts))
|
| 252 |
|
| 253 |
except Exception as e:
|
| 254 |
logger.error(f"Error processing record: {str(e)}")
|
|
|
|
| 256 |
|
| 257 |
def _build_faiss_index(self, chunks):
|
| 258 |
if not chunks:
|
| 259 |
+
return None, None
|
| 260 |
try:
|
| 261 |
embeddings = self.embedder.encode(chunks, show_progress_bar=False, convert_to_numpy=True)
|
| 262 |
index = faiss.IndexFlatL2(embeddings.shape[1])
|
| 263 |
index.add(embeddings)
|
| 264 |
+
return index, embeddings
|
| 265 |
except Exception as e:
|
| 266 |
logger.error(f"Error building FAISS index: {str(e)}")
|
| 267 |
+
return None, None
|
| 268 |
|
| 269 |
+
def get_smart_contexts(self, question: str, lang: str = "en") -> Dict[str, List[str]]:
|
| 270 |
+
"""Récupère des contextes intelligents basés sur le type de question"""
|
| 271 |
+
question_lower = question.lower()
|
| 272 |
+
contexts = {
|
| 273 |
+
"medical": [],
|
| 274 |
+
"educational": [],
|
| 275 |
+
"emergency": [],
|
| 276 |
+
"prevention": []
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
try:
|
|
|
|
|
|
|
|
|
|
| 280 |
q_emb = self.embedder.encode([question], convert_to_numpy=True)
|
| 281 |
+
|
| 282 |
+
# Détection du type de question
|
| 283 |
+
is_emergency = any(word in question_lower for word in ['emergency', 'urgent', 'severe', 'critical', 'urgence', 'grave'])
|
| 284 |
+
is_prevention = any(word in question_lower for word in ['prevent', 'prevention', 'avoid', 'prévenir', 'éviter'])
|
| 285 |
+
is_educational = any(word in question_lower for word in ['what is', 'explain', 'how', 'why', "qu'est-ce que", 'expliquer', 'comment', 'pourquoi'])
|
| 286 |
+
|
| 287 |
+
# Récupération contextuelle intelligente
|
| 288 |
+
if is_emergency and self.emergency_index:
|
| 289 |
+
_, I = self.emergency_index.search(q_emb, min(3, len(self.emergency_chunks)))
|
| 290 |
+
contexts["emergency"] = [self.emergency_chunks[i] for i in I[0] if i < len(self.emergency_chunks)]
|
| 291 |
+
|
| 292 |
+
if is_prevention and self.prevention_index:
|
| 293 |
+
_, I = self.prevention_index.search(q_emb, min(2, len(self.prevention_chunks)))
|
| 294 |
+
contexts["prevention"] = [self.prevention_chunks[i] for i in I[0] if i < len(self.prevention_chunks)]
|
| 295 |
+
|
| 296 |
+
if is_educational and self.edu_index:
|
| 297 |
+
_, I = self.edu_index.search(q_emb, min(3, len(self.educational_chunks)))
|
| 298 |
+
contexts["educational"] = [self.educational_chunks[i] for i in I[0] if i < len(self.educational_chunks)]
|
| 299 |
+
|
| 300 |
+
# Toujours inclure du contexte médical général
|
| 301 |
+
if self.medical_index:
|
| 302 |
+
n_med = 4 if not any(contexts.values()) else 2
|
| 303 |
+
_, I = self.medical_index.search(q_emb, min(n_med, len(self.medical_chunks)))
|
| 304 |
+
contexts["medical"] = [self.medical_chunks[i] for i in I[0] if i < len(self.medical_chunks)]
|
| 305 |
|
| 306 |
except Exception as e:
|
| 307 |
+
logger.error(f"Error getting smart contexts: {str(e)}")
|
| 308 |
+
|
| 309 |
+
return contexts
|
| 310 |
|
| 311 |
+
# === 4. OPTIMIZED LLM GENERATOR ===
|
| 312 |
+
class CompetitionMedicalLLM:
|
| 313 |
def __init__(self, model_name: str = MODEL_NAME):
|
| 314 |
self.device = DEVICE
|
| 315 |
+
logger.info(f"Loading competition model {model_name} on {self.device}...")
|
| 316 |
|
| 317 |
try:
|
| 318 |
+
# Configuration optimale pour DialoGPT-medium sur CPU
|
| 319 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
|
| 320 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
| 321 |
+
model_name,
|
| 322 |
+
torch_dtype=torch.float32,
|
| 323 |
+
low_cpu_mem_usage=True,
|
| 324 |
+
device_map="auto" if DEVICE == "cuda" else None
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
# Configuration du tokenizer
|
| 328 |
+
if self.tokenizer.pad_token is None:
|
| 329 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 330 |
+
|
| 331 |
self.generator = pipeline(
|
| 332 |
"text-generation",
|
| 333 |
+
model=self.model,
|
| 334 |
+
tokenizer=self.tokenizer,
|
| 335 |
device=-1, # CPU
|
| 336 |
+
framework="pt"
|
|
|
|
| 337 |
)
|
| 338 |
|
| 339 |
+
logger.info(f"Competition model {model_name} loaded successfully")
|
| 340 |
except Exception as e:
|
| 341 |
logger.error(f"Error loading model: {str(e)}")
|
| 342 |
self.generator = None
|
| 343 |
|
| 344 |
+
def generate_expert_response(self, question: str, contexts: Dict[str, List[str]], lang: str = "en") -> str:
|
| 345 |
+
"""Génère une réponse d'expert médical de niveau compétition"""
|
| 346 |
|
| 347 |
if self.generator is None:
|
| 348 |
+
return self._expert_fallback_response(question, contexts, lang)
|
| 349 |
|
| 350 |
try:
|
| 351 |
+
# Construction du prompt expert
|
| 352 |
+
prompt = self._build_expert_prompt(question, contexts, lang)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
|
| 354 |
+
# Génération avec paramètres optimisés
|
| 355 |
response = self.generator(
|
| 356 |
prompt,
|
| 357 |
+
max_length=len(prompt) + 300, # Plus long pour des réponses complètes
|
| 358 |
+
num_return_sequences=1,
|
| 359 |
temperature=TEMPERATURE,
|
| 360 |
top_p=TOP_P,
|
| 361 |
top_k=TOP_K,
|
| 362 |
do_sample=True,
|
| 363 |
+
pad_token_id=self.tokenizer.eos_token_id,
|
| 364 |
+
eos_token_id=self.tokenizer.eos_token_id,
|
| 365 |
+
repetition_penalty=1.1, # Évite les répétitions
|
| 366 |
+
length_penalty=1.0,
|
| 367 |
+
early_stopping=True
|
| 368 |
)
|
| 369 |
|
| 370 |
+
# Extraction et nettoyage expert
|
| 371 |
full_text = response[0]['generated_text']
|
| 372 |
response_text = full_text[len(prompt):].strip()
|
| 373 |
|
| 374 |
+
# Post-processing expert
|
| 375 |
+
response_text = self._expert_post_process(response_text, lang)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
|
| 377 |
+
return response_text
|
| 378 |
|
| 379 |
except Exception as e:
|
| 380 |
+
logger.error(f"Error in expert generation: {str(e)}")
|
| 381 |
+
return self._expert_fallback_response(question, contexts, lang)
|
| 382 |
+
|
| 383 |
+
def _build_expert_prompt(self, question: str, contexts: Dict[str, List[str]], lang: str) -> str:
|
| 384 |
+
"""Construit un prompt de niveau expert pour la compétition"""
|
| 385 |
+
|
| 386 |
+
# Agrégation intelligente des contextes
|
| 387 |
+
context_parts = []
|
| 388 |
+
if contexts.get("emergency"):
|
| 389 |
+
context_parts.append(f"🚨 Emergency Protocol: {' | '.join(contexts['emergency'][:2])}")
|
| 390 |
+
if contexts.get("medical"):
|
| 391 |
+
context_parts.append(f"📋 Clinical Information: {' | '.join(contexts['medical'][:2])}")
|
| 392 |
+
if contexts.get("prevention"):
|
| 393 |
+
context_parts.append(f"🛡️ Prevention Guidelines: {' | '.join(contexts['prevention'][:1])}")
|
| 394 |
+
if contexts.get("educational"):
|
| 395 |
+
context_parts.append(f"📚 Educational Content: {' | '.join(contexts['educational'][:1])}")
|
| 396 |
+
|
| 397 |
+
context_str = "\n".join(context_parts) if context_parts else "General medical consultation context."
|
| 398 |
+
|
| 399 |
+
# Prompt structuré pour excellence
|
| 400 |
+
if lang == "fr":
|
| 401 |
+
prompt = f"""Contexte médical expert:
|
| 402 |
+
{context_str}
|
| 403 |
+
|
| 404 |
+
Question du patient: {question}
|
| 405 |
+
|
| 406 |
+
Réponse médicale experte (structurée et complète):"""
|
| 407 |
+
else:
|
| 408 |
+
prompt = f"""Expert medical context:
|
| 409 |
+
{context_str}
|
| 410 |
+
|
| 411 |
+
Patient question: {question}
|
| 412 |
+
|
| 413 |
+
Expert medical response (structured and comprehensive):"""
|
| 414 |
+
|
| 415 |
+
return prompt
|
| 416 |
|
| 417 |
+
def _expert_post_process(self, response: str, lang: str) -> str:
|
| 418 |
+
"""Post-traitement expert de la réponse"""
|
| 419 |
+
|
| 420 |
+
# Nettoyage des artifacts
|
| 421 |
+
for stop_seq in ["</s>", "\nPatient:", "\nDoctor:", "\nExpert:", "\n\nContext:", "Question:"]:
|
| 422 |
+
if stop_seq in response:
|
| 423 |
+
response = response.split(stop_seq)[0].strip()
|
| 424 |
+
|
| 425 |
+
# Structuration expert
|
| 426 |
+
if len(response.split('.')) > 2: # Si réponse assez longue
|
| 427 |
+
sentences = [s.strip() for s in response.split('.') if s.strip()]
|
| 428 |
+
if len(sentences) >= 3:
|
| 429 |
+
response = '. '.join(sentences[:4]) + '.' # Limiter à 4 phrases max
|
| 430 |
+
|
| 431 |
+
# Ajout disclaimer expert
|
| 432 |
+
disclaimer = {
|
| 433 |
+
"en": "\n\n⚕️ Medical Disclaimer: This information is for educational purposes. Always consult a qualified healthcare professional for proper diagnosis and treatment.",
|
| 434 |
+
"fr": "\n\n⚕️ Avertissement médical: Cette information est à des fins éducatives. Consultez toujours un professionnel de santé qualifié pour un diagnostic et traitement appropriés."
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
if "consult" not in response.lower() and "disclaimer" not in response.lower():
|
| 438 |
+
response += disclaimer.get(lang, disclaimer["en"])
|
| 439 |
+
|
| 440 |
+
return response.strip()
|
| 441 |
+
|
| 442 |
+
def _expert_fallback_response(self, question: str, contexts: Dict[str, List[str]], lang: str) -> str:
|
| 443 |
+
"""Réponse de fallback de niveau expert"""
|
| 444 |
|
| 445 |
templates = {
|
| 446 |
+
"en": {
|
| 447 |
+
"intro": "Based on medical expertise and available clinical information:",
|
| 448 |
+
"structure": "\n\n🔍 Assessment: This appears to be a medical inquiry requiring professional evaluation.\n\n💡 General Guidance: Monitor symptoms, maintain proper hygiene, stay hydrated, and seek appropriate medical care.\n\n⚠️ Important: For accurate diagnosis and treatment, please consult with a healthcare professional.",
|
| 449 |
+
"context_available": "According to medical literature and clinical guidelines: "
|
| 450 |
+
},
|
| 451 |
+
"fr": {
|
| 452 |
+
"intro": "Sur la base de l'expertise médicale et des informations cliniques disponibles:",
|
| 453 |
+
"structure": "\n\n🔍 Évaluation: Il s'agit d'une demande médicale nécessitant une évaluation professionnelle.\n\n💡 Guidance générale: Surveillez les symptômes, maintenez une hygiène appropriée, restez hydraté et consultez un professionnel de santé.\n\n⚠️ Important: Pour un diagnostic et traitement précis, veuillez consulter un professionnel de santé.",
|
| 454 |
+
"context_available": "Selon la littérature médicale et les directives cliniques: "
|
| 455 |
+
}
|
| 456 |
}
|
| 457 |
|
| 458 |
+
template = templates.get(lang, templates["en"])
|
| 459 |
+
response = template["intro"]
|
| 460 |
+
|
| 461 |
+
# Intégrer contextes si disponibles
|
| 462 |
+
all_contexts = []
|
| 463 |
+
for context_list in contexts.values():
|
| 464 |
+
all_contexts.extend(context_list)
|
| 465 |
+
|
| 466 |
+
if all_contexts:
|
| 467 |
+
response += f" {template['context_available']}{' | '.join(all_contexts[:2])}"
|
| 468 |
+
|
| 469 |
+
response += template["structure"]
|
| 470 |
+
|
| 471 |
+
return response
|
| 472 |
|
| 473 |
+
# === MAIN COMPETITION PIPELINE ===
|
| 474 |
+
class CompetitionMedicalAIPipeline:
|
| 475 |
def __init__(self):
|
| 476 |
+
logger.info("🏆 Initializing COMPETITION Medical AI Pipeline...")
|
| 477 |
try:
|
| 478 |
+
self.lang_detector = AdvancedLanguageDetector()
|
| 479 |
+
self.translator = OptimizedTranslator()
|
| 480 |
+
self.rag = AdvancedMedicalRAG()
|
| 481 |
+
self.llm = CompetitionMedicalLLM()
|
| 482 |
+
logger.info("🎯 Competition Medical AI Pipeline ready for excellence!")
|
| 483 |
except Exception as e:
|
| 484 |
+
logger.error(f"Error initializing competition pipeline: {str(e)}")
|
| 485 |
raise
|
| 486 |
|
| 487 |
def process(self, question: str, user_lang: str = "auto", conversation_history: list = None) -> Dict[str, Any]:
|
| 488 |
+
"""Traitement de niveau compétition"""
|
| 489 |
try:
|
| 490 |
if not question or not question.strip():
|
| 491 |
return self._empty_question_response(user_lang)
|
| 492 |
|
| 493 |
+
# Détection langue avancée
|
| 494 |
detected_lang = self.lang_detector.detect_language(question) if user_lang == "auto" else user_lang
|
| 495 |
+
logger.info(f"🎯 Processing competition-level question in {detected_lang}")
|
| 496 |
+
|
| 497 |
+
# Traduction si nécessaire avec qualité optimale
|
| 498 |
+
question_en = question
|
| 499 |
+
if detected_lang != "en":
|
| 500 |
+
question_en = self.translator.translate(question, detected_lang, "en")
|
| 501 |
+
|
| 502 |
+
# RAG intelligent multi-contexte
|
| 503 |
+
smart_contexts = self.rag.get_smart_contexts(question_en, "en")
|
| 504 |
+
|
| 505 |
+
# Génération experte
|
| 506 |
+
response_en = self.llm.generate_expert_response(question_en, smart_contexts, "en")
|
| 507 |
|
| 508 |
+
# Traduction retour avec qualité optimale
|
| 509 |
+
final_response = response_en
|
| 510 |
+
if detected_lang != "en":
|
| 511 |
+
final_response = self.translator.translate(response_en, "en", detected_lang)
|
| 512 |
|
| 513 |
+
# Contextes utilisés pour transparence
|
| 514 |
+
all_contexts = []
|
| 515 |
+
for context_list in smart_contexts.values():
|
| 516 |
+
all_contexts.extend(context_list)
|
| 517 |
|
| 518 |
return {
|
| 519 |
+
"response": final_response,
|
| 520 |
"source_lang": detected_lang,
|
| 521 |
+
"context_used": all_contexts[:5], # Top 5 contextes
|
| 522 |
+
"confidence": "high" # Indicateur de qualité
|
| 523 |
}
|
| 524 |
|
| 525 |
except Exception as e:
|
| 526 |
+
logger.error(f"Competition processing error: {str(e)}")
|
| 527 |
return self._error_response(str(e), user_lang if user_lang != "auto" else "en")
|
| 528 |
|
| 529 |
def _empty_question_response(self, user_lang: str) -> Dict[str, Any]:
|
| 530 |
+
"""Réponse pour question vide"""
|
| 531 |
responses = {
|
| 532 |
+
"en": "Please provide a medical question for me to assist you with professional healthcare guidance.",
|
| 533 |
+
"fr": "Veuillez poser une question médicale pour que je puisse vous fournir des conseils de santé professionnels."
|
| 534 |
}
|
| 535 |
lang = user_lang if user_lang != "auto" else "en"
|
| 536 |
return {
|
|
|
|
| 541 |
}
|
| 542 |
|
| 543 |
def _error_response(self, error: str, lang: str) -> Dict[str, Any]:
|
| 544 |
+
"""Réponse d'erreur professionnelle"""
|
| 545 |
responses = {
|
| 546 |
+
"en": "I apologize, but I'm experiencing technical difficulties. Please try rephrasing your medical question, and I'll provide you with professional healthcare guidance.",
|
| 547 |
+
"fr": "Je m'excuse, mais je rencontre des difficultés techniques. Veuillez reformuler votre question médicale, et je vous fournirai des conseils de santé professionnels."
|
| 548 |
}
|
| 549 |
return {
|
| 550 |
"response": responses.get(lang, responses["en"]),
|
| 551 |
"source_lang": lang,
|
| 552 |
"context_used": [],
|
| 553 |
+
"confidence": "medium"
|
| 554 |
}
|
| 555 |
|
| 556 |
+
# Alias for backward compatibility
|
| 557 |
+
MedicalAIPipeline = CompetitionMedicalAIPipeline
|
|
|
requirements.txt
CHANGED
|
@@ -1,34 +1,39 @@
|
|
| 1 |
-
# FastAPI Medical AI -
|
| 2 |
# Core FastAPI dependencies
|
| 3 |
fastapi==0.104.1
|
| 4 |
uvicorn[standard]==0.24.0
|
| 5 |
python-multipart==0.0.6
|
| 6 |
pydantic==2.4.2
|
| 7 |
|
| 8 |
-
#
|
| 9 |
transformers==4.35.2
|
| 10 |
-
torch==2.1.0
|
| 11 |
sentence-transformers==2.2.2
|
| 12 |
faiss-cpu==1.7.4
|
|
|
|
| 13 |
|
| 14 |
-
# Audio processing
|
| 15 |
librosa==0.10.1
|
| 16 |
soundfile==0.12.1
|
| 17 |
numpy==1.24.3
|
| 18 |
|
| 19 |
-
#
|
| 20 |
-
|
| 21 |
-
# accelerate==0.24.1 # Not needed for CPU
|
| 22 |
-
# optimum==1.13.2 # Not needed for basic setup
|
| 23 |
|
| 24 |
# Language processing
|
| 25 |
sentencepiece==0.1.99
|
| 26 |
langdetect==1.0.9
|
| 27 |
|
| 28 |
-
#
|
| 29 |
-
|
|
|
|
| 30 |
|
| 31 |
# System monitoring
|
| 32 |
psutil==5.9.6
|
| 33 |
|
| 34 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FastAPI Medical AI - Requirements
|
| 2 |
# Core FastAPI dependencies
|
| 3 |
fastapi==0.104.1
|
| 4 |
uvicorn[standard]==0.24.0
|
| 5 |
python-multipart==0.0.6
|
| 6 |
pydantic==2.4.2
|
| 7 |
|
| 8 |
+
# ML and AI models
|
| 9 |
transformers==4.35.2
|
| 10 |
+
torch==2.1.0
|
| 11 |
sentence-transformers==2.2.2
|
| 12 |
faiss-cpu==1.7.4
|
| 13 |
+
faster-whisper==0.9.0
|
| 14 |
|
| 15 |
+
# Audio processing
|
| 16 |
librosa==0.10.1
|
| 17 |
soundfile==0.12.1
|
| 18 |
numpy==1.24.3
|
| 19 |
|
| 20 |
+
# HTTP requests for testing
|
| 21 |
+
requests==2.31.0
|
|
|
|
|
|
|
| 22 |
|
| 23 |
# Language processing
|
| 24 |
sentencepiece==0.1.99
|
| 25 |
langdetect==1.0.9
|
| 26 |
|
| 27 |
+
# Performance optimizations
|
| 28 |
+
accelerate==0.24.1
|
| 29 |
+
optimum==1.13.2
|
| 30 |
|
| 31 |
# System monitoring
|
| 32 |
psutil==5.9.6
|
| 33 |
|
| 34 |
+
# Development and testing
|
| 35 |
+
pytest==7.4.3
|
| 36 |
+
pytest-asyncio==0.21.1
|
| 37 |
+
|
| 38 |
+
# Optional: For production deployment
|
| 39 |
+
gunicorn==21.2.0
|