import argparse
import json
import random
import re
from collections import defaultdict
from typing import Any, Dict, List, Tuple

# Define category order (also the order used when reporting per-category results)
CATEGORY_ORDER = [
    "detection",
    "classification",
    "localization",
    "comparison",
    "relationship",
    "diagnosis",
    "characterization",
]

def extract_letter_answer(answer: str) -> str:
    """Extract just the letter answer from various answer formats.

    Args:
        answer: The answer string to extract a letter from

    Returns:
        str: The extracted letter in uppercase, or the cleaned, uppercased
            original string if no A-F letter is found
    """
    if not answer:
        return ""

    # Convert to string and clean
    answer = str(answer).strip()

    # If it's just a single letter A-F, return it
    if len(answer) == 1 and answer.upper() in "ABCDEF":
        return answer.upper()

    # Try to match patterns like "A)", "A.", "A ", etc.
    match = re.match(r"^([A-F])[).\s]", answer, re.IGNORECASE)
    if match:
        return match.group(1).upper()

    # Try to find any standalone A-F letter preceded by a space or the start of
    # the string and followed by a space, period, parenthesis, or the end
    matches = re.findall(r"(?:^|\s)([A-F])(?:[).\s]|$)", answer, re.IGNORECASE)
    if matches:
        return matches[0].upper()

    # Last resort: just find any A-F letter
    letters = re.findall(r"[A-F]", answer, re.IGNORECASE)
    if letters:
        return letters[0].upper()

    # If no letter is found, return the original answer (already stripped), uppercased
    return answer.upper()

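# Illustrative behaviour of extract_letter_answer (example inputs are hypothetical):
#   extract_letter_answer("B) Pneumothorax")  -> "B"
#   extract_letter_answer("The answer is C.") -> "C"
#   extract_letter_answer("unknown")          -> "UNKNOWN"  (no A-F letter found)
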
def parse_json_lines(file_path: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Parse a predictions file and extract valid prediction entries.

    Handles two formats: a single LLaVA-style JSON document with a "results"
    list, and the original JSON Lines format with one prediction per line.

    Args:
        file_path: Path to the predictions file to parse

    Returns:
        Tuple containing:
            - str: Model name, or the file path if no model name was found
            - List[Dict[str, Any]]: List of valid prediction entries
    """
    valid_predictions = []
    model_name = None

    # First try to parse as LLaVA format
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
            if data.get("model") == "llava-med-v1.5-mistral-7b":
                model_name = data["model"]
                for result in data.get("results", []):
                    if all(k in result for k in ["case_id", "question_id", "correct_answer"]):
                        # Extract answer with priority: model_answer > validated_answer > raw_output
                        model_answer = (
                            result.get("model_answer")
                            or result.get("validated_answer")
                            or result.get("raw_output", "")
                        )
                        # Add default categories for LLaVA results
                        prediction = {
                            "case_id": result["case_id"],
                            "question_id": result["question_id"],
                            "model_answer": model_answer,
                            "correct_answer": result["correct_answer"],
                            "input": {
                                "question_data": {
                                    "metadata": {
                                        "categories": [
                                            "detection",
                                            "classification",
                                            "localization",
                                            "comparison",
                                            "relationship",
                                            "diagnosis",
                                            "characterization",
                                        ]
                                    }
                                }
                            },
                        }
                        valid_predictions.append(prediction)
                return model_name, valid_predictions
    except (json.JSONDecodeError, KeyError):
        pass

    # If not LLaVA format, process as original format
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            if line.startswith("HTTP Request:"):
                continue
            try:
                data = json.loads(line.strip())
                if "model" in data:
                    model_name = data["model"]
                if all(
                    k in data for k in ["model_answer", "correct_answer", "case_id", "question_id"]
                ):
                    valid_predictions.append(data)
            except json.JSONDecodeError:
                continue

    return model_name if model_name else file_path, valid_predictions

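# A minimal sketch of the JSON Lines record shape this parser accepts
# (field values are hypothetical; only the keys checked above are required):
#   {"model": "some-model", "case_id": "001", "question_id": "q1",
#    "model_answer": "A", "correct_answer": "A",
#    "input": {"question_data": {"metadata": {"categories": ["diagnosis"]}}}}
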
def filter_common_questions(
    predictions_list: List[List[Dict[str, Any]]]
) -> List[List[Dict[str, Any]]]:
    """Ensure only questions that exist across all models are evaluated.

    Args:
        predictions_list: List of prediction lists from different models

    Returns:
        List[List[Dict[str, Any]]]: Filtered predictions containing only common questions
    """
    question_sets = [
        set((p["case_id"], p["question_id"]) for p in preds) for preds in predictions_list
    ]
    common_questions = set.intersection(*question_sets)
    return [
        [p for p in preds if (p["case_id"], p["question_id"]) in common_questions]
        for preds in predictions_list
    ]

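# Example (hypothetical data): if model A answered {("c1", "q1"), ("c1", "q2")}
# and model B answered {("c1", "q2"), ("c2", "q3")}, only ("c1", "q2") survives,
# so every model is scored on exactly the same set of questions.
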
def calculate_accuracy(
    predictions: List[Dict[str, Any]]
) -> Tuple[float, int, int, Dict[str, Dict[str, float]]]:
    """Compute overall and category-level accuracy.

    Args:
        predictions: List of prediction entries to analyze

    Returns:
        Tuple containing:
            - float: Overall accuracy percentage
            - int: Number of correct predictions
            - int: Total number of predictions
            - Dict[str, Dict[str, float]]: Category-level accuracy statistics
    """
    if not predictions:
        return 0.0, 0, 0, {}

    category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
    correct = 0
    total = 0

    # Print a small random sample of extracted answers as a sanity check
    sample_size = min(5, len(predictions))
    sampled_indices = random.sample(range(len(predictions)), sample_size)
    print("\nSample extracted answers:")
    for i in sampled_indices:
        pred = predictions[i]
        model_ans = extract_letter_answer(pred["model_answer"])
        correct_ans = extract_letter_answer(pred["correct_answer"])
        print(f"QID: {pred['question_id']}")
        print(f"  Raw Model Answer: {pred['model_answer']}")
        print(f"  Extracted Model Answer: {model_ans}")
        print(f"  Raw Correct Answer: {pred['correct_answer']}")
        print(f"  Extracted Correct Answer: {correct_ans}")
        print("-" * 80)

    for pred in predictions:
        try:
            model_ans = extract_letter_answer(pred["model_answer"])
            correct_ans = extract_letter_answer(pred["correct_answer"])
            categories = (
                pred.get("input", {})
                .get("question_data", {})
                .get("metadata", {})
                .get("categories", [])
            )
            if model_ans and correct_ans:
                total += 1
                is_correct = model_ans == correct_ans
                if is_correct:
                    correct += 1
                for category in categories:
                    category_performance[category]["total"] += 1
                    if is_correct:
                        category_performance[category]["correct"] += 1
        except KeyError:
            continue

    category_accuracies = {
        category: {
            "accuracy": (stats["correct"] / stats["total"]) * 100 if stats["total"] > 0 else 0,
            "total": stats["total"],
            "correct": stats["correct"],
        }
        for category, stats in category_performance.items()
    }

    return (correct / total * 100 if total > 0 else 0.0, correct, total, category_accuracies)

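# The returned tuple has the shape (values here are hypothetical):
#   (75.0, 3, 4, {"diagnosis": {"accuracy": 50.0, "total": 2, "correct": 1}, ...})
# i.e. overall accuracy %, correct count, total scored, and per-category stats.
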
def compare_models(file_paths: List[str]) -> None:
    """Compare accuracy between multiple model prediction files.

    Args:
        file_paths: List of paths to model prediction files to compare
    """
    # Parse all files
    parsed_results = [parse_json_lines(file_path) for file_path in file_paths]
    model_names, predictions_list = zip(*parsed_results)

    # Get initial stats
    print("\n📊 **Initial Accuracy**:")
    results = []
    category_results = []
    for preds, name in zip(predictions_list, model_names):
        acc, correct, total, category_acc = calculate_accuracy(preds)
        results.append((acc, correct, total, name))
        category_results.append(category_acc)
        print(f"{name}: Accuracy = {acc:.2f}% ({correct}/{total} correct)")

    # Get common questions across all models
    filtered_predictions = filter_common_questions(predictions_list)
    print(
        f"\nQuestions per model after ensuring common questions: {[len(p) for p in filtered_predictions]}"
    )

    # Compute accuracy on common questions
    print("\n📊 **Accuracy on Common Questions**:")
    filtered_results = []
    filtered_category_results = []
    for preds, name in zip(filtered_predictions, model_names):
        acc, correct, total, category_acc = calculate_accuracy(preds)
        filtered_results.append((acc, correct, total, name))
        filtered_category_results.append(category_acc)
        print(f"{name}: Accuracy = {acc:.2f}% ({correct}/{total} correct)")

    # Print category-wise accuracy
    print("\nCategory Performance (Common Questions):")
    for category in CATEGORY_ORDER:
        print(f"\n{category.capitalize()}:")
        for model_name, category_acc in zip(model_names, filtered_category_results):
            stats = category_acc.get(category, {"accuracy": 0, "total": 0, "correct": 0})
            print(f"  {model_name}: {stats['accuracy']:.2f}% ({stats['correct']}/{stats['total']})")

def main():
    parser = argparse.ArgumentParser(
        description="Compare accuracy across multiple model prediction files"
    )
    parser.add_argument("files", nargs="+", help="Paths to model prediction files")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for sampling")
    args = parser.parse_args()

    random.seed(args.seed)
    compare_models(args.files)


if __name__ == "__main__":
    main()
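
# Example invocation (script and file names are illustrative):
#   python compare_accuracy.py model_a_predictions.jsonl model_b_predictions.jsonl --seed 42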