Dama03 committed
Commit 3581d64 · 1 Parent(s): e1a8fc4
Files changed (6)
  1. README.md +67 -5
  2. app.py +25 -18
  3. audio_utils.py +0 -162
  4. fastapi_app.py +126 -413
  5. medical_ai.py +150 -369
  6. requirements.txt +11 -16
README.md CHANGED
@@ -1,12 +1,74 @@
 ---
-title: Medical
-emoji: 📈
+title: Medical AI Assistant
+emoji: 🩺
 colorFrom: green
-colorTo: purple
+colorTo: blue
 sdk: docker
 pinned: false
 license: mit
-short_description: medical assistant AI
+app_port: 7860
+short_description: Multilingual medical consultation AI assistant
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# 🩺 Medical AI Assistant
+
+A multilingual medical consultation AI assistant optimized for Hugging Face Spaces.
+
+## ✨ Features
+
+- 🌍 **Multilingual Support**: English and French medical consultations
+- 🧠 **Medical Knowledge**: AI-powered responses based on medical literature
+- ⚡ **Fast Processing**: Optimized for quick responses
+- 📚 **Educational Focus**: Provides educational medical information with proper disclaimers
+
+## 🚀 Quick Start
+
+1. Visit the **API Documentation** at `/docs` to explore all endpoints
+2. Try the main endpoint: `POST /medical/ask`
+3. Use demo questions from `/medical/demo`
+
+## 📋 Example Usage
+
+### English Consultation
+```json
+{
+  "question": "What are the symptoms of malaria?",
+  "language": "en"
+}
+```
+
+### French Consultation
+```json
+{
+  "question": "Quels sont les symptômes du paludisme?",
+  "language": "fr"
+}
+```
+
+## 🔧 API Endpoints
+
+- `GET /` - Welcome message and API info
+- `POST /medical/ask` - Medical consultation endpoint
+- `GET /health` - System health check
+- `GET /medical/demo` - Sample demo questions
+- `GET /docs` - Interactive API documentation
+- `GET /redoc` - Alternative API documentation
+
+## ⚕️ Medical Disclaimer
+
+This AI assistant provides educational medical information only. Always consult qualified healthcare professionals for proper medical diagnosis and treatment.
+
+## 🏗️ Technical Details
+
+- **Framework**: FastAPI
+- **AI Models**: Lightweight transformers optimized for Spaces
+- **Languages**: Python 3.9+
+- **Deployment**: Docker on Hugging Face Spaces
+
+## 📞 Support
+
+For technical issues or questions, please check the API documentation at `/docs`.
+
+---
+
+*Optimized for Hugging Face Spaces deployment*
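
As a quick sanity check of the endpoint documented above, here is a minimal client sketch; the base URL is a placeholder for wherever this Space is deployed, and the payload mirrors the README's example request.

```python
import requests

# Placeholder URL; substitute the actual Space URL.
BASE_URL = "https://your-space.hf.space"

# Payload fields follow the MedicalQuestion model introduced in this commit.
payload = {"question": "What are the symptoms of malaria?", "language": "en"}

resp = requests.post(f"{BASE_URL}/medical/ask", json=payload, timeout=60)
resp.raise_for_status()
data = resp.json()
print(data["detected_language"], data["processing_time"])
print(data["response"])
```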
app.py CHANGED
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 """
-Medical AI Assistant - FastAPI Only Entry Point
-Simplified for backend integration
+Medical AI Assistant - Hugging Face Spaces Version
+Optimized for Spaces deployment
 """
 
 import os
@@ -20,22 +20,22 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 def setup_environment():
-    """Setup environment variables for FastAPI deployment"""
+    """Setup environment variables for Spaces deployment"""
     # Set environment variables for optimal performance
     os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
     os.environ.setdefault("HF_HOME", "/tmp/huggingface")
     os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/transformers")
 
-    # FastAPI specific
+    # Spaces specific - use port 7860
     os.environ.setdefault("HOST", "0.0.0.0")
-    os.environ.setdefault("PORT", "8000")
+    os.environ.setdefault("PORT", "7860")  # Changed to Spaces default
 
-    logger.info("✅ Environment configured for FastAPI Medical AI")
+    logger.info("✅ Environment configured for Hugging Face Spaces")
 
 def main():
     """Main application entry point"""
     try:
-        logger.info("🩺 Starting Medical AI Assistant - FastAPI Edition")
+        logger.info("🩺 Starting Medical AI Assistant - Spaces Edition")
 
         # Setup environment
         setup_environment()
@@ -44,22 +44,22 @@ def main():
         from fastapi_app import app
         import uvicorn
 
-        # Get port from environment or use default
-        port = int(os.getenv("PORT", 8000))
+        # Get port from environment - Spaces uses 7860
+        port = int(os.getenv("PORT", 7860))
         host = os.getenv("HOST", "0.0.0.0")
 
         logger.info(f"🚀 Starting FastAPI server on {host}:{port}")
-        logger.info(f"📚 API Documentation will be available at: http://{host}:{port}/docs")
-        logger.info(f"🔄 Alternative docs at: http://{host}:{port}/redoc")
+        logger.info(f"📚 API Documentation available at: http://{host}:{port}/docs")
 
-        # Launch the FastAPI application
+        # Launch the FastAPI application with Spaces-optimized settings
         uvicorn.run(
             app,
             host=host,
             port=port,
             log_level="info",
-            reload=False,  # Set to True for development
-            access_log=True
+            reload=False,  # Never reload in production
+            access_log=True,
+            workers=1  # Single worker for Spaces
         )
 
     except KeyboardInterrupt:
@@ -71,16 +71,23 @@ def main():
 # For direct import (Hugging Face Spaces compatibility)
 try:
     from fastapi_app import app
-    logger.info("✅ FastAPI app imported successfully")
+    logger.info("✅ FastAPI app imported successfully for Spaces")
 except ImportError as e:
     logger.error(f"❌ Failed to import FastAPI app: {str(e)}")
     # Create minimal fallback app
     from fastapi import FastAPI
-    app = FastAPI(title="Medical AI - Error", description="Failed to load main application")
+    app = FastAPI(
+        title="Medical AI - Loading",
+        description="Medical AI Assistant is starting up..."
+    )
 
     @app.get("/")
-    async def error_root():
-        return {"error": "Medical AI Assistant failed to load properly"}
+    async def loading_root():
+        return {
+            "message": "🩺 Medical AI Assistant is starting up...",
+            "status": "loading",
+            "docs": "/docs"
+        }
 
 if __name__ == "__main__":
     main()
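
One detail worth noting in the `setup_environment` change above: `os.environ.setdefault` writes the 7860 default only when `PORT` is unset, so a platform-injected `PORT` still takes precedence. A minimal illustration:

```python
import os

os.environ["PORT"] = "8080"            # e.g. injected by the hosting platform
os.environ.setdefault("PORT", "7860")  # no-op: PORT is already set
assert os.environ["PORT"] == "8080"

del os.environ["PORT"]
os.environ.setdefault("PORT", "7860")  # takes effect: PORT was unset
assert os.environ["PORT"] == "7860"
```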
audio_utils.py DELETED
@@ -1,162 +0,0 @@
-import io
-import os
-import numpy as np
-import librosa
-import soundfile as sf
-from typing import Union
-import logging
-
-logger = logging.getLogger(__name__)
-
-def preprocess_audio(audio_data: Union[bytes, str], target_sr: int = 16000) -> np.ndarray:
-    """
-    Preprocess audio data for Whisper transcription.
-
-    Args:
-        audio_data: Audio data as bytes or file path
-        target_sr: Target sample rate for Whisper (16kHz)
-
-    Returns:
-        numpy array of preprocessed audio
-    """
-    try:
-        if isinstance(audio_data, bytes):
-            # Load audio from bytes
-            audio_buffer = io.BytesIO(audio_data)
-            audio, sr = sf.read(audio_buffer)
-            logger.info(f"Loaded audio from bytes: {len(audio)} samples at {sr}Hz")
-        elif isinstance(audio_data, str):
-            # Load audio from file path
-            if not os.path.exists(audio_data):
-                logger.error(f"Audio file not found: {audio_data}")
-                return np.array([], dtype=np.float32)
-
-            audio, sr = librosa.load(audio_data, sr=None)
-            logger.info(f"Loaded audio from file {audio_data}: {len(audio)} samples at {sr}Hz")
-        else:
-            logger.error(f"Unsupported audio data type: {type(audio_data)}")
-            return np.array([], dtype=np.float32)
-
-        # Handle empty audio
-        if len(audio) == 0:
-            logger.warning("Empty audio data received")
-            return np.array([], dtype=np.float32)
-
-        # Convert to mono if stereo
-        if len(audio.shape) > 1:
-            audio = librosa.to_mono(audio)
-            logger.info("Converted stereo to mono")
-
-        # Resample to target sample rate if needed
-        if sr != target_sr:
-            audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
-            logger.info(f"Resampled from {sr}Hz to {target_sr}Hz")
-
-        # Normalize audio to [-1, 1] range
-        if np.max(np.abs(audio)) > 0:
-            audio = audio / np.max(np.abs(audio))
-
-        # Convert to float32
-        audio = audio.astype(np.float32)
-
-        # Remove silence from beginning and end (optional but helpful)
-        try:
-            audio, _ = librosa.effects.trim(audio, top_db=20)
-            logger.info(f"Trimmed silence, final length: {len(audio)} samples")
-        except Exception as e:
-            logger.warning(f"Could not trim silence: {str(e)}")
-
-        # Ensure minimum length (avoid very short clips)
-        min_samples = int(0.1 * target_sr)  # 0.1 second minimum
-        if len(audio) < min_samples:
-            logger.warning(f"Audio too short ({len(audio)} samples), padding to {min_samples}")
-            audio = np.pad(audio, (0, min_samples - len(audio)), mode='constant')
-
-        logger.info(f"Audio preprocessing completed: {len(audio)} samples at {target_sr}Hz")
-        return audio
-
-    except Exception as e:
-        logger.error(f"Error preprocessing audio: {str(e)}", exc_info=True)
-        # Return empty audio array as fallback
-        return np.array([], dtype=np.float32)
-
-def validate_audio_format(audio_path: str) -> bool:
-    """
-    Validate if audio file format is supported
-
-    Args:
-        audio_path: Path to audio file
-
-    Returns:
-        True if format is supported, False otherwise
-    """
-    try:
-        if not os.path.exists(audio_path):
-            return False
-
-        # Get file extension
-        _, ext = os.path.splitext(audio_path.lower())
-
-        # Supported formats
-        supported_formats = ['.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac']
-
-        return ext in supported_formats
-
-    except Exception as e:
-        logger.error(f"Error validating audio format: {str(e)}")
-        return False
-
-def get_audio_info(audio_path: str) -> dict:
-    """
-    Get information about audio file
-
-    Args:
-        audio_path: Path to audio file
-
-    Returns:
-        Dictionary with audio information
-    """
-    try:
-        if not os.path.exists(audio_path):
-            return {"error": "File not found"}
-
-        # Load audio metadata
-        info = sf.info(audio_path)
-
-        return {
-            "duration": info.duration,
-            "sample_rate": info.samplerate,
-            "channels": info.channels,
-            "format": info.format,
-            "subtype": info.subtype,
-            "frames": info.frames
-        }
-
-    except Exception as e:
-        logger.error(f"Error getting audio info: {str(e)}")
-        return {"error": str(e)}
-
-def enhance_audio_quality(audio: np.ndarray, sr: int = 16000) -> np.ndarray:
-    """
-    Apply basic audio enhancement for better transcription
-
-    Args:
-        audio: Audio signal
-        sr: Sample rate
-
-    Returns:
-        Enhanced audio signal
-    """
-    try:
-        # Apply high-pass filter to remove low-frequency noise
-        audio_filtered = librosa.effects.preemphasis(audio)
-
-        # Normalize again after filtering
-        if np.max(np.abs(audio_filtered)) > 0:
-            audio_filtered = audio_filtered / np.max(np.abs(audio_filtered))
-
-        return audio_filtered.astype(np.float32)
-
-    except Exception as e:
-        logger.warning(f"Audio enhancement failed: {str(e)}")
-        return audio
fastapi_app.py CHANGED
@@ -1,20 +1,17 @@
 #!/usr/bin/env python3
 """
-Medical AI Assistant - FastAPI Only Version
-Simplified endpoints for backend integration with Swagger UI
+Medical AI Assistant - Lightweight FastAPI for Spaces
+Optimized for Hugging Face Spaces deployment
 """
 
-from fastapi import FastAPI, HTTPException, File, UploadFile, BackgroundTasks
+from fastapi import FastAPI, HTTPException, File, UploadFile
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
-from fastapi.openapi.docs import get_swagger_ui_html
-from fastapi.openapi.utils import get_openapi
 from pydantic import BaseModel, Field
-from typing import List, Optional, Dict, Any, Union
+from typing import List, Optional, Dict, Any
 import logging
 import uuid
 import os
-import json
 import asyncio
 from contextlib import asynccontextmanager
 import time
@@ -28,37 +25,23 @@ logger = logging.getLogger(__name__)
 
 # Initialize models globally
 pipeline = None
-whisper_model = None
 
 async def load_models():
     """Load ML models asynchronously"""
-    global pipeline, whisper_model
+    global pipeline
     try:
-        logger.info("Loading Medical AI models...")
+        logger.info("🔄 Loading Medical AI models...")
 
-        from medical_ai import CompetitionMedicalAIPipeline
-        pipeline = CompetitionMedicalAIPipeline()
+        # Use the lightweight version
+        from medical_ai import SpacesMedicalAIPipeline
+        pipeline = SpacesMedicalAIPipeline()
         logger.info("✅ Medical pipeline loaded successfully")
 
-        try:
-            from faster_whisper import WhisperModel
-            model_cache = os.getenv('HF_HOME', '/tmp/models')
-            whisper_model = WhisperModel(
-                "medium",
-                device="cpu",
-                compute_type="int8",
-                download_root=model_cache
-            )
-            logger.info("✅ Whisper model loaded successfully")
-        except Exception as e:
-            logger.warning(f"⚠️ Could not load Whisper model: {str(e)}")
-            whisper_model = None
-
-        logger.info("🚀 All models loaded successfully")
+        logger.info("🚀 All models ready for Spaces!")
 
     except Exception as e:
         logger.error(f"❌ Error loading models: {str(e)}", exc_info=True)
-        raise
+        # Don't raise - let the app start with fallback responses
 
@@ -68,496 +51,211 @@ async def lifespan(app: FastAPI):
         yield
     except Exception as e:
         logger.error(f"❌ Error during startup: {str(e)}", exc_info=True)
-        raise
+        # Continue anyway for demo purposes
+        yield
     finally:
         logger.info("🔄 Shutting down...")
 
-# Custom OpenAPI schema
-def custom_openapi():
-    if app.openapi_schema:
-        return app.openapi_schema
-
-    openapi_schema = get_openapi(
-        title="🩺 Medical AI Assistant API",
-        version="2.0.0",
-        description="""
-    ## 🎯 Advanced Medical AI Assistant
-
-    **Multilingual medical consultation API** supporting:
-    - 🌍 French, English, and local African languages
-    - 🎤 Audio processing with speech-to-text
-    - 🧠 Advanced medical knowledge retrieval
-    - ⚡ Real-time medical consultations
-
-    ### 🔧 Main Endpoints:
-    - **POST /medical/ask** - Text-based medical consultation
-    - **POST /medical/audio** - Audio-based medical consultation
-    - **GET /health** - System health check
-    - **POST /feedback** - Submit user feedback
-
-    ### 🔒 Important Medical Disclaimer:
-    This API provides educational medical information only. Always consult qualified healthcare professionals for medical advice.
-    """,
-        routes=app.routes,
-        contact={
-            "name": "Medical AI Support",
-            "email": "support@medicalai.com"
-        },
-        license_info={
-            "name": "MIT License",
-            "url": "https://opensource.org/licenses/MIT"
-        }
-    )
-
-    # Add custom tags
-    openapi_schema["tags"] = [
-        {
-            "name": "medical",
-            "description": "Medical consultation endpoints"
-        },
-        {
-            "name": "audio",
-            "description": "Audio processing endpoints"
-        },
-        {
-            "name": "system",
-            "description": "System monitoring and health"
-        },
-        {
-            "name": "feedback",
-            "description": "User feedback and analytics"
-        }
-    ]
-
-    app.openapi_schema = openapi_schema
-    return app.openapi_schema
-
-# Initialize FastAPI app
+# Initialize FastAPI app - Simplified for Spaces
 app = FastAPI(
     title="🩺 Medical AI Assistant",
-    description="Advanced multilingual medical consultation API",
-    version="2.0.0",
+    description="Multilingual medical consultation API optimized for Hugging Face Spaces",
+    version="2.0.0-spaces",
     lifespan=lifespan,
     docs_url="/docs",
-    redoc_url="/redoc",
-    openapi_url="/openapi.json"
+    redoc_url="/redoc"
 )
 
-# Set custom OpenAPI
-app.openapi = custom_openapi
-
 # CORS middleware
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
     allow_credentials=True,
     allow_methods=["*"],
-    allow_headers=["*"],
-    expose_headers=["*"]
+    allow_headers=["*"]
 )
 
 # ============================================================================
-# PYDANTIC MODELS FOR REQUEST/RESPONSE VALIDATION
+# PYDANTIC MODELS - Simplified
 # ============================================================================
 
 class MedicalQuestion(BaseModel):
     """Medical question request model"""
-    question: str = Field(..., description="The medical question", min_length=3, max_length=1000)
-    language: str = Field("auto", description="Preferred language (auto, en, fr)", pattern="^(auto|en|fr)$")
-    conversation_id: Optional[str] = Field(None, description="Optional conversation ID for context")
+    question: str = Field(..., description="The medical question", min_length=3, max_length=500)
+    language: str = Field("auto", description="Language (auto, en, fr)")
 
     class Config:
         schema_extra = {
             "example": {
-                "question": "What are the symptoms of malaria and how is it treated?",
-                "language": "en",
-                "conversation_id": "conv_123"
+                "question": "What are the symptoms of malaria?",
+                "language": "en"
             }
         }
 
 class MedicalResponse(BaseModel):
     """Medical response model"""
-    success: bool = Field(..., description="Whether the request was successful")
-    response: str = Field(..., description="The medical response")
-    detected_language: str = Field(..., description="Detected or used language")
-    conversation_id: str = Field(..., description="Conversation identifier")
-    context_used: List[str] = Field(default_factory=list, description="Medical contexts used")
-    processing_time: float = Field(..., description="Response time in seconds")
-    confidence: str = Field(..., description="Response confidence level")
+    success: bool = Field(..., description="Success status")
+    response: str = Field(..., description="Medical response")
+    detected_language: str = Field(..., description="Detected language")
+    processing_time: float = Field(..., description="Processing time")
 
     class Config:
         schema_extra = {
             "example": {
                 "success": True,
-                "response": "Malaria symptoms include high fever, chills, headache...",
+                "response": "Malaria symptoms include fever, chills, headache...",
                 "detected_language": "en",
-                "conversation_id": "conv_123",
-                "context_used": ["Malaria treatment protocols", "Symptom guidelines"],
-                "processing_time": 2.5,
-                "confidence": "high"
-            }
-        }
-
-class AudioResponse(BaseModel):
-    """Audio processing response model"""
-    success: bool = Field(..., description="Whether the request was successful")
-    transcription: str = Field(..., description="Transcribed text from audio")
-    response: str = Field(..., description="The medical response")
-    detected_language: str = Field(..., description="Detected audio language")
-    conversation_id: str = Field(..., description="Conversation identifier")
-    context_used: List[str] = Field(default_factory=list, description="Medical contexts used")
-    processing_time: float = Field(..., description="Response time in seconds")
-    audio_duration: Optional[float] = Field(None, description="Audio duration in seconds")
-
-    class Config:
-        schema_extra = {
-            "example": {
-                "success": True,
-                "transcription": "What are the symptoms of malaria?",
-                "response": "Malaria symptoms include high fever, chills...",
-                "detected_language": "en",
-                "conversation_id": "conv_456",
-                "context_used": ["Malaria diagnosis"],
-                "processing_time": 3.2,
-                "audio_duration": 4.5
-            }
-        }
-
-class FeedbackRequest(BaseModel):
-    """Feedback request model"""
-    conversation_id: str = Field(..., description="Conversation ID")
-    rating: int = Field(..., description="Rating from 1-5", ge=1, le=5)
-    feedback: Optional[str] = Field(None, description="Optional text feedback", max_length=500)
-
-    class Config:
-        schema_extra = {
-            "example": {
-                "conversation_id": "conv_123",
-                "rating": 5,
-                "feedback": "Very helpful and accurate medical information"
+                "processing_time": 2.1
             }
         }
 
 class HealthStatus(BaseModel):
-    """System health status model"""
-    status: str = Field(..., description="Overall system status")
-    models_loaded: bool = Field(..., description="Whether ML models are loaded")
-    audio_available: bool = Field(..., description="Whether audio processing is available")
-    uptime: float = Field(..., description="System uptime in seconds")
+    """System health status"""
+    status: str = Field(..., description="System status")
+    models_loaded: bool = Field(..., description="Models loaded status")
     version: str = Field(..., description="API version")
-
-    class Config:
-        schema_extra = {
-            "example": {
-                "status": "healthy",
-                "models_loaded": True,
-                "audio_available": True,
-                "uptime": 3600.0,
-                "version": "2.0.0"
-            }
-        }
-
-class ErrorResponse(BaseModel):
-    """Error response model"""
-    success: bool = Field(False, description="Always false for errors")
-    error: str = Field(..., description="Error message")
-    error_code: str = Field(..., description="Error code")
-    conversation_id: Optional[str] = Field(None, description="Conversation ID if available")
-
-# ============================================================================
-# UTILITY FUNCTIONS
-# ============================================================================
-
-def generate_conversation_id() -> str:
-    """Generate a unique conversation ID"""
-    return f"conv_{uuid.uuid4().hex[:8]}"
 
 def validate_models():
     """Check if models are loaded"""
     if pipeline is None:
         raise HTTPException(
             status_code=503,
-            detail="Medical AI models are not loaded yet. Please try again in a moment."
+            detail="Medical AI models are still loading. Please try again in a moment."
         )
 
 # ============================================================================
-# API ENDPOINTS
+# API ENDPOINTS - Simplified for Spaces
 # ============================================================================
 
 @app.get("/", tags=["system"])
 async def root():
-    """Root endpoint with API information"""
+    """Root endpoint"""
     return {
-        "message": "🩺 Medical AI Assistant API",
-        "version": "2.0.0",
+        "message": "🩺 Medical AI Assistant - Spaces Edition",
+        "version": "2.0.0-spaces",
         "status": "running",
         "docs": "/docs",
-        "redoc": "/redoc",
         "endpoints": {
             "medical_consultation": "/medical/ask",
-            "audio_consultation": "/medical/audio",
-            "health_check": "/health",
-            "feedback": "/feedback"
-        }
+            "health_check": "/health"
+        },
+        "demo_note": "Optimized for Hugging Face Spaces deployment"
     }
 
 @app.get("/health", response_model=HealthStatus, tags=["system"])
 async def health_check():
-    """
-    ## System Health Check
-
-    Returns the current status of the Medical AI system including:
-    - Overall system health
-    - Model loading status
-    - Audio processing availability
-    - System uptime
-    """
-    global pipeline, whisper_model
-
-    # Calculate uptime (simplified)
-    uptime = time.time() - getattr(health_check, 'start_time', time.time())
-    if not hasattr(health_check, 'start_time'):
-        health_check.start_time = time.time()
+    """System health check"""
+    global pipeline
 
     return HealthStatus(
         status="healthy" if pipeline is not None else "loading",
         models_loaded=pipeline is not None,
-        audio_available=whisper_model is not None,
-        uptime=uptime,
-        version="2.0.0"
+        version="2.0.0-spaces"
     )
 
 @app.post("/medical/ask", response_model=MedicalResponse, tags=["medical"])
 async def medical_consultation(request: MedicalQuestion):
     """
-    ## Text-based Medical Consultation
+    ## Medical Consultation
 
-    Process a medical question and return expert medical guidance.
+    Ask medical questions and get AI-powered responses.
 
     **Features:**
-    - 🌍 Multilingual support (auto-detect or specify language)
-    - 🧠 AI-powered medical knowledge retrieval
-    - ⚡ Fast response generation
-    - 🔒 Medical disclaimers included
-
-    **Supported Languages:** English (en), French (fr), Auto-detect (auto)
+    - 🌍 Multilingual support (English, French)
+    - 🧠 Medical knowledge retrieval
+    - ⚡ Optimized for Spaces
     """
     start_time = time.time()
-    validate_models()
-
-    conversation_id = request.conversation_id or generate_conversation_id()
 
-    try:
-        logger.info(f"🩺 Processing medical question: {request.question[:50]}...")
+    # Check if models are loaded - with graceful fallback
+    if pipeline is None:
+        logger.warning("Models not loaded, using fallback response")
+        processing_time = time.time() - start_time
 
-        # Process with medical AI pipeline
-        result = pipeline.process(
-            question=request.question,
-            user_lang=request.language,
-            conversation_history=[]
-        )
+        fallback_responses = {
+            "en": "Medical AI is still initializing. For immediate medical concerns, please consult a healthcare professional. Common symptoms like fever, headache, or persistent pain should be evaluated by a doctor.",
+            "fr": "L'IA médicale s'initialise encore. Pour des préoccupations médicales immédiates, veuillez consulter un professionnel de santé. Les symptômes courants comme la fièvre, les maux de tête ou la douleur persistante doivent être évalués par un médecin."
+        }
 
-        processing_time = time.time() - start_time
+        detected_lang = "fr" if any(word in request.question.lower() for word in ["quoi", "comment", "pourquoi"]) else "en"
 
         return MedicalResponse(
             success=True,
-            response=result["response"],
-            detected_language=result["source_lang"],
-            conversation_id=conversation_id,
-            context_used=result.get("context_used", []),
-            processing_time=round(processing_time, 2),
-            confidence=result.get("confidence", "medium")
+            response=fallback_responses.get(detected_lang, fallback_responses["en"]) + "\n\n⚕️ Medical Disclaimer: Always consult healthcare professionals for proper medical advice.",
+            detected_language=detected_lang,
+            processing_time=round(processing_time, 2)
         )
 
-    except Exception as e:
-        logger.error(f"❌ Error in medical consultation: {str(e)}", exc_info=True)
-        processing_time = time.time() - start_time
-
-        raise HTTPException(
-            status_code=500,
-            detail={
-                "success": False,
-                "error": "Internal processing error occurred",
-                "error_code": "MEDICAL_PROCESSING_ERROR",
-                "conversation_id": conversation_id,
-                "processing_time": round(processing_time, 2)
-            }
-        )
-
-@app.post("/medical/audio", response_model=AudioResponse, tags=["audio", "medical"])
-async def audio_medical_consultation(
-    file: UploadFile = File(..., description="Audio file (WAV, MP3, M4A, etc.)")
-):
-    """
-    ## Audio-based Medical Consultation
-
-    Process an audio medical question and return expert medical guidance.
-
-    **Features:**
-    - 🎤 Speech-to-text conversion
-    - 🌍 Language detection from audio
-    - 🧠 Medical AI processing of transcribed text
-    - 📝 Full transcription provided
-
-    **Supported Audio Formats:** WAV, MP3, M4A, FLAC, OGG
-    **Max File Size:** 25MB
-    **Max Duration:** 5 minutes
-    """
-    start_time = time.time()
-    validate_models()
-
-    if whisper_model is None:
-        raise HTTPException(
-            status_code=503,
-            detail="Audio processing is currently unavailable"
-        )
-
-    conversation_id = generate_conversation_id()
-
     try:
-        logger.info(f"🎤 Processing audio file: {file.filename}")
-
-        # Read audio file
-        file_bytes = await file.read()
-
-        # Process audio
-        from audio_utils import preprocess_audio
-        processed_audio = preprocess_audio(file_bytes)
-
-        if len(processed_audio) == 0:
-            raise HTTPException(
-                status_code=400,
-                detail="Could not process audio file. Please check the format and try again."
-            )
-
-        # Transcribe audio
-        segments, info = whisper_model.transcribe(
-            processed_audio,
-            beam_size=5,
-            language=None,
-            task='transcribe',
-            vad_filter=True
-        )
-
-        transcription = "".join([seg.text for seg in segments])
-        detected_language = info.language
-
-        if not transcription.strip():
-            raise HTTPException(
-                status_code=400,
-                detail="Could not transcribe audio. Please ensure clear speech and try again."
-            )
-
-        logger.info(f"🔤 Transcription: {transcription[:100]}...")
+        logger.info(f"🩺 Processing medical question: {request.question[:50]}...")
 
-        # Process transcribed text with medical AI
+        # Process with medical AI pipeline
         result = pipeline.process(
-            question=transcription,
-            user_lang=detected_language,
+            question=request.question,
+            user_lang=request.language,
             conversation_history=[]
         )
 
         processing_time = time.time() - start_time
 
-        return AudioResponse(
+        return MedicalResponse(
             success=True,
-            transcription=transcription,
             response=result["response"],
-            detected_language=detected_language,
-            conversation_id=conversation_id,
-            context_used=result.get("context_used", []),
-            processing_time=round(processing_time, 2),
-            audio_duration=len(processed_audio) / 16000  # Assuming 16kHz sample rate
+            detected_language=result["source_lang"],
+            processing_time=round(processing_time, 2)
         )
 
-    except HTTPException:
-        raise
     except Exception as e:
-        logger.error(f"❌ Error in audio processing: {str(e)}", exc_info=True)
+        logger.error(f"❌ Error in medical consultation: {str(e)}")
         processing_time = time.time() - start_time
 
         raise HTTPException(
             status_code=500,
             detail={
                 "success": False,
-                "error": "Audio processing error occurred",
-                "error_code": "AUDIO_PROCESSING_ERROR",
-                "conversation_id": conversation_id,
+                "error": "Processing error occurred. Please try again.",
                 "processing_time": round(processing_time, 2)
             }
         )
 
-@app.post("/feedback", tags=["feedback"])
-async def submit_feedback(request: FeedbackRequest):
+@app.get("/medical/demo", tags=["medical"])
+async def demo_questions():
     """
-    ## Submit User Feedback
-
-    Submit feedback about a medical consultation to help improve the service.
+    ## Demo Questions
 
-    **Rating Scale:**
-    - 1: Very Poor
-    - 2: Poor
-    - 3: Average
-    - 4: Good
-    - 5: Excellent
+    Get sample questions to try with the Medical AI.
     """
-    try:
-        logger.info(f"📊 Feedback received - ID: {request.conversation_id}, Rating: {request.rating}")
-
-        # Here you could store feedback in a database
-        # For now, just log it
-        feedback_data = {
-            "conversation_id": request.conversation_id,
-            "rating": request.rating,
-            "feedback": request.feedback,
-            "timestamp": time.time()
-        }
-
-        return {
-            "success": True,
-            "message": "Thank you for your feedback! This helps us improve our medical AI service.",
-            "feedback_id": f"fb_{uuid.uuid4().hex[:8]}"
-        }
-
-    except Exception as e:
-        logger.error(f"❌ Error processing feedback: {str(e)}")
-        raise HTTPException(
-            status_code=500,
-            detail="Error processing feedback"
-        )
+    return {
+        "demo_questions": {
+            "english": [
+                "What are the symptoms of malaria?",
+                "How can I prevent diabetes?",
+                "What should I do for a fever?",
+                "How to maintain good hygiene?"
+            ],
+            "french": [
+                "Quels sont les symptômes du paludisme?",
+                "Comment puis-je prévenir le diabète?",
+                "Que dois-je faire pour une fièvre?",
+                "Comment maintenir une bonne hygiène?"
+            ]
+        },
+        "note": "Try these questions to test the Medical AI Assistant"
+    }
 
-@app.get("/medical/specialties", tags=["medical"])
-async def get_medical_specialties():
-    """
-    ## Get Supported Medical Specialties
-
-    Returns a list of medical specialties and conditions supported by the AI.
-    """
+@app.get("/medical/info", tags=["medical"])
+async def medical_info():
+    """Medical AI information"""
     return {
+        "supported_languages": ["English", "French"],
         "specialties": [
-            {
-                "name": "Primary Care",
-                "description": "General medical consultations and health guidance",
-                "conditions": ["General symptoms", "Preventive care", "Health maintenance"]
-            },
-            {
-                "name": "Infectious Diseases",
-                "description": "Infectious disease diagnosis and treatment",
-                "conditions": ["Malaria", "Tuberculosis", "HIV/AIDS", "Respiratory infections"]
-            },
-            {
-                "name": "Emergency Medicine",
-                "description": "Emergency protocols and urgent care guidance",
-                "conditions": ["Stroke recognition", "Cardiac emergencies", "Trauma assessment"]
-            },
-            {
-                "name": "Chronic Disease Management",
-                "description": "Management of chronic conditions",
-                "conditions": ["Diabetes", "Hypertension", "Gastritis"]
-            }
+            "General Medicine",
+            "Tropical Diseases",
+            "Preventive Care",
+            "Emergency Guidelines"
         ],
-        "languages_supported": ["English", "French", "Auto-detect"],
-        "disclaimer": "This AI provides educational information only. Always consult healthcare professionals for medical advice."
+        "disclaimer": "⚕️ This AI provides educational information only. Always consult qualified healthcare professionals for medical advice.",
+        "optimization": "Lightweight version optimized for Hugging Face Spaces"
     }
 
 # ============================================================================
@@ -571,13 +269,11 @@ async def not_found_handler(request, exc):
         content={
             "success": False,
            "error": "Endpoint not found",
-            "error_code": "NOT_FOUND",
             "available_endpoints": [
                 "/docs - API Documentation",
-                "/medical/ask - Text consultation",
-                "/medical/audio - Audio consultation",
+                "/medical/ask - Medical consultation",
                 "/health - System status",
-                "/feedback - Submit feedback"
+                "/medical/demo - Demo questions"
             ]
         }
    )
@@ -588,9 +284,19 @@ async def validation_exception_handler(request, exc):
         status_code=422,
         content={
             "success": False,
-            "error": "Invalid request data",
-            "error_code": "VALIDATION_ERROR",
-            "details": exc.errors()
+            "error": "Invalid request data",
+            "details": "Please check your request format"
+        }
+    )
+
+@app.exception_handler(500)
+async def internal_error_handler(request, exc):
+    return JSONResponse(
+        status_code=500,
+        content={
+            "success": False,
+            "error": "Internal server error",
+            "message": "The Medical AI is experiencing technical difficulties"
         }
     )
 
@@ -598,17 +304,24 @@ async def validation_exception_handler(request, exc):
 # STARTUP MESSAGE
 # ============================================================================
 
+@app.on_event("startup")
+async def startup_event():
+    """Startup event handler"""
+    logger.info("🩺 Medical AI Assistant starting up on Hugging Face Spaces")
+    logger.info("📚 API Documentation: /docs")
+    logger.info("🔄 Alternative docs: /redoc")
+    logger.info("⚡ Optimized for Spaces deployment")
+
 if __name__ == "__main__":
     import uvicorn
 
-    print("🩺 Starting Medical AI Assistant API...")
-    print("📚 Documentation available at: http://localhost:8000/docs")
-    print("🔄 Alternative docs at: http://localhost:8000/redoc")
+    print("🩺 Starting Medical AI Assistant for Spaces...")
+    print("📚 Documentation available at: http://localhost:7860/docs")
 
     uvicorn.run(
         app,
        host="0.0.0.0",
-        port=8000,
+        port=7860,  # Spaces port
        log_level="info",
        reload=False
    )
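
Since this commit makes startup non-fatal (the app now serves a canned fallback from `/medical/ask` while models load instead of raising), a client may want to wait for readiness via `/health`. A sketch under the same placeholder-URL assumption as above:

```python
import time
import requests

BASE_URL = "https://your-space.hf.space"  # placeholder; use the real Space URL

# Poll /health until the pipeline reports ready; before that, /medical/ask
# returns the initializing-fallback text rather than a 503.
for _ in range(30):
    health = requests.get(f"{BASE_URL}/health", timeout=10).json()
    if health["models_loaded"]:
        break
    time.sleep(5)

answer = requests.post(
    f"{BASE_URL}/medical/ask",
    json={"question": "How can I prevent diabetes?", "language": "en"},
    timeout=60,
).json()
print(answer["response"])
```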
medical_ai.py CHANGED
@@ -1,4 +1,4 @@
1
- # medical_ai.py - VERSION COMPETITION OPTIMISÉE
2
 
3
  import os
4
  import json
@@ -7,8 +7,7 @@ from typing import List, Dict, Any
7
  from sentence_transformers import SentenceTransformer
8
  import faiss
9
  from functools import lru_cache
10
- from transformers import NllbTokenizer, AutoModelForSeq2SeqLM, pipeline
11
- from transformers import AutoModelForCausalLM, AutoTokenizer
12
  import torch
13
  from typing import Optional
14
  import logging
@@ -18,64 +17,52 @@ import re
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
20
 
21
- # === CONFIGURATION COMPÉTITION ===
22
  EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
23
- NLLB_MODEL_NAME = "facebook/nllb-200-distilled-600M"
24
- # MODÈLE PRINCIPAL - MEDIUM pour la compétition
25
- MODEL_NAME = "microsoft/DialoGPT-medium"
26
  PATIENT_RECORDS_PATH = "patient_records.json"
27
 
28
- # Configuration optimisée pour CPU avec performance maximale
29
  DEVICE = "cpu"
30
- MAX_LENGTH = 512 # Augmenté pour des réponses plus complètes
31
- TEMPERATURE = 0.7 # Équilibre créativité/cohérence
32
  TOP_P = 0.9
33
- TOP_K = 50
34
 
35
- # === 1. DÉTECTION DE LANGUE AVANCÉE ===
36
- class AdvancedLanguageDetector:
37
  def __init__(self):
 
38
  try:
39
- # Utilise un modèle plus précis pour la détection
40
- self.lang_id = pipeline("text-classification",
41
- model="papluca/xlm-roberta-base-language-detection",
42
- device=-1) # Force CPU
43
- self.lang_map = {
44
- 'fr': 'fr', 'en': 'en', 'bss': 'bss', 'dua': 'dua', 'ewo': 'ewo',
45
- 'fr-FR': 'fr', 'en-EN': 'en', 'fr_XX': 'fr', 'en_XX': 'en',
46
- 'LABEL_0': 'en', 'LABEL_1': 'fr' # Fallbacks
47
- }
48
- logger.info("Advanced language detector initialized")
49
- except Exception as e:
50
- logger.error(f"Error initializing language detector: {str(e)}")
51
- self.lang_id = None
52
 
53
  @lru_cache(maxsize=256)
54
  def detect_language(self, text: str) -> str:
55
  if not text.strip():
56
  return 'en'
57
 
58
- # Méthode hybride : ML + règles
59
- if self.lang_id:
60
  try:
61
- pred = self.lang_id(text)[0]
62
- detected = pred['label'] if isinstance(pred, dict) else str(pred)
63
- confidence = pred.get('score', 0.5) if isinstance(pred, dict) else 0.5
64
-
65
- # Si confiance faible, utiliser détection par mots-clés
66
- if confidence < 0.8:
67
- return self._keyword_detection(text)
68
-
69
- return self.lang_map.get(detected, 'en')
70
  except:
71
- return self._keyword_detection(text)
72
 
 
73
  return self._keyword_detection(text)
74
 
75
  def _keyword_detection(self, text: str) -> str:
76
- """Détection par mots-clés comme fallback"""
77
- french_indicators = ['que', 'quoi', 'comment', 'pourquoi', 'symptômes', 'maladie', 'traitement', 'médecin', 'santé']
78
- english_indicators = ['what', 'how', 'why', 'symptoms', 'disease', 'treatment', 'doctor', 'health']
79
 
80
  text_lower = text.lower()
81
  fr_score = sum(2 if indicator in text_lower else 0 for indicator in french_indicators)
@@ -83,60 +70,32 @@ class AdvancedLanguageDetector:
83
 
84
  return 'fr' if fr_score > en_score else 'en'
85
 
86
- # === 2. TRADUCTION OPTIMISÉE ===
87
- class OptimizedTranslator:
88
- def __init__(self, model_name=NLLB_MODEL_NAME):
 
89
  try:
90
- self.tokenizer = NllbTokenizer.from_pretrained(model_name)
91
- self.model = AutoModelForSeq2SeqLM.from_pretrained(
92
- model_name,
93
- torch_dtype=torch.float32, # CPU optimized
94
- low_cpu_mem_usage=True
95
- )
96
- self.lang_code_map = {
97
- 'fr': 'fra_Latn', 'en': 'eng_Latn', 'bss': 'bss_Latn',
98
- 'dua': 'dua_Latn', 'ewo': 'ewo_Latn',
99
- }
100
- logger.info("Optimized translator initialized")
101
  except Exception as e:
102
  logger.error(f"Error initializing translator: {str(e)}")
103
- self.tokenizer = None
104
- self.model = None
105
 
106
  @lru_cache(maxsize=256)
107
  def translate(self, text: str, source_lang: str, target_lang: str) -> str:
108
  if not text.strip() or source_lang == target_lang:
109
  return text
110
-
111
- if self.tokenizer is None or self.model is None:
112
- return text
113
 
114
- try:
115
- src = self.lang_code_map.get(source_lang, 'eng_Latn')
116
- tgt = self.lang_code_map.get(target_lang, 'eng_Latn')
117
-
118
- self.tokenizer.src_lang = src
119
- inputs = self.tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
120
-
121
- with torch.no_grad():
122
- generated_tokens = self.model.generate(
123
- **inputs,
124
- forced_bos_token_id=self.tokenizer.convert_tokens_to_ids(tgt),
125
- max_length=512,
126
- num_beams=4, # Améliore la qualité
127
- early_stopping=True
128
- )
129
-
130
- result = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
131
- return result
132
- except Exception as e:
133
- logger.error(f"Translation error: {str(e)}")
134
- return text
135
 
136
- # === 3. RAG MÉDICAL AVANCÉ ===
137
- class AdvancedMedicalRAG:
138
  def __init__(self, embedding_model_name=EMBEDDING_MODEL_NAME, records_path=PATIENT_RECORDS_PATH):
139
  try:
 
140
  self.embedder = SentenceTransformer(embedding_model_name)
141
 
142
  if not os.path.exists(records_path):
@@ -146,43 +105,33 @@ class AdvancedMedicalRAG:
146
  with open(records_path, 'r', encoding='utf-8') as f:
147
  self.records = json.load(f)
148
 
149
- # Construction d'indices spécialisés
150
  self.medical_chunks = []
151
- self.educational_chunks = []
152
- self.emergency_chunks = []
153
- self.prevention_chunks = []
154
-
155
- self._build_specialized_chunks()
156
 
157
- # Indices FAISS multiples pour différents types de requêtes
158
- self.medical_index, _ = self._build_faiss_index(self.medical_chunks)
159
- self.edu_index, _ = self._build_faiss_index(self.educational_chunks)
160
- self.emergency_index, _ = self._build_faiss_index(self.emergency_chunks)
161
- self.prevention_index, _ = self._build_faiss_index(self.prevention_chunks)
162
 
163
- logger.info(f"Advanced RAG initialized: {len(self.medical_chunks)} medical, "
164
- f"{len(self.educational_chunks)} educational, {len(self.emergency_chunks)} emergency chunks")
165
 
166
  except Exception as e:
167
- logger.error(f"Error initializing Advanced RAG: {str(e)}")
168
  self._initialize_fallback()
169
 
170
  def _create_sample_records(self, path: str):
171
- """Crée des enregistrements médicaux de base pour la compétition"""
172
  sample_records = [
173
  {
174
  "id": "malaria_001",
175
- "diagnosis": {"en": "Malaria (Plasmodium falciparum)", "fr": "Paludisme (Plasmodium falciparum)"},
176
- "symptoms": {"en": "High fever, chills, headache, nausea, vomiting, fatigue", "fr": "Fièvre élevée, frissons, maux de tête, nausées, vomissements, fatigue"},
177
- "medications": [{"name": {"en": "Artemether-Lumefantrine", "fr": "Artéméther-Luméfantrine"}, "dosage": "20mg/120mg twice daily for 3 days"}],
178
- "care_instructions": {"en": "Complete bed rest, increase fluid intake, complete full medication course, return if symptoms worsen or fever persists after 48 hours", "fr": "Repos complet au lit, augmenter l'apport hydrique, terminer le traitement complet, revenir si les symptômes s'aggravent ou si la fièvre persiste après 48 heures"}
179
  },
180
  {
181
- "id": "diabetes_prevention",
182
- "context_type": "prevention",
183
- "topic": {"en": "Type 2 Diabetes Prevention", "fr": "Prévention du Diabète de Type 2"},
184
- "educational_content": {"en": "Maintain healthy BMI (18.5-24.9), engage in 150 minutes moderate exercise weekly, consume balanced diet rich in fiber and low in processed sugars, regular blood glucose monitoring for high-risk individuals", "fr": "Maintenir un IMC sain (18,5-24,9), pratiquer 150 minutes d'exercice modéré par semaine, consommer une alimentation équilibrée riche en fibres et pauvre en sucres transformés, surveillance régulière de la glycémie pour les personnes à risque"},
185
- "target_group": "Adults over 30, family history of diabetes, sedentary lifestyle"
186
  }
187
  ]
188
 
@@ -190,65 +139,25 @@ class AdvancedMedicalRAG:
190
  json.dump(sample_records, f, ensure_ascii=False, indent=2)
191
 
192
  def _initialize_fallback(self):
193
- """Initialise un système de fallback basique"""
194
- self.medical_chunks = ["General medical consultation and symptom assessment"]
195
- self.educational_chunks = ["Health education and prevention guidelines"]
196
- self.emergency_chunks = ["Emergency medical procedures and protocols"]
197
- self.prevention_chunks = ["Disease prevention and health maintenance"]
198
-
199
  self.medical_index = None
200
- self.edu_index = None
201
- self.emergency_index = None
202
- self.prevention_index = None
203
 
204
- def _build_specialized_chunks(self):
205
- """Construit des chunks spécialisés pour différents types de requêtes médicales"""
206
  for rec in self.records:
207
  try:
208
- # Chunks médicaux (diagnostics, traitements)
209
- if 'diagnosis' in rec:
210
- medical_parts = []
211
- medical_parts.append(f"Condition: {rec['diagnosis'].get('en', '')}")
 
212
 
213
- if 'symptoms' in rec:
214
- medical_parts.append(f"Symptoms: {rec['symptoms'].get('en', '')}")
215
-
216
- if 'medications' in rec:
217
- meds = [f"{m['name'].get('en', '')} ({m.get('dosage', '')})" for m in rec['medications']]
218
- medical_parts.append(f"Treatment: {', '.join(meds)}")
219
-
220
- if 'care_instructions' in rec:
221
- medical_parts.append(f"Care instructions: {rec['care_instructions'].get('en', '')}")
222
-
223
- if medical_parts:
224
- self.medical_chunks.append(". ".join(medical_parts))
225
-
226
- # Chunks éducatifs
227
- if rec.get('context_type') == 'prevention' or 'educational_content' in rec:
228
- edu_parts = []
229
- if 'topic' in rec:
230
- edu_parts.append(f"Topic: {rec['topic'].get('en', '')}")
231
- if 'educational_content' in rec:
232
- edu_parts.append(f"Information: {rec['educational_content'].get('en', '')}")
233
- if 'target_group' in rec:
234
- edu_parts.append(f"Target: {rec['target_group']}")
235
-
236
- if edu_parts:
237
- chunk = ". ".join(edu_parts)
238
- self.educational_chunks.append(chunk)
239
- if 'prevention' in chunk.lower():
240
- self.prevention_chunks.append(chunk)
241
-
242
- # Chunks d'urgence
243
- if rec.get('context_type') == 'emergency_education' or 'emergency' in str(rec).lower():
244
- emergency_parts = []
245
- if 'scenario' in rec:
246
- emergency_parts.append(f"Emergency: {rec['scenario'].get('en', '')}")
247
- if 'action_steps' in rec:
248
- emergency_parts.append(f"Actions: {rec['action_steps'].get('en', '')}")
249
-
250
- if emergency_parts:
251
- self.emergency_chunks.append(". ".join(emergency_parts))
252
 
253
  except Exception as e:
254
  logger.error(f"Error processing record: {str(e)}")
@@ -256,281 +165,152 @@ class AdvancedMedicalRAG:
 
     def _build_faiss_index(self, chunks):
         if not chunks:
-            return None, None
         try:
             embeddings = self.embedder.encode(chunks, show_progress_bar=False, convert_to_numpy=True)
             index = faiss.IndexFlatL2(embeddings.shape[1])
             index.add(embeddings)
-            return index, embeddings
         except Exception as e:
             logger.error(f"Error building FAISS index: {str(e)}")
-            return None, None
 
-    def get_smart_contexts(self, question: str, lang: str = "en") -> Dict[str, List[str]]:
-        """Retrieve smart contexts based on the type of question"""
-        question_lower = question.lower()
-        contexts = {
-            "medical": [],
-            "educational": [],
-            "emergency": [],
-            "prevention": []
-        }
-
         try:
-            q_emb = self.embedder.encode([question], convert_to_numpy=True)
-
-            # Detect the question type
-            is_emergency = any(word in question_lower for word in ['emergency', 'urgent', 'severe', 'critical', 'urgence', 'grave'])
-            is_prevention = any(word in question_lower for word in ['prevent', 'prevention', 'avoid', 'prévenir', 'éviter'])
-            is_educational = any(word in question_lower for word in ['what is', 'explain', 'how', 'why', "qu'est-ce que", 'expliquer', 'comment', 'pourquoi'])
-
-            # Smart contextual retrieval
-            if is_emergency and self.emergency_index:
-                _, I = self.emergency_index.search(q_emb, min(3, len(self.emergency_chunks)))
-                contexts["emergency"] = [self.emergency_chunks[i] for i in I[0] if i < len(self.emergency_chunks)]
-
-            if is_prevention and self.prevention_index:
-                _, I = self.prevention_index.search(q_emb, min(2, len(self.prevention_chunks)))
-                contexts["prevention"] = [self.prevention_chunks[i] for i in I[0] if i < len(self.prevention_chunks)]
-
-            if is_educational and self.edu_index:
-                _, I = self.edu_index.search(q_emb, min(3, len(self.educational_chunks)))
-                contexts["educational"] = [self.educational_chunks[i] for i in I[0] if i < len(self.educational_chunks)]
-
-            # Always include some general medical context
-            if self.medical_index:
-                n_med = 4 if not any(contexts.values()) else 2
-                _, I = self.medical_index.search(q_emb, min(n_med, len(self.medical_chunks)))
-                contexts["medical"] = [self.medical_chunks[i] for i in I[0] if i < len(self.medical_chunks)]
 
         except Exception as e:
-            logger.error(f"Error getting smart contexts: {str(e)}")
-
-        return contexts
 
-# === 4. OPTIMIZED LLM GENERATOR ===
-class CompetitionMedicalLLM:
     def __init__(self, model_name: str = MODEL_NAME):
         self.device = DEVICE
-        logger.info(f"Loading competition model {model_name} on {self.device}...")
 
         try:
-            # Optimal configuration for DialoGPT-medium on CPU
-            self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                torch_dtype=torch.float32,
-                low_cpu_mem_usage=True,
-                device_map="auto" if DEVICE == "cuda" else None
-            )
-
-            # Tokenizer configuration
-            if self.tokenizer.pad_token is None:
-                self.tokenizer.pad_token = self.tokenizer.eos_token
-
             self.generator = pipeline(
                 "text-generation",
-                model=self.model,
-                tokenizer=self.tokenizer,
                 device=-1,  # CPU
-                framework="pt"
             )
 
-            logger.info(f"Competition model {model_name} loaded successfully")
         except Exception as e:
             logger.error(f"Error loading model: {str(e)}")
             self.generator = None
 
-    def generate_expert_response(self, question: str, contexts: Dict[str, List[str]], lang: str = "en") -> str:
-        """Generate a competition-grade expert medical response"""
 
         if self.generator is None:
-            return self._expert_fallback_response(question, contexts, lang)
 
         try:
-            # Build the expert prompt
-            prompt = self._build_expert_prompt(question, contexts, lang)
 
-            # Generate with optimized parameters
             response = self.generator(
                 prompt,
-                max_length=len(prompt) + 300,  # Longer, for complete responses
-                num_return_sequences=1,
                 temperature=TEMPERATURE,
                 top_p=TOP_P,
                 top_k=TOP_K,
                 do_sample=True,
-                pad_token_id=self.tokenizer.eos_token_id,
-                eos_token_id=self.tokenizer.eos_token_id,
-                repetition_penalty=1.1,  # Avoids repetition
-                length_penalty=1.0,
-                early_stopping=True
             )
 
-            # Expert extraction and cleanup
             full_text = response[0]['generated_text']
             response_text = full_text[len(prompt):].strip()
 
-            # Expert post-processing
-            response_text = self._expert_post_process(response_text, lang)
 
-            return response_text
 
         except Exception as e:
-            logger.error(f"Error in expert generation: {str(e)}")
-            return self._expert_fallback_response(question, contexts, lang)
-
-    def _build_expert_prompt(self, question: str, contexts: Dict[str, List[str]], lang: str) -> str:
-        """Build a competition-grade expert prompt"""
-
-        # Smart context aggregation
-        context_parts = []
-        if contexts.get("emergency"):
-            context_parts.append(f"🚨 Emergency Protocol: {' | '.join(contexts['emergency'][:2])}")
-        if contexts.get("medical"):
-            context_parts.append(f"📋 Clinical Information: {' | '.join(contexts['medical'][:2])}")
-        if contexts.get("prevention"):
-            context_parts.append(f"🛡️ Prevention Guidelines: {' | '.join(contexts['prevention'][:1])}")
-        if contexts.get("educational"):
-            context_parts.append(f"📚 Educational Content: {' | '.join(contexts['educational'][:1])}")
-
-        context_str = "\n".join(context_parts) if context_parts else "General medical consultation context."
-
-        # Structured prompt for quality output
-        if lang == "fr":
-            prompt = f"""Contexte médical expert:
-{context_str}
-
-Question du patient: {question}
-
-Réponse médicale experte (structurée et complète):"""
-        else:
-            prompt = f"""Expert medical context:
-{context_str}
-
-Patient question: {question}
-
-Expert medical response (structured and comprehensive):"""
-
-        return prompt
 
-    def _expert_post_process(self, response: str, lang: str) -> str:
-        """Expert post-processing of the response"""
-
-        # Clean up generation artifacts
-        for stop_seq in ["</s>", "\nPatient:", "\nDoctor:", "\nExpert:", "\n\nContext:", "Question:"]:
-            if stop_seq in response:
-                response = response.split(stop_seq)[0].strip()
-
-        # Expert structuring
-        if len(response.split('.')) > 2:  # If the response is long enough
-            sentences = [s.strip() for s in response.split('.') if s.strip()]
-            if len(sentences) >= 3:
-                response = '. '.join(sentences[:4]) + '.'  # Cap at 4 sentences
-
-        # Append the expert disclaimer
-        disclaimer = {
-            "en": "\n\n⚕️ Medical Disclaimer: This information is for educational purposes. Always consult a qualified healthcare professional for proper diagnosis and treatment.",
-            "fr": "\n\n⚕️ Avertissement médical: Cette information est à des fins éducatives. Consultez toujours un professionnel de santé qualifié pour un diagnostic et traitement appropriés."
-        }
-
-        if "consult" not in response.lower() and "disclaimer" not in response.lower():
-            response += disclaimer.get(lang, disclaimer["en"])
-
-        return response.strip()
-
-    def _expert_fallback_response(self, question: str, contexts: Dict[str, List[str]], lang: str) -> str:
-        """Expert-grade fallback response"""
 
         templates = {
-            "en": {
-                "intro": "Based on medical expertise and available clinical information:",
-                "structure": "\n\n🔍 Assessment: This appears to be a medical inquiry requiring professional evaluation.\n\n💡 General Guidance: Monitor symptoms, maintain proper hygiene, stay hydrated, and seek appropriate medical care.\n\n⚠️ Important: For accurate diagnosis and treatment, please consult with a healthcare professional.",
-                "context_available": "According to medical literature and clinical guidelines: "
-            },
-            "fr": {
-                "intro": "Sur la base de l'expertise médicale et des informations cliniques disponibles:",
-                "structure": "\n\n🔍 Évaluation: Il s'agit d'une demande médicale nécessitant une évaluation professionnelle.\n\n💡 Guidance générale: Surveillez les symptômes, maintenez une hygiène appropriée, restez hydraté et consultez un professionnel de santé.\n\n⚠️ Important: Pour un diagnostic et traitement précis, veuillez consulter un professionnel de santé.",
-                "context_available": "Selon la littérature médicale et les directives cliniques: "
-            }
         }
 
-        template = templates.get(lang, templates["en"])
-        response = template["intro"]
-
-        # Integrate contexts when available
-        all_contexts = []
-        for context_list in contexts.values():
-            all_contexts.extend(context_list)
-
-        if all_contexts:
-            response += f" {template['context_available']}{' | '.join(all_contexts[:2])}"
-
-        response += template["structure"]
-
-        return response
 
-# === MAIN COMPETITION PIPELINE ===
-class CompetitionMedicalAIPipeline:
     def __init__(self):
-        logger.info("🏆 Initializing COMPETITION Medical AI Pipeline...")
         try:
-            self.lang_detector = AdvancedLanguageDetector()
-            self.translator = OptimizedTranslator()
-            self.rag = AdvancedMedicalRAG()
-            self.llm = CompetitionMedicalLLM()
-            logger.info("🎯 Competition Medical AI Pipeline ready for excellence!")
         except Exception as e:
-            logger.error(f"Error initializing competition pipeline: {str(e)}")
             raise
 
     def process(self, question: str, user_lang: str = "auto", conversation_history: list = None) -> Dict[str, Any]:
-        """Competition-grade processing"""
         try:
             if not question or not question.strip():
                 return self._empty_question_response(user_lang)
 
-            # Advanced language detection
             detected_lang = self.lang_detector.detect_language(question) if user_lang == "auto" else user_lang
-            logger.info(f"🎯 Processing competition-level question in {detected_lang}")
-
-            # Translate to English if needed, at optimal quality
-            question_en = question
-            if detected_lang != "en":
-                question_en = self.translator.translate(question, detected_lang, "en")
-
-            # Smart multi-context RAG
-            smart_contexts = self.rag.get_smart_contexts(question_en, "en")
-
-            # Expert generation
-            response_en = self.llm.generate_expert_response(question_en, smart_contexts, "en")
 
-            # Translate back at optimal quality
-            final_response = response_en
-            if detected_lang != "en":
-                final_response = self.translator.translate(response_en, "en", detected_lang)
 
-            # Contexts used, reported for transparency
-            all_contexts = []
-            for context_list in smart_contexts.values():
-                all_contexts.extend(context_list)
 
             return {
-                "response": final_response,
                 "source_lang": detected_lang,
-                "context_used": all_contexts[:5],  # Top 5 contexts
-                "confidence": "high"  # Quality indicator
             }
 
         except Exception as e:
-            logger.error(f"Competition processing error: {str(e)}")
             return self._error_response(str(e), user_lang if user_lang != "auto" else "en")
 
     def _empty_question_response(self, user_lang: str) -> Dict[str, Any]:
-        """Response for an empty question"""
         responses = {
-            "en": "Please provide a medical question for me to assist you with professional healthcare guidance.",
-            "fr": "Veuillez poser une question médicale pour que je puisse vous fournir des conseils de santé professionnels."
         }
         lang = user_lang if user_lang != "auto" else "en"
         return {
@@ -541,17 +321,18 @@ class CompetitionMedicalAIPipeline:
         }
 
     def _error_response(self, error: str, lang: str) -> Dict[str, Any]:
-        """Professional error response"""
         responses = {
-            "en": "I apologize, but I'm experiencing technical difficulties. Please try rephrasing your medical question, and I'll provide you with professional healthcare guidance.",
-            "fr": "Je m'excuse, mais je rencontre des difficultés techniques. Veuillez reformuler votre question médicale, et je vous fournirai des conseils de santé professionnels."
         }
         return {
             "response": responses.get(lang, responses["en"]),
             "source_lang": lang,
             "context_used": [],
-            "confidence": "medium"
         }
 
-# Compatibility alias
-MedicalAIPipeline = CompetitionMedicalAIPipeline
 
+# medical_ai.py - LIGHTWEIGHT VERSION FOR HUGGING FACE SPACES
 
 import os
 import json
 
 from sentence_transformers import SentenceTransformer
 import faiss
 from functools import lru_cache
+from transformers import pipeline, AutoTokenizer
 import torch
 from typing import Optional
 import logging
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+# === LIGHTWEIGHT CONFIGURATION FOR SPACES ===
 EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
+# Use a much lighter model that works on Spaces
+MODEL_NAME = "microsoft/DialoGPT-small"  # Changed from medium to small
 PATIENT_RECORDS_PATH = "patient_records.json"
 
+# Optimized for Spaces CPU limits
 DEVICE = "cpu"
+MAX_LENGTH = 256  # Reduced for faster processing
+TEMPERATURE = 0.7
 TOP_P = 0.9
+TOP_K = 40
 
+# === 1. SIMPLE LANGUAGE DETECTOR ===
+class SimpleLanguageDetector:
     def __init__(self):
+        # Use lightweight langdetect instead of a heavy ML model
         try:
+            from langdetect import detect
+            self.detect_func = detect
+            logger.info("Simple language detector initialized")
+        except ImportError:
+            logger.warning("langdetect not available, using keyword detection")
+            self.detect_func = None
 
     @lru_cache(maxsize=256)
     def detect_language(self, text: str) -> str:
         if not text.strip():
             return 'en'
 
+        # Try langdetect first
+        if self.detect_func:
             try:
+                detected = self.detect_func(text)
+                if detected in ['fr', 'en']:
+                    return detected
             except:
+                pass
 
+        # Fallback to keyword detection
         return self._keyword_detection(text)
 
     def _keyword_detection(self, text: str) -> str:
+        """Keyword-based detection as fallback"""
+        french_indicators = ['que', 'quoi', 'comment', 'pourquoi', 'symptômes', 'maladie', 'traitement']
+        english_indicators = ['what', 'how', 'why', 'symptoms', 'disease', 'treatment']
 
         text_lower = text.lower()
         fr_score = sum(2 if indicator in text_lower else 0 for indicator in french_indicators)
         en_score = sum(2 if indicator in text_lower else 0 for indicator in english_indicators)
 
         return 'fr' if fr_score > en_score else 'en'
 
+# === 2. SIMPLE TRANSLATOR ===
+class SimpleTranslator:
+    def __init__(self):
+        # Use a lightweight translation approach
         try:
+            # Only load if really needed
+            self.translator = None
+            logger.info("Simple translator initialized")
         except Exception as e:
             logger.error(f"Error initializing translator: {str(e)}")
+            self.translator = None
 
     @lru_cache(maxsize=256)
     def translate(self, text: str, source_lang: str, target_lang: str) -> str:
         if not text.strip() or source_lang == target_lang:
             return text
 
+        # For the Spaces demo, we use simple template responses;
+        # in production, you'd want proper translation
+        return text  # Simplified for demo
 
+# === 3. LIGHTWEIGHT MEDICAL RAG ===
+class LightweightMedicalRAG:
     def __init__(self, embedding_model_name=EMBEDDING_MODEL_NAME, records_path=PATIENT_RECORDS_PATH):
         try:
+            logger.info("Loading lightweight embedder...")
             self.embedder = SentenceTransformer(embedding_model_name)
 
             if not os.path.exists(records_path):
                 self._create_sample_records(records_path)
             with open(records_path, 'r', encoding='utf-8') as f:
                 self.records = json.load(f)
 
+            # Build simple medical chunks
             self.medical_chunks = []
+            self._build_medical_chunks()
 
+            # Single FAISS index
+            self.medical_index = self._build_faiss_index(self.medical_chunks)
 
+            logger.info(f"Lightweight RAG initialized: {len(self.medical_chunks)} chunks")
 
         except Exception as e:
+            logger.error(f"Error initializing RAG: {str(e)}")
             self._initialize_fallback()
 
     def _create_sample_records(self, path: str):
+        """Create basic medical records"""
         sample_records = [
             {
                 "id": "malaria_001",
+                "diagnosis": {"en": "Malaria", "fr": "Paludisme"},
+                "symptoms": {"en": "Fever, chills, headache", "fr": "Fièvre, frissons, maux de tête"},
+                "treatment": {"en": "Antimalarial medication and rest", "fr": "Médicaments antipaludiques et repos"}
             },
             {
+                "id": "diabetes_001",
+                "diagnosis": {"en": "Diabetes", "fr": "Diabète"},
+                "symptoms": {"en": "Increased thirst, frequent urination", "fr": "Soif excessive, mictions fréquentes"},
+                "treatment": {"en": "Diet control and medication", "fr": "Contrôle alimentaire et médicaments"}
             }
         ]
 
         with open(path, 'w', encoding='utf-8') as f:
             json.dump(sample_records, f, ensure_ascii=False, indent=2)
 
     def _initialize_fallback(self):
+        """Initialize fallback system"""
+        self.medical_chunks = [
+            "General medical consultation and symptom assessment",
+            "Common tropical diseases like malaria require immediate medical attention",
+            "Diabetes management involves diet control and regular monitoring"
+        ]
         self.medical_index = None
 
+    def _build_medical_chunks(self):
+        """Build simple medical chunks"""
         for rec in self.records:
             try:
+                if 'diagnosis' in rec and 'symptoms' in rec:
+                    chunk = f"Condition: {rec['diagnosis'].get('en', '')}. "
+                    chunk += f"Symptoms: {rec['symptoms'].get('en', '')}. "
+                    if 'treatment' in rec:
+                        chunk += f"Treatment: {rec['treatment'].get('en', '')}"
 
+                    self.medical_chunks.append(chunk)
 
             except Exception as e:
                 logger.error(f"Error processing record: {str(e)}")
 
     def _build_faiss_index(self, chunks):
         if not chunks:
+            return None
         try:
             embeddings = self.embedder.encode(chunks, show_progress_bar=False, convert_to_numpy=True)
             index = faiss.IndexFlatL2(embeddings.shape[1])
             index.add(embeddings)
+            return index
         except Exception as e:
             logger.error(f"Error building FAISS index: {str(e)}")
+            return None
 
+    def get_contexts(self, question: str, lang: str = "en") -> List[str]:
+        """Get relevant medical contexts"""
         try:
+            if self.medical_index is None:
+                return self.medical_chunks[:2]
 
+            q_emb = self.embedder.encode([question], convert_to_numpy=True)
+            _, I = self.medical_index.search(q_emb, min(3, len(self.medical_chunks)))
+            return [self.medical_chunks[i] for i in I[0] if i < len(self.medical_chunks)]
 
         except Exception as e:
+            logger.error(f"Error getting contexts: {str(e)}")
+            return self.medical_chunks[:2]
 
+# === 4. LIGHTWEIGHT LLM ===
+class LightweightMedicalLLM:
     def __init__(self, model_name: str = MODEL_NAME):
         self.device = DEVICE
+        logger.info(f"Loading lightweight model {model_name}...")
 
         try:
+            # Use pipeline for simplicity
             self.generator = pipeline(
                 "text-generation",
+                model=model_name,
                 device=-1,  # CPU
+                torch_dtype=torch.float32,
+                model_kwargs={"low_cpu_mem_usage": True}
             )
 
+            logger.info(f"Lightweight model {model_name} loaded successfully")
         except Exception as e:
             logger.error(f"Error loading model: {str(e)}")
             self.generator = None
 
+    def generate_response(self, question: str, contexts: List[str], lang: str = "en") -> str:
+        """Generate medical response"""
 
         if self.generator is None:
+            return self._fallback_response(question, contexts, lang)
 
         try:
+            # Build simple prompt
+            context_str = " | ".join(contexts[:2]) if contexts else "General medical consultation"
+
+            if lang == "fr":
+                prompt = f"Contexte médical: {context_str}\n\nQuestion: {question}\n\nRéponse médicale:"
+            else:
+                prompt = f"Medical context: {context_str}\n\nQuestion: {question}\n\nMedical response:"
 
+            # Generate with conservative settings for Spaces
             response = self.generator(
                 prompt,
+                max_new_tokens=150,  # shorter for faster processing (max_length counts tokens, not characters)
                 temperature=TEMPERATURE,
                 top_p=TOP_P,
                 top_k=TOP_K,
                 do_sample=True,
+                pad_token_id=self.generator.tokenizer.eos_token_id
             )
 
+            # Extract response
             full_text = response[0]['generated_text']
             response_text = full_text[len(prompt):].strip()
 
+            # Add medical disclaimer
+            disclaimer = {
+                "en": "\n\n⚕️ Medical Disclaimer: Consult a healthcare professional for proper diagnosis.",
+                "fr": "\n\n⚕️ Avertissement médical: Consultez un professionnel de santé pour un diagnostic approprié."
+            }
+
+            if "disclaimer" not in response_text.lower():
+                response_text += disclaimer.get(lang, disclaimer["en"])
 
+            return response_text.strip()
 
         except Exception as e:
+            logger.error(f"Error in generation: {str(e)}")
+            return self._fallback_response(question, contexts, lang)
 
+    def _fallback_response(self, question: str, contexts: List[str], lang: str) -> str:
+        """Fallback response for errors"""
 
         templates = {
+            "en": "Based on medical knowledge: This requires professional medical evaluation. Please consult with a healthcare provider for proper diagnosis and treatment. Stay hydrated and monitor symptoms.",
+            "fr": "Selon les connaissances médicales: Ceci nécessite une évaluation médicale professionnelle. Veuillez consulter un professionnel de santé pour un diagnostic et traitement appropriés."
         }
 
+        return templates.get(lang, templates["en"])
 
+# === MAIN PIPELINE ===
+class SpacesMedicalAIPipeline:
     def __init__(self):
+        logger.info("🚀 Initializing Spaces Medical AI Pipeline...")
         try:
+            self.lang_detector = SimpleLanguageDetector()
+            self.translator = SimpleTranslator()
+            self.rag = LightweightMedicalRAG()
+            self.llm = LightweightMedicalLLM()
+            logger.info("✅ Spaces Medical AI Pipeline ready!")
         except Exception as e:
+            logger.error(f"Error initializing pipeline: {str(e)}")
             raise
 
     def process(self, question: str, user_lang: str = "auto", conversation_history: list = None) -> Dict[str, Any]:
+        """Process medical question for Spaces"""
         try:
             if not question or not question.strip():
                 return self._empty_question_response(user_lang)
 
+            # Detect language
             detected_lang = self.lang_detector.detect_language(question) if user_lang == "auto" else user_lang
+            logger.info(f"Processing question in {detected_lang}")
 
+            # Get medical contexts
+            contexts = self.rag.get_contexts(question, detected_lang)
 
+            # Generate response
+            response = self.llm.generate_response(question, contexts, detected_lang)
 
             return {
+                "response": response,
                 "source_lang": detected_lang,
+                "context_used": contexts[:3],
+                "confidence": "medium"
             }
 
         except Exception as e:
+            logger.error(f"Processing error: {str(e)}")
             return self._error_response(str(e), user_lang if user_lang != "auto" else "en")
 
     def _empty_question_response(self, user_lang: str) -> Dict[str, Any]:
+        """Response for empty question"""
         responses = {
+            "en": "Please provide a medical question for consultation.",
+            "fr": "Veuillez poser une question médicale pour consultation."
         }
         lang = user_lang if user_lang != "auto" else "en"
         return {
         }
 
     def _error_response(self, error: str, lang: str) -> Dict[str, Any]:
+        """Error response"""
         responses = {
+            "en": "I'm experiencing technical difficulties. Please try rephrasing your medical question.",
+            "fr": "Je rencontre des difficultés techniques. Veuillez reformuler votre question médicale."
         }
         return {
             "response": responses.get(lang, responses["en"]),
             "source_lang": lang,
             "context_used": [],
+            "confidence": "low"
         }
 
+# Compatibility aliases
+CompetitionMedicalAIPipeline = SpacesMedicalAIPipeline
+MedicalAIPipeline = SpacesMedicalAIPipeline
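
The rewritten module keeps the old class names as aliases, so existing imports keep working. A minimal smoke test of the new pipeline could look like this (a sketch, not part of this commit; it assumes `medical_ai.py` is on the import path and that `patient_records.json` can be created in the working directory):

```python
# smoke_test.py - minimal local verification sketch (hypothetical, not in this commit)
from medical_ai import SpacesMedicalAIPipeline

pipeline = SpacesMedicalAIPipeline()  # loads the embedder, FAISS index, and DialoGPT-small

# One English and one French question exercise language detection,
# RAG retrieval, and generation end to end.
for question in [
    "What are the symptoms of malaria?",
    "Quels sont les symptômes du paludisme?",
]:
    result = pipeline.process(question, user_lang="auto")
    print(result["source_lang"], result["confidence"])
    print(result["response"])
```
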
requirements.txt CHANGED
@@ -1,39 +1,34 @@
-# FastAPI Medical AI - Requirements
 # Core FastAPI dependencies
 fastapi==0.104.1
 uvicorn[standard]==0.24.0
 python-multipart==0.0.6
 pydantic==2.4.2
 
-# ML and AI models
 transformers==4.35.2
-torch==2.1.0
 sentence-transformers==2.2.2
 faiss-cpu==1.7.4
-faster-whisper==0.9.0
 
-# Audio processing
 librosa==0.10.1
 soundfile==0.12.1
 numpy==1.24.3
 
-# HTTP requests for testing
-requests==2.31.0
 
 # Language processing
 sentencepiece==0.1.99
 langdetect==1.0.9
 
-# Performance optimizations
-accelerate==0.24.1
-optimum==1.13.2
 
 # System monitoring
 psutil==5.9.6
 
-# Development and testing
-pytest==7.4.3
-pytest-asyncio==0.21.1
-
-# Optional: For production deployment
-gunicorn==21.2.0

+# FastAPI Medical AI - Fixed Requirements for Hugging Face Spaces
 # Core FastAPI dependencies
 fastapi==0.104.1
 uvicorn[standard]==0.24.0
 python-multipart==0.0.6
 pydantic==2.4.2
 
+# Lightweight ML models for Spaces
 transformers==4.35.2
+# CPU-only torch wheel; index options must go on their own line in a requirements file
+--extra-index-url https://download.pytorch.org/whl/cpu
+torch==2.1.0+cpu
 sentence-transformers==2.2.2
 faiss-cpu==1.7.4
 
+# Audio processing (lightweight versions)
 librosa==0.10.1
 soundfile==0.12.1
 numpy==1.24.3
 
+# Remove heavy dependencies that cause timeouts
+# faster-whisper==0.9.0  # Too heavy for Spaces
+# accelerate==0.24.1  # Not needed for CPU
+# optimum==1.13.2  # Not needed for basic setup
 
 # Language processing
 sentencepiece==0.1.99
 langdetect==1.0.9
 
+# HTTP requests
+requests==2.31.0
 
 # System monitoring
 psutil==5.9.6
 
+# Keep lightweight for Spaces deployment
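
After installing these pinned requirements, a quick environment check can confirm that the CPU-only torch wheel was picked up (a sketch; the expected values assume the pins above):

```python
# verify_env.py - optional sanity check for the Spaces environment (sketch)
import faiss  # provided by faiss-cpu
import torch
import transformers

print("torch:", torch.__version__)                   # expect 2.1.0+cpu
print("cuda available:", torch.cuda.is_available())  # expect False on Spaces CPU hardware
print("transformers:", transformers.__version__)     # expect 4.35.2
print("faiss index dim OK:", faiss.IndexFlatL2(8).d == 8)
```
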