#!/usr/bin/env python3
"""
Medical X-ray Question Generation Benchmark, aka ChestAgentBench.

This script generates clinical questions from chest X-ray case data in the
Eurorad dataset using GPT-4o. It structures questions across different
analytical categories and saves them as JSON.
"""
import os
import re
import json
from typing import Any, Dict, List, Optional
from pprint import pprint

import openai
from tqdm import tqdm

from benchmark.utils import load_eurorad_dataset
from benchmark.llm import get_llm_response
# Constants
DATA_DIR = "set your data directory here, e.g. /home/MedRAX/data"
DATASET_PATH = os.path.join(DATA_DIR, "eurorad_metadata.json")

SYSTEM_PROMPT = """
You are an expert medical benchmark creation assistant.
Your goal is to generate questions that evaluate a multimodal medical AI agent's ability to interpret and reason about chest X-rays.
""".strip()

CATEGORIES_META = {
    "detection": "Identify and locate specific findings in the chest X-ray.",
    "classification": "Determine whether specific findings are present or absent in the chest X-ray.",
    "enumeration": "Count the number of target findings in the chest X-ray.",
    "localization": "Locate a given finding in the chest X-ray.",
    "comparison": "Compare the size or position of a specific finding in the chest X-ray.",
    "relationship": "Determine the relationship between two or more findings in the chest X-ray.",
    "diagnosis": "Make a diagnosis or determine a treatment plan by interpreting the chest X-ray.",
    "characterization": "Describe specific attributes (shape, density, margins, etc.) of findings.",
    "reasoning": "Explain the medical rationale and thought process behind findings and conclusions.",
}
CATEGORIES = list(CATEGORIES_META.keys())

CATEGORY_COMBINATIONS = [
    ["detection", "localization", "characterization", "reasoning"],  # Detailed Finding Analysis
    ["detection", "classification", "relationship", "reasoning"],  # Pattern Recognition & Relations
    ["localization", "comparison", "relationship", "reasoning"],  # Spatial Understanding
    ["classification", "comparison", "diagnosis", "reasoning"],  # Clinical Decision Making
    ["classification", "characterization", "diagnosis", "reasoning"],  # Diagnostic Characterization
]
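# Each combination above pairs three skill categories with "reasoning", so every
# case yields one question per combination: five questions per case in total.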
DEFAULT_SECTIONS = [
    "history",
    "image_finding",
    "discussion",
    "differential_diagnosis",
    "diagnosis",
    "figures",
]
class Question:
    """A class to generate clinical questions from case data.

    This class handles creating structured clinical questions by combining case data with
    specified categories and difficulty levels.

    Attributes:
        type (str): The type of question (e.g. multiple choice)
        difficulty (str): Difficulty level of the question
        case_data (Dict[str, Any]): Dictionary containing the clinical case data
        case_content (str): Formatted case data from selected sections
        case_id (str): Unique identifier for the case
        categories (List[str]): List of analytical categories this question tests
        sections (List[str]): Case sections to include in the question
        raw_content (Optional[str]): Raw LLM response to the question prompt
        content (Optional[Dict[str, str]]): Extracted content from the raw LLM response
    """
    def __init__(
        self,
        type: str,
        difficulty: str,
        case_data: Dict[str, Any],
        categories: List[str],
        sections: Optional[List[str]] = None,
        system_prompt: str = "You are an expert medical benchmark creation assistant.",
    ) -> None:
        self.type = type
        self.difficulty = difficulty
        self.case_data = case_data
        self.case_id = case_data["case_id"]
        self.categories = categories
        # Avoid a mutable default argument; fall back to the module-level default.
        self.sections = sections if sections is not None else DEFAULT_SECTIONS
        self.system_prompt = system_prompt
        self.case_content = self.select_case_sections()
        self.raw_content: Optional[str] = None
        self.content: Optional[Dict[str, str]] = None
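    # Illustrative construction (hypothetical case_data; real cases come from
    # the Eurorad metadata loaded in main()):
    #
    #   q = Question(
    #       type="multiple choice (A/B/C/D/E/F)",
    #       difficulty="complex",
    #       case_data={"case_id": "12345", "history": "...", "figures": []},
    #       categories=CATEGORY_COMBINATIONS[0],
    #   )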
    def create_question_prompt(self) -> str:
        """Creates a formatted prompt for generating a clinical question.

        Returns:
            str: A structured prompt containing the question parameters and clinical data
        """
        category_descriptions = "\n".join(
            f"{category}: {desc}"
            for category, desc in CATEGORIES_META.items()
            if category in self.categories
        )

        # The prompt body is written flush-left so that no indentation needs to
        # be stripped afterwards.
        return f"""
You must follow these guidelines:
1. Questions must be answerable using only the case context and chest X-rays.
- Questions must explicitly mention the referenced figures
- Questions can only reference the chest X-ray figures
2. Questions must have unambiguous, verifiable answers, and should:
- Challenge the agent's analytical capabilities
- Require multi-step reasoning
- Test the ability to make precise observations
- Evaluate the capability to derive insights and findings from the chest X-ray
3. The agent has access to tools like classification, report generation, segmentation, grounding, visual question answering, etc. Your question should be complex enough to require the use of such tools.

Create a {self.difficulty} {self.type} clinical question that integrates the following:
{category_descriptions}

based on the following clinical case:
{self.case_content}

Do not use any information derived from the CT and MRI images. Do not reveal any information or findings about the chest X-rays in the question itself.
Your question should require the agent to derive insights and findings from the chest X-ray by itself.
Your answer should be verifiable directly in the context of the case.
You can only use the image findings that come from the chest X-ray figures.

Your response must follow this exact format:
THOUGHTS: [think about the reasoning steps and tools the agent should use to answer the question]
QUESTION: [complete question with relevant context. Incorrect choices should be very close to the correct answer.]
FIGURES: [list of required figures, e.g. ["Figure 1", "Figure 2a"]]
EXPLANATION: [short explanation of why your answer is verifiable in the case]
ANSWER: [correct answer, e.g. "A"]
""".strip()
    def select_case_sections(self) -> str:
        """Extract and format selected sections from case data into paragraphs.

        Returns:
            str: Formatted string with case sections and content
        """
        section_mapping = {
            "history": ("history", "No history provided."),
            "image_finding": ("image_finding", "No findings provided."),
            "discussion": ("discussion", "No discussion provided."),
            "differential_diagnosis": (
                "differential_diagnosis",
                "No differential diagnosis provided.",
            ),
            "diagnosis": ("diagnosis", "No diagnosis provided."),
            "figures": ("figures", "No figures provided."),
        }

        formatted = []
        for section in self.sections:
            if section in section_mapping:
                key, default = section_mapping[section]
                content = self.case_data.get(key, default)
                if key == "figures":
                    # Flatten the nested figure/subfigure structure into
                    # "number: caption" lines.
                    figures_text = []
                    for figure in content:
                        for subfig in figure["subfigures"]:
                            figures_text.append(f"{subfig['number']}: {subfig['caption']}")
                    content = "\n".join(figures_text)
                formatted.append(f"{section}:\n{content}")
        return "\n\n".join(formatted)
    def create_question(
        self,
        client: openai.OpenAI,
        temperature: float = 0.7,
        top_p: float = 0.95,
        max_tokens: int = 500,
        model: str = "gpt-4o",
    ) -> str:
        """Create a clinical question using an LLM.

        Args:
            client (openai.OpenAI): OpenAI client instance
            temperature (float): Controls randomness in responses. Defaults to 0.7.
            top_p (float): Controls diversity via nucleus sampling. Defaults to 0.95.
            max_tokens (int): Max tokens in model response. Defaults to 500.
            model (str): OpenAI model to use. Defaults to "gpt-4o".

        Returns:
            str: LLM response containing formatted question components
        """
        self.raw_content = get_llm_response(
            client=client,
            prompt=self.create_question_prompt(),
            system_prompt=self.system_prompt,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            model=model,
        )
        self.content = self.extract_content()
        return self.raw_content
    def extract_content(self) -> Dict[str, str]:
        """Extract sections from the raw LLM response using regex patterns.

        Returns:
            Dict[str, str]: Extracted sections including thoughts, question, figures,
                explanation, and answer
        """
        keywords = ["THOUGHTS", "QUESTION", "FIGURES", "EXPLANATION", "ANSWER"]
        content = {}
        for kw in keywords:
            # Capture everything after "KEYWORD:" up to the next all-caps header
            # (e.g. "ANSWER:") or the end of the response.
            pattern = rf"{kw}:\s*(.*?)(?=\n[A-Z]+:|$)"
            match = re.search(pattern, self.raw_content, re.DOTALL)
            content[kw.lower()] = match.group(1).strip() if match else None
        return content
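    # Illustrative result (hypothetical response text):
    #
    #   {
    #       "thoughts": "The agent should first localize the opacity...",
    #       "question": "Based on Figure 1a, which of the following... A) ... F) ...",
    #       "figures": '["Figure 1a"]',
    #       "explanation": "The case discussion explicitly states...",
    #       "answer": "C",
    #   }
    #
    # Note that "figures" is extracted as a raw string, not parsed into a list.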
    def save(self, output_path: str) -> Dict[str, Any]:
        """Save question content and metadata as a JSON file.

        Args:
            output_path (str): Directory path where the JSON file will be saved

        Returns:
            Dict[str, Any]: Question data including content (thoughts, question, figures,
                explanation, answer) and metadata (type, difficulty, categories, etc.)
        """
        question_metadata = self.content.copy()

        # Add metadata
        question_metadata["metadata"] = {
            "case_id": self.case_id,
            "type": self.type,
            "difficulty": self.difficulty,
            "categories": self.categories,
            "sections": self.sections,
        }

        # Create a directory for the case
        case_dir = os.path.join(output_path, str(self.case_id))
        os.makedirs(case_dir, exist_ok=True)

        # Save the question metadata to a JSON file. Note: the default
        # __hash__() is identity-based, so the filename suffix only guarantees
        # uniqueness within a single run, not stable names across runs.
        output_file = os.path.join(case_dir, f"{self.case_id}_{self.__hash__()}.json")
        with open(output_file, "w") as f:
            json.dump(question_metadata, f, indent=2)
        return question_metadata
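# Resulting on-disk layout, one directory per case (hash suffixes illustrative):
#
#   benchmark/questions/
#       16798/
#           16798_8749263514320.json
#           16798_8749263514455.json
#           ...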
def generate_questions(
    dataset: Dict[str, Any],
    client: openai.OpenAI,
    output_dir: str,
    skip_first: int = 100,
    temperature: float = 0.7,
    top_p: float = 0.95,
    max_tokens: int = 1200,
    model: str = "gpt-4o",
) -> None:
    """Generate questions for each case and category combination.

    Args:
        dataset: Dictionary of case data
        client: OpenAI client instance
        output_dir: Directory to save generated questions
        skip_first: Number of highest-numbered cases to skip (must be positive)
        temperature: LLM temperature parameter
        top_p: LLM top_p parameter
        max_tokens: Maximum tokens for LLM response
        model: LLM model name
    """
    # Sort case IDs numerically and drop the last `skip_first` cases.
    target_cases = sorted(dataset.keys(), key=int)[:-skip_first]

    for case_id in tqdm(target_cases, desc="Processing cases"):
        case_data = dataset[case_id]
        for categories in tqdm(CATEGORY_COMBINATIONS, desc=f"Categories for case {case_id}"):
            question = Question(
                type="multiple choice (A/B/C/D/E/F)",
                difficulty="complex",
                case_data=case_data,
                categories=categories,
                sections=DEFAULT_SECTIONS,
                system_prompt=SYSTEM_PROMPT,
            )
            question.create_question(
                client=client,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens,
                model=model,
            )
            question.save(output_dir)
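# With the defaults this issues len(target_cases) * len(CATEGORY_COMBINATIONS)
# LLM calls (five per case), so expect a long run and nontrivial API cost on
# the full dataset.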
def main():
    """Main execution function."""
    client = openai.OpenAI()

    # Load and verify the dataset
    dataset = load_eurorad_dataset(
        DATASET_PATH,
        section="Chest Imaging",
        as_dict=True,
        filter_by_caption=[
            "xray",
            "x-ray",
            "x ray",
            "ray",
            "xr",
            "radiograph",
        ],
    )
    print(f"\n---\nFound {len(dataset)} cases with X-ray mentions\n---\n")

    # Optional: print a sample case for verification
    case_data = dataset["16798"]
    pprint(case_data, sort_dicts=False)

    # Generate questions
    generate_questions(dataset=dataset, client=client, output_dir="benchmark/questions")


if __name__ == "__main__":
    main()
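# Usage sketch (assumptions: OPENAI_API_KEY is exported, and DATA_DIR above has
# been pointed at a directory containing eurorad_metadata.json):
#
#   export OPENAI_API_KEY=sk-...
#   python generate_questions.py
#
# The script filename is assumed here; run whatever this file is saved as from
# the repository root so the `benchmark` package imports resolve.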