Spaces:
Runtime error
Runtime error
| from typing import Dict, Optional, Tuple, Type | |
| from pydantic import BaseModel, Field | |
| import skimage.io | |
| import torch | |
| import torchvision | |
| import torchxrayvision as xrv | |
| from langchain_core.callbacks import ( | |
| AsyncCallbackManagerForToolRun, | |
| CallbackManagerForToolRun, | |
| ) | |
| from langchain_core.tools import BaseTool | |
class ChestXRayInput(BaseModel):
    """Input for chest X-ray analysis tools. Only supports JPG or PNG images."""

    # Single required argument: a Field without a default is mandatory, so the
    # tool call is rejected by schema validation when the path is missing.
    # The description text is surfaced to the LLM through the tool's JSON schema.
    image_path: str = Field(
        description="Path to the radiology image file, only supports JPG or PNG images",
    )
class ChestXRayClassifierTool(BaseTool):
    """Tool that classifies chest X-ray images for multiple pathologies.

    This tool uses a pre-trained torchxrayvision DenseNet model to analyze chest
    X-ray images and predict the likelihood of various pathologies. With the
    default ``densenet121-res224-all`` weights the model scores these 18 conditions:
    Atelectasis, Cardiomegaly, Consolidation, Edema, Effusion, Emphysema,
    Enlarged Cardiomediastinum, Fibrosis, Fracture, Hernia, Infiltration,
    Lung Lesion, Lung Opacity, Mass, Nodule, Pleural Thickening, Pneumonia, Pneumothorax

    The output values represent the probability (from 0 to 1) of each condition
    being present in the image. A higher value indicates a higher likelihood of
    the condition being present. Output keys are taken from the loaded model's
    own ``pathologies`` list, so other weight sets are labelled correctly.
    """

    name: str = "chest_xray_classifier"
    description: str = (
        "A tool that analyzes chest X-ray images and classifies them for 18 different pathologies. "
        "Input should be the path to a chest X-ray image file. "
        "Output is a dictionary of pathologies and their predicted probabilities (0 to 1). "
        "Pathologies include: Atelectasis, Cardiomegaly, Consolidation, Edema, Effusion, Emphysema, "
        "Enlarged Cardiomediastinum, Fibrosis, Fracture, Hernia, Infiltration, Lung Lesion, "
        "Lung Opacity, Mass, Nodule, Pleural Thickening, Pneumonia, and Pneumothorax. "
        "Higher values indicate a higher likelihood of the condition being present."
    )
    args_schema: Type[BaseModel] = ChestXRayInput
    model: xrv.models.DenseNet = None
    # After __init__ this holds a resolved ``torch.device`` (the annotation is
    # kept as-is for compatibility with existing callers/config).
    device: Optional[str] = "cuda"
    transform: torchvision.transforms.Compose = None

    def __init__(self, model_name: str = "densenet121-res224-all", device: Optional[str] = "cuda"):
        """Load the DenseNet weights and build the preprocessing pipeline.

        Args:
            model_name: torchxrayvision weight identifier, passed as
                ``xrv.models.DenseNet(weights=model_name)``.
            device: Torch device string such as ``"cuda"``, ``"cuda:1"`` or
                ``"cpu"``. When falsy, CUDA is used if available, otherwise CPU.
                A CUDA device requested on a host without CUDA falls back to CPU
                with a warning instead of crashing when the model is moved.
        """
        super().__init__()
        self.model = xrv.models.DenseNet(weights=model_name)
        self.model.eval()

        if not device:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        resolved_device = torch.device(device)
        if resolved_device.type == "cuda" and not torch.cuda.is_available():
            import warnings

            warnings.warn(
                f"CUDA device {device!r} requested but CUDA is not available; "
                "falling back to CPU."
            )
            resolved_device = torch.device("cpu")
        self.device = resolved_device
        self.model = self.model.to(self.device)

        # Square center-crop, then resize to 224x224: the preprocessing recipe
        # from the torchxrayvision README for the "res224" DenseNet weights.
        # Without the explicit resize the model has to rescale inputs itself.
        self.transform = torchvision.transforms.Compose(
            [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(224)]
        )

    def _process_image(self, image_path: str) -> torch.Tensor:
        """
        Process the input chest X-ray image for model inference.

        Loads the image, rescales its pixel values with
        ``xrv.datasets.normalize``, keeps a single channel, applies the
        crop/resize transform and returns a batched tensor on ``self.device``.

        Args:
            image_path (str): The file path to the chest X-ray image.

        Returns:
            torch.Tensor: A processed image tensor of shape (1, 1, H, W) ready
            for model inference.

        Raises:
            FileNotFoundError: If the specified image file does not exist.
            ValueError: If the image cannot be properly loaded or processed.
        """
        img = skimage.io.imread(image_path)
        # normalize() maps [0, max_value] onto torchxrayvision's input range;
        # derive max_value from the bit depth so 16-bit PNGs are not treated as
        # 8-bit (which would push values far outside the expected range).
        max_value = 65535 if img.dtype == "uint16" else 255
        img = xrv.datasets.normalize(img, max_value)
        if img.ndim > 2:
            # Colour / RGBA input: keep only the first channel as grayscale.
            img = img[:, :, 0]
        img = img[None, :, :]  # add channel axis -> (1, H, W)
        img = self.transform(img)
        img = torch.from_numpy(img).unsqueeze(0)  # add batch axis -> (1, 1, H, W)
        return img.to(self.device)

    def _run(
        self,
        image_path: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, float], Dict]:
        """Classify the chest X-ray image for multiple pathologies.

        Args:
            image_path (str): The path to the chest X-ray image file.
            run_manager (Optional[CallbackManagerForToolRun]): The callback manager for the tool run.

        Returns:
            Tuple[Dict[str, float], Dict]: A tuple containing the classification results
                (pathologies and their probabilities from 0 to 1)
                and any additional metadata. On failure the first element is
                ``{"error": <message>}`` and the metadata status is "failed";
                errors are returned rather than raised so the calling agent
                can read them.
        """
        try:
            img = self._process_image(image_path)
            with torch.inference_mode():
                preds = self.model(img)[0].cpu()
            probabilities = preds.numpy().astype(float).tolist()
            # Label outputs with the loaded model's own label order rather than
            # the library-wide default list. Weight sets trained on fewer labels
            # use "" placeholders for unsupported slots (per torchxrayvision);
            # those are skipped.
            output = {
                pathology: probability
                for pathology, probability in zip(self.model.pathologies, probabilities)
                if pathology
            }
            metadata = {
                "image_path": image_path,
                "analysis_status": "completed",
                "note": "Probabilities range from 0 to 1, with higher values indicating higher likelihood of the condition.",
            }
            return output, metadata
        except Exception as e:
            return {"error": str(e)}, {
                "image_path": image_path,
                "analysis_status": "failed",
            }

    async def _arun(
        self,
        image_path: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, float], Dict]:
        """Asynchronously classify the chest X-ray image for multiple pathologies.

        Model inference is blocking, so the synchronous implementation is run
        in a worker thread via ``asyncio.to_thread``. This keeps the event loop
        responsive, and contextvars are propagated to the thread.

        Args:
            image_path (str): The path to the chest X-ray image file.
            run_manager (Optional[AsyncCallbackManagerForToolRun]): The async callback manager for the tool run.

        Returns:
            Tuple[Dict[str, float], Dict]: A tuple containing the classification results
                (pathologies and their probabilities from 0 to 1)
                and any additional metadata (same contract as ``_run``).
        """
        import asyncio

        return await asyncio.to_thread(self._run, image_path)