| """ | |
| Service layer for Weave API. | |
| This module provides high-level services for querying and processing Weave traces. | |
| It orchestrates the client, query builder, and processor components. | |
| """ | |
| from __future__ import annotations | |
| from typing import Any, Dict, List, Optional, Set | |
| from wandb_mcp_server.utils import get_rich_logger, get_server_args | |
| from wandb_mcp_server.weave_api.client import WeaveApiClient | |
| from wandb_mcp_server.weave_api.models import QueryResult | |
| from wandb_mcp_server.weave_api.processors import TraceProcessor | |
| from wandb_mcp_server.weave_api.query_builder import QueryBuilder | |
| # Import CallSchema to validate column names | |
| try: | |
| from weave.trace_server.trace_server_interface import CallSchema | |
| VALID_COLUMNS = set(CallSchema.__annotations__.keys()) | |
| HAVE_CALL_SCHEMA = True | |
| except ImportError: | |
| # Fallback if CallSchema isn't available | |
| VALID_COLUMNS = { | |
| "id", | |
| "project_id", | |
| "op_name", | |
| "display_name", | |
| "trace_id", | |
| "parent_id", | |
| "started_at", | |
| "attributes", | |
| "inputs", | |
| "ended_at", | |
| "exception", | |
| "output", | |
| "summary", | |
| "wb_user_id", | |
| "wb_run_id", | |
| "deleted_at", | |
| "storage_size_bytes", | |
| "total_storage_size_bytes", | |
| } | |
| HAVE_CALL_SCHEMA = False | |
| logger = get_rich_logger(__name__) | |
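
# Illustrative usage sketch (the entity and project names below are
# placeholders, not part of this module):
#
#     service = TraceService()  # resolves WANDB_API_KEY from the environment
#     result = service.query_traces(
#         entity_name="my-team",
#         project_name="my-project",
#         columns=["id", "op_name", "latency_ms", "costs"],
#         limit=10,
#     )
#     for trace in result.traces or []:
#         print(trace["id"], trace.get("latency_ms"))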


class TraceService:
    """Service for querying and processing Weave traces."""

    # Define cost fields once as a class constant
    COST_FIELDS = {"total_cost", "completion_cost", "prompt_cost"}

    # Define synthetic columns that shouldn't be passed to the API but can be reconstructed
    SYNTHETIC_COLUMNS = {"costs"}

    # Define the latency field mapping
    LATENCY_FIELD_MAPPING = {"latency_ms": "summary.weave.latency_ms"}
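
    # For example, a user-facing request for "latency_ms" (whether as a sort
    # field or a column) is translated to the server-side path
    # "summary.weave.latency_ms".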

    def __init__(
        self,
        api_key: Optional[str] = None,
        server_url: Optional[str] = None,
        retries: int = 3,
        timeout: int = 10,
    ):
        """Initialize the TraceService.

        Args:
            api_key: W&B API key. If not provided, uses WANDB_API_KEY env var.
            server_url: Weave API server URL. Defaults to 'https://trace.wandb.ai'.
            retries: Number of retries for failed requests.
            timeout: Request timeout in seconds.
        """
        # If no API key provided, try to get from environment
        if api_key is None:
            import os

            # Try to get from environment (set by auth middleware for HTTP or user for STDIO)
            api_key = os.environ.get("WANDB_API_KEY")

            # If still no key, try get_server_args as fallback
            if not api_key:
                server_config = get_server_args()
                api_key = server_config.wandb_api_key

        # Pass the resolved API key to WeaveApiClient.
        # If api_key is None or "", WeaveApiClient will raise its ValueError.
        self.client = WeaveApiClient(
            api_key=api_key,
            server_url=server_url,
            retries=retries,
            timeout=timeout,
        )

        # Initialize collection for invalid columns (for warning messages)
        self.invalid_columns = set()

    def _validate_and_filter_columns(
        self, columns: Optional[List[str]]
    ) -> tuple[Optional[List[str]], List[str], Set[str]]:
        """Validate columns against CallSchema and filter out synthetic/invalid columns.

        Handles mapping of 'latency_ms' to 'summary.weave.latency_ms'.

        Args:
            columns: List of columns.

        Returns:
            Tuple of (filtered_columns_for_api, requested_synthetic_columns, invalid_columns_reported)
        """
        if not columns:
            # Return None for filtered_columns_for_api if the input is empty
            return (None, [], set())

        filtered_columns_for_api: list[str] = []
        requested_synthetic_columns: list[str] = []
        invalid_columns_reported: set[str] = set()
        # Avoid duplicate processing if a column is listed multiple times
        processed_columns = set()

        for col_name in columns:
            if col_name in processed_columns:
                continue
            processed_columns.add(col_name)

            if col_name == "latency_ms":
                # 'latency_ms' is synthetic; its data comes from 'summary.weave.latency_ms'
                requested_synthetic_columns.append("latency_ms")
                # Ensure the source field is requested from the API
                source_field = self.LATENCY_FIELD_MAPPING["latency_ms"]
                if source_field not in filtered_columns_for_api:
                    filtered_columns_for_api.append(source_field)
                # Also add 'summary' itself, since 'summary.weave.latency_ms' implies 'summary'
                if (
                    "summary" not in filtered_columns_for_api
                    and source_field.startswith("summary.")
                ):
                    filtered_columns_for_api.append("summary")
                logger.info(
                    f"Column 'latency_ms' requested: will be synthesized from '{source_field}'. Added '{source_field}' to API columns."
                )
            elif col_name == "costs":
                # 'costs' is synthetic; its data comes from 'summary.weave.costs'
                requested_synthetic_columns.append("costs")
                # Ensure the source field ('summary') is requested
                if "summary" not in filtered_columns_for_api:
                    filtered_columns_for_api.append("summary")
                logger.info(
                    "Column 'costs' requested: will be synthesized from 'summary.weave.costs'. Added 'summary' to API columns."
                )
            elif col_name == "status":
                # 'status' can be top-level or come from 'summary.weave.status'
                requested_synthetic_columns.append("status")
                # Add 'status' to the API columns to try a top-level fetch first.
                # If it isn't present, it will be synthesized from summary.
                if "status" not in filtered_columns_for_api:
                    filtered_columns_for_api.append("status")
                # Also ensure 'summary' is available for the fallback
                if "summary" not in filtered_columns_for_api:
                    filtered_columns_for_api.append("summary")
                logger.info(
                    "Column 'status' requested: will attempt direct fetch or synthesize from 'summary.weave.status'."
                )
            elif col_name in VALID_COLUMNS:
                # Direct valid column
                if col_name not in filtered_columns_for_api:
                    filtered_columns_for_api.append(col_name)
            elif "." in col_name:  # Potentially a dot-separated path
                base_field = col_name.split(".")[0]
                if base_field in VALID_COLUMNS:
                    # Valid nested field (e.g., "summary.weave.latency_ms", "attributes.foo")
                    if col_name not in filtered_columns_for_api:
                        filtered_columns_for_api.append(col_name)
                    logger.info(
                        f"Nested column field '{col_name}' requested, added to API columns."
                    )
                else:
                    logger.warning(
                        f"Invalid base field '{base_field}' in nested column '{col_name}'. It will be ignored."
                    )
                    invalid_columns_reported.add(col_name)
            else:
                # Neither a direct valid column, a recognized synthetic, nor a valid-looking nested path
                logger.warning(
                    f"Invalid column '{col_name}' requested. It will be ignored."
                )
                invalid_columns_reported.add(col_name)

        # Deduplicate filtered_columns_for_api while preserving order as much as possible
        # (order may not matter to the API as much as presence)
        final_filtered_columns_for_api = []
        seen_in_final = set()
        for fc in filtered_columns_for_api:
            if fc not in seen_in_final:
                final_filtered_columns_for_api.append(fc)
                seen_in_final.add(fc)

        return (
            final_filtered_columns_for_api,
            requested_synthetic_columns,
            invalid_columns_reported,
        )
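
    # Illustrative behavior (the concrete column names are examples):
    #
    #     _validate_and_filter_columns(["id", "latency_ms", "bogus"])
    #     # -> (["id", "summary.weave.latency_ms", "summary"],
    #     #     ["latency_ms"],
    #     #     {"bogus"})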

    def _ensure_required_columns_for_synthetic(
        self,
        filtered_columns: Optional[List[str]],
        requested_synthetic_columns: List[str],
    ) -> Optional[List[str]]:
        """Ensure required columns for synthetic fields are included.

        Args:
            filtered_columns: List of columns after filtering out synthetic ones.
            requested_synthetic_columns: List of requested synthetic columns.

        Returns:
            Updated filtered columns list with required columns added.
        """
        if not filtered_columns:
            filtered_columns = []
        required_columns = set(filtered_columns)

        # Add required columns for synthesizing costs
        if "costs" in requested_synthetic_columns:
            # Costs data comes from summary.weave.costs
            if "summary" not in required_columns:
                logger.info("Adding 'summary' column as it's required for costs data")
                required_columns.add("summary")

        # Add other required columns for other synthetic fields as needed
        return list(required_columns)

    def _add_synthetic_columns(
        self,
        traces: List[Dict[str, Any]],
        requested_synthetic_columns: List[str],
        invalid_columns: Set[str],
    ) -> List[Dict[str, Any]]:
        """Add synthetic columns back to the traces and add warnings for invalid columns.

        Args:
            traces: List of trace dictionaries.
            requested_synthetic_columns: List of requested synthetic columns.
            invalid_columns: Set of invalid column names that were requested.

        Returns:
            Updated traces with synthetic columns added and invalid column warnings.
        """
        if not requested_synthetic_columns and not invalid_columns:
            return traces

        updated_traces = []
        for trace in traces:
            updated_trace = trace.copy()

            # Add costs data if requested
            if "costs" in requested_synthetic_columns:
                costs_data = trace.get("summary", {}).get("weave", {}).get("costs", {})
                if costs_data:
                    logger.debug(
                        f"Adding synthetic 'costs' column with {len(costs_data)} providers"
                    )
                    updated_trace["costs"] = costs_data
                else:
                    logger.warning(f"No costs data found in trace {trace.get('id')}")
                    updated_trace["costs"] = {}

            # Add status from summary if requested
            if "status" in requested_synthetic_columns:
                status = trace.get("status")  # Check whether it's already in the trace
                if not status:
                    # Extract from summary.weave.status
                    status = trace.get("summary", {}).get("weave", {}).get("status")
                if status:
                    logger.debug(f"Adding synthetic 'status' from summary: {status}")
                    updated_trace["status"] = status
                else:
                    logger.warning(f"No status data found in trace {trace.get('id')}")
                    updated_trace["status"] = None

            # Add latency_ms from summary if requested
            if "latency_ms" in requested_synthetic_columns:
                latency = trace.get("latency_ms")  # Check whether it's already in the trace
                if latency is None:
                    # Extract from summary.weave.latency_ms
                    latency = (
                        trace.get("summary", {}).get("weave", {}).get("latency_ms")
                    )
                if latency is not None:
                    logger.debug(
                        f"Adding synthetic 'latency_ms' from summary: {latency}"
                    )
                    updated_trace["latency_ms"] = latency
                else:
                    logger.warning(
                        f"No latency_ms data found in trace {trace.get('id')}"
                    )
                    updated_trace["latency_ms"] = None

            # Add warnings for invalid columns
            for col in invalid_columns:
                warning_message = f"{col} is not a valid column name, no data returned"
                updated_trace[col] = warning_message

            updated_traces.append(updated_trace)

        return updated_traces
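
    # Illustrative input/output (field values are placeholders):
    #
    #     trace = {"id": "abc", "summary": {"weave": {"status": "success",
    #                                                 "latency_ms": 123}}}
    #     _add_synthetic_columns([trace], ["status", "latency_ms"], {"bogus"})
    #     # -> [{..., "status": "success", "latency_ms": 123,
    #     #      "bogus": "bogus is not a valid column name, no data returned"}]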

    def query_traces(
        self,
        entity_name: str,
        project_name: str,
        filters: Optional[Dict[str, Any]] = None,
        sort_by: str = "started_at",
        sort_direction: str = "desc",
        limit: Optional[int] = None,
        offset: int = 0,
        include_costs: bool = True,
        include_feedback: bool = True,
        columns: Optional[List[str]] = None,
        expand_columns: Optional[List[str]] = None,
        truncate_length: Optional[int] = 200,
        return_full_data: bool = False,
        metadata_only: bool = False,
    ) -> QueryResult:
        """Query traces from the Weave API.

        Args:
            entity_name: Weights & Biases entity name.
            project_name: Weights & Biases project name.
            filters: Dictionary of filter conditions.
            sort_by: Field to sort by.
            sort_direction: Sort direction ('asc' or 'desc').
            limit: Maximum number of results to return.
            offset: Number of results to skip (for pagination).
            include_costs: Include tracked API cost information in the results.
            include_feedback: Include Weave annotations in the results.
            columns: List of specific columns to include in the results.
            expand_columns: List of columns to expand in the results.
            truncate_length: Maximum length for string values.
            return_full_data: Whether to include full untruncated trace data.
            metadata_only: Whether to only include metadata without traces.

        Returns:
            QueryResult object with metadata and optionally traces.
        """
        # Clear invalid columns from previous requests
        self.invalid_columns = set()

        # Special handling for cost-based sorting
        client_side_cost_sort = sort_by in self.COST_FIELDS

        # Handle the latency field mapping (this covers 'latency_ms')
        if sort_by in self.LATENCY_FIELD_MAPPING:
            logger.info(
                f"Mapping sort field '{sort_by}' to '{self.LATENCY_FIELD_MAPPING[sort_by]}'"
            )
            server_sort_by = self.LATENCY_FIELD_MAPPING[sort_by]
            server_sort_direction = sort_direction
        elif client_side_cost_sort:
            include_costs = True
            server_sort_by = "started_at"
            server_sort_direction = sort_direction
        elif "." in sort_by:  # Handles general dot-separated paths
            base_field = sort_by.split(".")[0]
            if base_field in VALID_COLUMNS:
                logger.info(f"Using nested sort field for server: {sort_by}")
                server_sort_by = sort_by
                server_sort_direction = sort_direction
            else:
                logger.warning(
                    f"Invalid base field '{base_field}' in sort_by '{sort_by}', falling back to 'started_at'."
                )
                server_sort_by = "started_at"
                server_sort_direction = sort_direction
        elif sort_by not in VALID_COLUMNS:
            logger.warning(
                f"Invalid sort field '{sort_by}', falling back to 'started_at'."
            )
            server_sort_by = "started_at"
            server_sort_direction = sort_direction
        else:  # sort_by is in VALID_COLUMNS and not a special case
            server_sort_by = sort_by
            server_sort_direction = sort_direction

        # Validate and filter columns using CallSchema
        filtered_api_columns, rs_columns, inv_columns = (
            self._validate_and_filter_columns(columns)
        )

        # Store invalid columns for later
        self.invalid_columns = inv_columns

        # If 'costs' was requested as a column, make sure cost data is included
        if "costs" in rs_columns:
            include_costs = True

        # Adding 'latency_ms' to the synthetic fields and ensuring the columns
        # required for synthetic fields are both handled inside
        # _validate_and_filter_columns.

        # Prepare query parameters
        query_params = {
            "entity_name": entity_name,
            "project_name": project_name,
            "filters": filters or {},
            "sort_by": server_sort_by,
            "sort_direction": server_sort_direction,
            # No limit if we're sorting by cost client-side
            "limit": None if client_side_cost_sort else limit,
            "offset": offset,
            "include_costs": include_costs,
            "include_feedback": include_feedback,
            "columns": filtered_api_columns,  # Only the columns intended for the API
            "expand_columns": expand_columns,
        }

        # Build request body
        request_body = QueryBuilder.prepare_query_params(query_params)

        # Extract synthetic fields if any were specified
        synthetic_fields = request_body.pop("_synthetic_fields", [])

        # Make sure all requested synthetic columns are included in synthetic_fields
        for col in rs_columns:
            if col not in synthetic_fields:
                synthetic_fields.append(col)

        # Execute query
        all_traces = list(self.client.query_traces(request_body))

        # Add synthetic columns and invalid column warnings back to the results
        if rs_columns or inv_columns:
            all_traces = self._add_synthetic_columns(
                all_traces, rs_columns, inv_columns
            )

        # Client-side cost-based sorting if needed
        if client_side_cost_sort and all_traces:
            logger.info(f"Performing client-side sorting by {sort_by}")

            # Sort traces by cost
            all_traces.sort(
                key=lambda t: TraceProcessor.get_cost(t, sort_by),
                reverse=(sort_direction == "desc"),
            )

            # Apply the limit if specified
            if limit is not None:
                all_traces = all_traces[:limit]

        # If we need to synthesize fields, do it
        if synthetic_fields:
            logger.info(f"Synthesizing fields: {synthetic_fields}")
            all_traces = [
                TraceProcessor.synthesize_fields(trace, synthetic_fields)
                for trace in all_traces
            ]

        # Process traces
        result = TraceProcessor.process_traces(
            traces=all_traces,
            truncate_length=truncate_length or 0,
            return_full_data=return_full_data,
            metadata_only=metadata_only,
        )

        return result
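
    # Illustrative call (entity/project names are placeholders): sorting by
    # "total_cost" takes the client-side path above, so all matching traces are
    # fetched without a server-side limit and the limit is applied after sorting.
    #
    #     result = service.query_traces(
    #         entity_name="my-team",
    #         project_name="my-project",
    #         sort_by="total_cost",
    #         sort_direction="desc",
    #         limit=5,
    #     )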

    def query_paginated_traces(
        self,
        entity_name: str,
        project_name: str,
        chunk_size: int = 20,
        filters: Optional[Dict[str, Any]] = None,
        sort_by: str = "started_at",
        sort_direction: str = "desc",
        target_limit: Optional[int] = None,
        include_costs: bool = True,
        include_feedback: bool = True,
        columns: Optional[List[str]] = None,
        expand_columns: Optional[List[str]] = None,
        truncate_length: Optional[int] = 200,
        return_full_data: bool = False,
        metadata_only: bool = False,
    ) -> QueryResult:
        """Query traces with pagination.

        Args:
            entity_name: Weights & Biases entity name.
            project_name: Weights & Biases project name.
            chunk_size: Number of traces to retrieve in each chunk.
            filters: Dictionary of filter conditions.
            sort_by: Field to sort by.
            sort_direction: Sort direction ('asc' or 'desc').
            target_limit: Maximum total number of results to return.
            include_costs: Include tracked API cost information in the results.
            include_feedback: Include Weave annotations in the results.
            columns: List of specific columns to include in the results.
            expand_columns: List of columns to expand in the results.
            truncate_length: Maximum length for string values.
            return_full_data: Whether to include full untruncated trace data.
            metadata_only: Whether to only include metadata without traces.

        Returns:
            QueryResult object with metadata and optionally traces.
        """
        # Special handling for cost-based sorting
        client_side_cost_sort = sort_by in self.COST_FIELDS

        # Determine the effective sort field for the server
        effective_sort_by = "started_at"  # Default
        if sort_by == "latency_ms":
            effective_sort_by = self.LATENCY_FIELD_MAPPING["latency_ms"]
            logger.info(
                f"Paginated sort by 'latency_ms', server will use '{effective_sort_by}'."
            )
        elif "." in sort_by:
            base_field = sort_by.split(".")[0]
            if base_field in VALID_COLUMNS:
                effective_sort_by = sort_by
                logger.info(
                    f"Paginated sort by nested field '{sort_by}', server will use it directly."
                )
            else:
                logger.warning(
                    f"Paginated sort by invalid nested field '{sort_by}', defaulting to 'started_at'."
                )
        elif sort_by in VALID_COLUMNS and sort_by not in self.COST_FIELDS:
            # Exclude COST_FIELDS, as they are sorted client-side
            effective_sort_by = sort_by
        elif sort_by not in self.COST_FIELDS:
            # Not a valid field and not a cost field: warn and use the default
            logger.warning(
                f"Paginated sort by invalid field '{sort_by}', defaulting to 'started_at'."
            )

        # Validate and filter columns using CallSchema (pass the original 'columns')
        filtered_api_columns, rs_columns, inv_columns = (
            self._validate_and_filter_columns(columns)
        )

        # Store invalid columns for later
        self.invalid_columns = inv_columns

        # If 'costs' was requested as a column, make sure cost data is included
        if "costs" in rs_columns:
            include_costs = True

        # Columns required for synthetic fields are already ensured by
        # _validate_and_filter_columns.

        if client_side_cost_sort:
            logger.info(f"Cost-based sorting detected: {sort_by}")
            all_traces = self._query_for_cost_sorting(
                entity_name=entity_name,
                project_name=project_name,
                filters=filters,
                sort_by=sort_by,
                sort_direction=sort_direction,
                target_limit=target_limit,
                columns=filtered_api_columns,  # Pass the filtered columns for the API
                expand_columns=expand_columns,
                include_costs=True,
                include_feedback=include_feedback,
                requested_synthetic_columns=rs_columns,
                invalid_columns=inv_columns,
            )
        else:
            # Normal paginated query logic
            all_traces = []
            current_offset = 0

            while True:
                logger.info(
                    f"Querying chunk with offset {current_offset}, size {chunk_size}"
                )
                remaining = (
                    target_limit - len(all_traces) if target_limit else chunk_size
                )
                current_chunk_size = (
                    min(chunk_size, remaining) if target_limit else chunk_size
                )

                chunk_result = self.query_traces(
                    entity_name=entity_name,
                    project_name=project_name,
                    filters=filters,
                    sort_by=effective_sort_by,
                    sort_direction=sort_direction,
                    limit=current_chunk_size,
                    offset=current_offset,
                    include_costs=include_costs,
                    include_feedback=include_feedback,
                    # Pass the original 'columns' here; query_traces validates and
                    # filters them itself, so requests like 'latency_ms' are handled
                    # correctly by its nested call to _validate_and_filter_columns.
                    columns=columns,
                    expand_columns=expand_columns,
                    return_full_data=True,  # We want raw data for now
                    metadata_only=False,
                )

                # Get the traces from the QueryResult, handling both None and empty lists
                traces_from_chunk = (
                    chunk_result.traces if chunk_result and chunk_result.traces else []
                )
                if not traces_from_chunk:
                    break

                all_traces.extend(traces_from_chunk)

                if len(traces_from_chunk) < current_chunk_size or (
                    target_limit and len(all_traces) >= target_limit
                ):
                    break

                current_offset += chunk_size

        # Process all traces at once with the appropriate parameters
        if target_limit and all_traces:
            all_traces = all_traces[:target_limit]

        result = TraceProcessor.process_traces(
            traces=all_traces,
            truncate_length=truncate_length or 0,
            return_full_data=return_full_data,
            metadata_only=metadata_only,
        )

        logger.debug(
            f"Final result from query_paginated_traces: {len(result.model_dump_json(indent=2))} JSON characters"
        )
        assert isinstance(result, QueryResult), (
            f"Result type must be a QueryResult, found: {type(result)}"
        )
        return result
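
    # Illustrative pagination behavior (names are placeholders): with
    # chunk_size=20 and target_limit=50, the loop above issues chunks of 20,
    # 20, and 10 traces at offsets 0, 20, and 40, then stops.
    #
    #     result = service.query_paginated_traces(
    #         entity_name="my-team",
    #         project_name="my-project",
    #         chunk_size=20,
    #         target_limit=50,
    #     )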

    def _query_for_cost_sorting(
        self,
        entity_name: str,
        project_name: str,
        filters: Optional[Dict[str, Any]] = None,
        sort_by: str = "total_cost",
        sort_direction: str = "desc",
        target_limit: Optional[int] = None,
        columns: Optional[List[str]] = None,
        expand_columns: Optional[List[str]] = None,
        include_costs: bool = True,
        include_feedback: bool = True,
        requested_synthetic_columns: Optional[List[str]] = None,
        invalid_columns: Optional[Set[str]] = None,
    ) -> List[Dict[str, Any]]:
        """Special two-stage query logic for cost-based sorting.

        Args:
            entity_name: Weights & Biases entity name.
            project_name: Weights & Biases project name.
            filters: Dictionary of filter conditions.
            sort_by: Cost field to sort by.
            sort_direction: Sort direction ('asc' or 'desc').
            target_limit: Maximum number of results to return.
            columns: List of specific columns to include in the results.
            expand_columns: List of columns to expand in the results.
            include_costs: Include tracked API cost information in the results.
            include_feedback: Include Weave annotations in the results.
            requested_synthetic_columns: List of synthetic columns requested by the user.
            invalid_columns: Set of invalid column names that were requested.

        Returns:
            List of trace dictionaries sorted by the specified cost field.
        """
        if invalid_columns is None:
            invalid_columns = set()

        # First pass: fetch all trace IDs and costs
        first_pass_query = {
            "entity_name": entity_name,
            "project_name": project_name,
            "filters": filters or {},
            "sort_by": "started_at",  # Use a standard sort for the first pass
            "sort_direction": "desc",
            "limit": 1000000,  # Explicitly set a large limit to get all traces
            "include_costs": True,  # We need costs for sorting
            "include_feedback": False,  # Feedback isn't needed for the first pass
            "columns": ["id", "summary"],  # 'summary' is needed for the costs data
        }
        first_pass_request = QueryBuilder.prepare_query_params(first_pass_query)
        first_pass_results = list(self.client.query_traces(first_pass_request))
        logger.info(
            f"First pass of cost sorting request retrieved {len(first_pass_results)} traces"
        )

        # Filter and sort by cost
        filtered_results = [
            t
            for t in first_pass_results
            if TraceProcessor.get_cost(t, sort_by) is not None
        ]
        filtered_results.sort(
            key=lambda t: TraceProcessor.get_cost(t, sort_by),
            reverse=(sort_direction == "desc"),
        )

        # Get the IDs of the top N traces
        top_ids = (
            [t["id"] for t in filtered_results[:target_limit] if "id" in t]
            if target_limit
            else [t["id"] for t in filtered_results if "id" in t]
        )
        logger.info(f"After sorting by {sort_by}, selected {len(top_ids)} trace IDs")

        if not top_ids:
            return []

        # Second pass: fetch the full details for the selected traces
        second_pass_query = {
            "entity_name": entity_name,
            "project_name": project_name,
            "filters": {"call_ids": top_ids},
            "include_costs": include_costs,
            "include_feedback": include_feedback,
            "columns": columns,
            "expand_columns": expand_columns,
        }

        # Make sure we request 'summary' if costs were requested
        if requested_synthetic_columns and "costs" in requested_synthetic_columns:
            if not columns or "summary" not in columns:
                if not second_pass_query["columns"]:
                    second_pass_query["columns"] = ["summary"]
                elif "summary" not in second_pass_query["columns"]:
                    second_pass_query["columns"].append("summary")
                logger.info("Added 'summary' to columns for cost data retrieval")

        second_pass_request = QueryBuilder.prepare_query_params(second_pass_query)
        second_pass_results = list(self.client.query_traces(second_pass_request))
        logger.info(f"Second pass retrieved {len(second_pass_results)} traces")

        # Add synthetic columns and invalid column warnings back to the results
        if requested_synthetic_columns or invalid_columns:
            second_pass_results = self._add_synthetic_columns(
                second_pass_results,
                requested_synthetic_columns or [],
                invalid_columns,
            )

        # Ensure the results are in the same order as the selected IDs
        id_to_index = {call_id: i for i, call_id in enumerate(top_ids)}
        second_pass_results.sort(
            key=lambda t: id_to_index.get(t.get("id"), float("inf"))
        )

        return second_pass_results
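
    # Illustrative two-stage flow (IDs and cost values are placeholders):
    # the first pass returns lightweight rows such as
    #     {"id": "a", "summary": {...}}  # total_cost 0.9
    #     {"id": "b", "summary": {...}}  # total_cost 0.1
    # so sorting desc by "total_cost" with target_limit=1 selects ["a"], and
    # the second pass fetches full details only for trace "a".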