Changed files:
- README.md +67 -5
- app.py +25 -18
- audio_utils.py +0 -162
- fastapi_app.py +126 -413
- medical_ai.py +150 -369
- requirements.txt +11 -16
README.md
CHANGED
@@ -1,12 +1,74 @@
 ---
-title: Medical
-emoji:
+title: Medical AI Assistant
+emoji: 🩺
 colorFrom: green
-colorTo:
+colorTo: blue
 sdk: docker
 pinned: false
 license: mit
-
+app_port: 7860
+short_description: Multilingual medical consultation AI assistant
 ---
 
-
+# 🩺 Medical AI Assistant
+
+A multilingual medical consultation AI assistant optimized for Hugging Face Spaces.
+
+## ✨ Features
+
+- 🌍 **Multilingual Support**: English and French medical consultations
+- 🧠 **Medical Knowledge**: AI-powered responses based on medical literature
+- ⚡ **Fast Processing**: Optimized for quick responses
+- 📚 **Educational Focus**: Provides educational medical information with proper disclaimers
+
+## 🚀 Quick Start
+
+1. Visit the **API Documentation** at `/docs` to explore all endpoints
+2. Try the main endpoint: `POST /medical/ask`
+3. Use demo questions from `/medical/demo`
+
+## 📋 Example Usage
+
+### English Consultation
+```json
+{
+  "question": "What are the symptoms of malaria?",
+  "language": "en"
+}
+```
+
+### French Consultation
+```json
+{
+  "question": "Quels sont les symptômes du paludisme?",
+  "language": "fr"
+}
+```
+
+## 🔧 API Endpoints
+
+- `GET /` - Welcome message and API info
+- `POST /medical/ask` - Medical consultation endpoint
+- `GET /health` - System health check
+- `GET /medical/demo` - Sample demo questions
+- `GET /docs` - Interactive API documentation
+- `GET /redoc` - Alternative API documentation
+
+## ⚕️ Medical Disclaimer
+
+This AI assistant provides educational medical information only. Always consult qualified healthcare professionals for proper medical diagnosis and treatment.
+
+## 🏗️ Technical Details
+
+- **Framework**: FastAPI
+- **AI Models**: Lightweight transformers optimized for Spaces
+- **Languages**: Python 3.9+
+- **Deployment**: Docker on Hugging Face Spaces
+
+## 📞 Support
+
+For technical issues or questions, please check the API documentation at `/docs`.
+
+---
+
+*Optimized for Hugging Face Spaces deployment*
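The request bodies in the new README's Example Usage section map directly onto `POST /medical/ask`. A minimal client sketch, assuming a `requests`-based caller; `BASE_URL` is a placeholder for your Space's address:

```python
# Minimal client sketch for POST /medical/ask; BASE_URL is a placeholder.
import requests

BASE_URL = "http://localhost:7860"  # e.g. https://<your-space>.hf.space

payload = {"question": "What are the symptoms of malaria?", "language": "en"}
resp = requests.post(f"{BASE_URL}/medical/ask", json=payload, timeout=60)
resp.raise_for_status()

data = resp.json()  # fields documented above: success, response, detected_language, processing_time
print(f"[{data['detected_language']}] ({data['processing_time']}s)")
print(data["response"])
```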
app.py
CHANGED
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 """
-Medical AI Assistant -
-
+Medical AI Assistant - Hugging Face Spaces Version
+Optimized for Spaces deployment
 """
 
 import os
@@ -20,22 +20,22 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 def setup_environment():
-    """Setup environment variables for
+    """Setup environment variables for Spaces deployment"""
     # Set environment variables for optimal performance
     os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
     os.environ.setdefault("HF_HOME", "/tmp/huggingface")
     os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/transformers")
 
-    #
+    # Spaces specific - use port 7860
     os.environ.setdefault("HOST", "0.0.0.0")
-    os.environ.setdefault("PORT", "
+    os.environ.setdefault("PORT", "7860")  # Changed to Spaces default
 
-    logger.info("✅ Environment configured for
+    logger.info("✅ Environment configured for Hugging Face Spaces")
 
 def main():
     """Main application entry point"""
     try:
-        logger.info("🩺 Starting Medical AI Assistant -
+        logger.info("🩺 Starting Medical AI Assistant - Spaces Edition")
 
         # Setup environment
         setup_environment()
@@ -44,22 +44,22 @@ def main():
         from fastapi_app import app
         import uvicorn
 
-        # Get port from environment
-        port = int(os.getenv("PORT",
+        # Get port from environment - Spaces uses 7860
+        port = int(os.getenv("PORT", 7860))
         host = os.getenv("HOST", "0.0.0.0")
 
         logger.info(f"🚀 Starting FastAPI server on {host}:{port}")
-        logger.info(f"📚 API Documentation
-        logger.info(f"🔄 Alternative docs at: http://{host}:{port}/redoc")
+        logger.info(f"📚 API Documentation available at: http://{host}:{port}/docs")
 
-        # Launch the FastAPI application
+        # Launch the FastAPI application with Spaces-optimized settings
         uvicorn.run(
             app,
             host=host,
             port=port,
             log_level="info",
-            reload=False,  #
-            access_log=True
+            reload=False,  # Never reload in production
+            access_log=True,
+            workers=1  # Single worker for Spaces
         )
 
     except KeyboardInterrupt:
@@ -71,16 +71,23 @@ def main():
 # For direct import (Hugging Face Spaces compatibility)
 try:
     from fastapi_app import app
-    logger.info("✅ FastAPI app imported successfully")
+    logger.info("✅ FastAPI app imported successfully for Spaces")
 except ImportError as e:
     logger.error(f"❌ Failed to import FastAPI app: {str(e)}")
     # Create minimal fallback app
     from fastapi import FastAPI
-    app = FastAPI(
+    app = FastAPI(
+        title="Medical AI - Loading",
+        description="Medical AI Assistant is starting up..."
+    )
 
     @app.get("/")
-    async def
-        return {
+    async def loading_root():
+        return {
+            "message": "🩺 Medical AI Assistant is starting up...",
+            "status": "loading",
+            "docs": "/docs"
+        }
 
 if __name__ == "__main__":
     main()
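One detail worth noting in the new `setup_environment()`: because it uses `os.environ.setdefault`, the `7860` value is only a default, so a `PORT` injected by the host still wins. A small sketch of that resolution order (the `9000` value is hypothetical):

```python
# Sketch of the setdefault-based PORT resolution that setup_environment() relies on.
import os

os.environ.pop("PORT", None)           # nothing injected by the host
os.environ.setdefault("PORT", "7860")  # Spaces default applies
assert int(os.getenv("PORT", 7860)) == 7860

os.environ["PORT"] = "9000"            # hypothetical host-injected value
os.environ.setdefault("PORT", "7860")  # no-op: PORT is already set
assert int(os.getenv("PORT", 7860)) == 9000
print("host-provided PORT overrides the Spaces default")
```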
audio_utils.py
DELETED
@@ -1,162 +0,0 @@
-import io
-import os
-import numpy as np
-import librosa
-import soundfile as sf
-from typing import Union
-import logging
-
-logger = logging.getLogger(__name__)
-
-def preprocess_audio(audio_data: Union[bytes, str], target_sr: int = 16000) -> np.ndarray:
-    """
-    Preprocess audio data for Whisper transcription.
-
-    Args:
-        audio_data: Audio data as bytes or file path
-        target_sr: Target sample rate for Whisper (16kHz)
-
-    Returns:
-        numpy array of preprocessed audio
-    """
-    try:
-        if isinstance(audio_data, bytes):
-            # Load audio from bytes
-            audio_buffer = io.BytesIO(audio_data)
-            audio, sr = sf.read(audio_buffer)
-            logger.info(f"Loaded audio from bytes: {len(audio)} samples at {sr}Hz")
-        elif isinstance(audio_data, str):
-            # Load audio from file path
-            if not os.path.exists(audio_data):
-                logger.error(f"Audio file not found: {audio_data}")
-                return np.array([], dtype=np.float32)
-
-            audio, sr = librosa.load(audio_data, sr=None)
-            logger.info(f"Loaded audio from file {audio_data}: {len(audio)} samples at {sr}Hz")
-        else:
-            logger.error(f"Unsupported audio data type: {type(audio_data)}")
-            return np.array([], dtype=np.float32)
-
-        # Handle empty audio
-        if len(audio) == 0:
-            logger.warning("Empty audio data received")
-            return np.array([], dtype=np.float32)
-
-        # Convert to mono if stereo
-        if len(audio.shape) > 1:
-            audio = librosa.to_mono(audio)
-            logger.info("Converted stereo to mono")
-
-        # Resample to target sample rate if needed
-        if sr != target_sr:
-            audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
-            logger.info(f"Resampled from {sr}Hz to {target_sr}Hz")
-
-        # Normalize audio to [-1, 1] range
-        if np.max(np.abs(audio)) > 0:
-            audio = audio / np.max(np.abs(audio))
-
-        # Convert to float32
-        audio = audio.astype(np.float32)
-
-        # Remove silence from beginning and end (optional but helpful)
-        try:
-            audio, _ = librosa.effects.trim(audio, top_db=20)
-            logger.info(f"Trimmed silence, final length: {len(audio)} samples")
-        except Exception as e:
-            logger.warning(f"Could not trim silence: {str(e)}")
-
-        # Ensure minimum length (avoid very short clips)
-        min_samples = int(0.1 * target_sr)  # 0.1 second minimum
-        if len(audio) < min_samples:
-            logger.warning(f"Audio too short ({len(audio)} samples), padding to {min_samples}")
-            audio = np.pad(audio, (0, min_samples - len(audio)), mode='constant')
-
-        logger.info(f"Audio preprocessing completed: {len(audio)} samples at {target_sr}Hz")
-        return audio
-
-    except Exception as e:
-        logger.error(f"Error preprocessing audio: {str(e)}", exc_info=True)
-        # Return empty audio array as fallback
-        return np.array([], dtype=np.float32)
-
-def validate_audio_format(audio_path: str) -> bool:
-    """
-    Validate if audio file format is supported
-
-    Args:
-        audio_path: Path to audio file
-
-    Returns:
-        True if format is supported, False otherwise
-    """
-    try:
-        if not os.path.exists(audio_path):
-            return False
-
-        # Get file extension
-        _, ext = os.path.splitext(audio_path.lower())
-
-        # Supported formats
-        supported_formats = ['.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac']
-
-        return ext in supported_formats
-
-    except Exception as e:
-        logger.error(f"Error validating audio format: {str(e)}")
-        return False
-
-def get_audio_info(audio_path: str) -> dict:
-    """
-    Get information about audio file
-
-    Args:
-        audio_path: Path to audio file
-
-    Returns:
-        Dictionary with audio information
-    """
-    try:
-        if not os.path.exists(audio_path):
-            return {"error": "File not found"}
-
-        # Load audio metadata
-        info = sf.info(audio_path)
-
-        return {
-            "duration": info.duration,
-            "sample_rate": info.samplerate,
-            "channels": info.channels,
-            "format": info.format,
-            "subtype": info.subtype,
-            "frames": info.frames
-        }
-
-    except Exception as e:
-        logger.error(f"Error getting audio info: {str(e)}")
-        return {"error": str(e)}
-
-def enhance_audio_quality(audio: np.ndarray, sr: int = 16000) -> np.ndarray:
-    """
-    Apply basic audio enhancement for better transcription
-
-    Args:
-        audio: Audio signal
-        sr: Sample rate
-
-    Returns:
-        Enhanced audio signal
-    """
-    try:
-        # Apply high-pass filter to remove low-frequency noise
-        audio_filtered = librosa.effects.preemphasis(audio)
-
-        # Normalize again after filtering
-        if np.max(np.abs(audio_filtered)) > 0:
-            audio_filtered = audio_filtered / np.max(np.abs(audio_filtered))
-
-        return audio_filtered.astype(np.float32)
-
-    except Exception as e:
-        logger.warning(f"Audio enhancement failed: {str(e)}")
-        return audio
fastapi_app.py
CHANGED
@@ -1,20 +1,17 @@
 #!/usr/bin/env python3
 """
-Medical AI Assistant - FastAPI
-
+Medical AI Assistant - Lightweight FastAPI for Spaces
+Optimized for Hugging Face Spaces deployment
 """
 
-from fastapi import FastAPI, HTTPException, File, UploadFile
+from fastapi import FastAPI, HTTPException, File, UploadFile
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
-from fastapi.openapi.docs import get_swagger_ui_html
-from fastapi.openapi.utils import get_openapi
 from pydantic import BaseModel, Field
-from typing import List, Optional, Dict, Any
+from typing import List, Optional, Dict, Any
 import logging
 import uuid
 import os
-import json
 import asyncio
 from contextlib import asynccontextmanager
 import time
@@ -28,37 +25,23 @@ logger = logging.getLogger(__name__)
 
 # Initialize models globally
 pipeline = None
-whisper_model = None
 
 async def load_models():
     """Load ML models asynchronously"""
-    global pipeline
+    global pipeline
     try:
-        logger.info("Loading Medical AI models...")
+        logger.info("🔄 Loading Medical AI models...")
 
-
-
+        # Use the lightweight version
+        from medical_ai import SpacesMedicalAIPipeline
+        pipeline = SpacesMedicalAIPipeline()
         logger.info("✅ Medical pipeline loaded successfully")
 
-
-        from faster_whisper import WhisperModel
-        model_cache = os.getenv('HF_HOME', '/tmp/models')
-        whisper_model = WhisperModel(
-            "medium",
-            device="cpu",
-            compute_type="int8",
-            download_root=model_cache
-        )
-        logger.info("✅ Whisper model loaded successfully")
-    except Exception as e:
-        logger.warning(f"⚠️ Could not load Whisper model: {str(e)}")
-        whisper_model = None
-
-        logger.info("🚀 All models loaded successfully")
+        logger.info("🚀 All models ready for Spaces!")
 
     except Exception as e:
         logger.error(f"❌ Error loading models: {str(e)}", exc_info=True)
-        raise
+        # Don't raise - let the app start with fallback responses
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
@@ -68,496 +51,211 @@ async def lifespan(app: FastAPI):
         yield
     except Exception as e:
         logger.error(f"❌ Error during startup: {str(e)}", exc_info=True)
-
+        # Continue anyway for demo purposes
+        yield
     finally:
         logger.info("🔄 Shutting down...")
 
-#
-def custom_openapi():
-    if app.openapi_schema:
-        return app.openapi_schema
-
-    openapi_schema = get_openapi(
-        title="🩺 Medical AI Assistant API",
-        version="2.0.0",
-        description="""
-## 🎯 Advanced Medical AI Assistant
-
-**Multilingual medical consultation API** supporting:
-- 🌍 French, English, and local African languages
-- 🎤 Audio processing with speech-to-text
-- 🧠 Advanced medical knowledge retrieval
-- ⚡ Real-time medical consultations
-
-### 🔧 Main Endpoints:
-- **POST /medical/ask** - Text-based medical consultation
-- **POST /medical/audio** - Audio-based medical consultation
-- **GET /health** - System health check
-- **POST /feedback** - Submit user feedback
-
-### 🔒 Important Medical Disclaimer:
-This API provides educational medical information only. Always consult qualified healthcare professionals for medical advice.
-        """,
-        routes=app.routes,
-        contact={
-            "name": "Medical AI Support",
-            "email": "support@medicalai.com"
-        },
-        license_info={
-            "name": "MIT License",
-            "url": "https://opensource.org/licenses/MIT"
-        }
-    )
-
-    # Add custom tags
-    openapi_schema["tags"] = [
-        {
-            "name": "medical",
-            "description": "Medical consultation endpoints"
-        },
-        {
-            "name": "audio",
-            "description": "Audio processing endpoints"
-        },
-        {
-            "name": "system",
-            "description": "System monitoring and health"
-        },
-        {
-            "name": "feedback",
-            "description": "User feedback and analytics"
-        }
-    ]
-
-    app.openapi_schema = openapi_schema
-    return app.openapi_schema
-
-# Initialize FastAPI app
+# Initialize FastAPI app - Simplified for Spaces
 app = FastAPI(
     title="🩺 Medical AI Assistant",
-    description="
-    version="2.0.0",
+    description="Multilingual medical consultation API optimized for Hugging Face Spaces",
+    version="2.0.0-spaces",
     lifespan=lifespan,
     docs_url="/docs",
-    redoc_url="/redoc"
-    openapi_url="/openapi.json"
+    redoc_url="/redoc"
 )
 
-# Set custom OpenAPI
-app.openapi = custom_openapi
-
 # CORS middleware
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
     allow_credentials=True,
     allow_methods=["*"],
-    allow_headers=["*"]
-    expose_headers=["*"]
+    allow_headers=["*"]
 )
 
 # ============================================================================
-# PYDANTIC MODELS
+# PYDANTIC MODELS - Simplified
 # ============================================================================
 
 class MedicalQuestion(BaseModel):
     """Medical question request model"""
-    question: str = Field(..., description="The medical question", min_length=3, max_length=
-    language: str = Field("auto", description="
-    conversation_id: Optional[str] = Field(None, description="Optional conversation ID for context")
+    question: str = Field(..., description="The medical question", min_length=3, max_length=500)
+    language: str = Field("auto", description="Language (auto, en, fr)")
 
     class Config:
         schema_extra = {
             "example": {
-                "question": "What are the symptoms of malaria
-                "language": "en"
-                "conversation_id": "conv_123"
+                "question": "What are the symptoms of malaria?",
+                "language": "en"
             }
         }
 
 class MedicalResponse(BaseModel):
     """Medical response model"""
-    success: bool = Field(..., description="
-    response: str = Field(..., description="
-    detected_language: str = Field(..., description="Detected
-
-    context_used: List[str] = Field(default_factory=list, description="Medical contexts used")
-    processing_time: float = Field(..., description="Response time in seconds")
-    confidence: str = Field(..., description="Response confidence level")
+    success: bool = Field(..., description="Success status")
+    response: str = Field(..., description="Medical response")
+    detected_language: str = Field(..., description="Detected language")
+    processing_time: float = Field(..., description="Processing time")
 
     class Config:
         schema_extra = {
             "example": {
                 "success": True,
-                "response": "Malaria symptoms include
+                "response": "Malaria symptoms include fever, chills, headache...",
                 "detected_language": "en",
-                "
-                "context_used": ["Malaria treatment protocols", "Symptom guidelines"],
-                "processing_time": 2.5,
-                "confidence": "high"
-            }
-        }
-
-class AudioResponse(BaseModel):
-    """Audio processing response model"""
-    success: bool = Field(..., description="Whether the request was successful")
-    transcription: str = Field(..., description="Transcribed text from audio")
-    response: str = Field(..., description="The medical response")
-    detected_language: str = Field(..., description="Detected audio language")
-    conversation_id: str = Field(..., description="Conversation identifier")
-    context_used: List[str] = Field(default_factory=list, description="Medical contexts used")
-    processing_time: float = Field(..., description="Response time in seconds")
-    audio_duration: Optional[float] = Field(None, description="Audio duration in seconds")
-
-    class Config:
-        schema_extra = {
-            "example": {
-                "success": True,
-                "transcription": "What are the symptoms of malaria?",
-                "response": "Malaria symptoms include high fever, chills...",
-                "detected_language": "en",
-                "conversation_id": "conv_456",
-                "context_used": ["Malaria diagnosis"],
-                "processing_time": 3.2,
-                "audio_duration": 4.5
-            }
-        }
-
-class FeedbackRequest(BaseModel):
-    """Feedback request model"""
-    conversation_id: str = Field(..., description="Conversation ID")
-    rating: int = Field(..., description="Rating from 1-5", ge=1, le=5)
-    feedback: Optional[str] = Field(None, description="Optional text feedback", max_length=500)
-
-    class Config:
-        schema_extra = {
-            "example": {
-                "conversation_id": "conv_123",
-                "rating": 5,
-                "feedback": "Very helpful and accurate medical information"
+                "processing_time": 2.1
             }
         }
 
 class HealthStatus(BaseModel):
-    """System health status
-    status: str = Field(..., description="
-    models_loaded: bool = Field(..., description="
-    audio_available: bool = Field(..., description="Whether audio processing is available")
-    uptime: float = Field(..., description="System uptime in seconds")
+    """System health status"""
+    status: str = Field(..., description="System status")
+    models_loaded: bool = Field(..., description="Models loaded status")
     version: str = Field(..., description="API version")
-
-    class Config:
-        schema_extra = {
-            "example": {
-                "status": "healthy",
-                "models_loaded": True,
-                "audio_available": True,
-                "uptime": 3600.0,
-                "version": "2.0.0"
-            }
-        }
-
-class ErrorResponse(BaseModel):
-    """Error response model"""
-    success: bool = Field(False, description="Always false for errors")
-    error: str = Field(..., description="Error message")
-    error_code: str = Field(..., description="Error code")
-    conversation_id: Optional[str] = Field(None, description="Conversation ID if available")
-
-# ============================================================================
-# UTILITY FUNCTIONS
-# ============================================================================
-
-def generate_conversation_id() -> str:
-    """Generate a unique conversation ID"""
-    return f"conv_{uuid.uuid4().hex[:8]}"
 
 def validate_models():
     """Check if models are loaded"""
     if pipeline is None:
         raise HTTPException(
            status_code=503,
-            detail="Medical AI models are
+            detail="Medical AI models are still loading. Please try again in a moment."
        )
 
 # ============================================================================
-# API ENDPOINTS
+# API ENDPOINTS - Simplified for Spaces
 # ============================================================================
 
 @app.get("/", tags=["system"])
 async def root():
-    """Root endpoint
+    """Root endpoint"""
     return {
-        "message": "🩺 Medical AI Assistant
-        "version": "2.0.0",
+        "message": "🩺 Medical AI Assistant - Spaces Edition",
+        "version": "2.0.0-spaces",
         "status": "running",
         "docs": "/docs",
-        "redoc": "/redoc",
         "endpoints": {
             "medical_consultation": "/medical/ask",
-            "
-
-
-        }
+            "health_check": "/health"
+        },
+        "demo_note": "Optimized for Hugging Face Spaces deployment"
     }
 
 @app.get("/health", response_model=HealthStatus, tags=["system"])
 async def health_check():
-    """
-
-
-    Returns the current status of the Medical AI system including:
-    - Overall system health
-    - Model loading status
-    - Audio processing availability
-    - System uptime
-    """
-    global pipeline, whisper_model
-
-    # Calculate uptime (simplified)
-    uptime = time.time() - getattr(health_check, 'start_time', time.time())
-    if not hasattr(health_check, 'start_time'):
-        health_check.start_time = time.time()
+    """System health check"""
+    global pipeline
 
     return HealthStatus(
         status="healthy" if pipeline is not None else "loading",
         models_loaded=pipeline is not None,
-
-        uptime=uptime,
-        version="2.0.0"
+        version="2.0.0-spaces"
     )
 
 @app.post("/medical/ask", response_model=MedicalResponse, tags=["medical"])
 async def medical_consultation(request: MedicalQuestion):
     """
-    ##
+    ## Medical Consultation
 
-
+    Ask medical questions and get AI-powered responses.
 
     **Features:**
-    - 🌍 Multilingual support (
-    - 🧠
-    - ⚡
-    - 🔒 Medical disclaimers included
-
-    **Supported Languages:** English (en), French (fr), Auto-detect (auto)
+    - 🌍 Multilingual support (English, French)
+    - 🧠 Medical knowledge retrieval
+    - ⚡ Optimized for Spaces
     """
     start_time = time.time()
-    validate_models()
-
-    conversation_id = request.conversation_id or generate_conversation_id()
 
-
-
+    # Check if models are loaded - with graceful fallback
+    if pipeline is None:
+        logger.warning("Models not loaded, using fallback response")
+        processing_time = time.time() - start_time
 
-
-
-
-
-        conversation_history=[]
-    )
+        fallback_responses = {
+            "en": "Medical AI is still initializing. For immediate medical concerns, please consult a healthcare professional. Common symptoms like fever, headache, or persistent pain should be evaluated by a doctor.",
+            "fr": "L'IA médicale s'initialise encore. Pour des préoccupations médicales immédiates, veuillez consulter un professionnel de santé. Les symptômes courants comme la fièvre, les maux de tête ou la douleur persistante doivent être évalués par un médecin."
+        }
 
-
+        detected_lang = "fr" if any(word in request.question.lower() for word in ["quoi", "comment", "pourquoi"]) else "en"
 
         return MedicalResponse(
             success=True,
-            response=
-            detected_language=
-
-            context_used=result.get("context_used", []),
-            processing_time=round(processing_time, 2),
-            confidence=result.get("confidence", "medium")
+            response=fallback_responses.get(detected_lang, fallback_responses["en"]) + "\n\n⚕️ Medical Disclaimer: Always consult healthcare professionals for proper medical advice.",
+            detected_language=detected_lang,
+            processing_time=round(processing_time, 2)
         )
-
-    except Exception as e:
-        logger.error(f"❌ Error in medical consultation: {str(e)}", exc_info=True)
-        processing_time = time.time() - start_time
-
-        raise HTTPException(
-            status_code=500,
-            detail={
-                "success": False,
-                "error": "Internal processing error occurred",
-                "error_code": "MEDICAL_PROCESSING_ERROR",
-                "conversation_id": conversation_id,
-                "processing_time": round(processing_time, 2)
-            }
-        )
-
-@app.post("/medical/audio", response_model=AudioResponse, tags=["audio", "medical"])
-async def audio_medical_consultation(
-    file: UploadFile = File(..., description="Audio file (WAV, MP3, M4A, etc.)")
-):
-    """
-    ## Audio-based Medical Consultation
-
-    Process an audio medical question and return expert medical guidance.
-
-    **Features:**
-    - 🎤 Speech-to-text conversion
-    - 🌍 Language detection from audio
-    - 🧠 Medical AI processing of transcribed text
-    - 📝 Full transcription provided
-
-    **Supported Audio Formats:** WAV, MP3, M4A, FLAC, OGG
-    **Max File Size:** 25MB
-    **Max Duration:** 5 minutes
-    """
-    start_time = time.time()
-    validate_models()
-
-    if whisper_model is None:
-        raise HTTPException(
-            status_code=503,
-            detail="Audio processing is currently unavailable"
-        )
-
-    conversation_id = generate_conversation_id()
 
     try:
-        logger.info(f"
-
-        # Read audio file
-        file_bytes = await file.read()
-
-        # Process audio
-        from audio_utils import preprocess_audio
-        processed_audio = preprocess_audio(file_bytes)
-
-        if len(processed_audio) == 0:
-            raise HTTPException(
-                status_code=400,
-                detail="Could not process audio file. Please check the format and try again."
-            )
-
-        # Transcribe audio
-        segments, info = whisper_model.transcribe(
-            processed_audio,
-            beam_size=5,
-            language=None,
-            task='transcribe',
-            vad_filter=True
-        )
-
-        transcription = "".join([seg.text for seg in segments])
-        detected_language = info.language
-
-        if not transcription.strip():
-            raise HTTPException(
-                status_code=400,
-                detail="Could not transcribe audio. Please ensure clear speech and try again."
-            )
-
-        logger.info(f"🔤 Transcription: {transcription[:100]}...")
+        logger.info(f"🩺 Processing medical question: {request.question[:50]}...")
 
-        # Process
+        # Process with medical AI pipeline
         result = pipeline.process(
-            question=
-            user_lang=
+            question=request.question,
+            user_lang=request.language,
             conversation_history=[]
        )
 
        processing_time = time.time() - start_time
 
-        return
+        return MedicalResponse(
            success=True,
-            transcription=transcription,
            response=result["response"],
-            detected_language=
-
-            context_used=result.get("context_used", []),
-            processing_time=round(processing_time, 2),
-            audio_duration=len(processed_audio) / 16000  # Assuming 16kHz sample rate
+            detected_language=result["source_lang"],
+            processing_time=round(processing_time, 2)
        )
 
-    except HTTPException:
-        raise
    except Exception as e:
-        logger.error(f"❌ Error in
+        logger.error(f"❌ Error in medical consultation: {str(e)}")
        processing_time = time.time() - start_time
 
        raise HTTPException(
            status_code=500,
            detail={
                "success": False,
-                "error": "
-                "error_code": "AUDIO_PROCESSING_ERROR",
-                "conversation_id": conversation_id,
+                "error": "Processing error occurred. Please try again.",
                "processing_time": round(processing_time, 2)
            }
        )
 
-@app.
-async def
+@app.get("/medical/demo", tags=["medical"])
+async def demo_questions():
    """
-    ##
+    ## Demo Questions
 
-
-
-    **Rating Scale:**
-    - 1: Very Poor
-    - 2: Poor
-    - 3: Average
-    - 4: Good
-    - 5: Excellent
+    Get sample questions to try with the Medical AI.
    """
-
-
-
-
-
-
-
-        "
-
-
-
-
-
-
-
-
-    except Exception as e:
-        logger.error(f"❌ Error processing feedback: {str(e)}")
-        raise HTTPException(
-            status_code=500,
-            detail="Error processing feedback"
-        )
+    return {
+        "demo_questions": {
+            "english": [
+                "What are the symptoms of malaria?",
+                "How can I prevent diabetes?",
+                "What should I do for a fever?",
+                "How to maintain good hygiene?"
+            ],
+            "french": [
+                "Quels sont les symptômes du paludisme?",
+                "Comment puis-je prévenir le diabète?",
+                "Que dois-je faire pour une fièvre?",
+                "Comment maintenir une bonne hygiène?"
+            ]
+        },
+        "note": "Try these questions to test the Medical AI Assistant"
+    }
 
-@app.get("/medical/
-async def
-    """
-    ## Get Supported Medical Specialties
-
-    Returns a list of medical specialties and conditions supported by the AI.
-    """
+@app.get("/medical/info", tags=["medical"])
+async def medical_info():
+    """Medical AI information"""
    return {
+        "supported_languages": ["English", "French"],
        "specialties": [
-
-
-
-
-        },
-        {
-            "name": "Infectious Diseases",
-            "description": "Infectious disease diagnosis and treatment",
-            "conditions": ["Malaria", "Tuberculosis", "HIV/AIDS", "Respiratory infections"]
-        },
-        {
-            "name": "Emergency Medicine",
-            "description": "Emergency protocols and urgent care guidance",
-            "conditions": ["Stroke recognition", "Cardiac emergencies", "Trauma assessment"]
-        },
-        {
-            "name": "Chronic Disease Management",
-            "description": "Management of chronic conditions",
-            "conditions": ["Diabetes", "Hypertension", "Gastritis"]
-        }
+            "General Medicine",
+            "Tropical Diseases",
+            "Preventive Care",
+            "Emergency Guidelines"
        ],
-        "
-        "
+        "disclaimer": "⚕️ This AI provides educational information only. Always consult qualified healthcare professionals for medical advice.",
+        "optimization": "Lightweight version optimized for Hugging Face Spaces"
    }
 
 # ============================================================================
@@ -571,13 +269,11 @@ async def not_found_handler(request, exc):
    content={
        "success": False,
        "error": "Endpoint not found",
-        "error_code": "NOT_FOUND",
        "available_endpoints": [
            "/docs - API Documentation",
-            "/medical/ask -
-            "/medical/audio - Audio consultation",
+            "/medical/ask - Medical consultation",
            "/health - System status",
-            "/
+            "/medical/demo - Demo questions"
        ]
    }
 )
@@ -588,9 +284,19 @@ async def validation_exception_handler(request, exc):
    status_code=422,
    content={
        "success": False,
-        "error": "Invalid request data",
-        "
-
+        "error": "Invalid request data",
+        "details": "Please check your request format"
+    }
+)
+
+@app.exception_handler(500)
+async def internal_error_handler(request, exc):
+    return JSONResponse(
+        status_code=500,
+        content={
+            "success": False,
+            "error": "Internal server error",
+            "message": "The Medical AI is experiencing technical difficulties"
        }
    )
 
@@ -598,17 +304,24 @@ async def validation_exception_handler(request, exc):
 # STARTUP MESSAGE
 # ============================================================================
 
+@app.on_event("startup")
+async def startup_event():
+    """Startup event handler"""
+    logger.info("🩺 Medical AI Assistant starting up on Hugging Face Spaces")
+    logger.info("📚 API Documentation: /docs")
+    logger.info("🔄 Alternative docs: /redoc")
+    logger.info("⚡ Optimized for Spaces deployment")
+
 if __name__ == "__main__":
    import uvicorn
 
-    print("🩺 Starting Medical AI Assistant
-    print("📚 Documentation available at: http://localhost:
-    print("🔄 Alternative docs at: http://localhost:8000/redoc")
+    print("🩺 Starting Medical AI Assistant for Spaces...")
+    print("📚 Documentation available at: http://localhost:7860/docs")
 
    uvicorn.run(
        app,
        host="0.0.0.0",
-        port=
+        port=7860,  # Spaces port
        log_level="info",
        reload=False
    )
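Since `load_models()` no longer raises and `/medical/ask` answers with a canned bilingual message while `pipeline` is `None`, a caller can use `/health` to distinguish the fallback from real model output. A minimal readiness-polling sketch; `BASE_URL` is a placeholder:

```python
# Poll /health until models_loaded flips to true; BASE_URL is a placeholder.
import time
import requests

BASE_URL = "http://localhost:7860"

def wait_until_ready(timeout_s: float = 120.0, poll_s: float = 5.0) -> bool:
    """Return True once /health reports models_loaded, False on timeout."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        health = requests.get(f"{BASE_URL}/health", timeout=10).json()
        if health.get("models_loaded"):
            return True
        time.sleep(poll_s)
    return False

if wait_until_ready():
    print("pipeline loaded - /medical/ask will use the real model")
else:
    print("still loading - /medical/ask will return the fallback response")
```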
medical_ai.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# medical_ai.py - VERSION
|
| 2 |
|
| 3 |
import os
|
| 4 |
import json
|
|
@@ -7,8 +7,7 @@ from typing import List, Dict, Any
|
|
| 7 |
from sentence_transformers import SentenceTransformer
|
| 8 |
import faiss
|
| 9 |
from functools import lru_cache
|
| 10 |
-
from transformers import
|
| 11 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 12 |
import torch
|
| 13 |
from typing import Optional
|
| 14 |
import logging
|
|
@@ -18,64 +17,52 @@ import re
|
|
| 18 |
logging.basicConfig(level=logging.INFO)
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
| 21 |
-
# === CONFIGURATION
|
| 22 |
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
MODEL_NAME = "microsoft/DialoGPT-medium"
|
| 26 |
PATIENT_RECORDS_PATH = "patient_records.json"
|
| 27 |
|
| 28 |
-
#
|
| 29 |
DEVICE = "cpu"
|
| 30 |
-
MAX_LENGTH =
|
| 31 |
-
TEMPERATURE = 0.7
|
| 32 |
TOP_P = 0.9
|
| 33 |
-
TOP_K =
|
| 34 |
|
| 35 |
-
# === 1.
|
| 36 |
-
class
|
| 37 |
def __init__(self):
|
|
|
|
| 38 |
try:
|
| 39 |
-
|
| 40 |
-
self.
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
'fr-FR': 'fr', 'en-EN': 'en', 'fr_XX': 'fr', 'en_XX': 'en',
|
| 46 |
-
'LABEL_0': 'en', 'LABEL_1': 'fr' # Fallbacks
|
| 47 |
-
}
|
| 48 |
-
logger.info("Advanced language detector initialized")
|
| 49 |
-
except Exception as e:
|
| 50 |
-
logger.error(f"Error initializing language detector: {str(e)}")
|
| 51 |
-
self.lang_id = None
|
| 52 |
|
| 53 |
@lru_cache(maxsize=256)
|
| 54 |
def detect_language(self, text: str) -> str:
|
| 55 |
if not text.strip():
|
| 56 |
return 'en'
|
| 57 |
|
| 58 |
-
#
|
| 59 |
-
if self.
|
| 60 |
try:
|
| 61 |
-
|
| 62 |
-
detected
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
# Si confiance faible, utiliser détection par mots-clés
|
| 66 |
-
if confidence < 0.8:
|
| 67 |
-
return self._keyword_detection(text)
|
| 68 |
-
|
| 69 |
-
return self.lang_map.get(detected, 'en')
|
| 70 |
except:
|
| 71 |
-
|
| 72 |
|
|
|
|
| 73 |
return self._keyword_detection(text)
|
| 74 |
|
| 75 |
def _keyword_detection(self, text: str) -> str:
|
| 76 |
-
"""
|
| 77 |
-
french_indicators = ['que', 'quoi', 'comment', 'pourquoi', 'symptômes', 'maladie', 'traitement'
|
| 78 |
-
english_indicators = ['what', 'how', 'why', 'symptoms', 'disease', 'treatment'
|
| 79 |
|
| 80 |
text_lower = text.lower()
|
| 81 |
fr_score = sum(2 if indicator in text_lower else 0 for indicator in french_indicators)
|
|
@@ -83,60 +70,32 @@ class AdvancedLanguageDetector:
|
|
| 83 |
|
| 84 |
return 'fr' if fr_score > en_score else 'en'
|
| 85 |
|
| 86 |
-
# === 2.
|
| 87 |
-
class
|
| 88 |
-
def __init__(self
|
|
|
|
| 89 |
try:
|
| 90 |
-
|
| 91 |
-
self.
|
| 92 |
-
|
| 93 |
-
torch_dtype=torch.float32, # CPU optimized
|
| 94 |
-
low_cpu_mem_usage=True
|
| 95 |
-
)
|
| 96 |
-
self.lang_code_map = {
|
| 97 |
-
'fr': 'fra_Latn', 'en': 'eng_Latn', 'bss': 'bss_Latn',
|
| 98 |
-
'dua': 'dua_Latn', 'ewo': 'ewo_Latn',
|
| 99 |
-
}
|
| 100 |
-
logger.info("Optimized translator initialized")
|
| 101 |
except Exception as e:
|
| 102 |
logger.error(f"Error initializing translator: {str(e)}")
|
| 103 |
-
self.
|
| 104 |
-
self.model = None
|
| 105 |
|
| 106 |
@lru_cache(maxsize=256)
|
| 107 |
def translate(self, text: str, source_lang: str, target_lang: str) -> str:
|
| 108 |
if not text.strip() or source_lang == target_lang:
|
| 109 |
return text
|
| 110 |
-
|
| 111 |
-
if self.tokenizer is None or self.model is None:
|
| 112 |
-
return text
|
| 113 |
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
self.tokenizer.src_lang = src
|
| 119 |
-
inputs = self.tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
|
| 120 |
-
|
| 121 |
-
with torch.no_grad():
|
| 122 |
-
generated_tokens = self.model.generate(
|
| 123 |
-
**inputs,
|
| 124 |
-
forced_bos_token_id=self.tokenizer.convert_tokens_to_ids(tgt),
|
| 125 |
-
max_length=512,
|
| 126 |
-
num_beams=4, # Améliore la qualité
|
| 127 |
-
early_stopping=True
|
| 128 |
-
)
|
| 129 |
-
|
| 130 |
-
result = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
|
| 131 |
-
return result
|
| 132 |
-
except Exception as e:
|
| 133 |
-
logger.error(f"Translation error: {str(e)}")
|
| 134 |
-
return text
|
| 135 |
|
| 136 |
-
# === 3.
|
| 137 |
-
class
|
| 138 |
def __init__(self, embedding_model_name=EMBEDDING_MODEL_NAME, records_path=PATIENT_RECORDS_PATH):
|
| 139 |
try:
|
|
|
|
| 140 |
self.embedder = SentenceTransformer(embedding_model_name)
|
| 141 |
|
| 142 |
if not os.path.exists(records_path):
|
|
@@ -146,43 +105,33 @@ class AdvancedMedicalRAG:
|
|
| 146 |
with open(records_path, 'r', encoding='utf-8') as f:
|
| 147 |
self.records = json.load(f)
|
| 148 |
|
| 149 |
-
#
|
| 150 |
self.medical_chunks = []
|
| 151 |
-
self.
|
| 152 |
-
self.emergency_chunks = []
|
| 153 |
-
self.prevention_chunks = []
|
| 154 |
-
|
| 155 |
-
self._build_specialized_chunks()
|
| 156 |
|
| 157 |
-
#
|
| 158 |
-
self.medical_index
|
| 159 |
-
self.edu_index, _ = self._build_faiss_index(self.educational_chunks)
|
| 160 |
-
self.emergency_index, _ = self._build_faiss_index(self.emergency_chunks)
|
| 161 |
-
self.prevention_index, _ = self._build_faiss_index(self.prevention_chunks)
|
| 162 |
|
| 163 |
-
logger.info(f"
|
| 164 |
-
f"{len(self.educational_chunks)} educational, {len(self.emergency_chunks)} emergency chunks")
|
| 165 |
|
| 166 |
except Exception as e:
|
| 167 |
-
logger.error(f"Error initializing
|
| 168 |
self._initialize_fallback()
|
| 169 |
|
| 170 |
def _create_sample_records(self, path: str):
|
| 171 |
-
"""
|
| 172 |
sample_records = [
|
| 173 |
{
|
| 174 |
"id": "malaria_001",
|
| 175 |
-
"diagnosis": {"en": "Malaria
|
| 176 |
-
"symptoms": {"en": "
|
| 177 |
-
"
|
| 178 |
-
"care_instructions": {"en": "Complete bed rest, increase fluid intake, complete full medication course, return if symptoms worsen or fever persists after 48 hours", "fr": "Repos complet au lit, augmenter l'apport hydrique, terminer le traitement complet, revenir si les symptômes s'aggravent ou si la fièvre persiste après 48 heures"}
|
| 179 |
},
|
| 180 |
{
|
| 181 |
-
"id": "
|
| 182 |
-
"
|
| 183 |
-
"
|
| 184 |
-
"
|
| 185 |
-
"target_group": "Adults over 30, family history of diabetes, sedentary lifestyle"
|
| 186 |
}
|
| 187 |
]
|
| 188 |
|
|
@@ -190,65 +139,25 @@ class AdvancedMedicalRAG:
             json.dump(sample_records, f, ensure_ascii=False, indent=2)

     def _initialize_fallback(self):
-        """…
-        self.medical_chunks = [
-            …
-        ]
         self.medical_index = None
-        self.edu_index = None
-        self.emergency_index = None
-        self.prevention_index = None

-    def _build_specialized_chunks(self):
-        """…
         for rec in self.records:
             try:
-                …
-                medical_parts.append(f"Symptoms: {rec['symptoms'].get('en', '')}")
-
-                if 'medications' in rec:
-                    meds = [f"{m['name'].get('en', '')} ({m.get('dosage', '')})" for m in rec['medications']]
-                    medical_parts.append(f"Treatment: {', '.join(meds)}")
-
-                if 'care_instructions' in rec:
-                    medical_parts.append(f"Care instructions: {rec['care_instructions'].get('en', '')}")
-
-                if medical_parts:
-                    self.medical_chunks.append(". ".join(medical_parts))
-
-                # Educational chunks
-                if rec.get('context_type') == 'prevention' or 'educational_content' in rec:
-                    edu_parts = []
-                    if 'topic' in rec:
-                        edu_parts.append(f"Topic: {rec['topic'].get('en', '')}")
-                    if 'educational_content' in rec:
-                        edu_parts.append(f"Information: {rec['educational_content'].get('en', '')}")
-                    if 'target_group' in rec:
-                        edu_parts.append(f"Target: {rec['target_group']}")
-
-                    if edu_parts:
-                        chunk = ". ".join(edu_parts)
-                        self.educational_chunks.append(chunk)
-                        if 'prevention' in chunk.lower():
-                            self.prevention_chunks.append(chunk)
-
-                # Emergency chunks
-                if rec.get('context_type') == 'emergency_education' or 'emergency' in str(rec).lower():
-                    emergency_parts = []
-                    if 'scenario' in rec:
-                        emergency_parts.append(f"Emergency: {rec['scenario'].get('en', '')}")
-                    if 'action_steps' in rec:
-                        emergency_parts.append(f"Actions: {rec['action_steps'].get('en', '')}")
-
-                    if emergency_parts:
-                        self.emergency_chunks.append(". ".join(emergency_parts))

             except Exception as e:
                 logger.error(f"Error processing record: {str(e)}")
@@ -256,281 +165,152 @@ class AdvancedMedicalRAG:

     def _build_faiss_index(self, chunks):
         if not chunks:
-            return None
         try:
             embeddings = self.embedder.encode(chunks, show_progress_bar=False, convert_to_numpy=True)
             index = faiss.IndexFlatL2(embeddings.shape[1])
             index.add(embeddings)
-            return index
         except Exception as e:
             logger.error(f"Error building FAISS index: {str(e)}")
-            return None

-    def get_smart_contexts(self, question, lang="en"):
-        """…
-        question_lower = question.lower()
-        contexts = {
-            "medical": [],
-            "educational": [],
-            "emergency": [],
-            "prevention": []
-        }
-
         try:
-            …
-            …
-            …
-            is_educational = any(word in question_lower for word in ['what is', 'explain', 'how', 'why', "qu'est-ce que", 'expliquer', 'comment', 'pourquoi'])
-
-            # Smart contextual retrieval
-            if is_emergency and self.emergency_index:
-                _, I = self.emergency_index.search(q_emb, min(3, len(self.emergency_chunks)))
-                contexts["emergency"] = [self.emergency_chunks[i] for i in I[0] if i < len(self.emergency_chunks)]
-
-            if is_prevention and self.prevention_index:
-                _, I = self.prevention_index.search(q_emb, min(2, len(self.prevention_chunks)))
-                contexts["prevention"] = [self.prevention_chunks[i] for i in I[0] if i < len(self.prevention_chunks)]
-
-            if is_educational and self.edu_index:
-                _, I = self.edu_index.search(q_emb, min(3, len(self.educational_chunks)))
-                contexts["educational"] = [self.educational_chunks[i] for i in I[0] if i < len(self.educational_chunks)]
-
-            # Always include general medical context
-            if self.medical_index:
-                n_med = 4 if not any(contexts.values()) else 2
-                _, I = self.medical_index.search(q_emb, min(n_med, len(self.medical_chunks)))
-                contexts["medical"] = [self.medical_chunks[i] for i in I[0] if i < len(self.medical_chunks)]

         except Exception as e:
-            logger.error(f"Error getting …
-
-        return contexts

-# === 4. …
-class …
     def __init__(self, model_name: str = MODEL_NAME):
         self.device = DEVICE
-        logger.info(f"Loading …

         try:
-            # …
-            self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                torch_dtype=torch.float32,
-                low_cpu_mem_usage=True,
-                device_map="auto" if DEVICE == "cuda" else None
-            )
-
-            # Tokenizer configuration
-            if self.tokenizer.pad_token is None:
-                self.tokenizer.pad_token = self.tokenizer.eos_token
-
             self.generator = pipeline(
                 "text-generation",
-                model=self.model,
-                tokenizer=self.tokenizer,
                 device=-1,  # CPU
-                …
             )

-            logger.info(f"…
         except Exception as e:
             logger.error(f"Error loading model: {str(e)}")
             self.generator = None

-    def generate_expert_response(self, question, contexts, lang="en"):
-        """…

         if self.generator is None:
-            return self._expert_fallback_response(question, contexts, lang)

         try:
-            # …
-            prompt = self._build_expert_prompt(question, contexts, lang)

-            # …
             response = self.generator(
                 prompt,
-                max_length=len(prompt) + …,
-                num_return_sequences=1,
                 temperature=TEMPERATURE,
                 top_p=TOP_P,
                 top_k=TOP_K,
                 do_sample=True,
-                pad_token_id=self.tokenizer.eos_token_id,
-                eos_token_id=self.tokenizer.eos_token_id,
-                repetition_penalty=1.1,  # avoids repetition
-                length_penalty=1.0,
-                early_stopping=True
             )

-            # …
             full_text = response[0]['generated_text']
             response_text = full_text[len(prompt):].strip()

-            # …
-            …

-            return response_text

         except Exception as e:
-            logger.error(f"Error in …
-            return self._expert_fallback_response(question, contexts, lang)

-    def _build_expert_prompt(self, question: str, contexts: Dict[str, List[str]], lang: str) -> str:
-        """Builds an expert-level prompt for the competition"""
-
-        # Smart aggregation of the contexts
-        context_parts = []
-        if contexts.get("emergency"):
-            context_parts.append(f"🚨 Emergency Protocol: {' | '.join(contexts['emergency'][:2])}")
-        if contexts.get("medical"):
-            context_parts.append(f"📋 Clinical Information: {' | '.join(contexts['medical'][:2])}")
-        if contexts.get("prevention"):
-            context_parts.append(f"🛡️ Prevention Guidelines: {' | '.join(contexts['prevention'][:1])}")
-        if contexts.get("educational"):
-            context_parts.append(f"📚 Educational Content: {' | '.join(contexts['educational'][:1])}")
-
-        context_str = "\n".join(context_parts) if context_parts else "General medical consultation context."
-
-        # Structured prompt for top quality
-        if lang == "fr":
-            prompt = f"""Contexte médical expert:
-{context_str}
-
-Question du patient: {question}
-
-Réponse médicale experte (structurée et complète):"""
-        else:
-            prompt = f"""Expert medical context:
-{context_str}
-
-Patient question: {question}
-
-Expert medical response (structured and comprehensive):"""
-
-        return prompt

-    def …
-        """…
-
-        # Clean up artifacts
-        for stop_seq in ["</s>", "\nPatient:", "\nDoctor:", "\nExpert:", "\n\nContext:", "Question:"]:
-            if stop_seq in response:
-                response = response.split(stop_seq)[0].strip()
-
-        # Expert structuring
-        if len(response.split('.')) > 2:  # if the response is long enough
-            sentences = [s.strip() for s in response.split('.') if s.strip()]
-            if len(sentences) >= 3:
-                response = '. '.join(sentences[:4]) + '.'  # cap at 4 sentences
-
-        # Add expert disclaimer
-        disclaimer = {
-            "en": "\n\n⚕️ Medical Disclaimer: This information is for educational purposes. Always consult a qualified healthcare professional for proper diagnosis and treatment.",
-            "fr": "\n\n⚕️ Avertissement médical: Cette information est à des fins éducatives. Consultez toujours un professionnel de santé qualifié pour un diagnostic et traitement appropriés."
-        }
-
-        if "consult" not in response.lower() and "disclaimer" not in response.lower():
-            response += disclaimer.get(lang, disclaimer["en"])
-
-        return response.strip()
-
-    def _expert_fallback_response(self, question: str, contexts: Dict[str, List[str]], lang: str) -> str:
-        """Expert-level fallback response"""

         templates = {
-            "en": {
-                "intro": "…
-                "structure": "\n\n🔍 Assessment: This appears to be a medical inquiry requiring professional evaluation.\n\n💡 General Guidance: Monitor symptoms, maintain proper hygiene, stay hydrated, and seek appropriate medical care.\n\n⚠️ Important: For accurate diagnosis and treatment, please consult with a healthcare professional.",
-                "context_available": "According to medical literature and clinical guidelines: "
-            },
-            "fr": {
-                "intro": "Sur la base de l'expertise médicale et des informations cliniques disponibles:",
-                "structure": "\n\n🔍 Évaluation: Il s'agit d'une demande médicale nécessitant une évaluation professionnelle.\n\n💡 Guidance générale: Surveillez les symptômes, maintenez une hygiène appropriée, restez hydraté et consultez un professionnel de santé.\n\n⚠️ Important: Pour un diagnostic et traitement précis, veuillez consulter un professionnel de santé.",
-                "context_available": "Selon la littérature médicale et les directives cliniques: "
-            }
         }

-        template = templates.get(lang, templates["en"])
-        response = template["intro"]
-
-        # Integrate contexts if available
-        all_contexts = []
-        for context_list in contexts.values():
-            all_contexts.extend(context_list)
-
-        if all_contexts:
-            response += f" {template['context_available']}{' | '.join(all_contexts[:2])}"
-
-        response += template["structure"]
-
-        return response

-# === PIPELINE …
-class CompetitionMedicalAIPipeline:
     def __init__(self):
-        logger.info("…
         try:
-            self.lang_detector = …
-            self.translator = …
-            self.rag = …
-            self.llm = …
-            logger.info("…
         except Exception as e:
-            logger.error(f"Error initializing …
             raise

     def process(self, question: str, user_lang: str = "auto", conversation_history: list = None) -> Dict[str, Any]:
-        """…
         try:
             if not question or not question.strip():
                 return self._empty_question_response(user_lang)

-            # …
             detected_lang = self.lang_detector.detect_language(question) if user_lang == "auto" else user_lang
-            logger.info(f"…
-
-            # Translate if needed, with optimal quality
-            question_en = question
-            if detected_lang != "en":
-                question_en = self.translator.translate(question, detected_lang, "en")
-
-            # Smart multi-context RAG
-            smart_contexts = self.rag.get_smart_contexts(question_en, "en")
-
-            # Expert generation
-            response_en = self.llm.generate_expert_response(question_en, smart_contexts, "en")

-            # …
-            final_response = response_en
-            if detected_lang != "en":
-                final_response = self.translator.translate(response_en, "en", detected_lang)

-            # …
-            all_contexts = []
-            for context_list in smart_contexts.values():
-                all_contexts.extend(context_list)

             return {
-                "response": final_response,
                 "source_lang": detected_lang,
-                "context_used": all_contexts[:3],
-                "confidence": "…
             }

         except Exception as e:
-            logger.error(f"…
             return self._error_response(str(e), user_lang if user_lang != "auto" else "en")

     def _empty_question_response(self, user_lang: str) -> Dict[str, Any]:
-        """…
         responses = {
-            "en": "Please provide a medical question for …
-            "fr": "Veuillez poser une question médicale pour …
         }
         lang = user_lang if user_lang != "auto" else "en"
         return {
@@ -541,17 +321,18 @@ class CompetitionMedicalAIPipeline:
         }

     def _error_response(self, error: str, lang: str) -> Dict[str, Any]:
-        """…
         responses = {
-            "en": "I …
-            "fr": "Je …
         }
         return {
             "response": responses.get(lang, responses["en"]),
             "source_lang": lang,
             "context_used": [],
-            "confidence": "…
         }

-# …
-
+# medical_ai.py - LIGHTWEIGHT VERSION FOR HUGGING FACE SPACES

 import os
 import json
 …
 from sentence_transformers import SentenceTransformer
 import faiss
 from functools import lru_cache
+from transformers import pipeline, AutoTokenizer
 import torch
 from typing import Optional, Dict, Any, List
 import logging
 …
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

+# === LIGHTWEIGHT CONFIGURATION FOR SPACES ===
 EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
+# Use a much lighter model that works on Spaces
+MODEL_NAME = "microsoft/DialoGPT-small"  # Changed from medium to small
 PATIENT_RECORDS_PATH = "patient_records.json"

+# Optimized for Spaces CPU limits
 DEVICE = "cpu"
+MAX_LENGTH = 256  # Reduced for faster processing
+TEMPERATURE = 0.7
 TOP_P = 0.9
+TOP_K = 40

+# === 1. SIMPLE LANGUAGE DETECTOR ===
+class SimpleLanguageDetector:
     def __init__(self):
+        # Use lightweight langdetect instead of a heavy ML model
         try:
+            from langdetect import detect
+            self.detect_func = detect
+            logger.info("Simple language detector initialized")
+        except ImportError:
+            logger.warning("langdetect not available, using keyword detection")
+            self.detect_func = None

     @lru_cache(maxsize=256)
     def detect_language(self, text: str) -> str:
         if not text.strip():
             return 'en'

+        # Try langdetect first
+        if self.detect_func:
             try:
+                detected = self.detect_func(text)
+                if detected in ['fr', 'en']:
+                    return detected
             except:
+                pass

+        # Fallback to keyword detection
         return self._keyword_detection(text)

     def _keyword_detection(self, text: str) -> str:
+        """Keyword-based detection as fallback"""
+        french_indicators = ['que', 'quoi', 'comment', 'pourquoi', 'symptômes', 'maladie', 'traitement']
+        english_indicators = ['what', 'how', 'why', 'symptoms', 'disease', 'treatment']

         text_lower = text.lower()
         fr_score = sum(2 if indicator in text_lower else 0 for indicator in french_indicators)
         en_score = sum(2 if indicator in text_lower else 0 for indicator in english_indicators)

         return 'fr' if fr_score > en_score else 'en'
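The detect-then-fall-back pattern above is easy to lift out. A minimal standalone sketch, assuming only that `langdetect` is installed (the `keyword_fallback` helper is illustrative, not part of the module):

```python
from functools import lru_cache

def keyword_fallback(text: str) -> str:
    # Illustrative stand-in for SimpleLanguageDetector._keyword_detection
    french_hints = ('symptômes', 'pourquoi', 'maladie', 'traitement')
    return 'fr' if any(w in text.lower() for w in french_hints) else 'en'

@lru_cache(maxsize=256)
def detect_language(text: str) -> str:
    """Try langdetect first, fall back to keyword matching on any failure."""
    try:
        from langdetect import detect
        lang = detect(text)  # returns ISO codes such as 'fr', 'en', 'es', ...
        return lang if lang in ('fr', 'en') else keyword_fallback(text)
    except Exception:  # ImportError, or langdetect's LangDetectException
        return keyword_fallback(text)

print(detect_language("Quels sont les symptômes du paludisme?"))  # -> 'fr'
```

One design note: because `detect_language` in the class above is an instance method, `@lru_cache` also keys on `self` and keeps the instance alive for the lifetime of the cache; caching a module-level function, as in this sketch, avoids that.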
+# === 2. SIMPLE TRANSLATOR ===
+class SimpleTranslator:
+    def __init__(self):
+        # Use a lightweight translation approach
         try:
+            # Only load if really needed
+            self.translator = None
+            logger.info("Simple translator initialized")
         except Exception as e:
             logger.error(f"Error initializing translator: {str(e)}")
+            self.translator = None

     @lru_cache(maxsize=256)
     def translate(self, text: str, source_lang: str, target_lang: str) -> str:
         if not text.strip() or source_lang == target_lang:
             return text

+        # For Spaces demo, we'll use simple template responses
+        # In production, you'd want proper translation
+        return text  # Simplified for demo
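The empty `self.translator` slot is where "proper translation" would go. A sketch of one way to fill it under the same lazy, CPU-only constraints — the Helsinki-NLP/opus-mt checkpoints and the per-direction cache are assumptions, not something this Space currently loads:

```python
from functools import lru_cache
from transformers import pipeline

_translators = {}

def get_translator(source_lang: str, target_lang: str):
    """Lazily build and cache one small MarianMT pipeline per direction."""
    key = (source_lang, target_lang)
    if key not in _translators:
        # Assumed model naming scheme, e.g. Helsinki-NLP/opus-mt-fr-en
        model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
        _translators[key] = pipeline("translation", model=model_name, device=-1)
    return _translators[key]

@lru_cache(maxsize=256)
def translate(text: str, source_lang: str, target_lang: str) -> str:
    if not text.strip() or source_lang == target_lang:
        return text
    result = get_translator(source_lang, target_lang)(text, max_length=512)
    return result[0]["translation_text"]
```

Deferring the model load until the first `translate` call keeps Space startup fast when most traffic is monolingual.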
+# === 3. LIGHTWEIGHT MEDICAL RAG ===
+class LightweightMedicalRAG:
     def __init__(self, embedding_model_name=EMBEDDING_MODEL_NAME, records_path=PATIENT_RECORDS_PATH):
         try:
+            logger.info("Loading lightweight embedder...")
             self.embedder = SentenceTransformer(embedding_model_name)

             if not os.path.exists(records_path):
                 …
             with open(records_path, 'r', encoding='utf-8') as f:
                 self.records = json.load(f)

+            # Build simple medical chunks
             self.medical_chunks = []
+            self._build_medical_chunks()

+            # Single FAISS index
+            self.medical_index = self._build_faiss_index(self.medical_chunks)

+            logger.info(f"Lightweight RAG initialized: {len(self.medical_chunks)} chunks")

         except Exception as e:
+            logger.error(f"Error initializing RAG: {str(e)}")
             self._initialize_fallback()

     def _create_sample_records(self, path: str):
+        """Create basic medical records"""
         sample_records = [
             {
                 "id": "malaria_001",
+                "diagnosis": {"en": "Malaria", "fr": "Paludisme"},
+                "symptoms": {"en": "Fever, chills, headache", "fr": "Fièvre, frissons, maux de tête"},
+                "treatment": {"en": "Antimalarial medication and rest", "fr": "Médicaments antipaludiques et repos"}
             },
             {
+                "id": "diabetes_001",
+                "diagnosis": {"en": "Diabetes", "fr": "Diabète"},
+                "symptoms": {"en": "Increased thirst, frequent urination", "fr": "Soif excessive, mictions fréquentes"},
+                "treatment": {"en": "Diet control and medication", "fr": "Contrôle alimentaire et médicaments"}
             }
         ]

         with open(path, 'w', encoding='utf-8') as f:
             json.dump(sample_records, f, ensure_ascii=False, indent=2)

     def _initialize_fallback(self):
+        """Initialize fallback system"""
+        self.medical_chunks = [
+            "General medical consultation and symptom assessment",
+            "Common tropical diseases like malaria require immediate medical attention",
+            "Diabetes management involves diet control and regular monitoring"
+        ]
         self.medical_index = None

+    def _build_medical_chunks(self):
+        """Build simple medical chunks"""
         for rec in self.records:
             try:
+                if 'diagnosis' in rec and 'symptoms' in rec:
+                    chunk = f"Condition: {rec['diagnosis'].get('en', '')}. "
+                    chunk += f"Symptoms: {rec['symptoms'].get('en', '')}. "
+                    if 'treatment' in rec:
+                        chunk += f"Treatment: {rec['treatment'].get('en', '')}"

+                    self.medical_chunks.append(chunk)

             except Exception as e:
                 logger.error(f"Error processing record: {str(e)}")

     def _build_faiss_index(self, chunks):
         if not chunks:
+            return None
         try:
             embeddings = self.embedder.encode(chunks, show_progress_bar=False, convert_to_numpy=True)
             index = faiss.IndexFlatL2(embeddings.shape[1])
             index.add(embeddings)
+            return index
         except Exception as e:
             logger.error(f"Error building FAISS index: {str(e)}")
+            return None
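To make the index semantics concrete: `IndexFlatL2` is an exact, brute-force index, `add` stores the embedding rows in insertion order, and `search` returns squared-L2 distances plus the row indices that map back into the chunk list. A small self-contained sketch (the two chunks mirror the sample records above):

```python
import faiss
from sentence_transformers import SentenceTransformer

chunks = [
    "Condition: Malaria. Symptoms: Fever, chills, headache. Treatment: Antimalarial medication and rest",
    "Condition: Diabetes. Symptoms: Increased thirst, frequent urination. Treatment: Diet control and medication",
]

embedder = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
embeddings = embedder.encode(chunks, convert_to_numpy=True)

index = faiss.IndexFlatL2(embeddings.shape[1])  # exact (brute-force) L2 index
index.add(embeddings)                           # row ids follow list order, so they map back to chunks

query = embedder.encode(["What are the symptoms of malaria?"], convert_to_numpy=True)
dists, ids = index.search(query, 1)             # nearest chunk: distances and row indices
print(chunks[ids[0][0]])                        # -> the malaria chunk
```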
+    def get_contexts(self, question: str, lang: str = "en") -> List[str]:
+        """Get relevant medical contexts"""
         try:
+            if self.medical_index is None:
+                return self.medical_chunks[:2]

+            q_emb = self.embedder.encode([question], convert_to_numpy=True)
+            _, I = self.medical_index.search(q_emb, min(3, len(self.medical_chunks)))
+            return [self.medical_chunks[i] for i in I[0] if i < len(self.medical_chunks)]

         except Exception as e:
+            logger.error(f"Error getting contexts: {str(e)}")
+            return self.medical_chunks[:2]
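Putting the RAG pieces together end to end, a short usage sketch (import path taken from the file header above; `patient_records.json` is created with the sample records on first run):

```python
from medical_ai import LightweightMedicalRAG

rag = LightweightMedicalRAG()
for ctx in rag.get_contexts("How is diabetes treated?"):
    print("-", ctx)
# Prints the nearest chunks, e.g. "Condition: Diabetes. Symptoms: ... Treatment: ..."
```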
+# === 4. LIGHTWEIGHT LLM ===
+class LightweightMedicalLLM:
     def __init__(self, model_name: str = MODEL_NAME):
         self.device = DEVICE
+        logger.info(f"Loading lightweight model {model_name}...")

         try:
+            # Use pipeline for simplicity
             self.generator = pipeline(
                 "text-generation",
+                model=model_name,
                 device=-1,  # CPU
+                torch_dtype=torch.float32,
+                model_kwargs={"low_cpu_mem_usage": True}
             )

+            logger.info(f"Lightweight model {model_name} loaded successfully")
         except Exception as e:
             logger.error(f"Error loading model: {str(e)}")
             self.generator = None

+    def generate_response(self, question: str, contexts: List[str], lang: str = "en") -> str:
+        """Generate medical response"""

         if self.generator is None:
+            return self._fallback_response(question, contexts, lang)

         try:
+            # Build simple prompt
+            context_str = " | ".join(contexts[:2]) if contexts else "General medical consultation"
+
+            if lang == "fr":
+                prompt = f"Contexte médical: {context_str}\n\nQuestion: {question}\n\nRéponse médicale:"
+            else:
+                prompt = f"Medical context: {context_str}\n\nQuestion: {question}\n\nMedical response:"

+            # Generate with conservative settings for Spaces
             response = self.generator(
                 prompt,
+                max_length=len(prompt) + 150,  # Shorter for faster processing
                 temperature=TEMPERATURE,
                 top_p=TOP_P,
                 top_k=TOP_K,
                 do_sample=True,
+                pad_token_id=self.generator.tokenizer.eos_token_id
             )

+            # Extract response
             full_text = response[0]['generated_text']
             response_text = full_text[len(prompt):].strip()

+            # Add medical disclaimer
+            disclaimer = {
+                "en": "\n\n⚕️ Medical Disclaimer: Consult a healthcare professional for proper diagnosis.",
+                "fr": "\n\n⚕️ Avertissement médical: Consultez un professionnel de santé pour un diagnostic approprié."
+            }
+
+            if "disclaimer" not in response_text.lower():
+                response_text += disclaimer.get(lang, disclaimer["en"])

+            return response_text.strip()

         except Exception as e:
+            logger.error(f"Error in generation: {str(e)}")
+            return self._fallback_response(question, contexts, lang)
+    def _fallback_response(self, question: str, contexts: List[str], lang: str) -> str:
+        """Fallback response for errors"""

         templates = {
+            "en": "Based on medical knowledge: This requires professional medical evaluation. Please consult with a healthcare provider for proper diagnosis and treatment. Stay hydrated and monitor symptoms.",
+            "fr": "Selon les connaissances médicales: Ceci nécessite une évaluation médicale professionnelle. Veuillez consulter un professionnel de santé pour un diagnostic et traitement appropriés."
         }

+        return templates.get(lang, templates["en"])
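One caveat on `generate_response` above: `max_length` is counted in tokens, while `len(prompt)` counts characters, so `len(prompt) + 150` merely over-allocates the budget rather than capping the reply at 150 of anything. `max_new_tokens` bounds the continuation directly; a minimal sketch with the same model and sampling settings, assuming nothing beyond what the class already loads:

```python
from transformers import pipeline

def generate_reply(prompt: str) -> str:
    generator = pipeline("text-generation", model="microsoft/DialoGPT-small", device=-1)
    out = generator(
        prompt,
        max_new_tokens=150,  # counts generated tokens only, independent of prompt length
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        top_k=40,
        pad_token_id=generator.tokenizer.eos_token_id,
    )
    return out[0]["generated_text"][len(prompt):].strip()
```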
+# === MAIN PIPELINE ===
+class SpacesMedicalAIPipeline:
     def __init__(self):
+        logger.info("🚀 Initializing Spaces Medical AI Pipeline...")
         try:
+            self.lang_detector = SimpleLanguageDetector()
+            self.translator = SimpleTranslator()
+            self.rag = LightweightMedicalRAG()
+            self.llm = LightweightMedicalLLM()
+            logger.info("✅ Spaces Medical AI Pipeline ready!")
         except Exception as e:
+            logger.error(f"Error initializing pipeline: {str(e)}")
             raise

     def process(self, question: str, user_lang: str = "auto", conversation_history: list = None) -> Dict[str, Any]:
+        """Process medical question for Spaces"""
         try:
             if not question or not question.strip():
                 return self._empty_question_response(user_lang)

+            # Detect language
             detected_lang = self.lang_detector.detect_language(question) if user_lang == "auto" else user_lang
+            logger.info(f"Processing question in {detected_lang}")

+            # Get medical contexts
+            contexts = self.rag.get_contexts(question, detected_lang)

+            # Generate response
+            response = self.llm.generate_response(question, contexts, detected_lang)

             return {
+                "response": response,
                 "source_lang": detected_lang,
+                "context_used": contexts[:3],
+                "confidence": "medium"
             }

         except Exception as e:
+            logger.error(f"Processing error: {str(e)}")
             return self._error_response(str(e), user_lang if user_lang != "auto" else "en")

     def _empty_question_response(self, user_lang: str) -> Dict[str, Any]:
+        """Response for empty question"""
         responses = {
+            "en": "Please provide a medical question for consultation.",
+            "fr": "Veuillez poser une question médicale pour consultation."
         }
         lang = user_lang if user_lang != "auto" else "en"
         return {
             …
         }

     def _error_response(self, error: str, lang: str) -> Dict[str, Any]:
+        """Error response"""
         responses = {
+            "en": "I'm experiencing technical difficulties. Please try rephrasing your medical question.",
+            "fr": "Je rencontre des difficultés techniques. Veuillez reformuler votre question médicale."
         }
         return {
             "response": responses.get(lang, responses["en"]),
             "source_lang": lang,
             "context_used": [],
+            "confidence": "low"
         }

+# Compatibility aliases
+CompetitionMedicalAIPipeline = SpacesMedicalAIPipeline
+MedicalAIPipeline = SpacesMedicalAIPipeline
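Since the aliases keep the old class names importable, existing callers such as fastapi_app.py don't need to change. A quick usage sketch of the module as it now stands (the question text is illustrative):

```python
from medical_ai import SpacesMedicalAIPipeline

ai = SpacesMedicalAIPipeline()          # loads embedder, FAISS index, and DialoGPT-small
result = ai.process("What causes diabetes?", user_lang="auto")

print(result["source_lang"])            # detected language, e.g. 'en'
print(result["confidence"])             # 'medium' on the happy path, 'low' on errors
print(result["response"])               # generated answer plus the medical disclaimer
```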
requirements.txt
CHANGED

@@ -1,39 +1,34 @@
-# FastAPI Medical AI - Requirements
 # Core FastAPI dependencies
 fastapi==0.104.1
 uvicorn[standard]==0.24.0
 python-multipart==0.0.6
 pydantic==2.4.2

-# ML …
 transformers==4.35.2
-torch==2.1.0
 sentence-transformers==2.2.2
 faiss-cpu==1.7.4
-faster-whisper==0.9.0

-# Audio processing
 librosa==0.10.1
 soundfile==0.12.1
 numpy==1.24.3

-# …
-…

 # Language processing
 sentencepiece==0.1.99
 langdetect==1.0.9

-# …
-accelerate==0.24.1
-optimum==1.13.2

 # System monitoring
 psutil==5.9.6

-# …
-pytest==7.4.3
-pytest-asyncio==0.21.1
-
-# Optional: For production deployment
-gunicorn==21.2.0
+# FastAPI Medical AI - Fixed Requirements for Hugging Face Spaces
 # Core FastAPI dependencies
 fastapi==0.104.1
 uvicorn[standard]==0.24.0
 python-multipart==0.0.6
 pydantic==2.4.2

+# Lightweight ML models for Spaces
 transformers==4.35.2
+# CPU-only torch wheel; pip requires the index option on its own line in a requirements file
+--extra-index-url https://download.pytorch.org/whl/cpu
+torch==2.1.0+cpu
 sentence-transformers==2.2.2
 faiss-cpu==1.7.4

+# Audio processing (lightweight versions)
 librosa==0.10.1
 soundfile==0.12.1
 numpy==1.24.3

+# Removed heavy dependencies that cause timeouts
+# faster-whisper==0.9.0  # Too heavy for Spaces
+# accelerate==0.24.1  # Not needed for CPU
+# optimum==1.13.2  # Not needed for basic setup

 # Language processing
 sentencepiece==0.1.99
 langdetect==1.0.9

+# HTTP requests
+requests==2.31.0

 # System monitoring
 psutil==5.9.6

+# Keep lightweight for Spaces deployment