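"""Benchmark a multimodal OpenAI chat model on Eurorad-derived image questions.

For each benchmark question, the script attaches the referenced figure images,
asks the model for a single answer letter (A-F), and logs usage, cost, and
answers as one JSON object per line (JSON Lines) in a timestamped
api_usage_*.json file.
"""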
import json
import openai
import os
import glob
import time
import logging
from datetime import datetime
from tenacity import retry, wait_exponential, stop_after_attempt

model_name = "chatgpt-4o-latest"
temperature = 0.2

log_filename = f"api_usage_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s")


def calculate_cost(
    prompt_tokens: int, completion_tokens: int, model: str = "chatgpt-4o-latest"
) -> float:
    """Calculate the cost of API usage based on token counts.

    Args:
        prompt_tokens: Number of tokens in the prompt
        completion_tokens: Number of tokens in the completion
        model: Model name to use for pricing, defaults to chatgpt-4o-latest

    Returns:
        float: Cost in USD
    """
    # Rates are USD per 1M tokens.
    pricing = {"chatgpt-4o-latest": {"prompt": 5.0, "completion": 15.0}}
    rates = pricing.get(model, {"prompt": 5.0, "completion": 15.0})
    return (prompt_tokens * rates["prompt"] + completion_tokens * rates["completion"]) / 1_000_000
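
# Worked example with the rates above ($5 / $15 per 1M prompt/completion tokens):
#     calculate_cost(1000, 200) == (1000 * 5.0 + 200 * 15.0) / 1_000_000 == 0.008 (USD)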


# Retry transient failures with exponential backoff. tenacity is imported above
# for this purpose; the exact wait/stop parameters here are assumed.
@retry(wait=wait_exponential(multiplier=1, min=4, max=60), stop=stop_after_attempt(5))
def create_multimodal_request(
    question_data: dict, case_details: dict, case_id: str, question_id: str, client: openai.OpenAI
) -> openai.types.chat.ChatCompletion | None:
    """Create and send a multimodal request to the OpenAI API.

    Args:
        question_data: Dictionary containing question details and figures
        case_details: Dictionary containing case information and figures
        case_id: Identifier for the medical case
        question_id: Identifier for the specific question
        client: OpenAI client instance

    Returns:
        openai.types.chat.ChatCompletion: API response object, or None if the
            request is skipped because no images were found
    """
    prompt = f"""Given the following medical case:
Please answer this multiple choice question:
{question_data['question']}
Base your answer only on the provided images and case information."""
    content = [{"type": "text", "text": prompt}]

    # Parse required figures
    try:
        # Try multiple ways of parsing figures
        if isinstance(question_data["figures"], str):
            try:
                required_figures = json.loads(question_data["figures"])
            except json.JSONDecodeError:
                required_figures = [question_data["figures"]]
        elif isinstance(question_data["figures"], list):
            required_figures = question_data["figures"]
        else:
            required_figures = [str(question_data["figures"])]
    except Exception as e:
        print(f"Error parsing figures: {e}")
        required_figures = []

    # Ensure each figure starts with "Figure "
    required_figures = [
        fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
    ]

    subfigures = []
    for figure in required_figures:
        # Handle both regular figures and those with letter suffixes
        base_figure_num = "".join(filter(str.isdigit, figure))
        figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None
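        # e.g. "Figure 1a" -> base_figure_num == "1", figure_letter == "a";
        #      "Figure 2"  -> base_figure_num == "2", figure_letter is None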
        # Find matching figures in case details
        matching_figures = [
            case_figure
            for case_figure in case_details.get("figures", [])
            if case_figure["number"] == f"Figure {base_figure_num}"
        ]
        if not matching_figures:
            print(f"No matching figure found for {figure} in case {case_id}")
            continue

        for case_figure in matching_figures:
            # If a specific letter is specified, filter subfigures
            if figure_letter:
                matching_subfigures = [
                    subfig
                    for subfig in case_figure.get("subfigures", [])
                    if subfig.get("number", "").lower().endswith(figure_letter.lower())
                    or subfig.get("label", "").lower() == figure_letter.lower()
                ]
                subfigures.extend(matching_subfigures)
            else:
                # If no letter specified, add all subfigures
                subfigures.extend(case_figure.get("subfigures", []))

    # Add images to content
    for subfig in subfigures:
        if "url" in subfig:
            content.append({"type": "image_url", "image_url": {"url": subfig["url"]}})
        else:
            print(f"Subfigure missing URL: {subfig}")

    messages = [
        {
            "role": "system",
            "content": "You are a medical imaging expert. Provide only the letter corresponding to your answer choice (A/B/C/D/E/F).",
        },
        {"role": "user", "content": content},
    ]
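
    # The user message mixes content parts: one text part followed by one
    # image_url part per subfigure, which is the Chat Completions format for
    # multimodal input.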

    # If no images were found, log the skip and return None
    if len(content) == 1:  # Only the text prompt exists
        print(f"No images found for case {case_id}, question {question_id}")
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": model_name,
            "temperature": temperature,
            "status": "skipped",
            "reason": "no_images",
            "cost": 0,
            "input": {
                "messages": messages,
                "question_data": {
                    "question": question_data["question"],
                    "explanation": question_data["explanation"],
                    "metadata": question_data.get("metadata", {}),
                    "figures": question_data["figures"],
                },
                "image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
                "image_captions": [subfig.get("caption", "") for subfig in subfigures],
            },
        }
        logging.info(json.dumps(log_entry))
        return None

    try:
        start_time = time.time()
        response = client.chat.completions.create(
            model=model_name, messages=messages, max_tokens=50, temperature=temperature
        )
        duration = time.time() - start_time

        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": model_name,
            "temperature": temperature,
            "duration": round(duration, 2),
            "usage": {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            },
            "cost": calculate_cost(response.usage.prompt_tokens, response.usage.completion_tokens),
            "model_answer": response.choices[0].message.content,
            "correct_answer": question_data["answer"],
            "input": {
                "messages": messages,
                "question_data": {
                    "question": question_data["question"],
                    "explanation": question_data["explanation"],
                    "metadata": question_data.get("metadata", {}),
                    "figures": question_data["figures"],
                },
                "image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
                "image_captions": [subfig.get("caption", "") for subfig in subfigures],
            },
        }
        logging.info(json.dumps(log_entry))
        return response
    except openai.RateLimitError:
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": model_name,
            "temperature": temperature,
            "status": "error",
            "reason": "rate_limit",
            "cost": 0,
            "input": {
                "messages": messages,
                "question_data": {
                    "question": question_data["question"],
                    "explanation": question_data["explanation"],
                    "metadata": question_data.get("metadata", {}),
                    "figures": question_data["figures"],
                },
                "image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
                "image_captions": [subfig.get("caption", "") for subfig in subfigures],
            },
        }
        logging.info(json.dumps(log_entry))
        print(
            f"\nRate limit hit for case {case_id}, question {question_id}. Waiting 20s...",
            flush=True,
        )
        time.sleep(20)
        # Re-raise so the @retry decorator can attempt the request again.
        raise
    except Exception as e:
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": model_name,
            "temperature": temperature,
            "status": "error",
            "error": str(e),
            "cost": 0,
            "input": {
                "messages": messages,
                "question_data": {
                    "question": question_data["question"],
                    "explanation": question_data["explanation"],
                    "metadata": question_data.get("metadata", {}),
                    "figures": question_data["figures"],
                },
                "image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
                "image_captions": [subfig.get("caption", "") for subfig in subfigures],
            },
        }
        logging.info(json.dumps(log_entry))
        print(f"Error processing case {case_id}, question {question_id}: {str(e)}")
        raise


def load_benchmark_questions(case_id: str) -> list:
    """Load benchmark questions for a given case.

    Args:
        case_id: Identifier for the medical case

    Returns:
        list: List of paths to question files
    """
    benchmark_dir = "../benchmark/questions"
    return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")
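
# Expected on-disk layout, inferred from the glob patterns in this file:
#   ../benchmark/questions/<case_id>/<case_id>_<question>.json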


def count_total_questions() -> tuple[int, int]:
    """Count total number of cases and questions in benchmark.

    Returns:
        tuple: (total_cases, total_questions)
    """
    total_cases = len(glob.glob("../benchmark/questions/*"))
    total_questions = sum(
        len(glob.glob(f"../benchmark/questions/{case_id}/*.json"))
        for case_id in os.listdir("../benchmark/questions")
    )
    return total_cases, total_questions


def main() -> None:
    """Main function to run the benchmark evaluation."""
    with open("../data/eurorad_metadata.json", "r") as file:
        data = json.load(file)

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OPENAI_API_KEY environment variable is not set.")
    # The client is passed explicitly to create_multimodal_request, so no
    # module-level global is needed.
    client = openai.OpenAI(api_key=api_key)

    total_cases, total_questions = count_total_questions()
    cases_processed = 0
    questions_processed = 0
    skipped_questions = 0

    print(f"Beginning benchmark evaluation for model {model_name} with temperature {temperature}")

    for case_id, case_details in data.items():
        question_files = load_benchmark_questions(case_id)
        if not question_files:
            continue
        cases_processed += 1

        for question_file in question_files:
            with open(question_file, "r") as file:
                question_data = json.load(file)
            question_id = os.path.basename(question_file).split(".")[0]
            questions_processed += 1

            response = create_multimodal_request(
                question_data, case_details, case_id, question_id, client
            )

            # Handle questions skipped because no images were found
            if response is None:
                skipped_questions += 1
                print(f"Skipped question: Case ID {case_id}, Question ID {question_id}")
                continue

            print(
                f"Progress: Case {cases_processed}/{total_cases}, "
                f"Question {questions_processed}/{total_questions}"
            )
            print(f"Case ID: {case_id}")
            print(f"Question ID: {question_id}")
            print(f"Model Answer: {response.choices[0].message.content}")
            print(f"Correct Answer: {question_data['answer']}\n")
| print(f"\nBenchmark Summary:") | |
| print(f"Total Cases Processed: {cases_processed}") | |
| print(f"Total Questions Processed: {questions_processed}") | |
| print(f"Total Questions Skipped: {skipped_questions}") | |


if __name__ == "__main__":
    main()
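
# Usage sketch (the script filename is an assumption; paths above are relative
# to the directory the script runs from):
#   export OPENAI_API_KEY=sk-...
#   python run_benchmark.py
# Cost can then be tallied from the JSON-lines log, e.g.:
#   python -c "import json, sys; print(sum(json.loads(l).get('cost', 0) for l in open(sys.argv[1])))" api_usage_YYYYMMDD_HHMMSS.json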