from __future__ import annotations

import json
import os
from typing import Any, Dict, List

import pytest

from wandb_mcp_server.utils import get_rich_logger
from tests.anthropic_test_utils import (
    call_anthropic,
    check_correctness_tool,
    extract_anthropic_tool_use,
    get_anthropic_tool_result_message,
)
from wandb_mcp_server.mcp_tools.query_wandbot import (
    WANDBOT_TOOL_DESCRIPTION,
    query_wandbot_api,
)
from wandb_mcp_server.mcp_tools.tools_utils import generate_anthropic_tool_schema
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = get_rich_logger(__name__)

# -----------------------------------------------------------------------------
# Environment guards
# -----------------------------------------------------------------------------
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
WANDBOT_BASE_URL = os.getenv(
    "WANDBOT_TEST_URL", "https://morg--wandbot-api-wandbotapi-serve.modal.run"
)
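# NOTE: WANDBOT_BASE_URL is not referenced directly in this module; the
# assumption is that query_wandbot_api reads the same WANDBOT_TEST_URL
# environment configuration internally.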
if not ANTHROPIC_API_KEY:
    pytest.skip(
        "ANTHROPIC_API_KEY environment variable not set; skipping Anthropic tests.",
        allow_module_level=True,
    )
# -----------------------------------------------------------------------------
# Static test context
# -----------------------------------------------------------------------------
MODEL_NAME = "claude-3-7-sonnet-20250219"
CORRECTNESS_MODEL_NAME = "claude-3-5-haiku-20241022"

# -----------------------------------------------------------------------------
# Build tool schema for Anthropic
# -----------------------------------------------------------------------------
available_tools: Dict[str, Dict[str, Any]] = {
    "query_wandbot_api": {
        "function": query_wandbot_api,
        "schema": generate_anthropic_tool_schema(
            func=query_wandbot_api,  # Pass the function itself
            description=WANDBOT_TOOL_DESCRIPTION,  # Use the imported description
        ),
    }
}

tools: List[Dict[str, Any]] = [available_tools["query_wandbot_api"]["schema"]]
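# For reference, the generated schema is assumed to follow Anthropic's standard
# tool format. An illustrative sketch (not the verified output of
# generate_anthropic_tool_schema):
#
#   {
#       "name": "query_wandbot_api",
#       "description": WANDBOT_TOOL_DESCRIPTION,
#       "input_schema": {
#           "type": "object",
#           "properties": {"question": {"type": "string"}},
#           "required": ["question"],
#       },
#   }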
# -----------------------------------------------------------------------------
# Natural-language queries to test
# -----------------------------------------------------------------------------
test_queries = [
    {
        "question": "What kinds of scorers does weave support?",
        "expected_output": "There are 2 types of scorers in weave, Function-based and Class-based.",
    },
    # Add more test cases here later
]
# -----------------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------------
@pytest.mark.parametrize("sample", test_queries)
def test_query_wandbot(sample):
    """End-to-end test: NL question → Anthropic → tool_use → result validation with correctness check."""
    query_text = sample["question"]
    expected_output = sample["expected_output"]  # Used later by the correctness check
    logger.info("\n==============================")
    logger.info("QUERY: %s", query_text)

    # --- Retry logic setup ---
    max_retries = 1
    last_reasoning = "No correctness check performed yet."
    last_is_correct = False
    first_call_assistant_response = None  # Response object from the previous model call
    tool_result = None  # Result of executing the tool
    tool_use_id = None  # Initialize *before* the loop so retries can reference it

    # Initial messages for the first attempt
    messages_first_call = [{"role": "user", "content": query_text}]

    for attempt in range(max_retries + 1):
        logger.info(f"\n--- Attempt {attempt + 1} / {max_retries + 1} ---")
        current_messages = messages_first_call  # Start with the base messages

        if attempt > 0:
            # Retry logic: add the previous assistant response, the tool result,
            # and user feedback asking the model to try again
            retry_messages = []
            if first_call_assistant_response:
                # 1. Add the previous assistant message (contains the tool use)
                retry_messages.append(
                    {
                        "role": first_call_assistant_response.role,
                        "content": first_call_assistant_response.content,
                    }
                )
                # 2. Add the result from executing the tool in the previous attempt
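                # Per Anthropic's tool-use protocol, the tool result is sent back
                # as a user-role message containing a `tool_result` block that
                # echoes the original `tool_use_id`. Assumed shape of the message
                # built by get_anthropic_tool_result_message (a sketch, not its
                # verified output):
                #   {"role": "user", "content": [{"type": "tool_result",
                #    "tool_use_id": tool_use_id, "content": str(tool_result)}]}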
                if tool_result is not None and tool_use_id is not None:
                    tool_result_message = get_anthropic_tool_result_message(
                        tool_result, tool_use_id
                    )
                    retry_messages.append(tool_result_message)
                else:
                    logger.warning(
                        f"Attempt {attempt + 1}: Cannot add tool result message; tool_result or tool_use_id missing."
                    )
                # 3. Add the user message asking for a retry
                retry_user_message_content = f"""
Executing the previous tool call resulted in:
```json
{json.dumps(tool_result, indent=2)}
```
A separate check determined this result was incorrect for the original query.
The reasoning provided was: "{last_reasoning}".
Please re-analyze the original query ("{query_text}") and the result from your previous attempt, then try generating the '{available_tools["query_wandbot_api"]["schema"]["name"]}' tool call again.
"""
                retry_messages.append(
                    {"role": "user", "content": retry_user_message_content}
                )
                # Rebuild the message list for the retry
                current_messages = messages_first_call[:1] + retry_messages
            else:
                logger.warning(
                    "Attempting retry, but no previous assistant response or tool_use_id found."
                )
                # If a retry is needed but we lack context, we should probably fail
                # or stick with the original messages. For now, proceed with the
                # original messages, though this may not be ideal.
                current_messages = messages_first_call
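        # After the blocks above, a retry conversation is assumed to look like:
        #   [user: original query,
        #    assistant: tool_use block,
        #    user: tool_result block,
        #    user: retry instructions with the judge's reasoning]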
        # --- First call: get the query_wandbot_api tool use ---
        try:
            response = call_anthropic(
                model_name=MODEL_NAME,
                messages=current_messages,  # Use the potentially updated message list
                tools=tools,
            )
            first_call_assistant_response = response  # Store for a potential *next* retry
        except Exception as e:
            pytest.fail(f"Attempt {attempt + 1}: Anthropic API call failed: {e}")

        try:
            # Extract the tool_use_id here
            _, tool_name, tool_input, tool_use_id = extract_anthropic_tool_use(response)
            if tool_use_id is None:
                logger.warning(
                    f"Attempt {attempt + 1}: Model did not return a tool use block."
                )
                # Decide how to handle this: maybe fail, maybe retry without tool
                # use? For now, continue to execution; it might fail gracefully,
                # or the correctness check will catch it.
        except ValueError as e:
            logger.error(
                f"Attempt {attempt + 1}: Failed to extract tool use from response: {response}"
            )
            pytest.fail(f"Attempt {attempt + 1}: Could not extract tool use: {e}")

        logger.info(f"Attempt {attempt + 1}: Tool emitted by model: {tool_name}")
        logger.info(
            f"Attempt {attempt + 1}: Tool input: {json.dumps(tool_input, indent=2)}"
        )

        assert tool_name == "query_wandbot_api", (
            f"Attempt {attempt + 1}: Expected 'query_wandbot_api', got '{tool_name}'"
        )
        assert "question" in tool_input, (
            f"Attempt {attempt + 1}: Tool input missing 'question'"
        )
        # --- Execute the WandBot tool ---
        try:
            # Pass only the arguments the current function signature accepts;
            # 'question' is assumed to be the sole parameter, and the assertion
            # above already guarantees it is present.
            actual_args = {"question": tool_input["question"]}
            tool_result = available_tools[tool_name]["function"](**actual_args)
            logger.info(
                f"Attempt {attempt + 1}: Tool result: {json.dumps(tool_result, indent=2)}"
            )  # Log the full result
            # Basic structure check before the correctness check
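            # Expected wandbot payload shape (illustrative, inferred from the
            # assertions below):
            #   {"answer": "<answer text>", "sources": ["<url>", ...]}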
            assert isinstance(tool_result, dict), "Tool result should be a dictionary"
            assert isinstance(tool_result.get("answer"), str), "'answer' should be a string"
            assert isinstance(tool_result.get("sources"), list), "'sources' should be a list"
        except Exception as e:
            logger.error(
                f"Attempt {attempt + 1}: Error executing or validating tool '{tool_name}' with input {tool_input}: {e}",
                exc_info=True,
            )
            pytest.fail(
                f"Attempt {attempt + 1}: Tool execution or basic validation failed: {e}"
            )
        # --- Second call: perform the correctness check ---
        logger.info(f"\n--- Starting correctness check for attempt {attempt + 1} ---")
        try:
            correctness_prompt = f"""
Please evaluate whether the 'Actual Tool Result' provides a helpful and relevant answer to the 'Original User Query'.
The 'Expected Output Hint' gives guidance on what a good answer should contain.
Use the 'check_correctness_tool' to provide your reasoning and conclusion.
Original User Query:
"{query_text}"
Expected Output Hint:
"{expected_output}"
Actual Tool Result from '{tool_name}':
```json
{json.dumps(tool_result, indent=2)}
```
"""
            messages_check_call = [{"role": "user", "content": correctness_prompt}]
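            # check_correctness_tool (imported above) is assumed to expose
            # 'reasoning' (str) and 'is_correct' (bool) fields in its input
            # schema, matching the assertions on the extracted tool input below.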
            correctness_response = call_anthropic(
                model_name=CORRECTNESS_MODEL_NAME,
                messages=messages_check_call,
                check_correctness_tool=check_correctness_tool,  # Pass the imported tool schema
            )
            logger.info(
                f"Attempt {attempt + 1}: Correctness check response:\n{correctness_response}\n\n"
            )
            _, check_tool_name, check_tool_input, _ = extract_anthropic_tool_use(
                correctness_response
            )
            assert check_tool_name == "check_correctness_tool", (
                f"Attempt {attempt + 1}: Expected correctness tool, got {check_tool_name}"
            )
            assert "reasoning" in check_tool_input, (
                f"Attempt {attempt + 1}: Correctness tool missing 'reasoning'"
            )
            assert "is_correct" in check_tool_input, (
                f"Attempt {attempt + 1}: Correctness tool missing 'is_correct'"
            )
            last_reasoning = check_tool_input["reasoning"]
            last_is_correct = check_tool_input["is_correct"]
            logger.info(
                f"Attempt {attempt + 1}: Correctness reasoning: {last_reasoning}"
            )
            logger.info(
                f"Attempt {attempt + 1}: Is correct according to LLM: {last_is_correct}"
            )
            if last_is_correct:
                logger.info(
                    f"--- Correctness check passed on attempt {attempt + 1}. ---"
                )
                break  # Exit the loop successfully
        except KeyError as e:
            logger.error(
                f"Attempt {attempt + 1}: Missing expected key in correctness tool input: {e}"
            )
            # check_tool_input may be unbound if extraction itself raised, so look
            # it up defensively rather than referencing it directly.
            logger.error(
                f"Attempt {attempt + 1}: Full input received: {locals().get('check_tool_input')}"
            )
            last_is_correct = False
            last_reasoning = f"Correctness tool response missing key: {e}"
            # Continue the loop if retries remain; failure is handled after the loop
        except Exception as e:
            logger.error(
                f"Attempt {attempt + 1}: Error during correctness check for query '{query_text}': {e}",
                exc_info=True,
            )
            last_is_correct = False
            last_reasoning = f"Correctness check failed with exception: {e}"
            # Continue the loop if retries remain; failure is handled after the loop
    # --- After the loop, fail the test if the last attempt wasn't correct ---
    if not last_is_correct:
        pytest.fail(
            f"LLM evaluation failed after {max_retries + 1} attempts. "
            f"Final is_correct flag is `{last_is_correct}`. "
            f"Final reasoning: '{last_reasoning}'"
        )

    # Reaching here means the correctness check passed within the allowed attempts.
    logger.info("--- Test passed within allowed attempts. ---")