Spaces:
Runtime error
Runtime error
| from typing import Dict, List, Optional, Tuple, Type, Any | |
| from pathlib import Path | |
| import uuid | |
| import tempfile | |
| import numpy as np | |
| import torch | |
| import torchvision | |
| import torchxrayvision as xrv | |
| import matplotlib.pyplot as plt | |
| import skimage.io | |
| import skimage.measure | |
| import skimage.transform | |
| import traceback | |
| from pydantic import BaseModel, Field | |
| from langchain_core.callbacks import ( | |
| AsyncCallbackManagerForToolRun, | |
| CallbackManagerForToolRun, | |
| ) | |
| from langchain_core.tools import BaseTool | |
class ChestXRaySegmentationInput(BaseModel):
    """Input schema for the Chest X-ray Segmentation Tool."""

    # Location of the image to analyse; required.
    image_path: str = Field(
        ...,
        description="Path to the chest X-ray image file to be segmented",
    )

    # Optional subset of structures to segment. Leaving this unset (None)
    # requests every structure the tool knows about.
    organs: Optional[List[str]] = Field(
        default=None,
        description=(
            "List of organs to segment. "
            "If None, all available organs will be segmented. "
            "Available organs: Left/Right Clavicle, Left/Right Scapula, "
            "Left/Right Lung, Left/Right Hilus Pulmonis, Heart, Aorta, "
            "Facies Diaphragmatica, Mediastinum, Weasand, Spine"
        ),
    )
class OrganMetrics(BaseModel):
    """Detailed metrics for a segmented organ."""

    # --- Basic metrics -----------------------------------------------------
    # Pixel count of the mask and its approximate physical area.
    area_pixels: int = Field(..., description="Area in pixels")
    area_cm2: float = Field(..., description="Approximate area in cm²")
    # Coordinates are stored row-first: (y, x).
    centroid: Tuple[float, float] = Field(
        ...,
        description="(y, x) coordinates of centroid",
    )
    bbox: Tuple[int, int, int, int] = Field(
        ...,
        description="Bounding box coordinates (min_y, min_x, max_y, max_x)",
    )

    # --- Size metrics ------------------------------------------------------
    width: int = Field(..., description="Width of the organ in pixels")
    height: int = Field(..., description="Height of the organ in pixels")
    aspect_ratio: float = Field(..., description="Height/width ratio")

    # --- Position metrics --------------------------------------------------
    # Free-form mapping of named positional measurements to values.
    relative_position: Dict[str, float] = Field(
        ...,
        description="Position relative to image boundaries (0-1 scale)",
    )

    # --- Analysis metrics --------------------------------------------------
    mean_intensity: float = Field(
        ...,
        description="Mean pixel intensity in the organ region",
    )
    std_intensity: float = Field(
        ...,
        description="Standard deviation of pixel intensity",
    )
    confidence_score: float = Field(
        ...,
        description="Model confidence score for this organ",
    )
