import os
import logging
import numpy as np
import torch
from torch import nn
import pandas as pd
from torchvision import transforms, models
from PIL import Image
import faiss
from transformers import AutoTokenizer, AutoModel, T5ForConditionalGeneration, T5Tokenizer
import gradio as gr
import cv2
import traceback
from datetime import datetime
import re
import random
import functools
import gc
from collections import OrderedDict
import json
import sys
import time
from tqdm.auto import tqdm
import warnings
import matplotlib.pyplot as plt
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List, Dict, Any, Union
import base64
import io

# Suppress unnecessary warnings
warnings.filterwarnings("ignore", category=UserWarning)
# === Configuration ===
class Config:
    """Configuration for MediQuery system"""
    # Model configuration
    IMAGE_MODEL = "chexnet"  # Options: "chexnet", "densenet"
    TEXT_MODEL = "biobert"  # Options: "biobert", "clinicalbert"
    GEN_MODEL = "flan-t5-base-finetuned"  # Base generation model

    # Resource management
    CACHE_SIZE = 50  # Reduced from 200 for deployment
    CACHE_EXPIRY_TIME = 1800  # Cache expiry time in seconds (30 minutes)
    LAZY_LOADING = True  # Enable lazy loading of models
    USE_HALF_PRECISION = True  # Use half precision for models if available

    # Feature flags
    DEBUG = True  # Enable detailed debugging
    PHI_DETECTION_ENABLED = True  # Enable PHI detection
    ANATOMY_MAPPING_ENABLED = True  # Enable anatomical mapping

    # Thresholds and parameters
    CONFIDENCE_THRESHOLD = 0.4  # Threshold for flagging low confidence
    TOP_K_RETRIEVAL = 10  # Reduced from 30 for deployment
    MAX_CONTEXT_DOCS = 3  # Reduced from 5 for deployment

    # Advanced retrieval settings
    DYNAMIC_RERANKING = True  # Dynamically adjust reranking weights
    DIVERSITY_PENALTY = 0.1  # Penalty for duplicate content

    # Performance optimization
    BATCH_SIZE = 1  # Reduced from 4 for deployment
    OPTIMIZE_MEMORY = True  # Optimize memory usage
    USE_CACHING = True  # Use caching for embeddings and queries

    # Path settings
    DEFAULT_KNOWLEDGE_BASE_DIR = "./knowledge_base"
    DEFAULT_MODEL_PATH = "./models/flan-t5-finetuned"
    LOG_DIR = "./logs"

    # Advanced settings
    EMBEDDING_AGGREGATION = "weighted_avg"  # Options: "avg", "weighted_avg", "cls", "pooled"
    EMBEDDING_NORMALIZE = True  # Normalize embeddings to unit length

    # Error recovery settings
    MAX_RETRIES = 2  # Reduced from 3 for deployment
    RECOVERY_WAIT_TIME = 1  # Seconds to wait between retries
# Set up logging with improved formatting
os.makedirs(Config.LOG_DIR, exist_ok=True)
logging.basicConfig(
    level=logging.DEBUG if Config.DEBUG else logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(Config.LOG_DIR, f"mediquery_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("MediQuery")

def debug_print(msg):
    """Print and log debug messages"""
    if Config.DEBUG:
        logger.debug(msg)
        print(f"DEBUG: {msg}")
# === Helper Functions for Conditions ===
def get_mimic_cxr_conditions():
    """Return the comprehensive list of conditions in the MIMIC-CXR dataset"""
    return [
        "atelectasis",
        "cardiomegaly",
        "consolidation",
        "edema",
        "enlarged cardiomediastinum",
        "fracture",
        "lung lesion",
        "lung opacity",
        "no finding",
        "pleural effusion",
        "pleural other",
        "pneumonia",
        "pneumothorax",
        "support devices"
    ]

def get_condition_synonyms():
    """Return synonyms for conditions to improve matching"""
    return {
        "atelectasis": ["atelectatic change", "collapsed lung", "lung collapse"],
        "cardiomegaly": ["enlarged heart", "cardiac enlargement", "heart enlargement"],
        "consolidation": ["airspace opacity", "air-space opacity", "alveolar opacity"],
        "edema": ["pulmonary edema", "fluid overload", "vascular congestion"],
        "fracture": ["broken bone", "bone fracture", "rib fracture"],
        "lung opacity": ["pulmonary opacity", "opacification", "lung opacification"],
        "pleural effusion": ["pleural fluid", "fluid in pleural space", "effusion"],
        "pneumonia": ["pulmonary infection", "lung infection", "bronchopneumonia"],
        "pneumothorax": ["air in pleural space", "collapsed lung", "ptx"],
        "support devices": ["tube", "line", "catheter", "pacemaker", "device"]
    }

def get_anatomical_regions():
    """Return mapping of anatomical regions with descriptions and conditions"""
    return {
        "upper_right_lung": {
            "description": "Upper right lung field",
            "conditions": ["pneumonia", "lung lesion", "pneumothorax", "atelectasis"]
        },
        "upper_left_lung": {
            "description": "Upper left lung field",
            "conditions": ["pneumonia", "lung lesion", "pneumothorax", "atelectasis"]
        },
        "middle_right_lung": {
            "description": "Middle right lung field",
            "conditions": ["pneumonia", "lung opacity", "atelectasis"]
        },
        "lower_right_lung": {
            "description": "Lower right lung field",
            "conditions": ["pneumonia", "pleural effusion", "atelectasis"]
        },
        "lower_left_lung": {
            "description": "Lower left lung field",
            "conditions": ["pneumonia", "pleural effusion", "atelectasis"]
        },
        "heart": {
            "description": "Cardiac silhouette",
            "conditions": ["cardiomegaly", "enlarged cardiomediastinum"]
        },
        "hilar": {
            "description": "Hilar regions",
            "conditions": ["enlarged cardiomediastinum", "adenopathy"]
        },
        "costophrenic_angles": {
            "description": "Costophrenic angles",
            "conditions": ["pleural effusion", "pneumothorax"]
        },
        "spine": {
            "description": "Spine",
            "conditions": ["fracture", "degenerative changes"]
        },
        "diaphragm": {
            "description": "Diaphragm",
            "conditions": ["elevated diaphragm", "flattened diaphragm"]
        }
    }
# === PHI Detection and Anonymization ===
def detect_phi(text):
    """Detect potential PHI (Protected Health Information) in text"""
    # Patterns for PHI detection
    patterns = {
        'name': r'\b[A-Z][a-z]+ [A-Z][a-z]+\b',
        'mrn': r'\b[A-Z]{0,3}[0-9]{4,10}\b',
        'ssn': r'\b[0-9]{3}[-]?[0-9]{2}[-]?[0-9]{4}\b',
        'date': r'\b(0?[1-9]|1[0-2])[\/\-](0?[1-9]|[12]\d|3[01])[\/\-](19|20)\d{2}\b',
        'phone': r'\b(\+\d{1,2}\s?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}\b',
        'email': r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
        'address': r'\b\d+\s+[A-Z][a-z]+\s+[A-Z][a-z]+\.?\b'
    }
    # Check each pattern; use finditer so patterns containing groups (date, phone)
    # still yield the full matched text rather than tuples of captured groups
    phi_detected = {}
    for phi_type, pattern in patterns.items():
        matches = [m.group(0) for m in re.finditer(pattern, text)]
        if matches:
            phi_detected[phi_type] = matches
    return phi_detected

def anonymize_text(text):
    """Replace potential PHI with [REDACTED]"""
    if not text:
        return ""
    if not Config.PHI_DETECTION_ENABLED:
        return text
    try:
        # Detect PHI
        phi_detected = detect_phi(text)
        # Replace PHI with [REDACTED]
        anonymized = text
        for phi_type, matches in phi_detected.items():
            for match in matches:
                anonymized = anonymized.replace(match, "[REDACTED]")
        return anonymized
    except Exception as e:
        debug_print(f"Error in anonymize_text: {str(e)}")
        return text
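
# Illustrative behaviour (not executed): with the patterns above,
#   anonymize_text("Seen by John Smith, contact jsmith@example.com")
# should return "Seen by [REDACTED], contact [REDACTED]".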
# === LRU Cache Implementation with Enhanced Features ===
class LRUCache:
    """LRU (Least Recently Used) Cache implementation with TTL and size tracking"""
    def __init__(self, capacity=Config.CACHE_SIZE, expiry_time=Config.CACHE_EXPIRY_TIME):
        self.cache = OrderedDict()
        self.capacity = capacity
        self.expiry_time = expiry_time  # in seconds
        self.timestamps = {}
        self.size_tracking = {
            "current_size_bytes": 0,
            "max_size_bytes": 0,
            "items_evicted": 0,
            "cache_hits": 0,
            "cache_misses": 0
        }

    def get(self, key):
        """Get item from cache with statistics tracking"""
        if key not in self.cache:
            self.size_tracking["cache_misses"] += 1
            return None
        # Check expiry
        if self.is_expired(key):
            self._remove_with_tracking(key)
            self.size_tracking["cache_misses"] += 1
            return None
        # Move to end (recently used)
        self.size_tracking["cache_hits"] += 1
        value = self.cache.pop(key)
        self.cache[key] = value
        return value

    def put(self, key, value):
        """Add item to cache with size tracking"""
        # Calculate approximate size of the value
        value_size = self._estimate_size(value)
        if key in self.cache:
            old_value = self.cache.pop(key)
            old_size = self._estimate_size(old_value)
            self.size_tracking["current_size_bytes"] -= old_size
        # Make space if needed (stop once the cache is empty to avoid looping forever)
        while self.cache and (
            len(self.cache) >= self.capacity or (
                Config.OPTIMIZE_MEMORY and
                self.size_tracking["current_size_bytes"] + value_size > 1e9  # 1 GB limit
            )
        ):
            self._evict_least_recently_used()
        # Add new item and timestamp
        self.cache[key] = value
        self.timestamps[key] = datetime.now().timestamp()
        self.size_tracking["current_size_bytes"] += value_size
        # Update max size
        if self.size_tracking["current_size_bytes"] > self.size_tracking["max_size_bytes"]:
            self.size_tracking["max_size_bytes"] = self.size_tracking["current_size_bytes"]

    def is_expired(self, key):
        """Check if item has expired"""
        if key not in self.timestamps:
            return True
        current_time = datetime.now().timestamp()
        return (current_time - self.timestamps[key]) > self.expiry_time

    def _evict_least_recently_used(self):
        """Remove least recently used item with tracking"""
        if not self.cache:
            return
        # The oldest item sits at the front of the OrderedDict; remove it through
        # _remove_with_tracking so size and eviction counters stay consistent
        oldest_key = next(iter(self.cache))
        self._remove_with_tracking(oldest_key)

    def _remove_with_tracking(self, key):
        """Remove item with size tracking"""
        if key in self.cache:
            value = self.cache.pop(key)
            value_size = self._estimate_size(value)
            self.size_tracking["current_size_bytes"] -= value_size
            self.size_tracking["items_evicted"] += 1
        if key in self.timestamps:
            self.timestamps.pop(key)

    def remove(self, key):
        """Remove item from cache"""
        self._remove_with_tracking(key)

    def clear(self):
        """Clear the cache"""
        self.cache.clear()
        self.timestamps.clear()
        self.size_tracking["current_size_bytes"] = 0

    def get_stats(self):
        """Get cache statistics"""
        return {
            "size_bytes": self.size_tracking["current_size_bytes"],
            "max_size_bytes": self.size_tracking["max_size_bytes"],
            "items": len(self.cache),
            "capacity": self.capacity,
            "items_evicted": self.size_tracking["items_evicted"],
            "hit_rate": self.size_tracking["cache_hits"] /
                        (self.size_tracking["cache_hits"] + self.size_tracking["cache_misses"] + 1e-8)
        }

    def _estimate_size(self, obj):
        """Estimate memory size of an object in bytes"""
        if obj is None:
            return 0
        if isinstance(obj, np.ndarray):
            return obj.nbytes
        elif isinstance(obj, torch.Tensor):
            return obj.element_size() * obj.nelement()
        elif isinstance(obj, (str, bytes)):
            return len(obj)
        elif isinstance(obj, (list, tuple)):
            return sum(self._estimate_size(x) for x in obj)
        elif isinstance(obj, dict):
            return sum(self._estimate_size(k) + self._estimate_size(v) for k, v in obj.items())
        else:
            # Fallback - rough estimate
            return sys.getsizeof(obj)
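
# Illustrative usage sketch (kept as a comment so it does not run at import time):
#
#   cache = LRUCache(capacity=2, expiry_time=60)
#   cache.put("a", np.zeros(8, dtype=np.float32))
#   cache.put("b", np.zeros(8, dtype=np.float32))
#   cache.put("c", np.zeros(8, dtype=np.float32))  # evicts "a", the least recently used key
#   cache.get("a")     # -> None (counted as a cache miss)
#   cache.get_stats()  # byte estimates, item count, evictions, hit rate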
# === Improved Lazy Model Loading ===
class LazyModel:
    """Lazy loading wrapper for models with proper method forwarding and error recovery"""
    def __init__(self, model_name, model_class, device, **kwargs):
        self.model_name = model_name
        self.model_class = model_class
        self.device = device
        self.kwargs = kwargs
        self._model = None
        self.last_error = None
        self.last_used = datetime.now()
        debug_print(f"LazyModel initialized for {model_name}")

    def _ensure_loaded(self, retries=Config.MAX_RETRIES):
        """Ensure model is loaded with retry mechanism"""
        if self._model is None:
            debug_print(f"Lazy loading model: {self.model_name}")
            for attempt in range(retries):
                try:
                    self._model = self.model_class.from_pretrained(self.model_name, **self.kwargs)
                    # Apply memory optimizations
                    if Config.OPTIMIZE_MEMORY:
                        # Convert to half precision if available and enabled
                        if Config.USE_HALF_PRECISION and self.device.type == 'cuda' and hasattr(self._model, 'half'):
                            self._model = self._model.half()
                            debug_print(f"Using half precision for {self.model_name}")
                    self._model = self._model.to(self.device)
                    self._model.eval()  # Set to evaluation mode
                    debug_print(f"Model {self.model_name} loaded successfully")
                    self.last_error = None
                    break
                except Exception as e:
                    self.last_error = str(e)
                    debug_print(f"Error loading model {self.model_name} (attempt {attempt+1}/{retries}): {str(e)}")
                    if attempt < retries - 1:
                        # Wait before retrying
                        time.sleep(Config.RECOVERY_WAIT_TIME)
                    else:
                        raise RuntimeError(f"Failed to load model {self.model_name} after {retries} attempts: {str(e)}")
        # Update last used timestamp
        self.last_used = datetime.now()
        return self._model

    def __call__(self, *args, **kwargs):
        """Call the model"""
        model = self._ensure_loaded()
        return model(*args, **kwargs)

    # Forward common model methods
    def generate(self, *args, **kwargs):
        """Forward generate method to model with error recovery"""
        model = self._ensure_loaded()
        try:
            return model.generate(*args, **kwargs)
        except Exception as e:
            # If generation fails, try reloading the model once
            debug_print(f"Generation failed, reloading model: {str(e)}")
            self.unload()
            model = self._ensure_loaded()
            return model.generate(*args, **kwargs)

    def to(self, device):
        """Move model to specified device"""
        self.device = device
        if self._model is not None:
            self._model = self._model.to(device)
        return self

    def eval(self):
        """Set model to evaluation mode"""
        if self._model is not None:
            self._model.eval()
        return self

    def unload(self):
        """Unload model from memory"""
        if self._model is not None:
            del self._model
            self._model = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            debug_print(f"Model {self.model_name} unloaded")
# === MediQuery Core System ===
class MediQuery:
    """Core MediQuery system for medical image and text analysis"""
    def __init__(self, knowledge_base_dir=Config.DEFAULT_KNOWLEDGE_BASE_DIR, model_path=Config.DEFAULT_MODEL_PATH):
        self.knowledge_base_dir = knowledge_base_dir
        self.model_path = model_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        debug_print(f"Using device: {self.device}")
        # Create directories if they don't exist
        os.makedirs(knowledge_base_dir, exist_ok=True)
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        # Initialize caches
        self.embedding_cache = LRUCache(capacity=Config.CACHE_SIZE)
        self.query_cache = LRUCache(capacity=Config.CACHE_SIZE)
        # Initialize models
        self._init_models()
        # Load knowledge base
        self._init_knowledge_base()
        debug_print("MediQuery system initialized")

    def _init_models(self):
        """Initialize all required models with lazy loading"""
        debug_print("Initializing models...")
        # Image model
        if Config.IMAGE_MODEL == "chexnet":
            self.image_model = models.densenet121(pretrained=False)
            # For deployment, the CheXNet weights would be downloaded during initialization
            try:
                # Simplified for deployment - keep only the feature extractor
                self.image_model = nn.Sequential(*list(self.image_model.children())[:-1])
                debug_print("CheXNet model initialized")
            except Exception as e:
                debug_print(f"Error initializing CheXNet: {str(e)}")
                # Fallback to standard DenseNet
                self.image_model = nn.Sequential(*list(models.densenet121(pretrained=True).children())[:-1])
        else:
            self.image_model = nn.Sequential(*list(models.densenet121(pretrained=True).children())[:-1])
        self.image_model = self.image_model.to(self.device).eval()

        # Text model - lazy loaded
        text_model_name = "dmis-lab/biobert-v1.1" if Config.TEXT_MODEL == "biobert" else "emilyalsentzer/Bio_ClinicalBERT"
        self.text_tokenizer = AutoTokenizer.from_pretrained(text_model_name)
        self.text_model = LazyModel(
            text_model_name,
            AutoModel,
            self.device
        )

        # Generation model - lazy loaded
        if os.path.exists(self.model_path):
            gen_model_path = self.model_path
        else:
            gen_model_path = "google/flan-t5-base"  # Fallback to base model
        self.gen_tokenizer = T5Tokenizer.from_pretrained(gen_model_path)
        self.gen_model = LazyModel(
            gen_model_path,
            T5ForConditionalGeneration,
            self.device
        )

        # Image transformation
        self.image_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        debug_print("Models initialized")

    def _init_knowledge_base(self):
        """Initialize knowledge base with FAISS indices"""
        debug_print("Initializing knowledge base...")
        # For deployment, we create a minimal knowledge base.
        # In a real deployment, you would download the knowledge base files.
        # Create a dummy knowledge base for demonstration
        self.text_data = pd.DataFrame({
            'combined_text': [
                "The chest X-ray shows clear lung fields with no evidence of consolidation, effusion, or pneumothorax. The heart size is normal. No acute cardiopulmonary abnormality.",
                "Bilateral patchy airspace opacities consistent with multifocal pneumonia. No pleural effusion or pneumothorax. Heart size is normal.",
                "Cardiomegaly with pulmonary vascular congestion and bilateral pleural effusions, consistent with congestive heart failure. No pneumothorax or pneumonia.",
                "Right upper lobe opacity concerning for pneumonia. No pleural effusion or pneumothorax. Heart size is normal.",
                "Left lower lobe atelectasis. No pneumothorax or pleural effusion. Heart size is normal.",
                "Bilateral pleural effusions with bibasilar atelectasis. Cardiomegaly present. Findings consistent with heart failure.",
                "Right pneumothorax with partial lung collapse. No pleural effusion. Heart size is normal.",
                "Endotracheal tube, central venous catheter, and nasogastric tube in place. No pneumothorax or pleural effusion.",
                "Hyperinflated lungs with flattened diaphragms, consistent with COPD. No acute infiltrate or effusion.",
                "Multiple rib fractures on the right side. No pneumothorax or hemothorax. Lung fields are clear."
            ],
            'valid_index': list(range(10))
        })
        # Create dummy FAISS indices
        self.image_index = None  # Will be created on first use
        self.text_index = None  # Will be created on first use
        debug_print("Knowledge base initialized")

    def _create_dummy_indices(self):
        """Create dummy FAISS indices for demonstration"""
        # Text embeddings (768 dimensions for BERT-based models)
        text_dim = 768
        text_embeddings = np.random.rand(len(self.text_data), text_dim).astype('float32')
        # Image embeddings (1024 dimensions for DenseNet121)
        image_dim = 1024
        image_embeddings = np.random.rand(len(self.text_data), image_dim).astype('float32')
        # Create FAISS indices
        self.text_index = faiss.IndexFlatL2(text_dim)
        self.text_index.add(text_embeddings)
        self.image_index = faiss.IndexFlatL2(image_dim)
        self.image_index.add(image_embeddings)
        debug_print("Dummy FAISS indices created")
    def process_image(self, image_path):
        """Process an X-ray image and return analysis results"""
        try:
            debug_print(f"Processing image: {image_path}")
            # Check cache
            if Config.USE_CACHING:
                cached_result = self.query_cache.get(f"img_{image_path}")
                if cached_result:
                    debug_print("Using cached image result")
                    return cached_result
            # Load and preprocess image
            image = Image.open(image_path).convert('RGB')
            image_tensor = self.image_transform(image).unsqueeze(0).to(self.device)
            # Generate image embedding
            with torch.no_grad():
                image_embedding = self.image_model(image_tensor)
                image_embedding = nn.functional.avg_pool2d(image_embedding, kernel_size=7).squeeze().cpu().numpy()
            # Initialize FAISS indices if needed
            if self.image_index is None:
                self._create_dummy_indices()
            # Retrieve similar cases
            distances, indices = self.image_index.search(np.array([image_embedding]), k=Config.TOP_K_RETRIEVAL)
            # Get relevant text data
            retrieved_texts = [self.text_data.iloc[idx]['combined_text'] for idx in indices[0]]
            # Generate context for the model
            context = "\n\n".join(retrieved_texts[:Config.MAX_CONTEXT_DOCS])
            # Generate analysis
            prompt = f"Analyze this chest X-ray based on similar cases:\n\n{context}\n\nProvide a detailed radiological assessment including findings and impression:"
            analysis = self._generate_text(prompt)
            # Generate attention map (simplified for deployment)
            attention_map = self._generate_attention_map(image)
            # Prepare result
            result = {
                "analysis": analysis,
                "attention_map": attention_map,
                "confidence": 0.85,  # Placeholder
                "similar_cases": retrieved_texts[:3]  # Return top 3 similar cases
            }
            # Cache result
            if Config.USE_CACHING:
                self.query_cache.put(f"img_{image_path}", result)
            return result
        except Exception as e:
            error_msg = f"Error processing image: {str(e)}\n{traceback.format_exc()}"
            debug_print(error_msg)
            return {"error": error_msg}
    def process_query(self, query_text):
        """Process a text query and return relevant information"""
        try:
            debug_print(f"Processing query: {query_text}")
            # Check cache
            if Config.USE_CACHING:
                cached_result = self.query_cache.get(f"txt_{query_text}")
                if cached_result:
                    debug_print("Using cached query result")
                    return cached_result
            # Anonymize query
            query_text = anonymize_text(query_text)
            # Generate text embedding
            query_embedding = self._generate_text_embedding(query_text)
            # Initialize FAISS indices if needed
            if self.text_index is None:
                self._create_dummy_indices()
            # Retrieve similar texts
            distances, indices = self.text_index.search(np.array([query_embedding]), k=Config.TOP_K_RETRIEVAL)
            # Get relevant text data
            retrieved_texts = [self.text_data.iloc[idx]['combined_text'] for idx in indices[0]]
            # Generate context for the model
            context = "\n\n".join(retrieved_texts[:Config.MAX_CONTEXT_DOCS])
            # Generate response
            prompt = f"Answer this medical question based on the following information:\n\nQuestion: {query_text}\n\nRelevant information:\n{context}\n\nDetailed answer:"
            response = self._generate_text(prompt)
            # Prepare result
            result = {
                "response": response,
                "confidence": 0.9,  # Placeholder
                "sources": retrieved_texts[:3]  # Return top 3 sources
            }
            # Cache result
            if Config.USE_CACHING:
                self.query_cache.put(f"txt_{query_text}", result)
            return result
        except Exception as e:
            error_msg = f"Error processing query: {str(e)}\n{traceback.format_exc()}"
            debug_print(error_msg)
            return {"error": error_msg}
    def _generate_text_embedding(self, text):
        """Generate embedding for text using the text model"""
        try:
            # Check cache
            if Config.USE_CACHING:
                cached_embedding = self.embedding_cache.get(f"txt_emb_{text}")
                if cached_embedding is not None:
                    return cached_embedding
            # Tokenize
            inputs = self.text_tokenizer(
                text,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            ).to(self.device)
            # Generate embedding
            with torch.no_grad():
                outputs = self.text_model(**inputs)
                # Use mean pooling; cast to float32 so FAISS can consume the vector
                # even when the model runs in half precision
                embedding = outputs.last_hidden_state.mean(dim=1).float().cpu().numpy()[0]
            # Cache embedding
            if Config.USE_CACHING:
                self.embedding_cache.put(f"txt_emb_{text}", embedding)
            return embedding
        except Exception as e:
            debug_print(f"Error generating text embedding: {str(e)}")
            # Return a random embedding as a fallback
            return np.random.rand(768).astype('float32')
    def _generate_text(self, prompt):
        """Generate text using the language model"""
        try:
            # Tokenize
            inputs = self.gen_tokenizer(
                prompt,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            ).to(self.device)
            # Generate
            with torch.no_grad():
                output_ids = self.gen_model.generate(
                    inputs.input_ids,
                    max_length=256,
                    num_beams=4,
                    early_stopping=True
                )
            # Decode
            output_text = self.gen_tokenizer.decode(output_ids[0], skip_special_tokens=True)
            return output_text
        except Exception as e:
            debug_print(f"Error generating text: {str(e)}")
            return "I apologize, but I'm unable to generate a response at this time. Please try again later."
    def _generate_attention_map(self, image):
        """Generate a simplified attention map for the image"""
        try:
            # Convert to numpy array
            img_np = np.array(image.resize((224, 224)))
            # Create a simple heatmap (this is a placeholder - a real implementation would use model attention)
            heatmap = np.zeros((224, 224), dtype=np.float32)
            # Add some random "attention" areas
            for _ in range(3):
                x, y = np.random.randint(50, 174, 2)
                radius = np.random.randint(20, 50)
                for i in range(224):
                    for j in range(224):
                        dist = np.sqrt((i - x)**2 + (j - y)**2)
                        if dist < radius:
                            heatmap[i, j] += max(0, 1 - dist / radius)
            # Normalize
            heatmap = heatmap / heatmap.max()
            # Apply colormap
            heatmap_colored = cv2.applyColorMap((heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET)
            # Overlay on original image
            img_rgb = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
            overlay = cv2.addWeighted(img_rgb, 0.7, heatmap_colored, 0.3, 0)
            # Convert to base64 for API response
            _, buffer = cv2.imencode('.png', overlay)
            img_str = base64.b64encode(buffer).decode('utf-8')
            return img_str
        except Exception as e:
            debug_print(f"Error generating attention map: {str(e)}")
            return None
    def cleanup(self):
        """Clean up resources"""
        debug_print("Cleaning up resources...")
        # Unload models
        if hasattr(self, 'text_model') and isinstance(self.text_model, LazyModel):
            self.text_model.unload()
        if hasattr(self, 'gen_model') and isinstance(self.gen_model, LazyModel):
            self.gen_model.unload()
        # Clear caches
        if hasattr(self, 'embedding_cache'):
            self.embedding_cache.clear()
        if hasattr(self, 'query_cache'):
            self.query_cache.clear()
        # Force garbage collection
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        debug_print("Cleanup complete")
# === FastAPI Application ===
app = FastAPI(title="MediQuery API", description="API for MediQuery AI medical assistant")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For production, specify the actual frontend domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize MediQuery system
mediquery = MediQuery()
# Define API models
class QueryRequest(BaseModel):
    text: str

class QueryResponse(BaseModel):
    response: str
    confidence: float
    sources: List[str]
    error: Optional[str] = None

class ImageAnalysisResponse(BaseModel):
    analysis: str
    attention_map: Optional[str] = None
    confidence: float
    similar_cases: List[str]
    error: Optional[str] = None
# API endpoints (the route paths below are illustrative defaults; adjust them to
# whatever the deployed frontend expects)
@app.post("/api/query")
async def process_text_query(query: QueryRequest):
    """Process a text query and return relevant information"""
    result = mediquery.process_query(query.text)
    return result

@app.post("/api/analyze-image")
async def analyze_image(file: UploadFile = File(...)):
    """Analyze an X-ray image and return results"""
    # Save uploaded file temporarily
    temp_file = f"/tmp/{file.filename}"
    with open(temp_file, "wb") as f:
        f.write(await file.read())
    # Process image
    result = mediquery.process_image(temp_file)
    # Clean up
    os.remove(temp_file)
    return result

@app.get("/api/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "ok", "version": "1.0.0"}
# === Gradio Interface ===
def create_gradio_interface():
    """Create a Gradio interface for the MediQuery system"""
    # Define processing functions
    def process_image_gradio(image):
        # Save image temporarily
        temp_file = "/tmp/gradio_image.png"
        image.save(temp_file)
        # Process image
        result = mediquery.process_image(temp_file)
        # Clean up
        os.remove(temp_file)
        # Prepare output
        analysis = result.get("analysis", "Error processing image")
        attention_map_b64 = result.get("attention_map")
        # Convert base64 to image if available
        attention_map = None
        if attention_map_b64:
            try:
                attention_map = Image.open(io.BytesIO(base64.b64decode(attention_map_b64)))
            except Exception:
                pass
        return analysis, attention_map

    def process_query_gradio(query):
        result = mediquery.process_query(query)
        return result.get("response", "Error processing query")

    # Create interface
    with gr.Blocks(title="MediQuery") as demo:
        gr.Markdown("# MediQuery - AI Medical Assistant")
        with gr.Tab("Image Analysis"):
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(type="pil", label="Upload Chest X-ray")
                    image_button = gr.Button("Analyze X-ray")
                with gr.Column():
                    text_output = gr.Textbox(label="Analysis Results", lines=10)
                    image_output = gr.Image(label="Attention Map")
            image_button.click(
                fn=process_image_gradio,
                inputs=[image_input],
                outputs=[text_output, image_output]
            )
        with gr.Tab("Text Query"):
            query_input = gr.Textbox(label="Medical Query", lines=3, placeholder="e.g., What does pneumonia look like on a chest X-ray?")
            query_button = gr.Button("Submit Query")
            query_output = gr.Textbox(label="Response", lines=10)
            query_button.click(
                fn=process_query_gradio,
                inputs=[query_input],
                outputs=[query_output]
            )
            gr.Markdown("## Example Queries")
            gr.Examples(
                examples=[
                    ["What does pleural effusion look like?"],
                    ["How to differentiate pneumonia from tuberculosis?"],
                    ["What are the signs of cardiomegaly on X-ray?"]
                ],
                inputs=[query_input]
            )
    return demo
# Create Gradio interface
demo = create_gradio_interface()

# Mount Gradio app to FastAPI
app = gr.mount_gradio_app(app, demo, path="/")
# Startup and shutdown events
@app.on_event("startup")
async def startup_event():
    """Initialize resources on startup"""
    debug_print("API starting up...")

@app.on_event("shutdown")
async def shutdown_event():
    """Clean up resources on shutdown"""
    debug_print("API shutting down...")
    mediquery.cleanup()
# Run the FastAPI app with uvicorn when executed directly
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)