# noqa: D100
"""Integration tests that verify Anthropic selects `query_paginated_wandb_gql`.

These tests send natural-language questions about the W&B *Models* data for the
`wandb-applied-ai-team/mcp-tests` project. The Anthropic model should respond
with a `tool_use` invoking `query_paginated_wandb_gql`, which we then execute
and validate.
"""

import json
import os
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List

import pytest

from tests.anthropic_test_utils import (
    call_anthropic,
    check_correctness_tool,
    extract_anthropic_tool_use,
)
from wandb_mcp_server.mcp_tools.query_wandb_gql import (
    QUERY_WANDB_GQL_TOOL_DESCRIPTION,
    query_paginated_wandb_gql,
)
from wandb_mcp_server.mcp_tools.tools_utils import generate_anthropic_tool_schema
from wandb_mcp_server.utils import get_git_commit, get_rich_logger

# Root logging configuration
logger = get_rich_logger(__name__)

# weave.init("wandb-applied-ai-team/wandb-mcp-server-test-outputs")
# os.environ["WANDB_SILENT"] = "true"

# -----------------------------------------------------------------------------
# Custom JSON encoder for datetime objects (similar to test_query_weave_traces.py)
# -----------------------------------------------------------------------------
class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that can handle datetime objects."""

    def default(self, obj):
        if isinstance(obj, datetime):
            return obj.isoformat()
        return super().default(obj)


# -----------------------------------------------------------------------------
# Environment guards
# -----------------------------------------------------------------------------
WANDB_API_KEY = os.getenv("WANDB_API_KEY")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")

if not WANDB_API_KEY:
    pytest.skip(
        "WANDB_API_KEY environment variable not set; skipping live GraphQL tests.",
        allow_module_level=True,
    )
if not ANTHROPIC_API_KEY:
    pytest.skip(
        "ANTHROPIC_API_KEY environment variable not set; skipping Anthropic tests.",
        allow_module_level=True,
    )

# -----------------------------------------------------------------------------
# Static test context
# -----------------------------------------------------------------------------
TEST_WANDB_ENTITY = "wandb-applied-ai-team"
TEST_WANDB_PROJECT = "mcp-tests"

# MODEL_NAME = "claude-3-7-sonnet-20250219"
# MODEL_NAME = "claude-4-sonnet-20250514"
MODEL_NAME = "claude-4-opus-20250514"
CORRECTNESS_MODEL_NAME = "claude-3-5-haiku-20241022"

# -----------------------------------------------------------------------------
# Build tool schema for Anthropic
# -----------------------------------------------------------------------------
available_tools: Dict[str, Dict[str, Any]] = {
    "query_paginated_wandb_gql": {
        "function": query_paginated_wandb_gql,
        "schema": generate_anthropic_tool_schema(
            func=query_paginated_wandb_gql,
            description=QUERY_WANDB_GQL_TOOL_DESCRIPTION,
        ),
    }
}

tools: List[Dict[str, Any]] = [available_tools["query_paginated_wandb_gql"]["schema"]]
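
# Illustrative sketch only: the exact output of generate_anthropic_tool_schema is
# defined in tools_utils, but an Anthropic Messages API tool entry generally
# resembles the following shape:
#
#     {
#         "name": "query_paginated_wandb_gql",
#         "description": QUERY_WANDB_GQL_TOOL_DESCRIPTION,
#         "input_schema": {"type": "object", "properties": {...}, "required": [...]},
#     }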
# -----------------------------------------------------------------------------
# Compute baseline runCount once so that tests have a stable expected value
# -----------------------------------------------------------------------------
BASELINE_QUERY = """
query ProjectRunCount($entity: String!, $project: String!) {
  project(name: $project, entityName: $entity) {
    runCount
  }
}
"""

BASELINE_VARIABLES = {"entity": TEST_WANDB_ENTITY, "project": TEST_WANDB_PROJECT}

# Compute baseline
logger.info(
    "Fetching baseline runCount for %s/%s", TEST_WANDB_ENTITY, TEST_WANDB_PROJECT
)
_baseline_result = query_paginated_wandb_gql(BASELINE_QUERY, BASELINE_VARIABLES)
BASELINE_RUN_COUNT: int = _baseline_result["project"]["runCount"]
logger.info("Baseline runCount = %s", BASELINE_RUN_COUNT)

# -----------------------------------------------------------------------------
# Natural-language queries to test
# -----------------------------------------------------------------------------
test_queries = [
    {
        "index": 0,
        "question": "How many runs are currently logged in the `{project_name}` project under the `{entity_name}` entity?",
        "expected_output": 37,
    },
    {
        "index": 1,
        "question": "What's the total experiment count for `{entity_name}/{project_name}`?",
        "expected_output": 37,
    },
    {
        "index": 2,
        "question": "In `{project_name}` in entity `{entity_name}` how many runs were run on April 29th 2025?",
        "expected_output": 37,
    },
    {
        "index": 3,
        "question": "Could you report the number of tracked runs in `{entity_name}/{project_name}` with lr 0.002?",
        "expected_output": 7,
    },
    {
        "index": 4,
        "question": "what was the run with the best eval loss in the `{project_name}` project belonging to `{entity_name}`.",
        "expected_output": "run_id: h0fm5qp5 OR run_name: transformer_7_bs-128_lr-0.008_5593616",
    },
    {
        "index": 5,
        "question": "How many steps in run gtng2y4l `{entity_name}/{project_name}` right now.",
        "expected_output": 750000,
    },
    {
        "index": 6,
        "question": "How many steps in run transformer_25_bs-33554432_lr-0.026000000000000002_2377215 `{entity_name}/{project_name}` right now.",
        "expected_output": 750000,
    },
    {
        "index": 7,
        "question": "What's the batch size of the run with best evaluation accuracy for `{project_name}` inside `{entity_name}`?",
        "expected_output": 16,
    },
    # {
    #     "index": 8,  # Example if uncommented
    #     "question": "Count the runs in my `{entity_name}` entity for the `{project_name}` project.",
    #     "expected_output": BASELINE_RUN_COUNT,
    # },
    # {
    #     "index": 9,  # Example if uncommented
    #     "question": "How big is the experiment set for `{entity_name}/{project_name}`?",
    #     "expected_output": BASELINE_RUN_COUNT,
    # },
    # {
    #     "index": 10,  # Example if uncommented
    #     "question": "Tell me the number of runs tracked in `{project_name}` (entity `{entity_name}`).",
    #     "expected_output": BASELINE_RUN_COUNT,
    # },
]

# -----------------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
    "sample",
    test_queries,
    ids=[f"sample_{i}" for i, _ in enumerate(test_queries)],
)
def test_query_wandb_gql(sample, weave_results_dir):
    """End-to-end test: NL question → Anthropic → tool_use → result validation."""
    start_time = time.monotonic()
    current_git_commit = get_git_commit()
    git_commit_id = f"commit_{current_git_commit}"
    current_test_file_name = os.path.basename(__file__)

    # Find the index of the current sample for unique naming and metadata
    sample_index = -1
    for i, s in enumerate(test_queries):
        if s == sample:
            sample_index = i
            break

    test_case_name = f"gql_query_{sample_index}_{sample.get('question', 'unknown_question')[:20].replace(' ', '_')}"

    query_text = sample["question"].format(
        entity_name=TEST_WANDB_ENTITY,
        project_name=TEST_WANDB_PROJECT,
    )
    expected_output = sample["expected_output"]
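
    # Flow sketch for what follows: ask MODEL_NAME for a `query_paginated_wandb_gql`
    # tool call, execute it, then have the cheaper CORRECTNESS_MODEL_NAME judge the
    # result via `check_correctness_tool`; if judged incorrect, retry once with that
    # feedback appended to the conversation.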
logger.info("QUERY: %s", query_text) # --- Retry Logic Setup --- max_retries = 1 last_reasoning = "No correctness check performed yet." last_is_correct = False first_call_assistant_response = None # Store the response dict from the first model tool_result = None # Store the result of executing the tool tool_name_used_in_test = None tool_input_used_in_test = None # Initialize log_data_for_file for the current test sample final_log_data_for_file = { "metadata": { "sample_name": test_case_name, "test_case_index": sample_index, "git_commit_id": git_commit_id, "source_test_file_name": current_test_file_name, "test_query_text": query_text, "expected_test_output": str(expected_output), "retry_attempt": 0, # Will be updated in the loop "max_retries_configured": max_retries, }, "inputs": { # Inputs to the overall test/evaluation "test_query": query_text, "expected_value": str(expected_output), }, "output": {}, # Will store tool output and correctness check details "score": False, # Default to False, updated on success "scorer_name": "gql_correctness_assertion", # Specific scorer for these tests "metrics": {}, # Will store execution_latency_seconds } try: # Initial messages for the first attempt messages_first_call = [{"role": "user", "content": query_text}] for attempt in range(max_retries + 1): logger.info(f"\n--- Attempt {attempt + 1} / {max_retries + 1} ---") final_log_data_for_file["metadata"]["retry_attempt"] = attempt + 1 if attempt > 0: # We are retrying. Add the previous assistant response and a user message with feedback. if first_call_assistant_response: messages_first_call.append( first_call_assistant_response ) # Add previous assistant message (contains tool use) else: # Should not happen in retry logic, but defensively handle logger.warning( "Attempting retry, but no previous assistant response found." ) # Construct the user message asking for a retry retry_user_message_content = f""" Executing the previous tool call resulted in: ```json {json.dumps(tool_result, indent=2, cls=DateTimeEncoder)} ``` A separate check determined this result was incorrect for the original query. The reasoning provided was: "{last_reasoning}". Please re-analyze the original query ("{query_text}") and the result from your previous attempt, then try generating the 'query_paginated_wandb_gql' tool call again. 
""" messages_first_call.append( {"role": "user", "content": retry_user_message_content} ) # --- First Call: Get the query_paginated_wandb_gql tool use --- response = call_anthropic( model_name=MODEL_NAME, messages=messages_first_call, tools=tools, # Provide the GQL tool schema ) first_call_assistant_response = ( response # Store this response for potential next retry ) _, tool_name, tool_input, _ = extract_anthropic_tool_use(response) tool_name_used_in_test = tool_name tool_input_used_in_test = tool_input logger.info(f"Attempt {attempt + 1}: Tool emitted by model: {tool_name}") logger.info( f"Attempt {attempt + 1}: Tool input: {json.dumps(tool_input, indent=2)}" ) assert tool_name == "query_paginated_wandb_gql", ( f"Attempt {attempt + 1}: Expected 'query_paginated_wandb_gql', got '{tool_name}'" ) # --- Execute the GQL tool --- try: tool_result = available_tools[tool_name]["function"](**tool_input) logger.info( f"Attempt {attempt + 1}: Tool result: {json.dumps(tool_result, indent=2, cls=DateTimeEncoder)}" ) # Log full result except Exception as e: logger.error( f"Attempt {attempt + 1}: Error executing tool '{tool_name}' with input {tool_input}: {e}", exc_info=True, ) final_log_data_for_file["output"]["tool_execution_error_details"] = str( e ) # If tool execution fails, we might want to stop retrying for this sample or handle differently. # For now, it will proceed to correctness check which will likely fail or be skipped. # Depending on the error, we might want to `pytest.fail` or `raise` to stop the current attempt. # For this iteration, we'll let it go to the correctness check, which will likely fail it. last_is_correct = False last_reasoning = f"Tool execution failed: {e}" if attempt >= max_retries: # If this was the last attempt raise # Re-raise the exception to fail the test continue # Skip to next retry attempt # --- Second Call: Perform Correctness Check (Separate Task) --- logger.info( f"\n--- Starting Correctness Check for Attempt {attempt + 1} ---" ) try: # Prepare the prompt for the check - provide all context clearly correctness_prompt = f""" Please evaluate if the provided 'Actual Tool Result' correctly addresses the 'Original User Query' and seems consistent with the 'Expected Output'. Use the 'check_correctness_tool' to provide your reasoning and conclusion. Original User Query: "{query_text}" Expected Output (for context, may not be directly comparable in structure): {json.dumps(expected_output, indent=2, cls=DateTimeEncoder)} Actual Tool Result from 'query_paginated_wandb_gql': {json.dumps(tool_result, indent=2, cls=DateTimeEncoder)} """ messages_check_call = [{"role": "user", "content": correctness_prompt}] correctness_response = call_anthropic( model_name=CORRECTNESS_MODEL_NAME, messages=messages_check_call, check_correctness_tool=check_correctness_tool, ) logger.info( f"Attempt {attempt + 1}: Correctness check response:\n{correctness_response}\n\n" ) # --- Extract and Validate Correctness Tool Use --- _, check_tool_name, check_tool_input, _ = extract_anthropic_tool_use( correctness_response ) assert check_tool_name == "check_correctness_tool", ( f"Attempt {attempt + 1}: Expected correctness tool, got {check_tool_name}" ) assert "reasoning" in check_tool_input, ( f"Attempt {attempt + 1}: Correctness tool missing 'reasoning'" ) assert "is_correct" in check_tool_input, ( f"Attempt {attempt + 1}: Correctness tool missing 'is_correct'" ) # 2. 
                try:
                    reasoning_text = check_tool_input["reasoning"]
                    is_correct_flag = check_tool_input["is_correct"]

                    # Store the latest results
                    last_reasoning = reasoning_text
                    last_is_correct = is_correct_flag

                    logger.info(
                        f"Attempt {attempt + 1}: Correctness Reasoning: {reasoning_text}"
                    )
                    logger.info(
                        f"Attempt {attempt + 1}: Is Correct according to LLM: {is_correct_flag}"
                    )

                    if is_correct_flag:
                        logger.info(
                            f"--- Correctness check passed on attempt {attempt + 1}. ---"
                        )
                        final_log_data_for_file["score"] = True
                        break  # Exit the loop successfully
                    # If not correct, and this is the last attempt, the loop will end naturally.

                except KeyError as e:
                    logger.error(
                        f"Attempt {attempt + 1}: Missing expected key in correctness tool input: {e}"
                    )
                    logger.error(
                        f"Attempt {attempt + 1}: Full input received: {check_tool_input}"
                    )
                    last_is_correct = False
                    last_reasoning = f"Correctness tool response missing key: {e}"
                    final_log_data_for_file["output"]["assertion_error_details"] = (
                        f"Correctness tool response missing key: {e}"
                    )
                    if attempt >= max_retries:
                        pytest.fail(
                            f"Attempt {attempt + 1}: Correctness tool response was missing key: {e}"
                        )
                    continue  # To next retry
                except Exception as e:
                    logger.error(
                        f"Attempt {attempt + 1}: Error processing correctness tool input: {e}",
                        exc_info=True,
                    )
                    last_is_correct = False
                    last_reasoning = f"Failed to process correctness tool input: {e}"
                    final_log_data_for_file["output"]["assertion_error_details"] = (
                        f"Failed to process correctness tool input: {e}"
                    )
                    if attempt >= max_retries:
                        pytest.fail(
                            f"Attempt {attempt + 1}: Failed to process correctness tool input: {e}"
                        )
                    continue  # To next retry

            except Exception as e:
                logger.error(
                    f"Attempt {attempt + 1}: Error during correctness check for query '{query_text}': {e}",
                    exc_info=True,
                )
                last_is_correct = False
                last_reasoning = f"Correctness check failed with exception: {e}"
                final_log_data_for_file["output"]["assertion_error_details"] = (
                    f"Correctness check failed with exception: {e}"
                )
                if attempt >= max_retries:
                    pytest.fail(
                        f"Attempt {attempt + 1}: Correctness check failed with exception: {e}"
                    )
                continue  # To next retry

        # After the loop, if not last_is_correct, it means all retries failed or it failed on the last attempt.
        if not last_is_correct and attempt >= max_retries:
            pytest.fail(
                f"LLM evaluation failed after {max_retries + 1} attempts for sample {sample_index}. "
                f"Final is_correct_flag is `{last_is_correct}`. "
                f"Final Reasoning: '{last_reasoning}'"
            )

    except Exception as test_exec_exception:
        # Catch any exception that might cause the test to fail before all retries are done
        # or even before the loop fully completes.
        logger.error(
            f"Test execution for sample {sample_index} failed globally: {test_exec_exception}",
            exc_info=True,
        )
        final_log_data_for_file["score"] = False
        final_log_data_for_file["output"]["test_exception"] = str(test_exec_exception)
        # We will write the JSON in `finally`, then re-raise or let pytest handle the failure.
        raise  # Re-raise the caught exception to ensure the test is marked as failed by pytest

    finally:
        end_time = time.monotonic()
        execution_latency_seconds = end_time - start_time
        final_log_data_for_file["metrics"]["execution_latency_seconds"] = (
            execution_latency_seconds
        )
        final_log_data_for_file["metadata"]["final_attempt_number_for_json"] = (
            final_log_data_for_file["metadata"]["retry_attempt"]
        )  # Should be updated inside loop

        # Populate output details from the last successful (or last attempted) tool call
        final_log_data_for_file["output"]["tool_name"] = tool_name_used_in_test
        final_log_data_for_file["output"]["tool_input"] = (
            json.dumps(tool_input_used_in_test, indent=2)
            if tool_input_used_in_test
            else None
        )
        final_log_data_for_file["output"]["tool_result"] = (
            json.dumps(tool_result, indent=2, cls=DateTimeEncoder)
            if tool_result
            else None
        )
        final_log_data_for_file["output"]["correctness_reasoning"] = last_reasoning
        final_log_data_for_file["score"] = last_is_correct  # Ensure final score is set

        # Generate a unique filename for the JSON output
        unique_file_id = str(uuid.uuid4())
        worker_id = os.environ.get(
            "PYTEST_XDIST_WORKER", "main_thread"
        )  # Default if not in xdist

        # Sanitize test_case_name for filename (take first 30 chars, replace spaces)
        safe_test_name_part = (
            test_case_name.replace(" ", "_").replace("/", "_").replace("\\", "_")[:30]
        )

        file_name = f"gql_test_idx_{sample_index}_{safe_test_name_part}_w_{worker_id}_attempt_{final_log_data_for_file['metadata']['final_attempt_number_for_json']}_{('pass' if final_log_data_for_file['score'] else 'fail')}_{unique_file_id}.json"
        file_path = weave_results_dir / file_name

        logger.critical(
            f"WRITING JSON for GQL Test: {test_case_name} (Index: {sample_index}, Last Attempt: {final_log_data_for_file['metadata']['final_attempt_number_for_json']}, Score: {final_log_data_for_file['score']}) to {file_path}"
        )

        try:
            with open(file_path, "w") as f:
                json.dump(final_log_data_for_file, f, indent=2, cls=DateTimeEncoder)
            logger.info(
                f"Result for GQL test {test_case_name} (Latency: {execution_latency_seconds:.2f}s) written to {file_path}"
            )
        except Exception as e:
            logger.error(
                f"Failed to write result JSON for GQL test {test_case_name} to {file_path}: {e}"
            )

    # If we reach here and no exception was raised by pytest.fail or re-raised from the try block,
    # it means the correctness check passed within the allowed attempts.
    if not last_is_correct:  # Final check if loop exited due to retries without success
        pytest.fail(
            f"LLM evaluation failed after {max_retries + 1} attempts for sample {sample_index}. "
            f"Final is_correct_flag is `{last_is_correct}`. "
            f"Final Reasoning: '{last_reasoning}'"
        )

    logger.info(
        f"--- Test for sample {sample_index} ({test_case_name}) completed. Score: {last_is_correct} ---"
    )
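

# Convenience entry point, not used by pytest: a minimal sketch for smoke-testing
# the module directly. It assumes WANDB_API_KEY and ANTHROPIC_API_KEY are already
# set (the module-level guards above otherwise skip/abort) and only prints the
# baseline run count computed at import time. The tests themselves are run with
# pytest, e.g. `pytest -k test_query_wandb_gql -v`.
if __name__ == "__main__":
    print(f"{TEST_WANDB_ENTITY}/{TEST_WANDB_PROJECT} runCount = {BASELINE_RUN_COUNT}")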