# noqa: D100
"""Integration tests that verify Anthropic selects `query_wandb_tool`.
These tests send natural-language questions about the W&B *Models* data for the
`wandb-applied-ai-team/mcp-tests` project. The Anthropic model should respond
with a `tool_use` invoking `query_wandb_tool`, which we then execute and
validate.
"""
import json
import os
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List
import pytest
from tests.anthropic_test_utils import (
call_anthropic,
check_correctness_tool,
extract_anthropic_tool_use,
)
from wandb_mcp_server.mcp_tools.query_wandb_gql import (
QUERY_WANDB_GQL_TOOL_DESCRIPTION,
query_paginated_wandb_gql,
)
from wandb_mcp_server.mcp_tools.tools_utils import generate_anthropic_tool_schema
from wandb_mcp_server.utils import get_git_commit, get_rich_logger
# Root logging configuration
logger = get_rich_logger(__name__)
# weave.init("wandb-applied-ai-team/wandb-mcp-server-test-outputs")
# os.environ["WANDB_SILENT"] = "true"
# -----------------------------------------------------------------------------
# Custom JSON encoder for datetime objects (similar to test_query_weave_traces.py)
# -----------------------------------------------------------------------------
class DateTimeEncoder(json.JSONEncoder):
"""JSON encoder that can handle datetime objects."""
def default(self, obj):
if isinstance(obj, datetime):
return obj.isoformat()
return super().default(obj)
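# The encoder above is passed as `cls=DateTimeEncoder` wherever tool results are
# serialized below, in case the GraphQL response contains datetime objects
# (e.g. parsed run timestamps) that the stock `json` module cannot serialize.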
# -----------------------------------------------------------------------------
# Environment guards
# -----------------------------------------------------------------------------
WANDB_API_KEY = os.getenv("WANDB_API_KEY")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
if not WANDB_API_KEY:
pytest.skip(
"WANDB_API_KEY environment variable not set; skipping live GraphQL tests.",
allow_module_level=True,
)
if not ANTHROPIC_API_KEY:
pytest.skip(
"ANTHROPIC_API_KEY environment variable not set; skipping Anthropic tests.",
allow_module_level=True,
)
# -----------------------------------------------------------------------------
# Static test context
# -----------------------------------------------------------------------------
TEST_WANDB_ENTITY = "wandb-applied-ai-team"
TEST_WANDB_PROJECT = "mcp-tests"
# MODEL_NAME = "claude-3-7-sonnet-20250219"
# MODEL_NAME = "claude-4-sonnet-20250514"
MODEL_NAME = "claude-4-opus-20250514"
CORRECTNESS_MODEL_NAME = "claude-3-5-haiku-20241022"
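# A smaller model acts as the LLM judge in the correctness check; the larger
# MODEL_NAME above is only responsible for generating the tool call.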
# -----------------------------------------------------------------------------
# Build tool schema for Anthropic
# -----------------------------------------------------------------------------
available_tools: Dict[str, Dict[str, Any]] = {
"query_paginated_wandb_gql": {
"function": query_paginated_wandb_gql,
"schema": generate_anthropic_tool_schema(
func=query_paginated_wandb_gql,
description=QUERY_WANDB_GQL_TOOL_DESCRIPTION,
),
}
}
tools: List[Dict[str, Any]] = [available_tools["query_paginated_wandb_gql"]["schema"]]
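# `tools` is forwarded to the Anthropic Messages API so the model can respond with a
# `tool_use` block for `query_paginated_wandb_gql`; `available_tools` keeps the matching
# Python callable so the test can execute whatever call the model emits.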
# -----------------------------------------------------------------------------
# Compute baseline runCount once so that tests have a stable expected value
# -----------------------------------------------------------------------------
BASELINE_QUERY = """
query ProjectRunCount($entity: String!, $project: String!) {
project(name: $project, entityName: $entity) {
runCount
}
}
"""
BASELINE_VARIABLES = {"entity": TEST_WANDB_ENTITY, "project": TEST_WANDB_PROJECT}
# Compute baseline
logger.info(
"Fetching baseline runCount for %s/%s", TEST_WANDB_ENTITY, TEST_WANDB_PROJECT
)
_baseline_result = query_paginated_wandb_gql(BASELINE_QUERY, BASELINE_VARIABLES)
BASELINE_RUN_COUNT: int = _baseline_result["project"]["runCount"]
logger.info("Baseline runCount = %s", BASELINE_RUN_COUNT)
# -----------------------------------------------------------------------------
# Natural-language queries to test
# -----------------------------------------------------------------------------
test_queries = [
{
"index": 0,
"question": "How many runs are currently logged in the `{project_name}` project under the `{entity_name}` entity?",
"expected_output": 37,
},
{
"index": 1,
"question": "What's the total experiment count for `{entity_name}/{project_name}`?",
"expected_output": 37,
},
{
"index": 2,
"question": "In `{project_name}` in entity `{entity_name}` how many runs were run on April 29th 2025?",
"expected_output": 37,
},
{
"index": 3,
"question": "Could you report the number of tracked runs in `{entity_name}/{project_name}` with lr 0.002?",
"expected_output": 7,
},
{
"index": 4,
"question": "what was the run with the best eval loss in the `{project_name}` project belonging to `{entity_name}`.",
"expected_output": "run_id: h0fm5qp5 OR run_name: transformer_7_bs-128_lr-0.008_5593616",
},
{
"index": 5,
"question": "How many steps in run gtng2y4l `{entity_name}/{project_name}` right now.",
"expected_output": 750000,
},
{
"index": 6,
"question": "How many steps in run transformer_25_bs-33554432_lr-0.026000000000000002_2377215 `{entity_name}/{project_name}` right now.",
"expected_output": 750000,
},
{
"index": 7,
"question": "What's the batch size of the run with best evaluation accuracy for `{project_name}` inside `{entity_name}`?",
"expected_output": 16,
},
# {
# "index": 8, # Example if uncommented
# "question": "Count the runs in my `{entity_name}` entity for the `{project_name}` project.",
# "expected_output": BASELINE_RUN_COUNT,
# },
# {
# "index": 9, # Example if uncommented
# "question": "How big is the experiment set for `{entity_name}/{project_name}`?",
# "expected_output": BASELINE_RUN_COUNT,
# },
# {
# "index": 10, # Example if uncommented
# "question": "Tell me the number of runs tracked in `{project_name}` (entity `{entity_name}`).",
# "expected_output": BASELINE_RUN_COUNT,
# },
]
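# Illustrative only (not asserted by the tests): for a config-filtered count such as
# sample 3, the model is typically expected to extend the baseline query with a
# filtered field, e.g. `runCount(filters: $filters)` or a filtered `runs(...)`
# connection using a MongoDB-style JSON filter on the run config. Whichever shape the
# model chooses is judged by the correctness check below rather than by string matching.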
# -----------------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
"sample",
test_queries,
ids=[f"sample_{i}" for i, _ in enumerate(test_queries)],
)
def test_query_wandb_gql(sample, weave_results_dir):
"""End-to-end test: NL question → Anthropic → tool_use → result validation."""
start_time = time.monotonic()
current_git_commit = get_git_commit()
git_commit_id = f"commit_{current_git_commit}"
current_test_file_name = os.path.basename(__file__)
# Find the index of the current sample for unique naming and metadata
sample_index = -1
for i, s in enumerate(test_queries):
if s == sample:
sample_index = i
break
test_case_name = f"gql_query_{sample_index}_{sample.get('question', 'unknown_question')[:20].replace(' ', '_')}"
query_text = sample["question"].format(
entity_name=TEST_WANDB_ENTITY,
project_name=TEST_WANDB_PROJECT,
)
expected_output = sample["expected_output"]
logger.info("\n==============================")
logger.info("QUERY: %s", query_text)
# --- Retry Logic Setup ---
max_retries = 1
last_reasoning = "No correctness check performed yet."
last_is_correct = False
first_call_assistant_response = None # Store the response dict from the first model
tool_result = None # Store the result of executing the tool
tool_name_used_in_test = None
tool_input_used_in_test = None
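# Captured outside the retry loop so the `finally` block can record the last tool
# call in the JSON artifact even if an assertion fails mid-attempt.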
# Initialize log_data_for_file for the current test sample
final_log_data_for_file = {
"metadata": {
"sample_name": test_case_name,
"test_case_index": sample_index,
"git_commit_id": git_commit_id,
"source_test_file_name": current_test_file_name,
"test_query_text": query_text,
"expected_test_output": str(expected_output),
"retry_attempt": 0, # Will be updated in the loop
"max_retries_configured": max_retries,
},
"inputs": { # Inputs to the overall test/evaluation
"test_query": query_text,
"expected_value": str(expected_output),
},
"output": {}, # Will store tool output and correctness check details
"score": False, # Default to False, updated on success
"scorer_name": "gql_correctness_assertion", # Specific scorer for these tests
"metrics": {}, # Will store execution_latency_seconds
}
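# This dict is serialized to a JSON file under `weave_results_dir` in the `finally`
# block below, whether the sample ultimately passes or fails.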
try:
# Initial messages for the first attempt
messages_first_call = [{"role": "user", "content": query_text}]
for attempt in range(max_retries + 1):
logger.info(f"\n--- Attempt {attempt + 1} / {max_retries + 1} ---")
final_log_data_for_file["metadata"]["retry_attempt"] = attempt + 1
if attempt > 0:
# We are retrying. Add the previous assistant response and a user message with feedback.
if first_call_assistant_response:
messages_first_call.append(
first_call_assistant_response
) # Add previous assistant message (contains tool use)
else:
# Should not happen in retry logic, but defensively handle
logger.warning(
"Attempting retry, but no previous assistant response found."
)
# Construct the user message asking for a retry
retry_user_message_content = f"""
Executing the previous tool call resulted in:
```json
{json.dumps(tool_result, indent=2, cls=DateTimeEncoder)}
```
A separate check determined this result was incorrect for the original query.
The reasoning provided was: "{last_reasoning}".
Please re-analyze the original query ("{query_text}") and the result from your previous attempt, then try generating the 'query_paginated_wandb_gql' tool call again.
"""
messages_first_call.append(
{"role": "user", "content": retry_user_message_content}
)
# --- First Call: Get the query_paginated_wandb_gql tool use ---
response = call_anthropic(
model_name=MODEL_NAME,
messages=messages_first_call,
tools=tools, # Provide the GQL tool schema
)
first_call_assistant_response = (
response # Store this response for potential next retry
)
_, tool_name, tool_input, _ = extract_anthropic_tool_use(response)
tool_name_used_in_test = tool_name
tool_input_used_in_test = tool_input
logger.info(f"Attempt {attempt + 1}: Tool emitted by model: {tool_name}")
logger.info(
f"Attempt {attempt + 1}: Tool input: {json.dumps(tool_input, indent=2)}"
)
assert tool_name == "query_paginated_wandb_gql", (
f"Attempt {attempt + 1}: Expected 'query_paginated_wandb_gql', got '{tool_name}'"
)
# --- Execute the GQL tool ---
try:
tool_result = available_tools[tool_name]["function"](**tool_input)
logger.info(
f"Attempt {attempt + 1}: Tool result: {json.dumps(tool_result, indent=2, cls=DateTimeEncoder)}"
) # Log full result
except Exception as e:
logger.error(
f"Attempt {attempt + 1}: Error executing tool '{tool_name}' with input {tool_input}: {e}",
exc_info=True,
)
final_log_data_for_file["output"]["tool_execution_error_details"] = str(
e
)
# If tool execution fails, we might want to stop retrying for this sample or handle differently.
# For now, it will proceed to correctness check which will likely fail or be skipped.
# Depending on the error, we might want to `pytest.fail` or `raise` to stop the current attempt.
# For this iteration, we'll let it go to the correctness check, which will likely fail it.
last_is_correct = False
last_reasoning = f"Tool execution failed: {e}"
if attempt >= max_retries: # If this was the last attempt
raise # Re-raise the exception to fail the test
continue # Skip to next retry attempt
# --- Second Call: Perform Correctness Check (Separate Task) ---
logger.info(
f"\n--- Starting Correctness Check for Attempt {attempt + 1} ---"
)
try:
# Prepare the prompt for the check - provide all context clearly
correctness_prompt = f"""
Please evaluate if the provided 'Actual Tool Result' correctly addresses the 'Original User Query' and seems consistent with the 'Expected Output'. Use the 'check_correctness_tool' to provide your reasoning and conclusion.
Original User Query:
"{query_text}"
Expected Output (for context, may not be directly comparable in structure):
{json.dumps(expected_output, indent=2, cls=DateTimeEncoder)}
Actual Tool Result from 'query_paginated_wandb_gql':
{json.dumps(tool_result, indent=2, cls=DateTimeEncoder)}
"""
messages_check_call = [{"role": "user", "content": correctness_prompt}]
correctness_response = call_anthropic(
model_name=CORRECTNESS_MODEL_NAME,
messages=messages_check_call,
check_correctness_tool=check_correctness_tool,
)
logger.info(
f"Attempt {attempt + 1}: Correctness check response:\n{correctness_response}\n\n"
)
# --- Extract and Validate Correctness Tool Use ---
_, check_tool_name, check_tool_input, _ = extract_anthropic_tool_use(
correctness_response
)
assert check_tool_name == "check_correctness_tool", (
f"Attempt {attempt + 1}: Expected correctness tool, got {check_tool_name}"
)
assert "reasoning" in check_tool_input, (
f"Attempt {attempt + 1}: Correctness tool missing 'reasoning'"
)
assert "is_correct" in check_tool_input, (
f"Attempt {attempt + 1}: Correctness tool missing 'is_correct'"
)
# 2. Extract the data from the input dictionary
try:
reasoning_text = check_tool_input["reasoning"]
is_correct_flag = check_tool_input["is_correct"]
# Store the latest results
last_reasoning = reasoning_text
last_is_correct = is_correct_flag
logger.info(
f"Attempt {attempt + 1}: Correctness Reasoning: {reasoning_text}"
)
logger.info(
f"Attempt {attempt + 1}: Is Correct according to LLM: {is_correct_flag}"
)
if is_correct_flag:
logger.info(
f"--- Correctness check passed on attempt {attempt + 1}. ---"
)
final_log_data_for_file["score"] = True
break # Exit the loop successfully
# If not correct, and this is the last attempt, the loop will end naturally.
except KeyError as e:
logger.error(
f"Attempt {attempt + 1}: Missing expected key in correctness tool input: {e}"
)
logger.error(
f"Attempt {attempt + 1}: Full input received: {check_tool_input}"
)
last_is_correct = False
last_reasoning = f"Correctness tool response missing key: {e}"
final_log_data_for_file["output"]["assertion_error_details"] = (
f"Correctness tool response missing key: {e}"
)
if attempt >= max_retries:
pytest.fail(
f"Attempt {attempt + 1}: Correctness tool response was missing key: {e}"
)
continue # To next retry
except Exception as e:
logger.error(
f"Attempt {attempt + 1}: Error processing correctness tool input: {e}",
exc_info=True,
)
last_is_correct = False
last_reasoning = f"Failed to process correctness tool input: {e}"
final_log_data_for_file["output"]["assertion_error_details"] = (
f"Failed to process correctness tool input: {e}"
)
if attempt >= max_retries:
pytest.fail(
f"Attempt {attempt + 1}: Failed to process correctness tool input: {e}"
)
continue # To next retry
except Exception as e:
logger.error(
f"Attempt {attempt + 1}: Error during correctness check for query '{query_text}': {e}",
exc_info=True,
)
last_is_correct = False
last_reasoning = f"Correctness check failed with exception: {e}"
final_log_data_for_file["output"]["assertion_error_details"] = (
f"Correctness check failed with exception: {e}"
)
if attempt >= max_retries:
pytest.fail(
f"Attempt {attempt + 1}: Correctness check failed with exception: {e}"
)
continue # To next retry
# After the loop, if not last_is_correct, it means all retries failed or it failed on the last attempt.
if not last_is_correct and attempt >= max_retries:
pytest.fail(
f"LLM evaluation failed after {max_retries + 1} attempts for sample {sample_index}. "
f"Final is_correct_flag is `{last_is_correct}`. "
f"Final Reasoning: '{last_reasoning}'"
)
except Exception as test_exec_exception:
# Catch any exception that might cause the test to fail before all retries are done
# or even before the loop fully completes.
logger.error(
f"Test execution for sample {sample_index} failed globally: {test_exec_exception}",
exc_info=True,
)
final_log_data_for_file["score"] = False
final_log_data_for_file["output"]["test_exception"] = str(test_exec_exception)
# We will write the JSON in `finally`, then re-raise or let pytest handle the failure.
raise # Re-raise the caught exception to ensure the test is marked as failed by pytest
finally:
end_time = time.monotonic()
execution_latency_seconds = end_time - start_time
final_log_data_for_file["metrics"]["execution_latency_seconds"] = (
execution_latency_seconds
)
final_log_data_for_file["metadata"]["final_attempt_number_for_json"] = (
final_log_data_for_file["metadata"]["retry_attempt"]
) # Should be updated inside loop
# Populate output details from the last successful (or last attempted) tool call
final_log_data_for_file["output"]["tool_name"] = tool_name_used_in_test
final_log_data_for_file["output"]["tool_input"] = (
json.dumps(tool_input_used_in_test, indent=2)
if tool_input_used_in_test
else None
)
final_log_data_for_file["output"]["tool_result"] = (
json.dumps(tool_result, indent=2, cls=DateTimeEncoder)
if tool_result
else None
)
final_log_data_for_file["output"]["correctness_reasoning"] = last_reasoning
final_log_data_for_file["score"] = last_is_correct # Ensure final score is set
# Generate a unique filename for the JSON output
unique_file_id = str(uuid.uuid4())
worker_id = os.environ.get(
"PYTEST_XDIST_WORKER", "main_thread"
) # Default if not in xdist
# Sanitize test_case_name for filename (take first 30 chars, replace spaces)
safe_test_name_part = (
test_case_name.replace(" ", "_").replace("/", "_").replace("\\", "_")[:30]
)
file_name = f"gql_test_idx_{sample_index}_{safe_test_name_part}_w_{worker_id}_attempt_{final_log_data_for_file['metadata']['final_attempt_number_for_json']}_{('pass' if final_log_data_for_file['score'] else 'fail')}_{unique_file_id}.json"
file_path = weave_results_dir / file_name
logger.critical(
f"WRITING JSON for GQL Test: {test_case_name} (Index: {sample_index}, Last Attempt: {final_log_data_for_file['metadata']['final_attempt_number_for_json']}, Score: {final_log_data_for_file['score']}) to {file_path}"
)
try:
with open(file_path, "w") as f:
json.dump(final_log_data_for_file, f, indent=2, cls=DateTimeEncoder)
logger.info(
f"Result for GQL test {test_case_name} (Latency: {execution_latency_seconds:.2f}s) written to {file_path}"
)
except Exception as e:
logger.error(
f"Failed to write result JSON for GQL test {test_case_name} to {file_path}: {e}"
)
# If we reach here and no exception was raised by pytest.fail or re-raised from the try block,
# it means the correctness check passed within the allowed attempts.
if not last_is_correct: # Final check if loop exited due to retries without success
pytest.fail(
f"LLM evaluation failed after {max_retries + 1} attempts for sample {sample_index}. "
f"Final is_correct_flag is `{last_is_correct}`. "
f"Final Reasoning: '{last_reasoning}'"
)
logger.info(
f"--- Test for sample {sample_index} ({test_case_name}) completed. Score: {last_is_correct} ---"
)