# tests/test_count_traces.py
import os
import pytest
from wandb_mcp_server.utils import get_rich_logger
from tests.anthropic_test_utils import call_anthropic, extract_anthropic_tool_use
from wandb_mcp_server.mcp_tools.count_traces import (
COUNT_WEAVE_TRACES_TOOL_DESCRIPTION,
count_traces,
)
from wandb_mcp_server.mcp_tools.tools_utils import generate_anthropic_tool_schema

logger = get_rich_logger(__name__)
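
# Silence wandb's own console output so test logs stay readable.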
os.environ["WANDB_SILENT"] = "true"
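
# These tests exercise the live Anthropic API; skip the module when no key is configured.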
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
if not ANTHROPIC_API_KEY:
pytest.skip(
"ANTHROPIC_API_KEY environment variable not set; skipping live Anthropic tests.",
allow_module_level=True,
)

TEST_WANDB_ENTITY = "wandb-applied-ai-team" # alternate test entity: "c-metrics"
TEST_WANDB_PROJECT = "wandb-mcp-tests"
MODEL_NAME = "claude-3-7-sonnet-20250219"
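
# Registry mapping tool names to their callables and auto-generated Anthropic tool schemas.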
available_tools = {
"count_traces": {
"function": count_traces,
"schema": generate_anthropic_tool_schema(
func=count_traces, description=COUNT_WEAVE_TRACES_TOOL_DESCRIPTION
),
}
}
tools = [available_tools["count_traces"]["schema"]]
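
# Each case pairs a natural-language prompt template with the ground-truth count for the
# test project; the inline comments note which filters the model must infer.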
test_queries = [
{
"question": "Please count the total number of traces recorded in the `{project_name}` project under the `{entity_name}` entity.",
"expected_output": 21639,
},
{
"question": "How many Weave call logs exist for the `{project_name}` project in my `{entity_name}` entity?", # (Uses "call logs" instead of "traces")
"expected_output": 21639,
},
{
"question": "What's the volume of traces for `{project_name}` in the `{entity_name}` entity?", # (Assumes default entity or requires clarification, implies counting)
"expected_output": 21639,
},
{
"question": "Count the calls that resulted in an error within the `{entity_name}/{project_name}` project.", # (Requires filtering by status='error')
"expected_output": 136,
},
{
"question": "How many times has the `generate_joke` operation been invoked in the `{project_name}` project for the `{entity_name}`?", # (Requires filtering by op_name)
"expected_output": 4,
},
{
"question": "The date is March 12th, 2025. Give me the parent trace count for `{entity_name}/{project_name}` last month.", # (Requires calculating and applying a time filter)
"expected_output": 262,
},
{
"question": "Can you count the parent traces in `{entity_name}/{project_name}`?", # (Requires, root traces)
"expected_output": 475,
},
{
"question": "`{entity_name}/{project_name}` trace tally?", # (Requires inferring the need for counting and likely asking for the entity)
"expected_output": 21639,
},
{
"question": "How many traces in `{entity_name}/{project_name}` took more than 10 minutes to run?", # (Requires an attribute filter)
"expected_output": 155,
},
{
"question": "How many traces in `{entity_name}/{project_name}` took less than 2 seconds to run?", # (Requires an attribute filter)
"expected_output": 12357,
},
{
"question": "THe date is April 20th, 2025. Count failed traces for the `openai.chat.completions` op within the `{entity_name}/{project_name}` project since the 27th of February 2025 up to March 1st..", # (Requires combining status='success', trace_roots_only=True, op_name, and a time filter)
"expected_output": 15,
},
]


# -----------------------
# Pytest integration
# -----------------------
@pytest.mark.parametrize(
"sample", test_queries, ids=[f"sample_{i}" for i, _ in enumerate(test_queries)]
)
def test_count_traces(sample):
"""Run each natural-language query end-to-end through the Anthropic model and
verify that the invoked tool returns the expected value."""
query_text = sample["question"].format(
entity_name=TEST_WANDB_ENTITY,
project_name=TEST_WANDB_PROJECT,
)
expected_output = sample["expected_output"]
logger.info("==============================")
logger.info(f"QUERY: {query_text}")
messages = [{"role": "user", "content": query_text}]
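    # Offer the model only the count_traces schema and let it decide how to invoke it.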
    response = call_anthropic(MODEL_NAME, messages, tools)
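    # Pull the first tool_use block (name and arguments) from the model's response.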
_, tool_name, tool_input, _ = extract_anthropic_tool_use(response)
logger.info(f"Tool emitted by model: {tool_name}")
logger.debug(f"Tool input: {tool_input}")
assert tool_name is not None, "Model did not emit a tool call"
# Execute the real tool — no mocking.
tool_result = available_tools[tool_name]["function"](**tool_input)
logger.info(f"Tool result: {tool_result} (expected {expected_output})")
assert tool_result == expected_output, (
f"Unexpected result for query `{query_text}`: {tool_result} (expected {expected_output})"
)