# IMPORTANT: Import spaces FIRST, before any CUDA-related packages try: import spaces HF_SPACES_GPU = True except ImportError: # Create a dummy decorator if not running on Spaces class spaces: @staticmethod def GPU(func): return func HF_SPACES_GPU = False import os import json import gc import traceback from typing import Optional, Tuple, Any import torch import gradio as gr import supervision as sv from PIL import Image # Try to import optional dependencies try: from transformers import ( AutoModelForCausalLM, AutoTokenizer, AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig, ) except Exception: AutoModelForCausalLM = None AutoTokenizer = None AutoModelForImageTextToText = None AutoProcessor = None BitsAndBytesConfig = None # Try to import huggingface_hub for model downloading try: from huggingface_hub import hf_hub_download except ImportError: hf_hub_download = None # Import RF-DETR (assumes it's in the same directory or installed) try: from rfdetr import RFDETRMedium except ImportError: print("Warning: RF-DETR not found. Please ensure it's properly installed.") RFDETRMedium = None # ============================================================================ # Configuration for Hugging Face Spaces # ============================================================================ class SpacesConfig: """Configuration optimized for Hugging Face Spaces.""" def __init__(self): # Get HF token from environment hf_token = os.environ.get('HF_TOKEN') or os.environ.get('HUGGINGFACE_TOKEN') self.settings = { 'results_dir': '/tmp/results', 'checkpoint': None, 'hf_model_repo': 'edeler/lorai', # Hugging Face model repository 'hf_model_filename': 'lorai.pth', 'hf_token': hf_token, 'resolution': 576, 'threshold': 0.7, 'use_llm': True, 'llm_model_id': 'google/medgemma-4b-it', 'llm_max_new_tokens': 200, 'llm_temperature': 0.2, 'llm_4bit': True, 'enable_caching': True, 'max_cache_size': 100, } def get(self, key: str, default: Any = None) -> Any: return self.settings.get(key, default) def set_hf_model_repo(self, repo_id: str, filename: str = 'lorai.pth'): """Set Hugging Face model repository.""" self.settings['hf_model_repo'] = repo_id self.settings['hf_model_filename'] = filename # ============================================================================ # Memory Management (simplified for Spaces) # ============================================================================ class MemoryManager: """Simplified memory management for Spaces.""" def __init__(self): self.memory_thresholds = { 'gpu_warning': 0.8, 'system_warning': 0.85, } def cleanup_memory(self, force: bool = False) -> None: """Perform memory cleanup.""" try: gc.collect() if torch and torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() except Exception as e: print(f"Memory cleanup error: {e}") # Global memory manager memory_manager = MemoryManager() # ============================================================================ # Model Loading # ============================================================================ def find_checkpoint(hf_repo: Optional[str] = None, hf_filename: str = 'lorai.pth') -> Optional[str]: """Find RF-DETR checkpoint in various locations or download from Hugging Face Hub.""" # First check if we should download from Hugging Face repo_id = hf_repo or os.environ.get('HF_MODEL_REPO') if repo_id and hf_hub_download is not None: try: print(f"Downloading checkpoint from Hugging Face Hub: {repo_id}/{hf_filename}") checkpoint_path = hf_hub_download( repo_id=repo_id, filename=hf_filename, cache_dir="/tmp/hf_cache" ) print(f"✓ Downloaded checkpoint to: {checkpoint_path}") return checkpoint_path except Exception as e: print(f"Warning: Failed to download from Hugging Face Hub: {e}") print("Falling back to local checkpoints...") # Fall back to local file search candidates = [ "lorai.pth", # Current directory "rf-detr-medium.pth", "/tmp/results/checkpoint_best_total.pth", "/tmp/results/checkpoint_best_ema.pth", "/tmp/results/checkpoint_best_regular.pth", "/tmp/results/checkpoint.pth", ] for path in candidates: if os.path.isfile(path): print(f"Found local checkpoint: {path}") return path return None def load_model(checkpoint_path: str, resolution: int): """Load RF-DETR model.""" if RFDETRMedium is None: raise RuntimeError("RF-DETR not available. Please install it properly.") model = RFDETRMedium(pretrain_weights=checkpoint_path, resolution=resolution) try: model.optimize_for_inference() except Exception: pass return model # ============================================================================ # LLM Integration # ============================================================================ class TextGenerator: """Simplified text generator for Spaces.""" def __init__(self, model_id: str, max_tokens: int = 200, temperature: float = 0.2): self.model_id = model_id self.max_tokens = max_tokens self.temperature = temperature self.model = None self.tokenizer = None self.processor = None self.is_multimodal = False def load_model(self, hf_token: Optional[str] = None): """Load the LLM model.""" if self.model is not None: return if (AutoModelForCausalLM is None and AutoModelForImageTextToText is None): raise RuntimeError("Transformers not available") # Clear memory before loading memory_manager.cleanup_memory() print(f"Loading model: {self.model_id}") model_kwargs = { "device_map": "auto", "low_cpu_mem_usage": True, } # Add token if provided if hf_token: model_kwargs["token"] = hf_token if torch and torch.cuda.is_available(): model_kwargs["torch_dtype"] = torch.bfloat16 # Use 4-bit quantization if available if BitsAndBytesConfig is not None: try: compute_dtype = torch.bfloat16 if torch and torch.cuda.is_available() else torch.float16 model_kwargs["quantization_config"] = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=compute_dtype, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4" ) model_kwargs["torch_dtype"] = compute_dtype except Exception: pass # Check if it's a multimodal model is_multimodal = "medgemma" in self.model_id.lower() if is_multimodal and AutoModelForImageTextToText is not None and AutoProcessor is not None: self.processor = AutoProcessor.from_pretrained(self.model_id, token=hf_token) self.model = AutoModelForImageTextToText.from_pretrained(self.model_id, **model_kwargs) self.is_multimodal = True elif AutoModelForCausalLM is not None and AutoTokenizer is not None: self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, token=hf_token) self.model = AutoModelForCausalLM.from_pretrained(self.model_id, **model_kwargs) self.is_multimodal = False else: raise RuntimeError("Required model classes not available") print("✓ Model loaded successfully") def generate(self, text: str, image: Optional[Image.Image] = None, hf_token: Optional[str] = None) -> str: """Generate text using the loaded model.""" self.load_model(hf_token) if self.model is None: return f"[Model not loaded: {text}]" try: # Create messages system_text = "You are a concise medical assistant. Provide a brief, clear summary of detection results. Avoid repetition and be direct. Do not give medical advice." user_text = f"Summarize these detection results in 3 clear sentences:\n\n{text}" if self.is_multimodal: # Multimodal model user_content = [{"type": "text", "text": user_text}] if image is not None: user_content.append({"type": "image", "image": image}) messages = [ {"role": "system", "content": [{"type": "text", "text": system_text}]}, {"role": "user", "content": user_content}, ] inputs = self.processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", ) if torch: inputs = inputs.to(self.model.device, dtype=torch.bfloat16) with torch.inference_mode(): generation = self.model.generate( **inputs, max_new_tokens=self.max_tokens, do_sample=self.temperature > 0, temperature=max(0.01, self.temperature) if self.temperature > 0 else None, use_cache=False, ) input_len = inputs["input_ids"].shape[-1] generation = generation[0][input_len:] decoded = self.processor.decode(generation, skip_special_tokens=True) return decoded.strip() else: # Text-only model messages = [ {"role": "system", "content": system_text}, {"role": "user", "content": user_text}, ] inputs = self.tokenizer.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", ) inputs = inputs.to(self.model.device) with torch.inference_mode(): generation = self.model.generate( **inputs, max_new_tokens=self.max_tokens, do_sample=self.temperature > 0, temperature=max(0.01, self.temperature) if self.temperature > 0 else None, use_cache=False, ) input_len = inputs["input_ids"].shape[-1] generation = generation[0][input_len:] decoded = self.tokenizer.decode(generation, skip_special_tokens=True) return decoded.strip() except Exception as e: error_msg = f"[Generation error: {e}]" print(f"Generation error: {traceback.format_exc()}") return f"{error_msg}\n\n{text}" # ============================================================================ # Application State # ============================================================================ class AppState: """Application state for Spaces.""" def __init__(self): self.config = SpacesConfig() self.model = None self.class_names = None self.text_generator = None def load_model(self): """Load the detection model.""" if self.model is not None: return checkpoint = find_checkpoint( hf_repo=self.config.get('hf_model_repo'), hf_filename=self.config.get('hf_model_filename', 'lorai.pth') ) if not checkpoint: hf_repo = self.config.get('hf_model_repo') or os.environ.get('HF_MODEL_REPO') if hf_repo: raise FileNotFoundError( f"No RF-DETR checkpoint found. Could not download from '{hf_repo}'. " "Please check the repository ID and ensure the model file exists." ) else: raise FileNotFoundError( "No RF-DETR checkpoint found. Please either:\n" "1. Set HF_MODEL_REPO environment variable (e.g., 'edeler/lorai'), or\n" "2. Upload lorai.pth to your Space's root directory" ) print(f"Loading RF-DETR from: {checkpoint}") self.model = load_model(checkpoint, self.config.get('resolution')) # Set default class names (can be overridden by results.json) # Index corresponds to class_id from the model default_class_names = ["Background", "Granuloma"] # Try to load class names from results.json loaded_classes = None try: results_json = "/tmp/results/results.json" if os.path.isfile(results_json): with open(results_json, 'r') as f: data = json.load(f) classes = [] for split in ("valid", "test", "train"): if "class_map" in data and split in data["class_map"]: for item in data["class_map"][split]: name = item.get("class") if name and name != "all" and name not in classes: classes.append(name) if classes: loaded_classes = classes except Exception as e: print(f"Could not load class names from results.json: {e}") # Use loaded classes if available, otherwise use defaults self.class_names = loaded_classes if loaded_classes else default_class_names print(f"Using class names: {self.class_names}") print("✓ RF-DETR model loaded") def preload_all_models(self): """Preload both detection and LLM models into VRAM at startup.""" print("=" * 60) print("Preloading all models into VRAM...") print("=" * 60) # Load detection model print("\n[1/2] Loading RF-DETR detection model...") self.load_model() # Load LLM model if self.config.get('use_llm'): print("\n[2/2] Loading MedGemma LLM model...") try: model_size = "4B" # Default to 4B model generator = self.get_text_generator(model_size) hf_token = self.config.get('hf_token') generator.load_model(hf_token) print("✓ MedGemma model loaded and ready") except Exception as e: print(f"⚠️ Warning: Could not preload LLM model: {e}") print("LLM will be loaded on first use instead") print("\n" + "=" * 60) print("✓ All models loaded and ready in VRAM!") print("=" * 60 + "\n") def get_text_generator(self, model_size: str = "4B") -> TextGenerator: """Get or create text generator.""" # Determine model ID based on size selection model_id = 'google/medgemma-27b-it' if model_size == "27B" else 'google/medgemma-4b-it' # Check if we need to create a new generator for different model size if (self.text_generator is None or hasattr(self.text_generator, 'model_id') and self.text_generator.model_id != model_id): max_tokens = self.config.get('llm_max_new_tokens') temperature = self.config.get('llm_temperature') self.text_generator = TextGenerator(model_id, max_tokens, temperature) return self.text_generator # ============================================================================ # UI and Inference # ============================================================================ def create_detection_interface(): """Create the Gradio interface.""" # Color palette for annotations COLOR_PALETTE = sv.ColorPalette.from_hex([ "#ffff00", "#ff9b00", "#ff66ff", "#3399ff", "#ff66b2", "#ff8080", "#b266ff", "#9999ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00", ]) @spaces.GPU def annotate_image(image: Image.Image, threshold: float, model_size: str = "4B") -> Tuple[Image.Image, str]: """Process an image and return annotated version with description.""" if image is None: return None, "Please upload an image." try: # Models are preloaded at startup, but check just in case if app_state.model is None: app_state.load_model() # Run detection detections = app_state.model.predict(image, threshold=threshold) # Annotate image bbox_annotator = sv.BoxAnnotator(color=COLOR_PALETTE, thickness=2) label_annotator = sv.LabelAnnotator(text_scale=0.5, text_color=sv.Color.BLACK) labels = [] for i in range(len(detections)): class_id = int(detections.class_id[i]) if detections.class_id is not None else None conf = float(detections.confidence[i]) if detections.confidence is not None else 0.0 if app_state.class_names and class_id is not None: if 0 <= class_id < len(app_state.class_names): label_name = app_state.class_names[class_id] else: label_name = str(class_id) else: label_name = str(class_id) if class_id is not None else "object" labels.append(f"{label_name} {conf:.2f}") annotated = image.copy() annotated = bbox_annotator.annotate(annotated, detections) annotated = label_annotator.annotate(annotated, detections, labels) # Generate description description = f"Found {len(detections)} detections above threshold {threshold}:\n\n" if len(detections) > 0: counts = {} for i in range(len(detections)): class_id = int(detections.class_id[i]) if detections.class_id is not None else None if app_state.class_names and class_id is not None: if 0 <= class_id < len(app_state.class_names): name = app_state.class_names[class_id] else: name = str(class_id) else: name = str(class_id) if class_id is not None else "object" counts[name] = counts.get(name, 0) + 1 for name, count in counts.items(): description += f"- {count}× {name}\n" # Use LLM for description if enabled if app_state.config.get('use_llm'): try: generator = app_state.get_text_generator(model_size) hf_token = app_state.config.get('hf_token') # Model is already preloaded, just generate llm_description = generator.generate(description, image=annotated, hf_token=hf_token) description = llm_description except Exception as e: print(f"LLM generation failed: {e}") # Just use the basic description if LLM fails pass else: description += "No objects detected above the confidence threshold." return annotated, description except Exception as e: error_msg = f"Error processing image: {str(e)}" print(f"Processing error: {traceback.format_exc()}") return None, error_msg # Create the interface with gr.Blocks(title="Medical Image Analysis", theme=gr.themes.Soft()) as demo: gr.Markdown("# 🏥 Medical Image Analysis") gr.Markdown("Upload a medical image to detect and analyze findings using AI.") # Check if HF token is available hf_token = app_state.config.get('hf_token') if not hf_token: gr.Markdown("⚠️ **Note:** HF_TOKEN not set. AI text generation will be disabled. Detection will still work.") else: gr.Markdown("✅ **AI-powered analysis enabled** using MedGemma 4B") with gr.Row(): with gr.Column(): input_image = gr.Image(type="pil", label="Upload Image", height=400) threshold_slider = gr.Slider( minimum=0.1, maximum=1.0, value=0.7, step=0.05, label="Confidence Threshold", info="Higher values = fewer but more confident detections" ) model_size_radio = gr.Radio( choices=["4B"], value="4B", label="MedGemma Model Size", info="Using MedGemma 4B for AI-generated analysis", visible=False # Hide since only one option ) analyze_btn = gr.Button("🔍 Analyze Image", variant="primary") # Example images gr.Examples( examples=[ ["1.jpg"], ["2.jpg"], ["3.jpg"], ], inputs=input_image, label="Example Images", examples_per_page=3 ) with gr.Column(): output_image = gr.Image(type="pil", label="Results", height=400) output_text = gr.Textbox( label="Analysis Results", lines=8, max_lines=15, show_copy_button=True ) # Wire up the interface analyze_btn.click( fn=annotate_image, inputs=[input_image, threshold_slider, model_size_radio], outputs=[output_image, output_text] ) # Also run when image is uploaded input_image.change( fn=annotate_image, inputs=[input_image, threshold_slider, model_size_radio], outputs=[output_image, output_text] ) # Footer gr.Markdown("---") return demo # ============================================================================ # Main Application # ============================================================================ # Global app state app_state = AppState() def main(): """Main entry point for the Spaces app.""" print("🚀 Starting Medical Image Analysis App") # Ensure results directory exists os.makedirs(app_state.config.get('results_dir'), exist_ok=True) # Preload models if NOT running on HF Spaces with GPU # On HF Spaces, models will be loaded on first inference call (triggered by @spaces.GPU) if not HF_SPACES_GPU: print("Running locally - preloading models into VRAM...") try: app_state.preload_all_models() except Exception as e: print(f"⚠️ Warning: Failed to preload models: {e}") print("Models will be loaded on first use instead") else: print("Running on HF Spaces - models will load on first inference (via @spaces.GPU)") print("This is the recommended approach for Spaces GPU management.") # Create and launch the interface demo = create_detection_interface() # Launch with Spaces-optimized settings demo.launch( server_name="0.0.0.0", server_port=7860, share=False, # Spaces handles this show_error=True, show_api=False, ) if __name__ == "__main__": main()