Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| import torch | |
# Default device for the model-backed tools: a CUDA GPU when PyTorch can see
# one, otherwise the CPU. initialize_agent() uses this as its `device` default.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
| import warnings | |
| from typing import * | |
| from dotenv import load_dotenv | |
| from transformers import logging | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langchain_openai import ChatOpenAI | |
| from interface import create_demo | |
| from medrax.agent import * | |
| from medrax.tools import * | |
| from medrax.utils import * | |
# Quiet third-party noise: ignore every Python warning category, and limit the
# transformers library logger to errors only.
warnings.simplefilter("ignore")
logging.set_verbosity_error()

# Pull variables from a local .env file into the process environment, if one
# exists (presumably API credentials for ChatOpenAI -- confirm). The return value
# is deliberately discarded.
_ = load_dotenv()
def initialize_agent(
    prompt_file,
    tools_to_use=None,
    model_dir="/model-weights",
    temp_dir="temp",
    device=device,
    model_name="gpt-4o",
    temperature=0.7,
    top_p=0.95,
    log_dir="logs",
):
    """Initialize the MedRAX agent with specified tools and configuration.

    Args:
        prompt_file (str): Path to file containing system prompts. The prompt
            stored under "MEDICAL_ASSISTANT" becomes the agent's system prompt.
        tools_to_use (Iterable[str] | str, optional): Names of the tools to
            initialize, in order. If None, all known tools are initialized.
            An empty iterable initializes no tools. A single string is treated
            as one tool name. Unknown names are skipped with a printed notice,
            and duplicate names are initialized only once.
        model_dir (str, optional): Directory containing model weights.
            Defaults to "/model-weights".
        temp_dir (str, optional): Directory for temporary files. Defaults to "temp".
        device (torch.device, optional): Device to run models on. Defaults to
            the module-level `device` (CUDA if available, else CPU).
        model_name (str, optional): Chat model name passed to ChatOpenAI.
            Defaults to "gpt-4o".
        temperature (float, optional): Sampling temperature passed to ChatOpenAI.
            Defaults to 0.7.
        top_p (float, optional): top_p value passed to ChatOpenAI. Defaults to 0.95.
        log_dir (str, optional): Directory passed to Agent as `log_dir`.
            Defaults to "logs".

    Returns:
        Tuple[Agent, Dict[str, BaseTool]]: Initialized agent and a dictionary of
        tool instances keyed by tool name (in the order they were requested).
    """
    prompts = load_prompts_from_file(prompt_file)
    prompt = prompts["MEDICAL_ASSISTANT"]

    # Factories rather than instances, so only the requested tools are
    # constructed (and only they load their model weights).
    all_tools = {
        "ChestXRayClassifierTool": lambda: ChestXRayClassifierTool(device=device),
        "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
        "LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
        "XRayVQATool": lambda: XRayVQATool(cache_dir=model_dir, device=device),
        "ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(
            cache_dir=model_dir, device=device
        ),
        "XRayPhraseGroundingTool": lambda: XRayPhraseGroundingTool(
            cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
        ),
        "ChestXRayGeneratorTool": lambda: ChestXRayGeneratorTool(
            model_path=f"{model_dir}/roentgen", temp_dir=temp_dir, device=device
        ),
        "ImageVisualizerTool": lambda: ImageVisualizerTool(),
        "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
    }

    # Only None means "everything": an explicit empty selection must not
    # silently load every model.
    if tools_to_use is None:
        tools_to_use = list(all_tools)
    elif isinstance(tools_to_use, str):
        # A bare string would otherwise be iterated character by character.
        tools_to_use = [tools_to_use]

    tools_dict = {}
    unknown_tools = []
    for tool_name in tools_to_use:
        if tool_name in tools_dict:
            continue  # already built; avoid constructing the same tool twice
        factory = all_tools.get(tool_name)
        if factory is None:
            unknown_tools.append(tool_name)
            continue
        tools_dict[tool_name] = factory()

    if unknown_tools:
        # print() rather than warnings.warn(): this module installs an "ignore"
        # filter for all warnings at import time, which would hide the message.
        print(
            "Skipping unknown tools: "
            + ", ".join(map(str, unknown_tools))
            + ". Available tools: "
            + ", ".join(all_tools)
        )

    # Checkpointer for the agent's conversation state (in-process only,
    # presumably -- see langgraph MemorySaver).
    checkpointer = MemorySaver()
    model = ChatOpenAI(model=model_name, temperature=temperature, top_p=top_p)
    agent = Agent(
        model,
        tools=list(tools_dict.values()),
        log_tools=True,
        log_dir=log_dir,
        system_prompt=prompt,
        checkpointer=checkpointer,
    )
    print("Agent initialized")
    return agent, tools_dict
if __name__ == "__main__":
    # Entry point for the MedRAX application: build the agent with a chosen
    # subset of tools, then launch the demo UI returned by create_demo().
    print("Starting server...")

    # Tools to load at startup. The commented-out names are also known to
    # initialize_agent() but are disabled in this configuration.
    selected_tools = [
        "ImageVisualizerTool",
        "DicomProcessorTool",
        "ChestXRayClassifierTool",
        "ChestXRaySegmentationTool",
        "ChestXRayReportGeneratorTool",
        "XRayVQATool",
        # "LlavaMedTool",
        # "XRayPhraseGroundingTool",
        # "ChestXRayGeneratorTool",
    ]

    agent, tools_dict = initialize_agent(
        prompt_file="medrax/docs/system_prompts.txt",
        tools_to_use=selected_tools,
        model_dir="/model-weights",
    )

    demo = create_demo(agent, tools_dict)
    # Bind on all interfaces. NOTE(review): if `demo` is a Gradio app, share=True
    # requests a public share link; hosted environments may also expect a
    # different port than 8585 -- confirm for the deployment target.
    demo.launch(server_name="0.0.0.0", server_port=8585, share=True)