# IMPORTANT: Import spaces FIRST, before any CUDA-related packages
try:
import spaces
HF_SPACES_GPU = True
except ImportError:
# Create a dummy decorator if not running on Spaces
class spaces:
@staticmethod
def GPU(func):
return func
HF_SPACES_GPU = False
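# The shim above lets this file run unchanged outside Spaces: @spaces.GPU becomes
# a no-op and HF_SPACES_GPU lets main() decide whether to preload models eagerly
# (local run) or lazily on the first GPU-decorated call (Spaces).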
import os
import json
import gc
import traceback
from typing import Optional, Tuple, Any
import torch
import gradio as gr
import supervision as sv
from PIL import Image
# Try to import optional dependencies
try:
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
AutoModelForImageTextToText,
AutoProcessor,
BitsAndBytesConfig,
)
except Exception:
AutoModelForCausalLM = None
AutoTokenizer = None
AutoModelForImageTextToText = None
AutoProcessor = None
BitsAndBytesConfig = None
# Try to import huggingface_hub for model downloading
try:
from huggingface_hub import hf_hub_download
except ImportError:
hf_hub_download = None
# Import RF-DETR (assumes it's in the same directory or installed)
try:
from rfdetr import RFDETRMedium
except ImportError:
print("Warning: RF-DETR not found. Please ensure it's properly installed.")
RFDETRMedium = None
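# Note: RFDETRMedium comes from Roboflow's RF-DETR package. If the import above
# fails, installing the package (typically `pip install rfdetr`) or vendoring it
# next to app.py should resolve it; detection cannot run without it.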
# ============================================================================
# Configuration for Hugging Face Spaces
# ============================================================================
class SpacesConfig:
"""Configuration optimized for Hugging Face Spaces."""
def __init__(self):
# Get HF token from environment
hf_token = os.environ.get('HF_TOKEN') or os.environ.get('HUGGINGFACE_TOKEN')
self.settings = {
'results_dir': '/tmp/results',
'checkpoint': None,
'hf_model_repo': 'edeler/lorai', # Hugging Face model repository
'hf_model_filename': 'lorai.pth',
'hf_token': hf_token,
'resolution': 576,
'threshold': 0.7,
'use_llm': True,
'llm_model_id': 'google/medgemma-4b-it',
'llm_max_new_tokens': 200,
'llm_temperature': 0.2,
'llm_4bit': True,
'enable_caching': True,
'max_cache_size': 100,
}
def get(self, key: str, default: Any = None) -> Any:
return self.settings.get(key, default)
def set_hf_model_repo(self, repo_id: str, filename: str = 'lorai.pth'):
"""Set Hugging Face model repository."""
self.settings['hf_model_repo'] = repo_id
self.settings['hf_model_filename'] = filename
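# Usage sketch (illustrative repo/filename, not shipped with this app) for pointing
# the app at a different fine-tuned checkpoint without editing the defaults above:
#
#   config = SpacesConfig()
#   config.set_hf_model_repo("your-username/your-rfdetr-weights", "model.pth")
#
# The same can be achieved at deploy time by setting the HF_MODEL_REPO environment
# variable, which find_checkpoint() below also checks.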
# ============================================================================
# Memory Management (simplified for Spaces)
# ============================================================================
class MemoryManager:
"""Simplified memory management for Spaces."""
def __init__(self):
self.memory_thresholds = {
'gpu_warning': 0.8,
'system_warning': 0.85,
}
def cleanup_memory(self, force: bool = False) -> None:
"""Perform memory cleanup."""
try:
gc.collect()
if torch and torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
except Exception as e:
print(f"Memory cleanup error: {e}")
# Global memory manager
memory_manager = MemoryManager()
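# The thresholds above are currently informational only; cleanup_memory() is what
# the app actually calls (before loading the LLM in TextGenerator.load_model) to
# run garbage collection and release cached CUDA memory.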
# ============================================================================
# Model Loading
# ============================================================================
def find_checkpoint(hf_repo: Optional[str] = None, hf_filename: str = 'lorai.pth') -> Optional[str]:
"""Find RF-DETR checkpoint in various locations or download from Hugging Face Hub."""
# First check if we should download from Hugging Face
repo_id = hf_repo or os.environ.get('HF_MODEL_REPO')
if repo_id and hf_hub_download is not None:
try:
print(f"Downloading checkpoint from Hugging Face Hub: {repo_id}/{hf_filename}")
checkpoint_path = hf_hub_download(
repo_id=repo_id,
filename=hf_filename,
cache_dir="/tmp/hf_cache"
)
print(f"βœ“ Downloaded checkpoint to: {checkpoint_path}")
return checkpoint_path
except Exception as e:
print(f"Warning: Failed to download from Hugging Face Hub: {e}")
print("Falling back to local checkpoints...")
# Fall back to local file search
candidates = [
"lorai.pth", # Current directory
"rf-detr-medium.pth",
"/tmp/results/checkpoint_best_total.pth",
"/tmp/results/checkpoint_best_ema.pth",
"/tmp/results/checkpoint_best_regular.pth",
"/tmp/results/checkpoint.pth",
]
for path in candidates:
if os.path.isfile(path):
print(f"Found local checkpoint: {path}")
return path
return None
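# Checkpoint resolution order used above: (1) the explicit hf_repo argument,
# (2) the HF_MODEL_REPO environment variable, (3) the local candidate paths.
# A typical Space configuration (illustrative values) would be:
#
#   HF_MODEL_REPO=edeler/lorai
#   HF_TOKEN=hf_xxx   # needed for gated/private weights such as MedGemma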
def load_model(checkpoint_path: str, resolution: int):
"""Load RF-DETR model."""
if RFDETRMedium is None:
raise RuntimeError("RF-DETR not available. Please install it properly.")
model = RFDETRMedium(pretrain_weights=checkpoint_path, resolution=resolution)
try:
model.optimize_for_inference()
except Exception:
pass
return model
# ============================================================================
# LLM Integration
# ============================================================================
class TextGenerator:
"""Simplified text generator for Spaces."""
def __init__(self, model_id: str, max_tokens: int = 200, temperature: float = 0.2):
self.model_id = model_id
self.max_tokens = max_tokens
self.temperature = temperature
self.model = None
self.tokenizer = None
self.processor = None
self.is_multimodal = False
def load_model(self, hf_token: Optional[str] = None):
"""Load the LLM model."""
if self.model is not None:
return
if (AutoModelForCausalLM is None and AutoModelForImageTextToText is None):
raise RuntimeError("Transformers not available")
# Clear memory before loading
memory_manager.cleanup_memory()
print(f"Loading model: {self.model_id}")
model_kwargs = {
"device_map": "auto",
"low_cpu_mem_usage": True,
}
# Add token if provided
if hf_token:
model_kwargs["token"] = hf_token
        if torch and torch.cuda.is_available():
            model_kwargs["torch_dtype"] = torch.bfloat16
            # Use 4-bit quantization when possible. bitsandbytes 4-bit loading
            # requires a CUDA device, so only attempt it on GPU hosts (this also
            # matches the intent of the 'llm_4bit' flag in SpacesConfig).
            if BitsAndBytesConfig is not None:
                try:
                    compute_dtype = torch.bfloat16
                    model_kwargs["quantization_config"] = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=compute_dtype,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4",
                    )
                    model_kwargs["torch_dtype"] = compute_dtype
                except Exception:
                    pass
# Check if it's a multimodal model
is_multimodal = "medgemma" in self.model_id.lower()
if is_multimodal and AutoModelForImageTextToText is not None and AutoProcessor is not None:
self.processor = AutoProcessor.from_pretrained(self.model_id, token=hf_token)
self.model = AutoModelForImageTextToText.from_pretrained(self.model_id, **model_kwargs)
self.is_multimodal = True
elif AutoModelForCausalLM is not None and AutoTokenizer is not None:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, token=hf_token)
self.model = AutoModelForCausalLM.from_pretrained(self.model_id, **model_kwargs)
self.is_multimodal = False
else:
raise RuntimeError("Required model classes not available")
print("βœ“ Model loaded successfully")
def generate(self, text: str, image: Optional[Image.Image] = None, hf_token: Optional[str] = None) -> str:
"""Generate text using the loaded model."""
self.load_model(hf_token)
if self.model is None:
return f"[Model not loaded: {text}]"
try:
# Create messages
system_text = "You are a concise medical assistant. Provide a brief, clear summary of detection results. Avoid repetition and be direct. Do not give medical advice."
user_text = f"Summarize these detection results in 3 clear sentences:\n\n{text}"
if self.is_multimodal:
# Multimodal model
user_content = [{"type": "text", "text": user_text}]
if image is not None:
user_content.append({"type": "image", "image": image})
messages = [
{"role": "system", "content": [{"type": "text", "text": system_text}]},
{"role": "user", "content": user_content},
]
inputs = self.processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
if torch:
inputs = inputs.to(self.model.device, dtype=torch.bfloat16)
with torch.inference_mode():
generation = self.model.generate(
**inputs,
max_new_tokens=self.max_tokens,
do_sample=self.temperature > 0,
temperature=max(0.01, self.temperature) if self.temperature > 0 else None,
use_cache=False,
)
input_len = inputs["input_ids"].shape[-1]
generation = generation[0][input_len:]
decoded = self.processor.decode(generation, skip_special_tokens=True)
return decoded.strip()
else:
# Text-only model
messages = [
{"role": "system", "content": system_text},
{"role": "user", "content": user_text},
]
inputs = self.tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
inputs = inputs.to(self.model.device)
with torch.inference_mode():
generation = self.model.generate(
**inputs,
max_new_tokens=self.max_tokens,
do_sample=self.temperature > 0,
temperature=max(0.01, self.temperature) if self.temperature > 0 else None,
use_cache=False,
)
input_len = inputs["input_ids"].shape[-1]
generation = generation[0][input_len:]
decoded = self.tokenizer.decode(generation, skip_special_tokens=True)
return decoded.strip()
except Exception as e:
error_msg = f"[Generation error: {e}]"
print(f"Generation error: {traceback.format_exc()}")
return f"{error_msg}\n\n{text}"
# ============================================================================
# Application State
# ============================================================================
class AppState:
"""Application state for Spaces."""
def __init__(self):
self.config = SpacesConfig()
self.model = None
self.class_names = None
self.text_generator = None
def load_model(self):
"""Load the detection model."""
if self.model is not None:
return
checkpoint = find_checkpoint(
hf_repo=self.config.get('hf_model_repo'),
hf_filename=self.config.get('hf_model_filename', 'lorai.pth')
)
if not checkpoint:
hf_repo = self.config.get('hf_model_repo') or os.environ.get('HF_MODEL_REPO')
if hf_repo:
raise FileNotFoundError(
f"No RF-DETR checkpoint found. Could not download from '{hf_repo}'. "
"Please check the repository ID and ensure the model file exists."
)
else:
raise FileNotFoundError(
"No RF-DETR checkpoint found. Please either:\n"
"1. Set HF_MODEL_REPO environment variable (e.g., 'edeler/lorai'), or\n"
"2. Upload lorai.pth to your Space's root directory"
)
print(f"Loading RF-DETR from: {checkpoint}")
self.model = load_model(checkpoint, self.config.get('resolution'))
# Set default class names (can be overridden by results.json)
# Index corresponds to class_id from the model
default_class_names = ["Background", "Granuloma"]
# Try to load class names from results.json
loaded_classes = None
try:
results_json = "/tmp/results/results.json"
if os.path.isfile(results_json):
with open(results_json, 'r') as f:
data = json.load(f)
classes = []
for split in ("valid", "test", "train"):
if "class_map" in data and split in data["class_map"]:
for item in data["class_map"][split]:
name = item.get("class")
if name and name != "all" and name not in classes:
classes.append(name)
if classes:
loaded_classes = classes
except Exception as e:
print(f"Could not load class names from results.json: {e}")
# Use loaded classes if available, otherwise use defaults
self.class_names = loaded_classes if loaded_classes else default_class_names
print(f"Using class names: {self.class_names}")
print("βœ“ RF-DETR model loaded")
def preload_all_models(self):
"""Preload both detection and LLM models into VRAM at startup."""
print("=" * 60)
print("Preloading all models into VRAM...")
print("=" * 60)
# Load detection model
print("\n[1/2] Loading RF-DETR detection model...")
self.load_model()
# Load LLM model
if self.config.get('use_llm'):
print("\n[2/2] Loading MedGemma LLM model...")
try:
model_size = "4B" # Default to 4B model
generator = self.get_text_generator(model_size)
hf_token = self.config.get('hf_token')
generator.load_model(hf_token)
print("βœ“ MedGemma model loaded and ready")
except Exception as e:
print(f"⚠️ Warning: Could not preload LLM model: {e}")
print("LLM will be loaded on first use instead")
print("\n" + "=" * 60)
print("βœ“ All models loaded and ready in VRAM!")
print("=" * 60 + "\n")
def get_text_generator(self, model_size: str = "4B") -> TextGenerator:
"""Get or create text generator."""
# Determine model ID based on size selection
model_id = 'google/medgemma-27b-it' if model_size == "27B" else 'google/medgemma-4b-it'
        # Create a new generator if none exists yet or the requested size changed
        if self.text_generator is None or self.text_generator.model_id != model_id:
max_tokens = self.config.get('llm_max_new_tokens')
temperature = self.config.get('llm_temperature')
self.text_generator = TextGenerator(model_id, max_tokens, temperature)
return self.text_generator
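# Note: get_text_generator() also understands model_size="27B"
# (google/medgemma-27b-it), but the UI below currently pins the radio to "4B";
# enabling 27B would require substantially more GPU memory.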
# ============================================================================
# UI and Inference
# ============================================================================
def create_detection_interface():
"""Create the Gradio interface."""
# Color palette for annotations
COLOR_PALETTE = sv.ColorPalette.from_hex([
"#ffff00", "#ff9b00", "#ff66ff", "#3399ff", "#ff66b2",
"#ff8080", "#b266ff", "#9999ff", "#66ffff", "#33ff99",
"#66ff66", "#99ff00",
])
@spaces.GPU
def annotate_image(image: Image.Image, threshold: float, model_size: str = "4B") -> Tuple[Image.Image, str]:
"""Process an image and return annotated version with description."""
if image is None:
return None, "Please upload an image."
try:
# Models are preloaded at startup, but check just in case
if app_state.model is None:
app_state.load_model()
# Run detection
detections = app_state.model.predict(image, threshold=threshold)
            # Map a raw class_id to a human-readable label, falling back to the
            # numeric id (or "object") when class names are unavailable
            def class_name_for(class_id):
                if class_id is None:
                    return "object"
                if app_state.class_names and 0 <= class_id < len(app_state.class_names):
                    return app_state.class_names[class_id]
                return str(class_id)
            # Annotate image
            bbox_annotator = sv.BoxAnnotator(color=COLOR_PALETTE, thickness=2)
            label_annotator = sv.LabelAnnotator(text_scale=0.5, text_color=sv.Color.BLACK)
            labels = []
            for i in range(len(detections)):
                class_id = int(detections.class_id[i]) if detections.class_id is not None else None
                conf = float(detections.confidence[i]) if detections.confidence is not None else 0.0
                labels.append(f"{class_name_for(class_id)} {conf:.2f}")
            annotated = image.copy()
            annotated = bbox_annotator.annotate(annotated, detections)
            annotated = label_annotator.annotate(annotated, detections, labels)
            # Generate description
            description = f"Found {len(detections)} detections above threshold {threshold}:\n\n"
            if len(detections) > 0:
                counts = {}
                for i in range(len(detections)):
                    class_id = int(detections.class_id[i]) if detections.class_id is not None else None
                    name = class_name_for(class_id)
                    counts[name] = counts.get(name, 0) + 1
for name, count in counts.items():
description += f"- {count}Γ— {name}\n"
# Use LLM for description if enabled
if app_state.config.get('use_llm'):
try:
generator = app_state.get_text_generator(model_size)
hf_token = app_state.config.get('hf_token')
# Model is already preloaded, just generate
llm_description = generator.generate(description, image=annotated, hf_token=hf_token)
description = llm_description
except Exception as e:
print(f"LLM generation failed: {e}")
# Just use the basic description if LLM fails
pass
else:
description += "No objects detected above the confidence threshold."
return annotated, description
except Exception as e:
error_msg = f"Error processing image: {str(e)}"
print(f"Processing error: {traceback.format_exc()}")
return None, error_msg
# Create the interface
with gr.Blocks(title="Medical Image Analysis", theme=gr.themes.Soft()) as demo:
gr.Markdown("# πŸ₯ Medical Image Analysis")
gr.Markdown("Upload a medical image to detect and analyze findings using AI.")
        # Surface whether MedGemma-based summaries are likely to work
        hf_token = app_state.config.get('hf_token')
        if not hf_token:
            gr.Markdown("⚠️ **Note:** HF_TOKEN is not set. MedGemma is a gated model, so AI text generation will likely be unavailable. Detection will still work.")
        else:
            gr.Markdown("βœ… **AI-powered analysis enabled** using MedGemma 4B")
with gr.Row():
with gr.Column():
input_image = gr.Image(type="pil", label="Upload Image", height=400)
threshold_slider = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.7,
step=0.05,
label="Confidence Threshold",
info="Higher values = fewer but more confident detections"
)
model_size_radio = gr.Radio(
choices=["4B"],
value="4B",
label="MedGemma Model Size",
info="Using MedGemma 4B for AI-generated analysis",
visible=False # Hide since only one option
)
analyze_btn = gr.Button("πŸ” Analyze Image", variant="primary")
# Example images
gr.Examples(
examples=[
["1.jpg"],
["2.jpg"],
["3.jpg"],
],
inputs=input_image,
label="Example Images",
examples_per_page=3
)
with gr.Column():
output_image = gr.Image(type="pil", label="Results", height=400)
output_text = gr.Textbox(
label="Analysis Results",
lines=8,
max_lines=15,
show_copy_button=True
)
# Wire up the interface
analyze_btn.click(
fn=annotate_image,
inputs=[input_image, threshold_slider, model_size_radio],
outputs=[output_image, output_text]
)
# Also run when image is uploaded
input_image.change(
fn=annotate_image,
inputs=[input_image, threshold_slider, model_size_radio],
outputs=[output_image, output_text]
)
# Footer
gr.Markdown("---")
return demo
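# The gr.Examples block above expects 1.jpg, 2.jpg and 3.jpg to be uploaded next
# to app.py; remove or adjust those entries if the Space does not ship sample images.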
# ============================================================================
# Main Application
# ============================================================================
# Global app state
app_state = AppState()
def main():
"""Main entry point for the Spaces app."""
print("πŸš€ Starting Medical Image Analysis App")
# Ensure results directory exists
os.makedirs(app_state.config.get('results_dir'), exist_ok=True)
# Preload models if NOT running on HF Spaces with GPU
# On HF Spaces, models will be loaded on first inference call (triggered by @spaces.GPU)
if not HF_SPACES_GPU:
print("Running locally - preloading models into VRAM...")
try:
app_state.preload_all_models()
except Exception as e:
print(f"⚠️ Warning: Failed to preload models: {e}")
print("Models will be loaded on first use instead")
else:
print("Running on HF Spaces - models will load on first inference (via @spaces.GPU)")
print("This is the recommended approach for Spaces GPU management.")
# Create and launch the interface
demo = create_detection_interface()
# Launch with Spaces-optimized settings
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False, # Spaces handles this
show_error=True,
show_api=False,
)
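# Local run sketch (outside Spaces), assuming the Space's requirements are installed
# and values are adjusted to your setup:
#
#   HF_MODEL_REPO=edeler/lorai HF_TOKEN=hf_xxx python app.py
#
# then open http://localhost:7860 (matching server_port=7860 / server_name="0.0.0.0"
# in the launch() call above).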
if __name__ == "__main__":
main()