from typing import Dict, List, Optional, Tuple, Union, Any
import json
import argparse
from collections import defaultdict

from tqdm import tqdm
# Maps each benchmark question type to the metadata categories it aggregates over.
QUESTION_TYPES = {
    "Detailed Finding Analysis": ["detection", "localization", "characterization"],
    "Pattern Recognition & Relations": ["detection", "classification", "relationship"],
    "Spatial Understanding": ["localization", "comparison", "relationship"],
    "Clinical Decision Making": ["classification", "comparison", "diagnosis"],
    "Diagnostic Classification": ["classification", "characterization", "diagnosis"],
}
def extract_answer_letter(answer: Optional[Union[str, Any]]) -> Optional[str]:
    """
    Extract just the letter from various answer formats.

    Args:
        answer: The answer text to extract the letter from

    Returns:
        Optional[str]: The extracted letter in uppercase, or None if no letter is found
    """
    if not answer:
        return None

    # Convert to string and clean
    answer = str(answer).strip()

    # If it's just a single letter, return it
    if len(answer) == 1 and answer.isalpha():
        return answer.upper()

    # Try to extract letter from format like "A)" or "A."
    if len(answer) >= 2 and answer[0].isalpha() and answer[1] in ").:- ":
        return answer[0].upper()

    # Try to extract letter from format like "A) Some text"
    if answer.startswith(("A)", "B)", "C)", "D)", "E)", "F)")):
        return answer[0].upper()

    return None
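
# Example behaviour, derived from the parsing rules above (inputs are illustrative only):
#   extract_answer_letter("A")             -> "A"
#   extract_answer_letter("b) fibrosis")   -> "B"
#   extract_answer_letter("Answer: A")     -> None  (no leading option letter)
#   extract_answer_letter(None)            -> None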
def analyze_gpt4_results(
    results_file: str, max_questions: Optional[int] = None
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
    """
    Analyze results in GPT-4 format.

    Args:
        results_file: Path to results file
        max_questions: Maximum number of questions to analyze

    Returns:
        Tuple containing:
            - overall_accuracy (float)
            - category_accuracies (Dict)
            - question_type_stats (Dict)
            - correct_ids (List[str])
            - incorrect_ids (List[str])
    """
    category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
    all_questions = 0
    all_correct = 0
    correct_ids = []
    incorrect_ids = []

    with open(results_file, "r") as f:
        lines = f.readlines()

    processed_questions = 0
    for line in tqdm(lines, desc="Analyzing Benchmark Results"):
        # Check if we've hit the maximum questions
        if max_questions is not None and processed_questions >= max_questions:
            break

        if line.startswith("HTTP Request:"):
            continue

        try:
            entry = json.loads(line)
            metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
            question_id = entry.get("question_id")
            model_letter = extract_answer_letter(entry.get("model_answer"))
            correct_letter = extract_answer_letter(entry.get("correct_answer"))

            if model_letter and correct_letter:
                all_questions += 1
                processed_questions += 1
                is_correct = model_letter == correct_letter
                if is_correct:
                    all_correct += 1
                    correct_ids.append(question_id)
                else:
                    incorrect_ids.append(question_id)

                for category in metadata.get("categories", []):
                    category_performance[category]["total"] += 1
                    if is_correct:
                        category_performance[category]["correct"] += 1
        except json.JSONDecodeError:
            continue

    return process_results(
        category_performance, all_questions, all_correct, correct_ids, incorrect_ids
    )
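
# Expected shape of each JSONL results line, shared by the analyzers in this file
# (inferred from the .get() calls above; field values are illustrative only):
# {"question_id": "q_0001",
#  "model_answer": "B) Pleural effusion",
#  "correct_answer": "B",
#  "input": {"question_data": {"metadata": {"categories": ["detection", "localization"]}}}}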
def analyze_llama_results(
    results_file: str, max_questions: Optional[int] = None
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
    """
    Analyze results in Llama format.

    Args:
        results_file: Path to results file
        max_questions: Maximum number of questions to analyze

    Returns:
        Tuple containing:
            - overall_accuracy (float)
            - category_accuracies (Dict)
            - question_type_stats (Dict)
            - correct_ids (List[str])
            - incorrect_ids (List[str])
    """
    category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
    all_questions = 0
    all_correct = 0
    correct_ids = []
    incorrect_ids = []

    with open(results_file, "r") as f:
        lines = f.readlines()

    # If max_questions is set, limit the number of lines processed
    if max_questions is not None:
        lines = lines[:max_questions]

    for line in tqdm(lines, desc="Analyzing Benchmark Results"):
        if line.startswith("HTTP Request:"):
            continue

        try:
            entry = json.loads(line)
            metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
            question_id = entry.get("question_id")
            model_letter = extract_answer_letter(entry.get("model_answer"))
            correct_letter = extract_answer_letter(entry.get("correct_answer"))

            if model_letter and correct_letter:
                all_questions += 1
                is_correct = model_letter == correct_letter
                if is_correct:
                    all_correct += 1
                    correct_ids.append(question_id)
                else:
                    incorrect_ids.append(question_id)

                for category in metadata.get("categories", []):
                    category_performance[category]["total"] += 1
                    if is_correct:
                        category_performance[category]["correct"] += 1
        except json.JSONDecodeError:
            continue

    return process_results(
        category_performance, all_questions, all_correct, correct_ids, incorrect_ids
    )
def analyze_chexagent_results(
    results_file: str, max_questions: Optional[int] = None
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
    """
    Analyze results in CheXagent format.

    Args:
        results_file: Path to results file
        max_questions: Maximum number of questions to analyze

    Returns:
        Tuple containing:
            - overall_accuracy (float)
            - category_accuracies (Dict)
            - question_type_stats (Dict)
            - correct_ids (List[str])
            - incorrect_ids (List[str])
    """
    category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
    all_questions = 0
    all_correct = 0
    correct_ids = []
    incorrect_ids = []

    with open(results_file, "r") as f:
        lines = f.readlines()

    # If max_questions is set, limit the number of lines processed
    if max_questions is not None:
        lines = lines[:max_questions]

    for line in tqdm(lines, desc="Analyzing Benchmark Results"):
        try:
            entry = json.loads(line)
            metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
            question_id = entry.get("question_id")
            model_letter = extract_answer_letter(entry.get("model_answer"))
            correct_letter = extract_answer_letter(entry.get("correct_answer"))

            if model_letter and correct_letter:
                all_questions += 1
                is_correct = model_letter == correct_letter
                if is_correct:
                    all_correct += 1
                    correct_ids.append(question_id)
                else:
                    incorrect_ids.append(question_id)

                for category in metadata.get("categories", []):
                    category_performance[category]["total"] += 1
                    if is_correct:
                        category_performance[category]["correct"] += 1
        except json.JSONDecodeError:
            continue

    return process_results(
        category_performance, all_questions, all_correct, correct_ids, incorrect_ids
    )
def process_results(
    category_performance: Dict,
    all_questions: int,
    all_correct: int,
    correct_ids: Optional[List[str]] = None,
    incorrect_ids: Optional[List[str]] = None,
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
    """
    Process raw results into final statistics.

    Args:
        category_performance: Dict containing performance by category
        all_questions: Total number of questions
        all_correct: Total number of correct answers
        correct_ids: List of IDs for correctly answered questions
        incorrect_ids: List of IDs for incorrectly answered questions

    Returns:
        Tuple containing:
            - overall_accuracy (float)
            - category_accuracies (Dict)
            - question_type_stats (Dict)
            - correct_ids (List[str])
            - incorrect_ids (List[str])
    """
    category_accuracies = {
        category: {
            "accuracy": stats["correct"] / stats["total"] * 100 if stats["total"] > 0 else 0,
            "total": stats["total"],
            "correct": stats["correct"],
        }
        for category, stats in category_performance.items()
    }

    question_type_stats = {}
    for qtype, categories in QUESTION_TYPES.items():
        total = sum(
            category_performance[cat]["total"] for cat in categories if cat in category_performance
        )
        correct = sum(
            category_performance[cat]["correct"]
            for cat in categories
            if cat in category_performance
        )
        question_type_stats[qtype] = {
            "accuracy": (correct / total * 100) if total > 0 else 0,
            "total": total,
            "correct": correct,
        }

    overall_accuracy = (all_correct / all_questions * 100) if all_questions > 0 else 0

    return (
        overall_accuracy,
        category_accuracies,
        question_type_stats,
        correct_ids or [],
        incorrect_ids or [],
    )
def print_analysis(
    overall_accuracy: float,
    category_accuracies: Dict,
    question_type_stats: Dict,
    correct_ids: List[str],
    incorrect_ids: List[str],
    model_name: str,
) -> None:
    """
    Print analysis results.

    Args:
        overall_accuracy: Overall accuracy percentage
        category_accuracies: Dict containing accuracy metrics by category
        question_type_stats: Dict containing stats by question type
        correct_ids: List of IDs for correctly answered questions
        incorrect_ids: List of IDs for incorrectly answered questions
        model_name: Name of the model being analyzed
    """
    total_questions = len(correct_ids) + len(incorrect_ids)
    print(
        f"\nOverall Accuracy: {overall_accuracy:.2f}% "
        f"({len(correct_ids)} correct out of {total_questions} questions)"
    )

    print("\nCategory Performance:")
    sorted_categories = sorted(
        category_accuracies.items(), key=lambda x: x[1]["accuracy"], reverse=True
    )
    for category, metrics in sorted_categories:
        print(f"{category}:")
        print(f" Accuracy: {metrics['accuracy']:.2f}%")
        print(f" Total Questions: {metrics['total']}")
        print(f" Correct Questions: {metrics['correct']}")

    print("\nQuestion Type Performance:")
    sorted_types = sorted(question_type_stats.items(), key=lambda x: x[1]["accuracy"], reverse=True)
    for qtype, metrics in sorted_types:
        print(f"\n{qtype}:")
        print(f" Accuracy: {metrics['accuracy']:.2f}%")
        print(f" Total Questions: {metrics['total']}")
        print(f" Correct Questions: {metrics['correct']}")
        print(f" Categories: {', '.join(QUESTION_TYPES[qtype])}")

    # Save question IDs to JSON
    question_ids = {"correct_ids": correct_ids, "incorrect_ids": incorrect_ids}
    output_filename = f"{model_name}_question_ids.json"
    with open(output_filename, "w") as f:
        json.dump(question_ids, f, indent=2)
    print(f"\nQuestion IDs have been saved to {output_filename}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Analyze benchmark results")
    parser.add_argument("results_file", help="Path to results file")
    parser.add_argument("benchmark_dir", nargs="?", help="Path to benchmark questions directory")
    parser.add_argument(
        "--model",
        choices=["llava-med", "chexagent", "llama", "gpt4", "medrax"],
        default="gpt4",
        help="Specify model format (default: gpt4)",
    )
    parser.add_argument("--max-questions", type=int, help="Maximum number of questions to analyze")
    args = parser.parse_args()

    if args.model == "gpt4":
        results = analyze_gpt4_results(args.results_file, args.max_questions)
    elif args.model == "llama":
        results = analyze_llama_results(args.results_file, args.max_questions)
    elif args.model == "chexagent":
        results = analyze_chexagent_results(args.results_file, args.max_questions)
    elif args.model == "medrax":
        # MedRAX results are parsed with the GPT-4 analyzer (same output format)
        results = analyze_gpt4_results(args.results_file, args.max_questions)
    else:
        parser.error(f"Unsupported model: {args.model}")

    print_analysis(*results, args.model)
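
# Example invocation (the script filename below is hypothetical):
#   python analyze_results.py results.jsonl --model llama --max-questions 200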