import base64
import json
import os
from typing import Any, Dict, List, Optional

import requests

from wandb_mcp_server.weave_api.query_builder import QueryBuilder
from wandb_mcp_server.mcp_tools.tools_utils import get_retry_session
from wandb_mcp_server.utils import get_rich_logger
from wandb_mcp_server.api_client import WandBApiManager

logger = get_rich_logger(__name__)

COUNT_WEAVE_TRACES_TOOL_DESCRIPTION = """count Weave traces and return the total storage \
size in bytes for the given filters.

Use this tool to query data from Weights & Biases Weave, an observability product for
tracing and evaluating LLMs and GenAI apps.

This tool only provides COUNT information and STORAGE SIZE (bytes) about traces, \
not actual logged traces data, metrics or run data.

**IMPORTANT PRODUCT DISTINCTION:**
W&B offers two distinct products with different purposes:

1. W&B Models: A system for ML experiment tracking, hyperparameter optimization, and model
    lifecycle management. Use `query_wandb_tool` for questions about:
    - Experiment runs, metrics, and performance comparisons
    - Artifact management and model registry
    - Hyperparameter optimization and sweeps
    - Project dashboards and reports

2. W&B Weave: A toolkit for LLM and GenAI application observability and evaluation. Use
    `query_weave_traces_tool` (this tool) for questions about:
    - Execution traces and paths of LLM operations
    - LLM inputs, outputs, and intermediate results
    - Chain of thought visualization and debugging
    - LLM evaluation results and feedback

**USE CASE SELECTOR - READ FIRST:**
- For runs, metrics, experiments, artifacts, sweeps etc → use query_wandb_tool
- For traces, LLM calls, chain-of-thought, LLM evaluations, AI agent traces, AI apps etc → use query_weave_traces_tool

=====================================================================
⚠️ TOOL SELECTION WARNING ⚠️
This tool is ONLY for WEAVE TRACES (LLM operations), NOT for run metrics or experiments!
=====================================================================

**KEYWORD GUIDE:**
If user question contains:
- "runs", "experiments", "metrics" → Use query_wandb_tool
- "traces", "LLM calls" etc → Use this tool

**COMMON MISUSE CASES:**
❌ "Looking at metrics of my latest runs" - Do NOT use this tool, use query_wandb_tool instead
❌ "Compare performance across experiments" - Do NOT use this tool, use query_wandb_tool instead

Returns the total number of traces in a project and the number of root
(i.e. "parent" or top-level) traces.

This is more efficient than query_trace_tool when you only need the count.
This can be useful to understand how many traces are in a project before
querying for them as query_trace_tool can return a lot of data.

Parameters
----------
entity_name : str
    The Weights & Biases entity name (team or username).
project_name : str
    The Weights & Biases project name.
filters : Dict[str, Any], optional
    Dict of filter conditions, supporting:
        - display_name: Filter by display name (string or regex pattern)
        - op_name_contains: Filter for op_name containing a substring. Not a good idea to use in conjunction with trace_roots_only.
        - trace_id: Filter by specific trace ID
        - status: Filter by trace status ('success', 'error', etc.)
        - time_range: Dict with "start" and "end" datetime strings
        - latency: Filter by latency in milliseconds (summary.weave.latency_ms).
            Use a nested dict with operators: $gt, $lt, $eq, $gte, $lte.
            ($lt and $lte are implemented via logical negation on the backend).
            e.g., {"latency": {"$gt": 5000}}
        - attributes: Dict of attribute path and value/operator to match.
            Supports nested paths (e.g., "metadata.model_name") via dot notation.
            Value can be literal for equality or a dict with operator ($gt, $lt, $eq, $gte, $lte) for comparison
            (e.g., {"token_count": {"$gt": 100}}).
        - has_exception: Boolean to filter traces with/without exceptions
        - trace_roots_only: Boolean to filter for only top-level (aka parent) traces

Returns
-------
int
    The number of traces matching the query parameters.

Examples
--------
>>> # Count failed traces
>>> count = count_traces(
...     entity_name="my-team",
...     project_name="my-project",
...     filters={"status": "error"}
... )
>>> # Count traces faster than 500ms
>>> count = count_traces(
...     entity_name="my-team",
...     project_name="my-project",
...     filters={"latency": {"$lt": 500}}
... )
"""

# Filter keys whose values are always sent as lists inside the request's
# top-level "filter" object (a scalar is wrapped into a one-element list).
_LIST_FILTER_KEYS = (
    "input_refs",
    "output_refs",
    "parent_ids",
    "call_ids",
    "wb_user_ids",
    "wb_run_ids",
)

