from typing import Any, Dict, List, Optional, Tuple, Type
from pathlib import Path
import uuid
import tempfile

import matplotlib.pyplot as plt
import torch
from PIL import Image
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoProcessor, BitsAndBytesConfig
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool

class XRayPhraseGroundingInput(BaseModel):
    """Input schema for the XRay Phrase Grounding Tool. Only supports JPG or PNG images."""

    image_path: str = Field(
        ...,
        description="Path to the frontal chest X-ray image file; only JPG or PNG images are supported",
    )
    phrase: str = Field(
        ...,
        description="Medical finding or condition to locate in the image (e.g., 'Pleural effusion')",
    )
    max_new_tokens: int = Field(default=300, description="Maximum number of new tokens to generate")
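
# A minimal sketch of how the schema above validates a tool call; the path is
# hypothetical, and Pydantic raises a ValidationError if image_path or phrase
# is missing (max_new_tokens falls back to its default of 300):
#
#   XRayPhraseGroundingInput(image_path="/path/to/xray.png", phrase="Pleural effusion")
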
class XRayPhraseGroundingTool(BaseTool):
    """Tool for grounding medical findings in chest X-ray images using the MAIRA-2 model.

    This tool processes chest X-ray images and locates specific medical findings mentioned
    in the input phrase. It returns both the bounding box coordinates and a visualization
    of the finding's location in the image.
    """

    name: str = "xray_phrase_grounding"
    description: str = (
        "Locates and visualizes specific medical findings in chest X-ray images. "
        "Takes a chest X-ray image and a medical phrase to locate (e.g., 'Pleural effusion', 'Cardiomegaly'). "
        "Returns bounding box coordinates in the format [x_topleft, y_topleft, x_bottomright, y_bottomright], "
        "where each value is between 0 and 1 and represents a relative position in the image, "
        "along with a visualization of the finding's location and analysis metadata. "
        "Example input: {'image_path': '/path/to/xray.png', 'phrase': 'Pleural effusion', 'max_new_tokens': 300}"
    )
    args_schema: Type[BaseModel] = XRayPhraseGroundingInput
    model: Any = None
    processor: Any = None
    device: Any = "cuda"  # holds a torch.device after __init__
    temp_dir: Optional[Path] = None

    def __init__(
        self,
        model_path: str = "microsoft/maira-2",
        cache_dir: Optional[str] = None,
        temp_dir: Optional[str] = None,
        load_in_4bit: bool = False,
        load_in_8bit: bool = False,
        device: Optional[str] = "cuda",
    ):
        """Initialize the XRay Phrase Grounding Tool."""
        super().__init__()
        # Wrap the device string in torch.device, falling back to CUDA,
        # so that self.device has a consistent type either way.
        self.device = torch.device(device if device else "cuda")
        # Set up the optional quantization config
        if load_in_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
        elif load_in_8bit:
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=True,
            )
        else:
            quantization_config = None
        # Load model and processor
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map=self.device,
            cache_dir=cache_dir,
            trust_remote_code=True,
            quantization_config=quantization_config,
        )
        self.processor = AutoProcessor.from_pretrained(
            model_path, cache_dir=cache_dir, trust_remote_code=True
        )
        self.model = self.model.eval()
        self.temp_dir = Path(temp_dir if temp_dir else tempfile.mkdtemp())
        # parents=True so a nested user-supplied temp_dir is created as well
        self.temp_dir.mkdir(parents=True, exist_ok=True)

    def _visualize_bboxes(
        self, image: Image.Image, bboxes: List[Tuple[float, float, float, float]], phrase: str
    ) -> str:
        """Create and save a visualization of multiple bounding boxes on the image."""
        plt.figure(figsize=(12, 12))
        plt.imshow(image, cmap="gray")
        for bbox in bboxes:
            # Boxes are normalized to [0, 1], so scale them to pixel coordinates
            x1, y1, x2, y2 = bbox
            width = x2 - x1
            height = y2 - y1
            plt.gca().add_patch(
                plt.Rectangle(
                    (x1 * image.width, y1 * image.height),
                    width * image.width,
                    height * image.height,
                    fill=False,
                    color="red",
                    linewidth=2,
                )
            )
        plt.title(f"Located: {phrase}", pad=20)
        plt.axis("off")
        viz_path = self.temp_dir / f"grounding_{uuid.uuid4().hex[:8]}.png"
        plt.savefig(viz_path, bbox_inches="tight", dpi=150)
        plt.close()
        return str(viz_path)

    def _run(
        self,
        image_path: str,
        phrase: str,
        max_new_tokens: int = 300,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Ground a medical finding phrase in an X-ray image.

        Args:
            image_path: Path to the chest X-ray image file
            phrase: Medical finding to locate in the image
            max_new_tokens: Maximum number of new tokens to generate
            run_manager: Optional callback manager

        Returns:
            Tuple[Dict, Dict]: Output dictionary and metadata dictionary
        """
        try:
            image = Image.open(image_path)
            if image.mode != "RGB":
                image = image.convert("RGB")
            inputs = self.processor.format_and_preprocess_phrase_grounding_input(
                frontal_image=image, phrase=phrase, return_tensors="pt"
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                output = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    use_cache=True,
                )
            prompt_length = inputs["input_ids"].shape[-1]
            decoded_text = self.processor.decode(
                output[0][prompt_length:], skip_special_tokens=True
            )
            predictions = self.processor.convert_output_to_plaintext_or_grounded_sequence(
                decoded_text
            )
            metadata = {
                "image_path": image_path,
                "original_size": image.size,
                "model_input_size": tuple(inputs["pixel_values"].shape[-2:]),
                "device": str(self.device),
                "analysis_status": "completed",
            }
            if not predictions:
                output = {
                    "predictions": [],
                    "visualization_path": None,
                }
                metadata["analysis_status"] = "completed_no_finding"
                return output, metadata
            # Process multiple predictions
            processed_predictions = []
            for pred_phrase, pred_bboxes in predictions:
                if not pred_bboxes:  # Skip if no bounding boxes
                    continue
                # Convert model bboxes to list format and get original image bboxes
                model_bboxes = [list(bbox) for bbox in pred_bboxes]
                original_bboxes = [
                    self.processor.adjust_box_for_original_image_size(
                        bbox, width=image.size[0], height=image.size[1]
                    )
                    for bbox in model_bboxes
                ]
                processed_predictions.append(
                    {
                        "phrase": pred_phrase,
                        "bounding_boxes": {
                            "model_coordinates": model_bboxes,
                            "image_coordinates": original_bboxes,
                        },
                    }
                )
            # Create visualization with all bounding boxes
            if processed_predictions:
                all_bboxes = []
                for pred in processed_predictions:
                    all_bboxes.extend(pred["bounding_boxes"]["image_coordinates"])
                viz_path = self._visualize_bboxes(image, all_bboxes, phrase)
            else:
                viz_path = None
                metadata["analysis_status"] = "completed_no_finding"
            output = {
                "predictions": processed_predictions,
                "visualization_path": viz_path,
            }
            return output, metadata
        except Exception as e:
            output = {"error": str(e)}
            metadata = {
                "image_path": image_path,
                "analysis_status": "failed",
                "error_details": str(e),
            }
            return output, metadata

    async def _arun(
        self,
        image_path: str,
        phrase: str,
        max_new_tokens: int = 300,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Asynchronous version of _run."""
        # Pass arguments by keyword so run_manager is not mistaken for
        # max_new_tokens; the sync implementation never uses its callback
        # manager, so it is not forwarded.
        return self._run(image_path, phrase, max_new_tokens=max_new_tokens)
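

# A minimal usage sketch, assuming a CUDA-capable GPU and a frontal chest
# X-ray at a hypothetical local path. The microsoft/maira-2 weights are gated
# on the Hugging Face Hub, so downloading them may require authentication.
# _run is called directly here for illustration; in a LangChain agent the tool
# would be invoked through the standard tool-calling interface instead.
if __name__ == "__main__":
    tool = XRayPhraseGroundingTool(load_in_4bit=True)  # 4-bit quantization to fit smaller GPUs
    output, metadata = tool._run(
        image_path="/path/to/xray.png",  # hypothetical path
        phrase="Pleural effusion",
    )
    print(metadata["analysis_status"])
    for pred in output.get("predictions", []):
        print(pred["phrase"], pred["bounding_boxes"]["image_coordinates"])
    if output.get("visualization_path"):
        print(f"Visualization saved to: {output['visualization_path']}")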