import gc
import json
import os
import traceback
from typing import Dict, Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from transformers import AutoProcessor, AutoModel
import gradio as gr

# --- Device Setup ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# For 8-bit models, the vision dtype is handled by bitsandbytes
# We still need HEAD_DTYPE for our classifier head
HEAD_DTYPE = torch.float32

# --- DINOv3 Specific Constants ---
DINOV3_PATCH_SIZE = 16
MAX_DINOV3_RESOLUTION = 4096

print(f"Using device: {DEVICE}")
print(f"Head model dtype: {HEAD_DTYPE}")


# --- Model Definitions (Copied from hybrid_model.py) ---
# NOTE: parameter names and layer shapes must stay exactly as they are, because the
# head checkpoints are loaded with load_state_dict(strict=True).

class RMSNorm(nn.Module):
    """RMS normalization over the last dimension with a learned per-feature scale.

    The normalization itself is computed in float32 and cast back to the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class SwiGLUFFN(nn.Module):
    """Gated feed-forward block: act(gate) * up, then a projection to out_features.

    ``w12`` produces the gate and up halves in one matmul; ``w3`` projects back down.
    """

    def __init__(self, in_features: int, hidden_features: int = None, out_features: int = None,
                 act_layer: nn.Module = nn.SiLU, dropout: float = 0.):
        super().__init__()
        out_features = out_features or in_features
        # Kept verbatim: this expression fixes the w12/w3 shapes that existing checkpoints expect.
        hidden_features = hidden_features or int(in_features * 8 / 3 / 2 * 2)
        hidden_features = (hidden_features + 1) // 2 * 2  # round up to an even width
        self.w12 = nn.Linear(in_features, hidden_features * 2, bias=False)
        self.act = act_layer()
        self.dropout1 = nn.Dropout(dropout)
        self.w3 = nn.Linear(hidden_features, out_features, bias=False)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        gate_val, up_val = self.w12(x).chunk(2, dim=-1)
        x = self.dropout1(self.act(gate_val) * up_val)
        x = self.dropout2(self.w3(x))
        return x


class ResBlockRMS(nn.Module):
    """Pre-norm residual block: ``x + SwiGLUFFN(RMSNorm(x))``."""

    def __init__(self, ch: int, dropout: float = 0.0, rms_norm_eps: float = 1e-6):
        super().__init__()
        self.norm = RMSNorm(ch, eps=rms_norm_eps)
        self.ffn = SwiGLUFFN(in_features=ch, dropout=dropout)

    def forward(self, x):
        return x + self.ffn(self.norm(x))


class HybridHeadModel(nn.Module):
    """Classifier head over a pooled image embedding.

    Optionally applies self-attention to the embedding treated as a length-1 sequence
    (plus residual and RMSNorm), then an MLP of residual SwiGLU blocks ending in a
    ``num_classes``-way linear layer. ``output_mode`` selects the final activation.
    """

    def __init__(self, features: int, hidden_dim: int = 1280, num_classes: int = 2, use_attention: bool = True,
                 num_attn_heads: int = 16, attn_dropout: float = 0.1, num_res_blocks: int = 3,
                 dropout_rate: float = 0.1, rms_norm_eps: float = 1e-6, output_mode: str = 'linear'):
        super().__init__()
        self.features = features
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.use_attention = use_attention
        self.output_mode = output_mode.lower()
        self.attention = None
        self.norm_attn = None

        if self.use_attention:
            actual_num_heads = num_attn_heads
            if features % num_attn_heads != 0:
                # nn.MultiheadAttention requires features to be divisible by the head count.
                possible_heads = [h for h in [1, 2, 4, 8, 16, 32] if features % h == 0]  # Expanded list
                if not possible_heads:
                    actual_num_heads = 1
                else:
                    actual_num_heads = min(possible_heads, key=lambda x: abs(x - num_attn_heads))
                if actual_num_heads != num_attn_heads:
                    print(f"HybridHead Warning: Adjusting heads {num_attn_heads}->{actual_num_heads} for features={features}")
            self.attention = nn.MultiheadAttention(features, actual_num_heads, dropout=attn_dropout,
                                                   batch_first=True, bias=True)
            self.norm_attn = RMSNorm(features, eps=rms_norm_eps)

        mlp_layers = [nn.Linear(features, hidden_dim), RMSNorm(hidden_dim, eps=rms_norm_eps)]
        for _ in range(num_res_blocks):
            mlp_layers.append(ResBlockRMS(hidden_dim, dropout=dropout_rate, rms_norm_eps=rms_norm_eps))
        mlp_layers.append(RMSNorm(hidden_dim, eps=rms_norm_eps))
        down_proj_hidden = hidden_dim // 2
        mlp_layers.append(SwiGLUFFN(hidden_dim, hidden_features=down_proj_hidden,
                                    out_features=down_proj_hidden, dropout=dropout_rate))
        mlp_layers.append(RMSNorm(down_proj_hidden, eps=rms_norm_eps))
        mlp_layers.append(nn.Linear(down_proj_hidden, num_classes))
        self.mlp_head = nn.Sequential(*mlp_layers)

    def forward(self, x: torch.Tensor):
        if self.use_attention and self.attention is not None:
            # Each embedding is treated as a sequence of one token.
            x_seq = x.unsqueeze(1)
            attn_output, _ = self.attention(x_seq, x_seq, x_seq)
            x = self.norm_attn(x + attn_output.squeeze(1))
        logits = self.mlp_head(x.to(HEAD_DTYPE))

        output_mode = self.output_mode
        if output_mode == 'linear':
            output = logits
        elif output_mode == 'sigmoid':
            output = torch.sigmoid(logits)
        elif output_mode == 'softmax':
            output = F.softmax(logits, dim=-1)
        elif output_mode == 'tanh_scaled':
            output = (torch.tanh(logits) + 1.0) / 2.0  # maps to [0, 1]
        else:
            raise RuntimeError(f"Invalid output_mode '{output_mode}'.")

        if self.num_classes == 1 and output.ndim == 2 and output.shape[1] == 1:
            output = output.squeeze(-1)
        return output


# --- Model Catalog ---
MODEL_CATALOG = {
    "AnatomyFlaws-v15.5 (DINOv3 7b bf16)": {
        "repo_id": "Enferlain/lumi-classifier",
        "config_filename": "AnatomyFlaws-v15.5_dinov3_7b_bnb_fl.config.json",
        "head_filename": "AnatomyFlaws-v15.5_dinov3_7b_bnb_fl_s3K_best_val.safetensors",
        # Explicitly define the vision model repo ID to prevent errors.
        # Alternatives: "Enferlain/dinov3-vit7b16-pretrain-lvd1689m-8bit" (bnb 8bit),
        #               "Enferlain/dinov3-vit7b16-pretrain-lvd1689m-int4" (int4)
        "vision_model_repo_id": "PIA-SPACE-LAB/dinov3-vit7b16-pretrain-lvd1689m",
    },
    "AnatomyFlaws-v14.7 (SigLIP naflex)": {
        "repo_id": "Enferlain/lumi-classifier",
        "config_filename": "AnatomyFlaws-v14.7_adabelief_fl_naflex_4670.config.json",
        "head_filename": "AnatomyFlaws-v14.7_adabelief_fl_naflex_4670_s2K.safetensors",
        # The base SigLIP model is not custom, so we use its official ID
        "vision_model_repo_id": "google/siglip2-so400m-patch16-naflex"
    },
}


# --- Model Manager Class ---
class ModelManager:
    """Loads and caches one (vision backbone, processor, classifier head) set from a catalog.

    Only one set is held at a time; switching models releases the previous one first.
    """

    def __init__(self, catalog: Dict[str, Dict[str, str]]):
        self.catalog = catalog
        self.current_model_name: str = None
        self.vision_model: nn.Module = None
        self.hf_processor: Any = None
        self.head_model: HybridHeadModel = None
        self.labels: Dict[int, str] = None
        self.config: Dict[str, Any] = None

    def _release_models(self) -> None:
        """Drop all references to the loaded models so their memory can be reclaimed."""
        self.current_model_name = None
        self.vision_model = None
        self.hf_processor = None
        self.head_model = None
        self.labels = None
        self.config = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def _load_vision_model(vision_model_repo_id: str) -> nn.Module:
        """Load the vision backbone in eval mode: bf16 (fp32 fallback) on CPU, fp16 on CUDA."""
        if DEVICE == "cpu":
            # For CPU, load unquantized model with BF16 (original format)
            print("Loading unquantized model for CPU...")
            try:
                model = AutoModel.from_pretrained(
                    vision_model_repo_id,
                    torch_dtype=torch.bfloat16,  # Keep original BF16 format
                    device_map={"": "cpu"},  # Force CPU device mapping
                    trust_remote_code=True,
                ).eval()
                print("Successfully loaded model in BF16 format.")
                return model
            except Exception as bf16_error:
                print(f"BF16 loading failed: {bf16_error}")
                print("Falling back to FP32...")
            # Outside the except block so the failed attempt (and its traceback) can be freed
            # before allocating the twice-as-large FP32 weights.
            gc.collect()
            model = AutoModel.from_pretrained(
                vision_model_repo_id,
                torch_dtype=torch.float32,  # Fallback to FP32
                device_map={"": "cpu"},
                trust_remote_code=True,
            ).eval()
            print("Successfully loaded model in FP32 format.")
            return model

        # GPU environments (DEVICE is "cuda" here).
        return AutoModel.from_pretrained(
            vision_model_repo_id,
            torch_dtype=torch.float16,
            trust_remote_code=True,  # consistent with the processor and the CPU path
        ).to(DEVICE).eval()

    @staticmethod
    def _build_head(config: Dict[str, Any]) -> HybridHeadModel:
        """Build a HybridHeadModel from a training config.

        Hyper-parameters come from ``config["predictor_params"]`` if present, else from
        the top level. Keys missing from the config fall back to the constructor
        defaults instead of being passed as ``None``.

        Raises:
            ValueError: if the config does not specify ``features``.
        """
        head_params = config.get("predictor_params", config)
        features = head_params.get("features")
        if features is None:
            raise ValueError("Head config is missing 'features'.")

        head_kwargs = {
            "hidden_dim": head_params.get("hidden_dim"),
            "num_classes": config.get("num_classes"),
            "num_attn_heads": head_params.get("num_attn_heads"),
            "attn_dropout": head_params.get("attn_dropout"),
            "num_res_blocks": head_params.get("num_res_blocks"),
            "dropout_rate": head_params.get("dropout_rate"),
            "rms_norm_eps": head_params.get("rms_norm_eps"),
            "output_mode": head_params.get("output_mode", "linear"),
        }
        head_kwargs = {key: value for key, value in head_kwargs.items() if value is not None}
        # A missing 'use_attention' has always meant "no attention"; keep that meaning
        # rather than falling back to the constructor default of True.
        head_kwargs["use_attention"] = bool(head_params.get("use_attention"))
        return HybridHeadModel(features=features, **head_kwargs)

    def load_model(self, model_name: str):
        """Make ``model_name`` the active model, downloading and loading it if needed.

        Raises:
            ValueError: if ``model_name`` is not in the catalog.
            RuntimeError: if any download or load step fails (state is then fully reset).
        """
        if model_name == self.current_model_name:
            return
        if model_name not in self.catalog:
            raise ValueError(f"Model '{model_name}' not found.")

        print(f"Switching to model: {model_name}...")
        model_info = self.catalog[model_name]
        repo_id = model_info["repo_id"]
        config_filename = model_info["config_filename"]
        head_filename = model_info["head_filename"]
        vision_model_repo_id = model_info["vision_model_repo_id"]

        # Free the previous backbone before loading the next one; holding two large
        # backbones (one is 7B parameters) at once can exhaust RAM/VRAM.
        self._release_models()

        try:
            config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
            with open(config_path, 'r', encoding='utf-8') as f:
                self.config = json.load(f)

            print(f"Loading vision model: {vision_model_repo_id}")
            self.hf_processor = AutoProcessor.from_pretrained(vision_model_repo_id, trust_remote_code=True)
            self.vision_model = self._load_vision_model(vision_model_repo_id)

            head_model_path = hf_hub_download(repo_id=repo_id, filename=head_filename)
            print(f"Loading head model: {head_filename}")
            state_dict = load_file(head_model_path, device='cpu')
            self.head_model = self._build_head(self.config)
            self.head_model.load_state_dict(state_dict, strict=True)
            self.head_model.to(DEVICE).eval()

            raw_labels = self.config.get("labels", {'0': 'Bad', '1': 'Good'})
            self.labels = {int(k): (v['name'] if isinstance(v, dict) else v) for k, v in raw_labels.items()}

            self.current_model_name = model_name
            print(f"Successfully loaded '{model_name}'.")
        except Exception as e:
            # Never keep a half-loaded model; the next call retries from scratch.
            self._release_models()
            raise RuntimeError(f"Failed to load model '{model_name}': {e}\n{traceback.format_exc()}") from e


# --- Global Model Manager Instance ---
model_manager = ModelManager(MODEL_CATALOG)


# --- Prediction helpers ---
def _fit_to_dinov3_grid(pil_image: Image.Image) -> Image.Image:
    """Cap the long side at MAX_DINOV3_RESOLUTION, then round both sides up to a multiple of DINOV3_PATCH_SIZE."""
    current_w, current_h = pil_image.size
    img_to_process = pil_image
    if max(current_w, current_h) > MAX_DINOV3_RESOLUTION:
        scale = MAX_DINOV3_RESOLUTION / max(current_w, current_h)
        # max(1, ...) prevents a zero-sized side for extremely thin images.
        current_w = max(1, int(current_w * scale))
        current_h = max(1, int(current_h * scale))
        img_to_process = pil_image.resize((current_w, current_h), Image.Resampling.LANCZOS)

    new_w = ((current_w + DINOV3_PATCH_SIZE - 1) // DINOV3_PATCH_SIZE) * DINOV3_PATCH_SIZE
    new_h = ((current_h + DINOV3_PATCH_SIZE - 1) // DINOV3_PATCH_SIZE) * DINOV3_PATCH_SIZE
    if new_w != current_w or new_h != current_h:
        img_to_process = img_to_process.resize((new_w, new_h), Image.Resampling.LANCZOS)
    return img_to_process


def _embed_dinov3(pil_image: Image.Image) -> torch.Tensor:
    """Return the mean of the DINOv3 patch-token features for one image."""
    vision_model = model_manager.vision_model
    inputs = model_manager.hf_processor(images=[_fit_to_dinov3_grid(pil_image)], return_tensors="pt")
    # Send inputs to the same device as the model (it may be device-mapped / quantized).
    pixel_values = inputs.pixel_values.to(vision_model.device)
    last_hidden_state = vision_model(pixel_values=pixel_values).last_hidden_state
    # Skip the leading 1 + num_register_tokens special tokens (presumably CLS + registers)
    # and average the remaining patch tokens.
    nreg = getattr(vision_model.config, 'num_register_tokens', 0)
    patch_embeddings = last_hidden_state[:, 1 + nreg:]
    return torch.mean(patch_embeddings, dim=1)


def _embed_siglip(pil_image: Image.Image, model_type: str) -> torch.Tensor:
    """Return the pooled SigLIP image embedding; ``model_type`` is the lower-cased base model name."""
    vision_model = model_manager.vision_model
    inputs = model_manager.hf_processor(images=[pil_image], return_tensors="pt")
    # Match the backbone's own device and dtype: a hard-coded float16 does not match the
    # bf16/fp32 weights used on CPU.
    pixel_values = inputs.get("pixel_values").to(device=vision_model.device, dtype=vision_model.dtype)
    if "naflex" in model_type:
        attention_mask = inputs.get("pixel_attention_mask").to(device=vision_model.device)
        spatial_shapes = torch.as_tensor(inputs.get("spatial_shapes"), dtype=torch.long).to(vision_model.device)
        vision_model_component = getattr(vision_model, 'vision_model', vision_model)
        return vision_model_component(pixel_values=pixel_values, attention_mask=attention_mask,
                                      spatial_shapes=spatial_shapes).pooler_output
    return vision_model.get_image_features(pixel_values=pixel_values)


def _to_probabilities(prediction: torch.Tensor, head: HybridHeadModel,
                      labels: Dict[int, str]) -> Dict[str, float]:
    """Convert head output for one image into a ``{label: probability}`` dict for gr.Label.

    Logits (``output_mode == 'linear'``) are passed through sigmoid (single output) or
    softmax (multi-class). Heads whose output_mode already applies an activation are used
    as-is, so the activation is never applied twice.
    """
    scores = prediction.detach().float().reshape(-1)
    already_activated = head.output_mode != 'linear'

    def label_for(index: int) -> str:
        return labels.get(index, f"Class {index}")

    if head.num_classes == 1:
        # Single output: score for class 1; class 0 is its complement.
        prob_good = (scores[0] if already_activated else torch.sigmoid(scores[0])).item()
        return {label_for(0): 1.0 - prob_good, label_for(1): prob_good}

    probs = scores if already_activated else F.softmax(scores, dim=-1)
    return {label_for(i): p for i, p in enumerate(probs.tolist())}


# --- Prediction Function (v3 from before) ---
def predict_anatomy_v3(image: Image.Image, model_name: str):
    """Classify ``image`` with the catalog model ``model_name``.

    Returns a ``{label: probability}`` dict for gr.Label. On error, returns a dict whose
    first key carries a truncated error message (the full traceback is printed).
    """
    if image is None:
        return {"Error": 1.0, "Info": 0.0}  # Return numeric values
    try:
        model_manager.load_model(model_name)
        pil_image = image.convert("RGB")
        base_model_type = model_manager.config.get("base_vision_model", "")
        model_type = base_model_type.lower()

        with torch.no_grad():
            if "dinov3" in model_type:
                emb = _embed_dinov3(pil_image)
            elif "siglip" in model_type:
                emb = _embed_siglip(pil_image, model_type)
            else:
                raise ValueError(f"Unknown base model type for embedding: {base_model_type}")
            if emb is None:
                raise ValueError("Failed to get embedding.")

            # L2-normalize in float32 (the backbone may output bf16/fp16).
            emb = emb.float()
            norm = torch.linalg.norm(emb, dim=-1, keepdim=True).clamp(min=1e-8)
            emb_normalized = emb / norm
            prediction = model_manager.head_model(emb_normalized.to(DEVICE, dtype=HEAD_DTYPE))

        return _to_probabilities(prediction, model_manager.head_model, model_manager.labels)
    except Exception as e:
        print(f"Error during prediction: {e}\n{traceback.format_exc()}")
        # Return properly formatted error for Gradio Label
        error_msg = str(e)[:50] + "..." if len(str(e)) > 50 else str(e)
        return {
            f"Error: {error_msg}": 1.0,
            "Please check logs": 0.0
        }


# --- Gradio Interface ---
DESCRIPTION = """
## Lumi's Anatomy Flaw Classifier Demo ✨
Select a model from the dropdown, then upload an image to classify its anatomy/structural correctness.
Will be slow since it runs on cpu, ~2minutes on dinov3.
"""

EXAMPLE_DIR = "examples"
default_model = list(MODEL_CATALOG.keys())[0]

# 1. Find the paths to our example images
example_paths = []
if os.path.isdir(EXAMPLE_DIR):
    example_paths = [
        os.path.join(EXAMPLE_DIR, fname)
        for fname in sorted(os.listdir(EXAMPLE_DIR))
        if fname.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))
    ]

# 2. Create the nested list Gradio needs: [[image, model_name], [image, model_name], ...]
examples_nested = [[path, default_model] for path in example_paths]

# 3. Create the interface, passing the correctly formatted list
interface = gr.Interface(
    fn=predict_anatomy_v3,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(choices=list(MODEL_CATALOG.keys()), value=default_model, label="Classifier Model")
    ],
    outputs=gr.Label(label="Class Probabilities", num_top_classes=2),
    title="Lumi's Anatomy Classifier",
    description=DESCRIPTION,
    examples=examples_nested if examples_nested else None,  # Pass the new nested list
    allow_flagging="never",
    cache_examples=True
)

if __name__ == "__main__":
    try:
        print("Pre-loading default model...")
        model_manager.load_model(default_model)
    except Exception as e:
        print(f"WARNING: Could not pre-load default model. Error: {e}")
    interface.launch()