vaibhavm29 committed on
Commit
7d4bd7e
·
1 Parent(s): 917983f

included medgemma tool

Files changed (4)
  1. app.py +2 -0
  2. medrax/tools/__init__.py +1 -0
  3. medrax/tools/medgemma.py +170 -0
  4. pyproject.toml +1 -1
app.py CHANGED
@@ -54,6 +54,7 @@ def initialize_agent(
         "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
         "LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
         "XRayVQATool": lambda: XRayVQATool(cache_dir=model_dir, device=device),
+        "MedGemmaXRayTool": lambda: MedGemmaXRayTool(cache_dir=model_dir, device=device),
         "ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(
             cache_dir=model_dir, device=device
         ),
@@ -107,6 +108,7 @@ if __name__ == "__main__":
         "XRayVQATool",
         "LlavaMedTool",
         "XRayPhraseGroundingTool",
+        "MedGemmaXRayTool",
         # "ChestXRayGeneratorTool",
     ]
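For context, a hedged sketch (helper names and values assumed, not taken from app.py) of how a name-keyed factory dict and an enabled-tool list like the ones above are typically resolved; this is why the registry key and the list entry must both be spelled "MedGemmaXRayTool":

# Hypothetical illustration only, not code from app.py: enabled names are looked up
# in the factory dict, so a spelling mismatch ("MedgemmaXRayTool" vs. "MedGemmaXRayTool")
# would raise a KeyError at startup.
from medrax.tools import MedGemmaXRayTool, XRayVQATool

model_dir, device = "/model-weights", "cuda"  # assumed values
tool_factories = {
    "XRayVQATool": lambda: XRayVQATool(cache_dir=model_dir, device=device),
    "MedGemmaXRayTool": lambda: MedGemmaXRayTool(cache_dir=model_dir, device=device),
}
enabled_tools = ["XRayVQATool", "MedGemmaXRayTool"]
tools = [tool_factories[name]() for name in enabled_tools]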
 
medrax/tools/__init__.py CHANGED
@@ -9,3 +9,4 @@ from .grounding import *
 from .generation import *
 from .dicom import *
 from .utils import *
+from .medgemma import *
medrax/tools/medgemma.py ADDED
@@ -0,0 +1,170 @@
+# medrax/tools/medgemma.py
+from typing import Any, Dict, Optional, Tuple, Type
+
+from pathlib import Path
+from pydantic import BaseModel, Field
+
+import torch
+from PIL import Image
+from transformers import (
+    AutoModelForImageTextToText,
+    AutoProcessor,
+)
+
+from langchain_core.tools import BaseTool
+from langchain_core.callbacks import (
+    CallbackManagerForToolRun,
+    AsyncCallbackManagerForToolRun,
+)
+
+
+class MedGemmaInput(BaseModel):
+    """Input schema for the MedGemma X-ray tool."""
+
+    image_path: str = Field(..., description="Path to a chest X-ray image")
+    prompt: str = Field(..., description="Question or instruction for the image")
+    max_new_tokens: int = Field(
+        300,
+        description="Maximum number of tokens to generate in the answer",
+    )
+
+
+class MedGemmaXRayTool(BaseTool):
+    """A tool that uses MedGemma to answer questions about chest X-ray images."""
+
+    name: str = "medgemma_xray_expert"
+    description: str = (
+        "The first tool the agent should use to answer questions about X-ray images. "
+        "It is specialized in multiple tasks, including visual question answering, "
+        "report generation, abnormality detection, anatomical localization, clinical interpretation, "
+        "comparative analysis, and identification and explanation of imaging signs. Input should be the path to an "
+        "X-ray image and a natural language prompt describing the task to be carried out."
+    )
+    args_schema: Type[BaseModel] = MedGemmaInput
+    return_direct: bool = True
+
+    # Model handles
+    model: Optional[AutoModelForImageTextToText] = None
+    processor: Optional[AutoProcessor] = None
+
+    # Config
+    model_name: str = "google/medgemma-4b-it"
+    device: str = "cuda" if torch.cuda.is_available() else "cpu"
+    dtype: torch.dtype = torch.bfloat16
+
+    def __init__(
+        self,
+        model_name: str = "google/medgemma-4b-it",
+        device: Optional[str] = None,
+        dtype: torch.dtype = torch.bfloat16,
+        cache_dir: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.model_name = model_name
+        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.dtype = dtype
+
+        # Load model & processor
+        self.model = AutoModelForImageTextToText.from_pretrained(
+            model_name,
+            device_map="auto",
+            torch_dtype=dtype,
+            trust_remote_code=True,
+            cache_dir=cache_dir,
+        )
+        self.processor = AutoProcessor.from_pretrained(
+            model_name, trust_remote_code=True, cache_dir=cache_dir
+        )
+        self.model.eval()
+
+    def _generate(
+        self,
+        image_path: str,
+        prompt: str,
+        max_new_tokens: int,
+    ) -> str:
+        """Run MedGemma and return the decoded answer."""
+        img = Image.open(image_path).convert("RGB")
+
+        # Build a chat-style conversation: system instruction plus the user's prompt and image
+        messages = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are an expert radiologist. Provide a detailed response to the user's query."}],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": prompt},
+                    {"type": "image", "image": img},
+                ],
+            },
+        ]
+
+        # Tokenise with the chat template
+        inputs = self.processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        ).to(self.model.device, dtype=self.dtype)
+
+        start_len = inputs["input_ids"].shape[-1]
+
+        # Generate, then decode only the newly generated tokens
+        with torch.inference_mode():
+            gens = self.model.generate(
+                **inputs,
+                max_new_tokens=max_new_tokens,
+                do_sample=False,
+            )
+        decoded = self.processor.decode(
+            gens[0][start_len:], skip_special_tokens=True
+        )
+        return decoded.strip()
+
+    def _run(
+        self,
+        image_path: str,
+        prompt: str,
+        max_new_tokens: int = 300,
+        run_manager: Optional[CallbackManagerForToolRun] = None,
+    ) -> Tuple[Dict[str, Any], Dict]:
+        """Validate the input, invoke the model, and return the output plus metadata."""
+        try:
+            if not Path(image_path).is_file():
+                raise FileNotFoundError(f"Image not found: {image_path}")
+
+            answer = self._generate(image_path, prompt, max_new_tokens)
+
+            return (
+                {"response": answer},
+                {
+                    "image_path": image_path,
+                    "prompt": prompt,
+                    "max_new_tokens": max_new_tokens,
+                    "status": "completed",
+                },
+            )
+
+        except Exception as e:
+            return (
+                {"error": str(e)},
+                {
+                    "image_path": image_path,
+                    "prompt": prompt,
+                    "max_new_tokens": max_new_tokens,
+                    "status": "failed",
+                    "error": str(e),
+                },
+            )
+
+    async def _arun(
+        self,
+        image_path: str,
+        prompt: str,
+        max_new_tokens: int = 300,
+        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+    ) -> Tuple[Dict[str, Any], Dict]:
+        """Asynchronous wrapper (delegates to the sync implementation)."""
+        return self._run(image_path, prompt, max_new_tokens)
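For orientation, a minimal usage sketch of the new tool outside the agent loop; the cache directory and image path are hypothetical, and the call simply mirrors the `_run` signature and return value defined above (in MedRAX the tool is normally invoked by the agent rather than called directly):

from medrax.tools import MedGemmaXRayTool

# Assumed paths; any chest X-ray image on disk works.
tool = MedGemmaXRayTool(cache_dir="/model-weights", device="cuda")
output, metadata = tool._run(
    image_path="samples/chest_xray.png",
    prompt="Is there evidence of pleural effusion?",
    max_new_tokens=200,
)
print(output.get("response") or output.get("error"))
print(metadata["status"])  # "completed" or "failed"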
pyproject.toml CHANGED
@@ -24,7 +24,7 @@ dependencies = [
     "pydantic>=1.8.0",
     "Pillow>=8.0.0",
     "torchxrayvision>=0.0.37",
-    "transformers @ git+https://github.com/huggingface/transformers.git@88d960937c81a32bfb63356a2e8ecf7999619681",
+    "transformers>=4.46.3",
     "tokenizers>=0.10.0",
     "sentencepiece>=0.1.95",
     "shortuuid>=1.0.0",