class ChestXRaySegmentationTool(BaseTool):
    """Tool for performing detailed segmentation analysis of chest X-ray images.

    The tool runs the torchxrayvision ``chestx_det`` PSPNet model on a
    center-cropped, 512x512 version of the input image, thresholds the
    per-organ probability maps at 0.5, maps the masks back onto the original
    image, and returns an overlay image path plus per-organ metrics.

    Physical areas are derived from the fixed ``pixel_spacing_mm`` attribute,
    not from the image itself, so ``area_cm2`` is only an approximation.
    """

    name: str = "chest_xray_segmentation"
    description: str = (
        "Segments chest X-ray images to specified anatomical structures. "
        "Available organs: Left/Right Clavicle (collar bones), Left/Right Scapula (shoulder blades), "
        "Left/Right Lung, Left/Right Hilus Pulmonis (lung roots), Heart, Aorta, "
        "Facies Diaphragmatica (diaphragm), Mediastinum (central cavity), Weasand (esophagus), "
        "and Spine. Returns segmentation visualization and comprehensive metrics. "
        "Let the user know the area is not accurate unless input has been DICOM."
    )
    args_schema: Type[BaseModel] = ChestXRaySegmentationInput
    model: Any = None
    device: Optional[str] = "cuda"
    transform: Any = None
    pixel_spacing_mm: float = 0.2
    temp_dir: Path = Path("temp")
    organ_map: Dict[str, int] = None

    def __init__(self, device: Optional[str] = "cuda", temp_dir: Optional[Path] = Path("temp")):
        """Initialize the segmentation tool with model and temporary directory.

        Args:
            device: Torch device string. A falsy value means "cuda". If a CUDA
                device is requested but CUDA is not available, the tool falls
                back to the CPU (with a warning) instead of failing.
            temp_dir: Directory where overlay images are written. ``None``
                selects the default ``Path("temp")``. Missing parent
                directories are created.
        """
        import warnings  # local import: only needed for the fallback notice

        super().__init__()

        resolved_device = torch.device(device if device else "cuda")
        if resolved_device.type == "cuda" and not torch.cuda.is_available():
            # Moving the model to an unavailable CUDA device would raise;
            # degrade to CPU so the tool remains usable on CPU-only hosts.
            warnings.warn(
                "CUDA was requested but is not available; falling back to CPU.",
                RuntimeWarning,
                stacklevel=2,
            )
            resolved_device = torch.device("cpu")
        self.device = resolved_device

        self.model = xrv.baseline_models.chestx_det.PSPNet().to(self.device)
        self.model.eval()

        self.transform = torchvision.transforms.Compose(
            [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(512)]
        )

        # The signature allows None, which Path() itself would reject.
        self.temp_dir = Path("temp") if temp_dir is None else Path(temp_dir)
        self.temp_dir.mkdir(parents=True, exist_ok=True)

        # Map friendly names to model target indices
        self.organ_map = {
            "Left Clavicle": 0,
            "Right Clavicle": 1,
            "Left Scapula": 2,
            "Right Scapula": 3,
            "Left Lung": 4,
            "Right Lung": 5,
            "Left Hilus Pulmonis": 6,
            "Right Hilus Pulmonis": 7,
            "Heart": 8,
            "Aorta": 9,
            "Facies Diaphragmatica": 10,
            "Mediastinum": 11,
            "Weasand": 12,
            "Spine": 13,
        }

    def _align_mask_to_original(
        self, mask: np.ndarray, original_shape: Tuple[int, int]
    ) -> np.ndarray:
        """Map a mask from model space back onto the full original image.

        Assumes the preprocessing transform center-crops the image to a square
        of side ``min(height, width)`` and then resizes it to 512x512.

        Args:
            mask: 2-D mask in the transformed (cropped and resized) space.
            original_shape: ``(height, width)`` of the original image.

        Returns:
            An array of shape ``original_shape`` that is zero outside the
            center-crop square and holds the resized mask inside it.
        """
        orig_h, orig_w = original_shape
        crop_size = min(orig_h, orig_w)
        # NOTE(review): confirm these offsets match XRayCenterCrop's rounding;
        # a different floor-division order could shift the mask by one pixel.
        crop_top = (orig_h - crop_size) // 2
        crop_left = (orig_w - crop_size) // 2

        # order=0 (nearest neighbour) without anti-aliasing keeps the mask
        # binary instead of introducing interpolated in-between values.
        resized_mask = skimage.transform.resize(
            mask, (crop_size, crop_size), order=0, preserve_range=True, anti_aliasing=False
        )

        full_mask = np.zeros(original_shape)
        full_mask[crop_top : crop_top + crop_size, crop_left : crop_left + crop_size] = resized_mask
        return full_mask

    def _compute_organ_metrics(
        self, mask: np.ndarray, original_img: np.ndarray, confidence: float
    ) -> Optional[OrganMetrics]:
        """Compute comprehensive metrics for a single organ mask.

        Args:
            mask: Binary organ mask, either in model space or already aligned
                to ``original_img``.
            original_img: 2-D original image the intensities are read from.
            confidence: Confidence value to store on the result unchanged.

        Returns:
            The populated ``OrganMetrics``, or ``None`` when the mask contains
            no labelled region.
        """
        # Align mask to the original image coordinates if needed
        if mask.shape != original_img.shape:
            mask = self._align_mask_to_original(mask, original_img.shape)

        # The mask is passed as a label image, so every foreground pixel
        # carries label 1 and is reported as one region.
        regions = skimage.measure.regionprops(mask.astype(int))
        if not regions:
            return None
        region = regions[0]

        min_y, min_x, max_y, max_x = (int(v) for v in region.bbox)
        box_height = max_y - min_y
        box_width = max_x - min_x

        # pixel_spacing_mm / 10 converts the pixel edge length to centimetres.
        area_cm2 = mask.sum() * (self.pixel_spacing_mm / 10) ** 2

        img_height, img_width = mask.shape
        cy, cx = region.centroid
        rel_y = cy / img_height
        rel_x = cx / img_width
        relative_pos = {
            "top": rel_y,
            "left": rel_x,
            "center_dist": np.sqrt((rel_y - 0.5) ** 2 + (rel_x - 0.5) ** 2),
        }

        organ_pixels = original_img[mask > 0]
        if len(organ_pixels) > 0:
            mean_intensity = organ_pixels.mean()
            std_intensity = organ_pixels.std()
        else:
            mean_intensity = 0
            std_intensity = 0

        return OrganMetrics(
            area_pixels=int(mask.sum()),
            area_cm2=float(area_cm2),
            centroid=(float(cy), float(cx)),
            bbox=(min_y, min_x, max_y, max_x),
            width=box_width,
            height=box_height,
            # max(1, ...) guards against division by zero for degenerate boxes.
            aspect_ratio=float(box_height / max(1, box_width)),
            relative_position=relative_pos,
            mean_intensity=float(mean_intensity),
            std_intensity=float(std_intensity),
            confidence_score=float(confidence),
        )

    def _save_visualization(
        self, original_img: np.ndarray, pred_masks: torch.Tensor, organ_indices: List[int]
    ) -> str:
        """Save visualization of original image with segmentation masks overlaid.

        Args:
            original_img: 2-D original image used as the grayscale background.
            pred_masks: Thresholded model output indexed as
                ``pred_masks[0, organ_index]``.
            organ_indices: Model target indices of the organs to draw.

        Returns:
            Path (as a string) of the PNG written inside ``self.temp_dir``.
        """
        extent = [0, original_img.shape[1], original_img.shape[0], 0]
        # Build the reverse lookup once instead of searching per organ.
        index_to_name = {index: name for name, index in self.organ_map.items()}
        # Generate color palette for organs
        palette = plt.cm.rainbow(np.linspace(0, 1, len(organ_indices)))

        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            ax.imshow(original_img, cmap="gray", extent=extent)

            has_legend_entries = False
            for organ_idx, color in zip(organ_indices, palette):
                mask = pred_masks[0, organ_idx].cpu().numpy()
                if mask.sum() <= 0:
                    continue

                # Align the mask to the original image coordinates
                if mask.shape != original_img.shape:
                    mask = self._align_mask_to_original(mask, original_img.shape)

                # RGBA overlay: transparent everywhere except the organ pixels.
                overlay = np.zeros((*original_img.shape, 4))
                overlay[mask > 0] = (*color[:3], 0.3)
                ax.imshow(overlay, extent=extent)

                # Empty line whose only purpose is to create a legend entry.
                ax.plot([], [], color=color, label=index_to_name[organ_idx], linewidth=3)
                has_legend_entries = True

            ax.set_title("Segmentation Overlay")
            if has_legend_entries:
                # Skipped when nothing was drawn, to avoid an empty-legend warning.
                ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
            ax.axis("off")

            save_path = self.temp_dir / f"segmentation_{uuid.uuid4().hex[:8]}.png"
            fig.savefig(save_path, bbox_inches="tight", dpi=300)
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)
        return str(save_path)

    def _resolve_organs(self, organs: Optional[List[str]]) -> Tuple[List[str], List[int]]:
        """Validate requested organ names and return (names, model indices).

        A falsy ``organs`` selects every organ in ``self.organ_map``.

        Raises:
            ValueError: If any requested name is not a key of ``self.organ_map``.
        """
        if not organs:
            return list(self.organ_map.keys()), list(self.organ_map.values())

        names = [o.strip() for o in organs]
        invalid_organs = [o for o in names if o not in self.organ_map]
        if invalid_organs:
            raise ValueError(f"Invalid organs specified: {invalid_organs}")
        return names, [self.organ_map[o] for o in names]

    def _run(
        self,
        image_path: str,
        organs: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Run segmentation analysis for specified organs.

        Args:
            image_path: Path of the chest X-ray image to read.
            organs: Organ names to analyse; falsy means all known organs.
            run_manager: Accepted for interface compatibility; not used.

        Returns:
            ``(output, metadata)``. On success ``output`` holds the overlay
            path and per-organ metric dicts. Any exception is caught and
            reported as ``{"error": ...}`` with a traceback in ``metadata``
            rather than being raised.
        """
        try:
            # Validate and get organ indices
            organs, organ_indices = self._resolve_organs(organs)

            # Load and process image
            original_img = skimage.io.imread(image_path)
            if len(original_img.shape) > 2:
                # Multi-channel input: only the first channel is used.
                original_img = original_img[:, :, 0]

            # 255 is passed as the maximum expected pixel value, i.e. the
            # input is treated as 8-bit.
            img = xrv.datasets.normalize(original_img, 255)
            img = img[None, ...]
            img = self.transform(img)
            img = torch.from_numpy(img).to(self.device)

            # Generate predictions
            with torch.no_grad():
                pred = self.model(img)
            pred_probs = torch.sigmoid(pred)
            pred_masks = (pred_probs > 0.5).float()

            # Save visualization
            viz_path = self._save_visualization(original_img, pred_masks, organ_indices)

            # Compute metrics for selected organs
            results = {}
            for idx, organ_name in zip(organ_indices, organs):
                organ_mask = pred_masks[0, idx]
                if organ_mask.sum() <= 0:
                    continue

                # Average the probability over the predicted organ pixels only.
                # Averaging over the whole map would mostly measure how large
                # the organ is rather than how confident the model is.
                confidence = float(pred_probs[0, idx][organ_mask > 0].mean().cpu())

                metrics = self._compute_organ_metrics(
                    organ_mask.cpu().numpy(), original_img, confidence
                )
                if metrics:
                    results[organ_name] = metrics

            output = {
                "segmentation_image_path": viz_path,
                "metrics": {organ: metrics.dict() for organ, metrics in results.items()},
            }
            metadata = {
                "image_path": image_path,
                "segmentation_image_path": viz_path,
                "original_size": original_img.shape,
                "model_size": tuple(img.shape[-2:]),
                "pixel_spacing_mm": self.pixel_spacing_mm,
                "requested_organs": organs,
                "processed_organs": list(results.keys()),
                "analysis_status": "completed",
            }
            return output, metadata
        except Exception as e:
            # Tool boundary: report the failure as data so the caller always
            # receives an (output, metadata) pair.
            error_output = {"error": str(e)}
            error_metadata = {
                "image_path": image_path,
                "analysis_status": "failed",
                "error_traceback": traceback.format_exc(),
            }
            return error_output, error_metadata

    async def _arun(
        self,
        image_path: str,
        organs: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Async version of _run.

        Delegates directly to the synchronous ``_run``, so the call blocks
        until inference finishes. ``run_manager`` is not forwarded.
        """
        return self._run(image_path, organs)