# Every key that belongs inside the 'filter' object of the request body, as per
# https://weave-docs.wandb.ai/reference/service-api/calls-query-stats-calls-query_stats-post
# "op_name" / "trace_id" are singular aliases folded into "op_names" / "trace_ids".
_DIRECT_FILTER_KEYS = frozenset(
    _LIST_FILTER_KEYS
    + ("op_names", "op_name", "trace_ids", "trace_id", "trace_roots_only")
)


def _merge_singular_and_plural(
    filters: Dict[str, Any], singular_key: str, plural_key: str
) -> List[Any]:
    """Combine a singular filter key and its plural form into one de-duplicated list."""
    merged: List[Any] = []
    if singular_key in filters:
        merged.append(filters[singular_key])
    if plural_key in filters:
        plural_value = filters[plural_key]
        if isinstance(plural_value, list):
            merged.extend(plural_value)
        else:
            merged.append(plural_value)
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # request body is deterministic (a plain set() would reorder arbitrarily).
    return list(dict.fromkeys(merged))


def _build_request_body(
    project_id: str, filters: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    """Translate user-facing filters into the /calls/query_stats request body."""
    request_body: Dict[str, Any] = {"project_id": project_id}
    if not filters:
        return request_body

    # Fields that go into the top-level 'filter' object.
    filter_payload: Dict[str, Any] = {}

    op_names = _merge_singular_and_plural(filters, "op_name", "op_names")
    if op_names:
        filter_payload["op_names"] = op_names

    trace_ids = _merge_singular_and_plural(filters, "trace_id", "trace_ids")
    if trace_ids:
        filter_payload["trace_ids"] = trace_ids

    for key in _LIST_FILTER_KEYS:
        if key in filters:
            value = filters[key]
            filter_payload[key] = value if isinstance(value, list) else [value]

    # Per docs, trace_roots_only is a boolean, not a list. If it is absent it
    # is omitted so that the API default (false) applies.
    if "trace_roots_only" in filters:
        filter_payload["trace_roots_only"] = filters["trace_roots_only"]

    if filter_payload:
        request_body["filter"] = filter_payload

    # Everything that is not a direct filter key goes into query.$expr.
    complex_filters = {
        key: value
        for key, value in filters.items()
        if key not in _DIRECT_FILTER_KEYS
    }
    if complex_filters:
        query_expr_obj = QueryBuilder.build_query_expression(complex_filters)
        if query_expr_obj:
            dumped_query = query_expr_obj.model_dump(by_alias=True, exclude_none=True)
            # Only attach the query when it actually carries an expression.
            if dumped_query and dumped_query.get("$expr"):
                request_body["query"] = dumped_query

    return request_body


def count_traces(
    entity_name: str,
    project_name: str,
    filters: Optional[Dict[str, Any]] = None,
    request_timeout: int = 30,
) -> int:
    """Count the number of traces matching the given filters.

    Posts to the Weave trace server's ``/calls/query_stats`` endpoint and
    returns the ``count`` field of the JSON response, without retrieving the
    trace data itself.

    Parameters
    ----------
    entity_name : str
        The Weights & Biases entity name (team or username).
    project_name : str
        The Weights & Biases project name.
    filters : Dict[str, Any], optional
        Dict of filter conditions. Keys ``op_name``/``op_names``,
        ``trace_id``/``trace_ids``, ``input_refs``, ``output_refs``,
        ``parent_ids``, ``call_ids``, ``wb_user_ids``, ``wb_run_ids`` and
        ``trace_roots_only`` are sent in the request's ``filter`` object.
        Every other key is handed to ``QueryBuilder.build_query_expression``
        and sent as the ``query`` expression, e.g.:
            - display_name: Filter by display name (string or regex pattern)
            - op_name_contains: Filter for op_name containing a substring
            - status: Filter by trace status ('success', 'error', etc.)
            - latency: Filter by latency in milliseconds
                (summary.weave.latency_ms). Use a nested dict with operators:
                $gt, $lt, $eq, $gte, $lte.
                Note: $lt and $lte are implemented via logical negation.
                e.g., {"latency": {"$gt": 5000}}
            - time_range: Dict with "start" and "end" datetime strings
            - attributes: Dict of attribute path and value/operator to match.
                Supports nested paths (e.g., "metadata.model_name") via dot
                notation. Value can be literal for equality or a dict with
                operator ($gt, $lt, $eq, $gte, $lte) for comparison
                (e.g., {"token_count": {"$gt": 100}}).
            - has_exception: Boolean to filter traces with/without exceptions
    request_timeout : int, optional
        Timeout for the HTTP request in seconds. Defaults to 30.

    Returns
    -------
    int
        The number of traces matching the query parameters (0 when the
        response has no ``count`` field).

    Raises
    ------
    ValueError
        If no W&B API key is available.
    Exception
        If the HTTP request fails, the server answers with a non-200 status,
        or the response body is not a JSON object.

    Examples
    --------
    >>> # Count failed traces
    >>> count = count_traces(
    ...     entity_name="my-team",
    ...     project_name="my-project",
    ...     filters={"status": "error"}
    ... )
    >>> # Count traces matching an attribute and latency > 1s
    >>> count = count_traces(
    ...     entity_name="my-team",
    ...     project_name="my-project",
    ...     filters={
    ...         "attributes": {"metadata.environment": "production"},
    ...         "latency": {"$gt": 1000}
    ...     }
    ... )
    """
    project_id = f"{entity_name}/{project_name}"

    # Get API key from context (set by auth middleware) or environment
    api_key = WandBApiManager.get_api_key()
    if not api_key:
        logger.error("W&B API key not found in context or environment variables.")
        raise ValueError("W&B API key is required to query Weave traces count.")

    # Debug logging to diagnose API key issues.
    # NOTE(security): this writes a prefix and suffix of the secret to the
    # logs; kept for diagnosability, but consider logging only the length.
    logger.debug(
        f"Using W&B API key: length={len(api_key)}, "
        f"first_6={api_key[:6] if len(api_key) >= 6 else 'N/A'}..., "
        f"last_4={api_key[-4:] if len(api_key) >= 4 else 'N/A'}"
    )

    request_body = _build_request_body(project_id, filters)
    # Serialize once; the same payload is both logged and sent.
    serialized_body = json.dumps(request_body)

    # A trailing slash on the configured base URL would otherwise yield "//calls/...".
    weave_server_url = os.environ.get(
        "WEAVE_TRACE_SERVER_URL", "https://trace.wandb.ai"
    ).rstrip("/")
    url = f"{weave_server_url}/calls/query_stats"

    # HTTP Basic auth with an empty username and the API key as the password.
    auth_token = base64.b64encode(f":{api_key}".encode()).decode()
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",  # /calls/query_stats returns application/json
        "Authorization": f"Basic {auth_token}",
    }

    session = get_retry_session()
    logger.debug(f"Posting to {url} with body: {serialized_body}")

    try:
        response = session.post(
            url,
            headers=headers,
            data=serialized_body,
            timeout=request_timeout,
        )
    except requests.exceptions.RequestException as e:
        logger.error(f"HTTP Request failed for project {project_id}: {e}")
        if isinstance(e, requests.exceptions.RetryError):
            cause = e.__cause__
            if cause and hasattr(cause, "reason") and cause.reason:
                logger.error(f"Specific reason for retry exhaustion: {cause.reason}")
        logger.debug(
            f"Failed request body during exception for {project_id}: {serialized_body}"
        )
        raise Exception(
            f"Failed to query Weave trace count for {project_id} due to network error: {e}"
        ) from e

    if response.status_code != 200:
        error_msg = (
            f"Error querying Weave trace count: {response.status_code} - {response.text}"
        )
        logger.error(error_msg)
        # Log API key info for debugging
        logger.error(
            f"API key info: length={len(api_key)}, is_40_chars={len(api_key) == 40}"
        )
        if "40 characters" in response.text:
            logger.error(
                f"W&B requires exactly 40 character API keys. Current key has {len(api_key)} characters."
            )
            # NOTE(security): logs part of the secret at ERROR level.
            logger.error(
                f"Key preview: {api_key[:8]}...{api_key[-4:] if len(api_key) >= 12 else ''}"
            )
        # Log request body for easier debugging on error
        logger.debug(f"Failed request body: {serialized_body}")
        raise Exception(error_msg)

    # Parsing is handled separately from the transport errors above: recent
    # versions of requests raise a JSONDecodeError that is *also* a
    # RequestException, which would otherwise be misreported as a network
    # error. ValueError is the common base of json.JSONDecodeError and the
    # requests/simplejson variants.
    try:
        response_json = response.json()
    except ValueError as e:
        logger.error(
            f"Failed to decode JSON response for {project_id}: {e}. Response text: {response.text}"
        )
        raise Exception(
            f"Failed to parse Weave API response for {project_id}: {e}"
        ) from e

    if not isinstance(response_json, dict):
        raise Exception(
            f"Failed to parse Weave API response for {project_id}: "
            f"expected a JSON object, got {type(response_json).__name__}"
        )

    return response_json.get("count", 0)  # Default to 0 if count is not in response