# noqa: D100
"""Integration tests that verify Anthropic selects `query_wandb_tool`.

These tests send natural-language questions about the W&B *Models* data for the
`wandb-applied-ai-team/mcp-tests` project.  The Anthropic model should respond
with a `tool_use` invoking `query_wandb_tool`, which we then execute and
validate.
"""

import json
import os
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List

import pytest

from tests.anthropic_test_utils import (
    call_anthropic,
    check_correctness_tool,
    extract_anthropic_tool_use,
)
from wandb_mcp_server.mcp_tools.query_wandb_gql import (
    QUERY_WANDB_GQL_TOOL_DESCRIPTION,
    query_paginated_wandb_gql,
)
from wandb_mcp_server.mcp_tools.tools_utils import generate_anthropic_tool_schema
from wandb_mcp_server.utils import get_git_commit, get_rich_logger

# Module-level logger
logger = get_rich_logger(__name__)

# weave.init("wandb-applied-ai-team/wandb-mcp-server-test-outputs")
# os.environ["WANDB_SILENT"] = "true"


# -----------------------------------------------------------------------------
# Custom JSON encoder for datetime objects (similar to test_query_weave_traces.py)
# -----------------------------------------------------------------------------
class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that can handle datetime objects."""

    def default(self, obj):
        if isinstance(obj, datetime):
            return obj.isoformat()
        return super().default(obj)


# -----------------------------------------------------------------------------
# Environment guards
# -----------------------------------------------------------------------------

WANDB_API_KEY = os.getenv("WANDB_API_KEY")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")

if not WANDB_API_KEY:
    pytest.skip(
        "WANDB_API_KEY environment variable not set; skipping live GraphQL tests.",
        allow_module_level=True,
    )
if not ANTHROPIC_API_KEY:
    pytest.skip(
        "ANTHROPIC_API_KEY environment variable not set; skipping Anthropic tests.",
        allow_module_level=True,
    )

# -----------------------------------------------------------------------------
# Static test context
# -----------------------------------------------------------------------------

TEST_WANDB_ENTITY = "wandb-applied-ai-team"
TEST_WANDB_PROJECT = "mcp-tests"
# MODEL_NAME = "claude-3-7-sonnet-20250219"
# MODEL_NAME = "claude-4-sonnet-20250514"
MODEL_NAME = "claude-4-opus-20250514"
CORRECTNESS_MODEL_NAME = "claude-3-5-haiku-20241022"

# -----------------------------------------------------------------------------
# Build tool schema for Anthropic
# -----------------------------------------------------------------------------

available_tools: Dict[str, Dict[str, Any]] = {
    "query_paginated_wandb_gql": {
        "function": query_paginated_wandb_gql,
        "schema": generate_anthropic_tool_schema(
            func=query_paginated_wandb_gql,
            description=QUERY_WANDB_GQL_TOOL_DESCRIPTION,
        ),
    }
}

tools: List[Dict[str, Any]] = [available_tools["query_paginated_wandb_gql"]["schema"]]
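# `tools` is passed to `call_anthropic` below so the model can emit a matching
# `tool_use` block for `query_paginated_wandb_gql`.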

# -----------------------------------------------------------------------------
# Compute baseline runCount once so that tests have a stable expected value
# -----------------------------------------------------------------------------

BASELINE_QUERY = """
query ProjectRunCount($entity: String!, $project: String!) {
  project(name: $project, entityName: $entity) {
    runCount
  }
}
"""
BASELINE_VARIABLES = {"entity": TEST_WANDB_ENTITY, "project": TEST_WANDB_PROJECT}

# Compute the baseline at import time (requires live W&B access); the tool is
# expected to return the parsed GraphQL data, e.g. {"project": {"runCount": <int>}}.
logger.info(
    "Fetching baseline runCount for %s/%s", TEST_WANDB_ENTITY, TEST_WANDB_PROJECT
)
_baseline_result = query_paginated_wandb_gql(BASELINE_QUERY, BASELINE_VARIABLES)
BASELINE_RUN_COUNT: int = _baseline_result["project"]["runCount"]
logger.info("Baseline runCount = %s", BASELINE_RUN_COUNT)

# -----------------------------------------------------------------------------
# Natural-language queries to test
# -----------------------------------------------------------------------------

test_queries = [
    {
        "index": 0,
        "question": "How many runs are currently logged in the `{project_name}` project under the `{entity_name}` entity?",
        "expected_output": 37,
    },
    {
        "index": 1,
        "question": "What's the total experiment count for `{entity_name}/{project_name}`?",
        "expected_output": 37,
    },
    {
        "index": 2,
        "question": "In `{project_name}` in entity `{entity_name}` how many runs were run on April 29th 2025?",
        "expected_output": 37,
    },
    {
        "index": 3,
        "question": "Could you report the number of tracked runs in `{entity_name}/{project_name}` with lr 0.002?",
        "expected_output": 7,
    },
    {
        "index": 4,
        "question": "what was the run with the best eval loss in the `{project_name}` project belonging to `{entity_name}`.",
        "expected_output": "run_id: h0fm5qp5 OR run_name: transformer_7_bs-128_lr-0.008_5593616",
    },
    {
        "index": 5,
        "question": "How many steps in run gtng2y4l `{entity_name}/{project_name}` right now.",
        "expected_output": 750000,
    },
    {
        "index": 6,
        "question": "How many steps in run transformer_25_bs-33554432_lr-0.026000000000000002_2377215 `{entity_name}/{project_name}` right now.",
        "expected_output": 750000,
    },
    {
        "index": 7,
        "question": "What's the batch size of the run with best evaluation accuracy for `{project_name}` inside `{entity_name}`?",
        "expected_output": 16,
    },
    # {
    #     "index": 8, # Example if uncommented
    #     "question": "Count the runs in my `{entity_name}` entity for the `{project_name}` project.",
    #     "expected_output": BASELINE_RUN_COUNT,
    # },
    # {
    #     "index": 9, # Example if uncommented
    #     "question": "How big is the experiment set for `{entity_name}/{project_name}`?",
    #     "expected_output": BASELINE_RUN_COUNT,
    # },
    # {
    #     "index": 10, # Example if uncommented
    #     "question": "Tell me the number of runs tracked in `{project_name}` (entity `{entity_name}`).",
    #     "expected_output": BASELINE_RUN_COUNT,
    # },
]
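# NOTE: the literal expected_output values above are pinned to the current state of
# the wandb-applied-ai-team/mcp-tests project; update them (or switch to the
# BASELINE_RUN_COUNT-based samples commented out above) if that project changes.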


# -----------------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------------


@pytest.mark.parametrize(
    "sample",
    test_queries,
    ids=[f"sample_{i}" for i, _ in enumerate(test_queries)],
)
def test_query_wandb_gql(sample, weave_results_dir):
    """End-to-end test: NL question → Anthropic → tool_use → result validation."""

    start_time = time.monotonic()
    current_git_commit = get_git_commit()
    git_commit_id = f"commit_{current_git_commit}"
    current_test_file_name = os.path.basename(__file__)

    # Find the index of the current sample for unique naming and metadata
    sample_index = test_queries.index(sample) if sample in test_queries else -1

    test_case_name = f"gql_query_{sample_index}_{sample.get('question', 'unknown_question')[:20].replace(' ', '_')}"

    query_text = sample["question"].format(
        entity_name=TEST_WANDB_ENTITY,
        project_name=TEST_WANDB_PROJECT,
    )
    expected_output = sample["expected_output"]

    logger.info("\n==============================")
    logger.info("QUERY: %s", query_text)

    # --- Retry Logic Setup ---
    max_retries = 1
    last_reasoning = "No correctness check performed yet."
    last_is_correct = False
    first_call_assistant_response = None  # Assistant response from the tool-selection call
    tool_result = None  # Store the result of executing the tool
    tool_name_used_in_test = None
    tool_input_used_in_test = None

    # Initialize log_data_for_file for the current test sample
    final_log_data_for_file = {
        "metadata": {
            "sample_name": test_case_name,
            "test_case_index": sample_index,
            "git_commit_id": git_commit_id,
            "source_test_file_name": current_test_file_name,
            "test_query_text": query_text,
            "expected_test_output": str(expected_output),
            "retry_attempt": 0,  # Will be updated in the loop
            "max_retries_configured": max_retries,
        },
        "inputs": {  # Inputs to the overall test/evaluation
            "test_query": query_text,
            "expected_value": str(expected_output),
        },
        "output": {},  # Will store tool output and correctness check details
        "score": False,  # Default to False, updated on success
        "scorer_name": "gql_correctness_assertion",  # Specific scorer for these tests
        "metrics": {},  # Will store execution_latency_seconds
    }

    try:
        # Initial messages for the first attempt
        messages_first_call = [{"role": "user", "content": query_text}]

        for attempt in range(max_retries + 1):
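            # Each attempt: (1) ask MODEL_NAME for a `query_paginated_wandb_gql` tool
            # call, (2) execute the tool, (3) have CORRECTNESS_MODEL_NAME judge the
            # result via `check_correctness_tool`; retry with that feedback if wrong.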
            logger.info(f"\n--- Attempt {attempt + 1} / {max_retries + 1} ---")
            final_log_data_for_file["metadata"]["retry_attempt"] = attempt + 1

            if attempt > 0:
                # We are retrying. Add the previous assistant response and a user message with feedback.
                if first_call_assistant_response:
                    messages_first_call.append(
                        first_call_assistant_response
                    )  # Add previous assistant message (contains tool use)
                else:
                    # Should not happen in retry logic, but defensively handle
                    logger.warning(
                        "Attempting retry, but no previous assistant response found."
                    )

                # Construct the user message asking for a retry
                retry_user_message_content = f"""
Executing the previous tool call resulted in:
```json
{json.dumps(tool_result, indent=2, cls=DateTimeEncoder)}
```
A separate check determined this result was incorrect for the original query.
The reasoning provided was: "{last_reasoning}".

Please re-analyze the original query ("{query_text}") and the result from your previous attempt, then try generating the 'query_paginated_wandb_gql' tool call again.
"""
                messages_first_call.append(
                    {"role": "user", "content": retry_user_message_content}
                )

            # --- First Call: Get the query_paginated_wandb_gql tool use ---
            response = call_anthropic(
                model_name=MODEL_NAME,
                messages=messages_first_call,
                tools=tools,  # Provide the GQL tool schema
            )
            # Keep this response so a retry prompt can include the previous tool use
            first_call_assistant_response = response
            _, tool_name, tool_input, _ = extract_anthropic_tool_use(response)

            tool_name_used_in_test = tool_name
            tool_input_used_in_test = tool_input

            logger.info(f"Attempt {attempt + 1}: Tool emitted by model: {tool_name}")
            logger.info(
                f"Attempt {attempt + 1}: Tool input: {json.dumps(tool_input, indent=2)}"
            )

            assert tool_name == "query_paginated_wandb_gql", (
                f"Attempt {attempt + 1}: Expected 'query_paginated_wandb_gql', got '{tool_name}'"
            )

            # --- Execute the GQL tool ---
            try:
                tool_result = available_tools[tool_name]["function"](**tool_input)
                logger.info(
                    f"Attempt {attempt + 1}: Tool result: {json.dumps(tool_result, indent=2, cls=DateTimeEncoder)}"
                )  # Log full result
            except Exception as e:
                logger.error(
                    f"Attempt {attempt + 1}: Error executing tool '{tool_name}' with input {tool_input}: {e}",
                    exc_info=True,
                )
                final_log_data_for_file["output"]["tool_execution_error_details"] = str(
                    e
                )
                # Tool execution failed: record it so the retry prompt can reference the
                # error, re-raise on the final attempt, otherwise skip the correctness
                # check and move straight to the next retry.
                last_is_correct = False
                last_reasoning = f"Tool execution failed: {e}"
                if attempt >= max_retries:  # If this was the last attempt
                    raise  # Re-raise the exception to fail the test
                continue  # Skip to next retry attempt

            # --- Second Call: Perform Correctness Check (Separate Task) ---
            logger.info(
                f"\n--- Starting Correctness Check for Attempt {attempt + 1} ---"
            )

            try:
                # Prepare the prompt for the check - provide all context clearly
                correctness_prompt = f"""
                Please evaluate if the provided 'Actual Tool Result' correctly addresses the 'Original User Query' and seems consistent with the 'Expected Output'. Use the 'check_correctness_tool' to provide your reasoning and conclusion.

                Original User Query:
                "{query_text}"

                Expected Output (for context, may not be directly comparable in structure):
                {json.dumps(expected_output, indent=2, cls=DateTimeEncoder)}

                Actual Tool Result from 'query_paginated_wandb_gql':
                {json.dumps(tool_result, indent=2, cls=DateTimeEncoder)}
                """

                messages_check_call = [{"role": "user", "content": correctness_prompt}]
                correctness_response = call_anthropic(
                    model_name=CORRECTNESS_MODEL_NAME,
                    messages=messages_check_call,
                    check_correctness_tool=check_correctness_tool,
                )

                logger.info(
                    f"Attempt {attempt + 1}: Correctness check response:\n{correctness_response}\n\n"
                )

                # --- Extract and Validate Correctness Tool Use ---
                _, check_tool_name, check_tool_input, _ = extract_anthropic_tool_use(
                    correctness_response
                )

                assert check_tool_name == "check_correctness_tool", (
                    f"Attempt {attempt + 1}: Expected correctness tool, got {check_tool_name}"
                )
                assert "reasoning" in check_tool_input, (
                    f"Attempt {attempt + 1}: Correctness tool missing 'reasoning'"
                )
                assert "is_correct" in check_tool_input, (
                    f"Attempt {attempt + 1}: Correctness tool missing 'is_correct'"
                )

                # Extract the reasoning and verdict from the tool input
                try:
                    reasoning_text = check_tool_input["reasoning"]
                    is_correct_flag = check_tool_input["is_correct"]

                    # Store the latest results
                    last_reasoning = reasoning_text
                    last_is_correct = is_correct_flag

                    logger.info(
                        f"Attempt {attempt + 1}: Correctness Reasoning: {reasoning_text}"
                    )
                    logger.info(
                        f"Attempt {attempt + 1}: Is Correct according to LLM: {is_correct_flag}"
                    )

                    if is_correct_flag:
                        logger.info(
                            f"--- Correctness check passed on attempt {attempt + 1}. ---"
                        )
                        final_log_data_for_file["score"] = True
                        break  # Exit the loop successfully

                    # If not correct, and this is the last attempt, the loop will end naturally.

                except KeyError as e:
                    logger.error(
                        f"Attempt {attempt + 1}: Missing expected key in correctness tool input: {e}"
                    )
                    logger.error(
                        f"Attempt {attempt + 1}: Full input received: {check_tool_input}"
                    )
                    last_is_correct = False
                    last_reasoning = f"Correctness tool response missing key: {e}"
                    final_log_data_for_file["output"]["assertion_error_details"] = (
                        f"Correctness tool response missing key: {e}"
                    )
                    if attempt >= max_retries:
                        pytest.fail(
                            f"Attempt {attempt + 1}: Correctness tool response was missing key: {e}"
                        )
                    continue  # To next retry
                except Exception as e:
                    logger.error(
                        f"Attempt {attempt + 1}: Error processing correctness tool input: {e}",
                        exc_info=True,
                    )
                    last_is_correct = False
                    last_reasoning = f"Failed to process correctness tool input: {e}"
                    final_log_data_for_file["output"]["assertion_error_details"] = (
                        f"Failed to process correctness tool input: {e}"
                    )
                    if attempt >= max_retries:
                        pytest.fail(
                            f"Attempt {attempt + 1}: Failed to process correctness tool input: {e}"
                        )
                    continue  # To next retry

            except Exception as e:
                logger.error(
                    f"Attempt {attempt + 1}: Error during correctness check for query '{query_text}': {e}",
                    exc_info=True,
                )
                last_is_correct = False
                last_reasoning = f"Correctness check failed with exception: {e}"
                final_log_data_for_file["output"]["assertion_error_details"] = (
                    f"Correctness check failed with exception: {e}"
                )
                if attempt >= max_retries:
                    pytest.fail(
                        f"Attempt {attempt + 1}: Correctness check failed with exception: {e}"
                    )
                continue  # To next retry

        # After the loop: if the final attempt was still judged incorrect, fail the test.
        if not last_is_correct and attempt >= max_retries:
            pytest.fail(
                f"LLM evaluation failed after {max_retries + 1} attempts for sample {sample_index}. "
                f"Final is_correct_flag is `{last_is_correct}`. "
                f"Final Reasoning: '{last_reasoning}'"
            )

    except Exception as test_exec_exception:
        # Catch any exception that might cause the test to fail before all retries are done
        # or even before the loop fully completes.
        logger.error(
            f"Test execution for sample {sample_index} failed globally: {test_exec_exception}",
            exc_info=True,
        )
        final_log_data_for_file["score"] = False
        final_log_data_for_file["output"]["test_exception"] = str(test_exec_exception)
        # We will write the JSON in `finally`, then re-raise or let pytest handle the failure.
        raise  # Re-raise the caught exception to ensure the test is marked as failed by pytest

    finally:
        end_time = time.monotonic()
        execution_latency_seconds = end_time - start_time
        final_log_data_for_file["metrics"]["execution_latency_seconds"] = (
            execution_latency_seconds
        )
        final_log_data_for_file["metadata"]["final_attempt_number_for_json"] = (
            final_log_data_for_file["metadata"]["retry_attempt"]
        )  # retry_attempt was kept current inside the loop

        # Populate output details from the last successful (or last attempted) tool call
        final_log_data_for_file["output"]["tool_name"] = tool_name_used_in_test
        final_log_data_for_file["output"]["tool_input"] = (
            json.dumps(tool_input_used_in_test, indent=2)
            if tool_input_used_in_test
            else None
        )
        final_log_data_for_file["output"]["tool_result"] = (
            json.dumps(tool_result, indent=2, cls=DateTimeEncoder)
            if tool_result
            else None
        )
        final_log_data_for_file["output"]["correctness_reasoning"] = last_reasoning
        final_log_data_for_file["score"] = last_is_correct  # Ensure final score is set

        # Generate a unique filename for the JSON output
        unique_file_id = str(uuid.uuid4())
        worker_id = os.environ.get(
            "PYTEST_XDIST_WORKER", "main_thread"
        )  # Default if not in xdist

        # Sanitize test_case_name for the filename (replace spaces/path separators, keep first 30 chars)
        safe_test_name_part = (
            test_case_name.replace(" ", "_").replace("/", "_").replace("\\", "_")[:30]
        )

        file_name = f"gql_test_idx_{sample_index}_{safe_test_name_part}_w_{worker_id}_attempt_{final_log_data_for_file['metadata']['final_attempt_number_for_json']}_{('pass' if final_log_data_for_file['score'] else 'fail')}_{unique_file_id}.json"
        file_path = weave_results_dir / file_name

        logger.critical(
            f"WRITING JSON for GQL Test: {test_case_name} (Index: {sample_index}, Last Attempt: {final_log_data_for_file['metadata']['final_attempt_number_for_json']}, Score: {final_log_data_for_file['score']}) to {file_path}"
        )
        try:
            with open(file_path, "w") as f:
                json.dump(final_log_data_for_file, f, indent=2, cls=DateTimeEncoder)
            logger.info(
                f"Result for GQL test {test_case_name} (Latency: {execution_latency_seconds:.2f}s) written to {file_path}"
            )
        except Exception as e:
            logger.error(
                f"Failed to write result JSON for GQL test {test_case_name} to {file_path}: {e}"
            )

    # Final safety net: if the loop exited without a passing correctness check and no
    # failure was raised above, fail the test explicitly here.
    if not last_is_correct:
        pytest.fail(
            f"LLM evaluation failed after {max_retries + 1} attempts for sample {sample_index}. "
            f"Final is_correct_flag is `{last_is_correct}`. "
            f"Final Reasoning: '{last_reasoning}'"
        )

    logger.info(
        f"--- Test for sample {sample_index} ({test_case_name}) completed. Score: {last_is_correct} ---"
    )