from typing import Any, Dict, Optional, Tuple, Type

import torch
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from PIL import Image
from pydantic import BaseModel, Field

from medrax.llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from medrax.llava.conversation import conv_templates
from medrax.llava.mm_utils import tokenizer_image_token, process_images
from medrax.llava.model.builder import load_pretrained_model

class LlavaMedInput(BaseModel):
    """Input for the LLaVA-Med Visual QA tool. Only supports JPG or PNG images."""

    question: str = Field(..., description="The question to ask about the medical image")
    image_path: Optional[str] = Field(
        None,
        description="Path to the medical image file (optional), only supports JPG or PNG images",
    )
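

# Example payload accepted by LlavaMedInput (illustrative; the image path is hypothetical):
#   {"question": "Is there evidence of pneumonia?", "image_path": "/data/cxr.png"}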


class LlavaMedTool(BaseTool):
    """Tool that performs medical visual question answering using LLaVA-Med.

    This tool uses a large language model fine-tuned on medical images to answer
    questions about medical images. It can handle both image-based questions and
    general medical questions without images.
    """

    name: str = "llava_med_qa"
    description: str = (
        "A tool that answers questions about biomedical images and general medical questions using LLaVA-Med. "
        "While it can process chest X-rays, it may not be as reliable as dedicated tools for detailed chest X-ray analysis. "
        "Input should be a question and, optionally, a path to a medical image file."
    )
    args_schema: Type[BaseModel] = LlavaMedInput
    tokenizer: Any = None
    model: Any = None
    image_processor: Any = None
    context_len: int = 200000

    def __init__(
        self,
        model_path: str = "microsoft/llava-med-v1.5-mistral-7b",
        # model_path: str = "microsoft/llava-rad",
        cache_dir: str = "/model-weights",
        low_cpu_mem_usage: bool = True,
        torch_dtype: torch.dtype = torch.bfloat16,
        device: str = "cuda",
        load_in_4bit: bool = False,
        load_in_8bit: bool = False,
        **kwargs,
    ):
        """Load the LLaVA-Med checkpoint along with its tokenizer and image processor.

        Args:
            model_path (str): Hugging Face model ID or local path of the checkpoint.
            cache_dir (str): Directory where model weights are cached.
            low_cpu_mem_usage (bool): Whether to reduce CPU memory usage while loading.
            torch_dtype (torch.dtype): Dtype used for the model weights.
            device (str): Device to load the model on.
            load_in_4bit (bool): Whether to load the model with 4-bit quantization.
            load_in_8bit (bool): Whether to load the model with 8-bit quantization.
            **kwargs: Extra arguments forwarded to load_pretrained_model.
        """
        super().__init__()
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path=model_path,
            model_base=None,
            # model_base="lmsys/vicuna-7b-v1.5",
            model_name=model_path,
            load_in_4bit=load_in_4bit,
            load_in_8bit=load_in_8bit,
            cache_dir=cache_dir,
            low_cpu_mem_usage=low_cpu_mem_usage,
            torch_dtype=torch_dtype,
            device=device,
            **kwargs,
        )
        self.model.eval()

    def _process_input(
        self, question: str, image_path: Optional[str] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Build the prompt token IDs and, if an image is given, its preprocessed tensor."""
        # Only inject image tokens when an image is actually provided; otherwise the
        # prompt would reference an image that does not exist.
        if image_path:
            if self.model.config.mm_use_im_start_end:
                question = (
                    DEFAULT_IM_START_TOKEN
                    + DEFAULT_IMAGE_TOKEN
                    + DEFAULT_IM_END_TOKEN
                    + "\n"
                    + question
                )
            else:
                question = DEFAULT_IMAGE_TOKEN + "\n" + question

        conv = conv_templates["vicuna_v1"].copy()
        conv.append_message(conv.roles[0], question)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = (
            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .to(self.model.device)
        )

        image_tensor = None
        if image_path:
            image = Image.open(image_path).convert("RGB")
            image_tensor = process_images([image], self.image_processor, self.model.config)[0]
            # Match the model's device and dtype instead of hard-coding .half().cuda(),
            # so bfloat16 and non-CUDA configurations keep working.
            image_tensor = image_tensor.unsqueeze(0).to(self.model.device, dtype=self.model.dtype)

        return input_ids, image_tensor

    def _run(
        self,
        question: str,
        image_path: Optional[str] = None,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Answer a medical question, optionally based on an input image.

        Args:
            question (str): The medical question to answer.
            image_path (Optional[str]): The path to the medical image file (if applicable).
            run_manager (Optional[CallbackManagerForToolRun]): The callback manager for the tool run.

        Returns:
            Tuple[str, Dict]: A tuple containing the model's answer and any additional
                metadata. Errors during processing or generation are caught and returned
                as an error message with a "failed" analysis status rather than raised.
        """
        try:
            # Tensors come back already on the model's device and dtype;
            # image_tensor is None for text-only questions.
            input_ids, image_tensor = self._process_input(question, image_path)

            with torch.inference_mode():
                output_ids = self.model.generate(
                    input_ids,
                    images=image_tensor,
                    do_sample=False,  # greedy decoding; temperature is ignored when sampling is off
                    max_new_tokens=500,
                    use_cache=True,
                )

            output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
            metadata = {
                "question": question,
                "image_path": image_path,
                "analysis_status": "completed",
            }
            return output, metadata
        except Exception as e:
            return f"Error generating answer: {str(e)}", {
                "question": question,
                "image_path": image_path,
                "analysis_status": "failed",
            }

    async def _arun(
        self,
        question: str,
        image_path: Optional[str] = None,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Asynchronously answer a medical question, optionally based on an input image.

        This method currently delegates to the synchronous version, as the model
        inference is not inherently asynchronous. For true asynchronous behavior,
        consider using a separate thread or process.

        Args:
            question (str): The medical question to answer.
            image_path (Optional[str]): The path to the medical image file (if applicable).
            run_manager (Optional[AsyncCallbackManagerForToolRun]): The async callback manager for the tool run.

        Returns:
            Tuple[str, Dict]: A tuple containing the model's answer and any additional
                metadata. As in _run, errors are caught and reported in the metadata
                rather than raised.
        """
        return self._run(question, image_path)
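

# A minimal usage sketch, assuming the checkpoint can be downloaded and a GPU with
# enough memory is available; the cache directory and image path are placeholders:
if __name__ == "__main__":
    tool = LlavaMedTool(cache_dir="/model-weights", load_in_8bit=True)
    answer, metadata = tool._run(
        question="Is there any evidence of pleural effusion?",
        image_path="example_cxr.png",  # hypothetical JPG or PNG image
    )
    print(answer)
    print(metadata["analysis_status"])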