# medgemma_tool.py

from typing import Any, Dict, Optional, Tuple, Type
from pathlib import Path

from pydantic import BaseModel, Field
import torch
from PIL import Image
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
)
from langchain_core.tools import BaseTool
from langchain_core.callbacks import (
    CallbackManagerForToolRun,
    AsyncCallbackManagerForToolRun,
)
class MedGemmaInput(BaseModel):
    """Input schema for the MedGemma X-ray tool."""

    image_path: str = Field(..., description="Path to a chest X-ray image")
    prompt: str = Field(..., description="Question or instruction for the image")
    max_new_tokens: int = Field(
        300,
        description="Maximum number of tokens to generate in the answer",
    )
class MedGemmaXRayTool(BaseTool):
    """A tool that uses MedGemma to answer questions about chest X-ray images."""

    name: str = "medgemma_xray_expert"
    description: str = (
        "The first tool the agent should use to answer any question about X-ray images. "
        "The tool is specialized in multiple tasks, including visual question answering, "
        "report generation, abnormality detection, anatomical localization, clinical interpretation, "
        "comparative analysis, and identification and explanation of imaging signs. Input should be the path to "
        "an X-ray image and a natural-language prompt describing the task to be carried out."
    )
    args_schema: Type[BaseModel] = MedGemmaInput
    return_direct: bool = True
    # _run returns a (content, metadata) tuple, so tell LangChain to treat the
    # second element as a tool artifact rather than part of the answer.
    response_format: str = "content_and_artifact"
    # model handles
    model: Optional[AutoModelForImageTextToText] = None
    processor: Optional[AutoProcessor] = None

    # config
    model_name: str = "google/medgemma-4b-it"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    dtype: torch.dtype = torch.bfloat16
    def __init__(
        self,
        model_name: str = "google/medgemma-4b-it",
        device: Optional[str] = None,
        dtype: torch.dtype = torch.bfloat16,
        cache_dir: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.model_name = model_name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = dtype

        # Load the model & processor once at construction time
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=dtype,
            trust_remote_code=True,
            cache_dir=cache_dir,
        )
        self.processor = AutoProcessor.from_pretrained(
            model_name, trust_remote_code=True, cache_dir=cache_dir
        )
        self.model.eval()
    def _generate(
        self,
        image_path: str,
        prompt: str,
        max_new_tokens: int,
    ) -> str:
        """Run MedGemma and return the decoded answer."""
        img = Image.open(image_path).convert("RGB")

        # Build a chat-style conversation: a system instruction, the user
        # prompt, and the image itself.
        messages = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": "You are an expert radiologist. Provide a detailed response to the user's query.",
                    }
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image", "image": img},
                ],
            },
        ]

        # Tokenise with the chat template
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device, dtype=self.dtype)
        start_len = inputs["input_ids"].shape[-1]

        # Generate, then decode only the newly produced tokens
        with torch.inference_mode():
            gens = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )
        decoded = self.processor.decode(
            gens[0][start_len:], skip_special_tokens=True
        )
        return decoded.strip()
    def _run(
        self,
        image_path: str,
        prompt: str,
        max_new_tokens: int = 300,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Validate the input, invoke the model, and return output + metadata."""
        try:
            if not Path(image_path).is_file():
                raise FileNotFoundError(f"Image not found: {image_path}")
            answer = self._generate(image_path, prompt, max_new_tokens)
            return (
                {"response": answer},
                {
                    "image_path": image_path,
                    "prompt": prompt,
                    "max_new_tokens": max_new_tokens,
                    "status": "completed",
                },
            )
        except Exception as e:
            return (
                {"error": str(e)},
                {
                    "image_path": image_path,
                    "prompt": prompt,
                    "max_new_tokens": max_new_tokens,
                    "status": "failed",
                    "error": str(e),
                },
            )
    async def _arun(
        self,
        image_path: str,
        prompt: str,
        max_new_tokens: int = 300,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Asynchronous wrapper (delegates to the sync implementation)."""
        return self._run(image_path, prompt, max_new_tokens)
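
For reference, here is a minimal usage sketch. It assumes the class above is saved as medgemma_tool.py (matching the header comment), that you have access to the gated google/medgemma-4b-it checkpoint on Hugging Face, and that a GPU (or enough memory) is available to load the model; the sample image path and question below are placeholders.

    from medgemma_tool import MedGemmaXRayTool

    # Instantiate once (this downloads and loads the model), then reuse it.
    xray_tool = MedGemmaXRayTool()

    # Plain dict input matching the MedGemmaInput schema.
    result = xray_tool.invoke(
        {
            "image_path": "samples/chest_xray.png",  # placeholder path
            "prompt": "Is there evidence of pleural effusion?",
            "max_new_tokens": 300,
        }
    )
    print(result)

Because name, description, and args_schema are defined on the class, the same instance can also be passed directly in the tools list of a LangChain agent.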