import argparse
import json
import requests
import base64
from PIL import Image
from io import BytesIO
from llava.conversation import conv_templates
import time
import os
import glob
import logging
from datetime import datetime
from tqdm import tqdm
import re
from typing import Dict, List, Optional, Union, Any, Tuple


def process_image(image_path: str, target_size: int = 640) -> Image.Image:
    """Process and resize an image to match model requirements.

    Args:
        image_path: Path to the input image file
        target_size: Target size for both width and height in pixels

    Returns:
        PIL.Image: Processed and padded image with dimensions (target_size, target_size)
    """
    image = Image.open(image_path)
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Scale to fit within target_size while preserving aspect ratio
    ratio = min(target_size / image.width, target_size / image.height)
    new_size = (int(image.width * ratio), int(image.height * ratio))
    image = image.resize(new_size, Image.LANCZOS)

    # Paste the resized image centered on a black square canvas
    new_image = Image.new("RGB", (target_size, target_size), (0, 0, 0))
    offset = ((target_size - new_size[0]) // 2, (target_size - new_size[1]) // 2)
    new_image.paste(image, offset)
    return new_image
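
# Illustrative sketch (the path below is hypothetical, not part of the benchmark data):
# a 1024x768 input is scaled to 640x480 and letterboxed with black bars to 640x640.
#   padded = process_image("MedMAX/data/example_figure.jpg")
#   padded.size  # -> (640, 640)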


def validate_answer(response_text: str) -> Optional[str]:
    """Extract and validate a single-letter response from the model's output.

    Handles multiple response formats and edge cases.

    Args:
        response_text: The full text output from the model

    Returns:
        A single letter answer (A-F) or None if no valid answer found
    """
    if not response_text:
        return None

    # Clean the response text
    cleaned = response_text.strip()

    # Extraction patterns, ordered from strictest to loosest
    extraction_patterns = [
        # Strict format with explicit letter answer
        r"(?:THE\s*)?(?:SINGLE\s*)?LETTER\s*(?:ANSWER\s*)?(?:IS:?)\s*([A-F])\b",
        # Patterns for extracting from longer descriptions
        r"(?:correct\s+)?(?:answer|option)\s*(?:is\s*)?([A-F])\b",
        r"\b(?:answer|option)\s*([A-F])[):]\s*",
        # Patterns for extracting from descriptive sentences
        r"(?:most\s+likely\s+)?(?:answer|option)\s*(?:is\s*)?([A-F])\b",
        r"suggest[s]?\s+(?:that\s+)?(?:the\s+)?(?:answer\s+)?(?:is\s*)?([A-F])\b",
        # Patterns with contextual words
        r"characteriz[e]?d?\s+by\s+([A-F])\b",
        r"indicat[e]?s?\s+([A-F])\b",
        # Fallback to "Option X" or "X)" formats
        r"Option\s*([A-F])\b",
        r"\b([A-F])\)\s*",
        # Fallback to a standalone letter
        r"^\s*([A-F])\s*$",
    ]

    # Try each pattern and return the first valid match
    for pattern in extraction_patterns:
        matches = re.findall(pattern, cleaned, re.IGNORECASE)
        for match in matches:
            # re.findall yields tuples when a pattern has multiple groups
            if isinstance(match, tuple):
                match = match[0] if match[0] in "ABCDEF" else None
            if match and match.upper() in "ABCDEF":
                return match.upper()

    # Final fallback: any standalone letter A-F anywhere in the text
    context_matches = re.findall(r"\b([A-F])\b", cleaned.upper())
    context_letters = [m for m in context_matches if m in "ABCDEF"]
    if context_letters:
        return context_letters[0]

    # No valid answer found
    return None
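
# A few illustrative calls (made-up model outputs; results follow from the patterns above):
#   validate_answer("The SINGLE LETTER answer is: C")   # -> "C"
#   validate_answer("Option B) Pneumothorax")           # -> "B"
#   validate_answer("I cannot tell from this image.")   # -> None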


def load_benchmark_questions(case_id: str) -> List[str]:
    """Find all question files for a given case ID.

    Args:
        case_id: The ID of the medical case

    Returns:
        List of paths to question JSON files
    """
    benchmark_dir = "MedMAX/benchmark/questions"
    return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")


def count_total_questions() -> Tuple[int, int]:
    """Count the total number of cases and questions in the benchmark.

    Returns:
        Tuple containing (total_cases, total_questions)
    """
    total_cases = len(glob.glob("MedMAX/benchmark/questions/*"))
    total_questions = sum(
        len(glob.glob(f"MedMAX/benchmark/questions/{case_id}/*.json"))
        for case_id in os.listdir("MedMAX/benchmark/questions")
    )
    return total_cases, total_questions
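
# Both helpers assume the directory layout hard-coded throughout this script:
#   MedMAX/benchmark/questions/<case_id>/<case_id>_*.json   (one JSON file per question)
#   MedMAX/data/updated_cases.json                          (case details and local figure paths)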


def create_inference_request(
    question_data: Dict[str, Any],
    case_details: Dict[str, Any],
    case_id: str,
    question_id: str,
    worker_addr: str,
    model_name: str,
    raw_output: bool = False,
) -> Union[Tuple[Optional[str], Optional[float]], Dict[str, Any]]:
    """Create and send an inference request to the worker.

    Args:
        question_data: Dictionary containing question details and figures
        case_details: Dictionary containing case information and figures
        case_id: Identifier for the medical case
        question_id: Identifier for the specific question
        worker_addr: Address of the worker endpoint
        model_name: Name of the model to use
        raw_output: Whether to return raw model output

    Returns:
        If raw_output is False: Tuple of (validated_answer, duration)
        If raw_output is True: Dictionary with full inference details
    """
    system_prompt = """You are a medical imaging expert. Your answer MUST be a SINGLE LETTER (A/B/C/D/E/F), provided in this format: 'The SINGLE LETTER answer is: X'.
"""

    prompt = f"""Given the following medical case:
Please answer this multiple choice question:
{question_data['question']}
Base your answer only on the provided images and case information. Respond with your SINGLE LETTER answer: """

    # Parse required figures (may be a JSON-encoded list, a plain string, or a list)
    try:
        if isinstance(question_data["figures"], str):
            try:
                required_figures = json.loads(question_data["figures"])
            except json.JSONDecodeError:
                required_figures = [question_data["figures"]]
        elif isinstance(question_data["figures"], list):
            required_figures = question_data["figures"]
        else:
            required_figures = [str(question_data["figures"])]
    except Exception as e:
        print(f"Error parsing figures: {e}")
        required_figures = []

    required_figures = [
        fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
    ]

    # Resolve each required figure (and optional subfigure letter) to local image paths
    image_paths = []
    for figure in required_figures:
        base_figure_num = "".join(filter(str.isdigit, figure))
        figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None

        matching_figures = [
            case_figure
            for case_figure in case_details.get("figures", [])
            if case_figure["number"] == f"Figure {base_figure_num}"
        ]

        for case_figure in matching_figures:
            if figure_letter:
                subfigures = [
                    subfig
                    for subfig in case_figure.get("subfigures", [])
                    if subfig.get("number", "").lower().endswith(figure_letter.lower())
                    or subfig.get("label", "").lower() == figure_letter.lower()
                ]
            else:
                subfigures = case_figure.get("subfigures", [])

            for subfig in subfigures:
                if "local_path" in subfig:
                    image_paths.append("MedMAX/data/" + subfig["local_path"])

    if not image_paths:
        print(f"No local images found for case {case_id}, question {question_id}")
        return "skipped", 0.0  # Special marker so the caller can count skipped questions

    try:
        start_time = time.time()

        # Process each image
        processed_images = [process_image(path) for path in image_paths]

        # Build the conversation using the Mistral-instruct template
        conv = conv_templates["mistral_instruct"].copy()

        # Ensure the prompt contains an image token
        if "<image>" not in prompt:
            text = prompt + "\n<image>"
        else:
            text = prompt

        message = (text, processed_images[0], "Default")  # Currently only the first image is sent
        conv.append_message(conv.roles[0], message)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        headers = {"User-Agent": "LLaVA-Med Client"}
        pload = {
            "model": model_name,
            "prompt": prompt,
            "max_new_tokens": 150,  # Only a single letter is needed, so keep generation short
            "temperature": 0.5,  # Lower temperature for more focused responses
            "stop": conv.sep2,
            "images": conv.get_images(),
            "top_p": 1,  # Full nucleus; rely on temperature to focus sampling
            "frequency_penalty": 0.0,
            "presence_penalty": 0.0,
        }

        max_retries = 3
        retry_delay = 5
        response_text = None

        for attempt in range(max_retries):
            try:
                response = requests.post(
                    worker_addr + "/worker_generate_stream",
                    headers=headers,
                    json=pload,
                    stream=True,
                    timeout=30,
                )

                complete_output = ""
                for chunk in response.iter_lines(
                    chunk_size=8192, decode_unicode=False, delimiter=b"\0"
                ):
                    if chunk:
                        data = json.loads(chunk.decode("utf-8"))
                        if data["error_code"] == 0:
                            output = data["text"].split("[/INST]")[-1]
                            complete_output = output
                        else:
                            print(f"\nError: {data['text']} (error_code: {data['error_code']})")
                            if attempt < max_retries - 1:
                                time.sleep(retry_delay)
                                break
                            return None, None

                if complete_output:
                    response_text = complete_output
                    break

            except (requests.exceptions.RequestException, json.JSONDecodeError) as e:
                if attempt < max_retries - 1:
                    print(f"\nNetwork error: {str(e)}. Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                else:
                    print(f"\nFailed after {max_retries} attempts: {str(e)}")
                    return None, None

        duration = time.time() - start_time

        if raw_output:
            inference_details = {
                "raw_output": response_text,
                "validated_answer": validate_answer(response_text),
                "duration": duration,
                "prompt": prompt,
                "system_prompt": system_prompt,
                "image_paths": image_paths,
                "payload": pload,
            }
            return inference_details

        return validate_answer(response_text), duration

    except Exception as e:
        print(f"Error in inference request: {str(e)}")
        return None, None
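
# Usage sketch (argument values are illustrative; main() below supplies the real ones):
#   answer, duration = create_inference_request(
#       question_data, case_details, "case_0001", "case_0001_1",
#       "http://localhost:40000", "llava-med-v1.5-mistral-7b",
#   )
# With raw_output=True (as main() uses), a dict containing the raw text, validated answer,
# timing, prompts, image paths, and request payload is returned instead of a tuple.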


def clean_payload(payload: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Remove image-related and large data from the payload to keep the log lean.

    Args:
        payload: Original request payload dictionary

    Returns:
        Cleaned payload dictionary with large data removed
    """
    if not payload:
        return None

    # Copy so the original payload is not modified
    cleaned_payload = payload.copy()

    # Remove large base64 image data
    if "images" in cleaned_payload:
        del cleaned_payload["images"]

    return cleaned_payload
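
# For example (abbreviated, made-up payload):
#   clean_payload({"model": "m", "prompt": "p", "images": ["<base64>..."]})
#   # -> {"model": "m", "prompt": "p"}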


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
    parser.add_argument("--worker-address", type=str)
    parser.add_argument("--model-name", type=str, default="llava-med-v1.5-mistral-7b")
    parser.add_argument("--output-dir", type=str, default="benchmark_results")
    parser.add_argument(
        "--raw-output", action="store_true", help="Return raw model output without validation"
    )
    parser.add_argument(
        "--num-cases",
        type=int,
        help="Number of cases to process if looking at raw outputs",
        default=2,
    )
    args = parser.parse_args()

    # Set up the output directory
    os.makedirs(args.output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Live log (a JSON array appended to after every question) and final results files
    live_log_filename = os.path.join(args.output_dir, f"live_benchmark_log_{timestamp}.json")
    final_results_filename = os.path.join(args.output_dir, f"final_results_{timestamp}.json")

    # Initialize the live log file with the start of a JSON array
    with open(live_log_filename, "w") as live_log_file:
        live_log_file.write("[\n")

    # Set up logging
    logging.basicConfig(
        filename=os.path.join(args.output_dir, f"benchmark_{timestamp}.log"),
        level=logging.INFO,
        format="%(message)s",
    )

    # Get the worker address, either directly or via the controller
    if args.worker_address:
        worker_addr = args.worker_address
    else:
        try:
            requests.post(args.controller_address + "/refresh_all_workers")
            ret = requests.post(args.controller_address + "/list_models")
            models = ret.json()["models"]
            ret = requests.post(
                args.controller_address + "/get_worker_address", json={"model": args.model_name}
            )
            worker_addr = ret.json()["address"]
            print(f"Worker address: {worker_addr}")
        except requests.exceptions.RequestException as e:
            print(f"Failed to connect to controller: {e}")
            return

    if worker_addr == "":
        print("No available worker")
        return

    # Load cases with local image paths
    with open("MedMAX/data/updated_cases.json", "r") as file:
        data = json.load(file)

    total_cases, total_questions = count_total_questions()
    print(f"\nStarting benchmark with {args.model_name}")
    print(f"Found {total_cases} cases with {total_questions} total questions")

    results = {
        "model": args.model_name,
        "timestamp": datetime.now().isoformat(),
        "total_cases": total_cases,
        "total_questions": total_questions,
        "results": [],
    }

    cases_processed = 0
    questions_processed = 0
    correct_answers = 0
    skipped_questions = 0
    total_processed_entries = 0

    # Process each case
    for case_id, case_details in tqdm(data.items(), desc="Processing cases"):
        question_files = load_benchmark_questions(case_id)
        if not question_files:
            continue

        cases_processed += 1
        for question_file in tqdm(
            question_files, desc=f"Processing questions for case {case_id}", leave=False
        ):
            with open(question_file, "r") as file:
                question_data = json.load(file)

            question_id = os.path.basename(question_file).split(".")[0]
            questions_processed += 1

            # Get the model's answer
            inference_result = create_inference_request(
                question_data,
                case_details,
                case_id,
                question_id,
                worker_addr,
                args.model_name,
                raw_output=True,  # Always use raw output for detailed logging
            )

            # Handle skipped questions (no local images found)
            if inference_result == ("skipped", 0.0):
                skipped_questions += 1
                print(f"\nCase {case_id}, Question {question_id}: Skipped (No images)")

                # Log the skipped question
                skipped_entry = {
                    "case_id": case_id,
                    "question_id": question_id,
                    "status": "skipped",
                    "reason": "No images found",
                }
                with open(live_log_filename, "a") as live_log_file:
                    json.dump(skipped_entry, live_log_file, indent=2)
                    live_log_file.write(",\n")  # Comma for the next entry
                continue

            # Extract the validated answer and timing
            answer = inference_result["validated_answer"]
            duration = inference_result["duration"]

            # Prepare a detailed log entry
            log_entry = {
                "case_id": case_id,
                "question_id": question_id,
                "question": question_data["question"],
                "correct_answer": question_data["answer"],
                "raw_output": inference_result["raw_output"],
                "validated_answer": answer,
                "model_answer": answer,
                "is_correct": answer == question_data["answer"] if answer else False,
                "duration": duration,
                "system_prompt": inference_result["system_prompt"],
                "input_prompt": inference_result["prompt"],
                "image_paths": inference_result["image_paths"],
                "payload": clean_payload(inference_result["payload"]),
            }

            # Write to the live log file
            with open(live_log_filename, "a") as live_log_file:
                json.dump(log_entry, live_log_file, indent=2)
                live_log_file.write(",\n")  # Comma for the next entry

            # Print progress to the console
            print(f"\nCase {case_id}, Question {question_id}")
            print(f"Model Answer: {answer}")
            print(f"Correct Answer: {question_data['answer']}")
            print(f"Time taken: {duration:.2f}s")

            # Track correct answers
            if answer == question_data["answer"]:
                correct_answers += 1

            # Append to results
            results["results"].append(log_entry)
            total_processed_entries += 1

            # Optional: stop the question loop once the case limit is reached (raw-output runs)
            if args.raw_output and cases_processed == args.num_cases:
                break

        # Optional: stop the case loop once the case limit is reached (raw-output runs)
        if args.raw_output and cases_processed == args.num_cases:
            break

    # Close the live log: drop the trailing ",\n" and terminate the JSON array.
    # (Writes in append mode always go to the end of the file, so reopen in "rb+" to edit in place.)
    with open(live_log_filename, "rb+") as live_log_file:
        live_log_file.seek(0, os.SEEK_END)
        size = live_log_file.tell()
        if size > 2:
            live_log_file.seek(size - 2)  # step back over the trailing ",\n"
            live_log_file.truncate()
        live_log_file.write(b"\n]")

    # Calculate final statistics
    answered_questions = questions_processed - skipped_questions
    results["summary"] = {
        "cases_processed": cases_processed,
        "questions_processed": questions_processed,
        "total_processed_entries": total_processed_entries,
        "correct_answers": correct_answers,
        "skipped_questions": skipped_questions,
        "accuracy": correct_answers / answered_questions if answered_questions > 0 else 0,
    }

    # Save final results
    with open(final_results_filename, "w") as f:
        json.dump(results, f, indent=2)

    print("\nBenchmark Summary:")
    print(f"Total Cases Processed: {cases_processed}")
    print(f"Total Questions Processed: {questions_processed}")
    print(f"Total Processed Entries: {total_processed_entries}")
    print(f"Correct Answers: {correct_answers}")
    print(f"Skipped Questions: {skipped_questions}")
    print(f"Accuracy: {results['summary']['accuracy'] * 100:.2f}%")
    print(f"\nResults saved to {args.output_dir}")
    print(f"Live log: {live_log_filename}")
    print(f"Final results: {final_results_filename}")


if __name__ == "__main__":
    main()
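
# Typical invocation (the script filename is a placeholder; addresses match the argparse defaults):
#   python llava_med_benchmark.py \
#       --controller-address http://localhost:21001 \
#       --model-name llava-med-v1.5-mistral-7b \
#       --output-dir benchmark_results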