"""
Service layer for Weave API.
This module provides high-level services for querying and processing Weave traces.
It orchestrates the client, query builder, and processor components.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Set
from wandb_mcp_server.utils import get_rich_logger, get_server_args
from wandb_mcp_server.weave_api.client import WeaveApiClient
from wandb_mcp_server.weave_api.models import QueryResult
from wandb_mcp_server.weave_api.processors import TraceProcessor
from wandb_mcp_server.weave_api.query_builder import QueryBuilder
# Import CallSchema to validate column names
try:
from weave.trace_server.trace_server_interface import CallSchema
VALID_COLUMNS = set(CallSchema.__annotations__.keys())
HAVE_CALL_SCHEMA = True
except ImportError:
# Fallback if CallSchema isn't available
VALID_COLUMNS = {
"id",
"project_id",
"op_name",
"display_name",
"trace_id",
"parent_id",
"started_at",
"attributes",
"inputs",
"ended_at",
"exception",
"output",
"summary",
"wb_user_id",
"wb_run_id",
"deleted_at",
"storage_size_bytes",
"total_storage_size_bytes",
}
HAVE_CALL_SCHEMA = False
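# Note: dot-separated column paths such as "summary.weave.latency_ms" are
# validated against their base field ("summary") rather than the full path;
# see TraceService._validate_and_filter_columns below.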
logger = get_rich_logger(__name__)
class TraceService:
"""Service for querying and processing Weave traces."""
# Define cost fields once as a class constant
COST_FIELDS = {"total_cost", "completion_cost", "prompt_cost"}
# Define synthetic columns that shouldn't be passed to the API but can be reconstructed
SYNTHETIC_COLUMNS = {"costs"}
# Define latency field mapping
LATENCY_FIELD_MAPPING = {"latency_ms": "summary.weave.latency_ms"}
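    # For example, a request to sort or select by "latency_ms" is transparently
    # mapped to the server-side field "summary.weave.latency_ms".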
def __init__(
self,
api_key: Optional[str] = None,
server_url: Optional[str] = None,
retries: int = 3,
timeout: int = 10,
):
"""Initialize the TraceService.
Args:
api_key: W&B API key. If not provided, uses WANDB_API_KEY env var.
server_url: Weave API server URL. Defaults to 'https://trace.wandb.ai'.
retries: Number of retries for failed requests.
timeout: Request timeout in seconds.
"""
# If no API key provided, try to get from environment
if api_key is None:
import os
# Try to get from environment (set by auth middleware for HTTP or user for STDIO)
api_key = os.environ.get("WANDB_API_KEY")
# If still no key, try get_server_args as fallback
if not api_key:
server_config = get_server_args()
api_key = server_config.wandb_api_key
# Pass the resolved API key to WeaveApiClient.
# If api_key is None or "", WeaveApiClient will raise its ValueError.
self.client = WeaveApiClient(
api_key=api_key,
server_url=server_url,
retries=retries,
timeout=timeout,
)
# Initialize collection for invalid columns (for warning messages)
self.invalid_columns = set()
def _validate_and_filter_columns(
self, columns: Optional[List[str]]
) -> tuple[Optional[List[str]], List[str], Set[str]]:
"""Validate columns against CallSchema and filter out synthetic/invalid columns.
Handles mapping of 'latency_ms' to 'summary.weave.latency_ms'.
Args:
columns: List of columns.
Returns:
Tuple of (filtered_columns_for_api, requested_synthetic_columns, invalid_columns_reported)
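
        Example (illustrative; ``svc`` is a TraceService instance and log
        output is omitted):
            >>> svc._validate_and_filter_columns(["id", "latency_ms", "bogus"])
            (['id', 'summary.weave.latency_ms', 'summary'], ['latency_ms'], {'bogus'})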
"""
if not columns:
return (
None,
[],
set(),
) # Return None for filtered_columns_for_api if input is None
filtered_columns_for_api: list[str] = []
requested_synthetic_columns: list[str] = []
invalid_columns_reported: set[str] = set()
processed_columns = (
set()
) # To avoid duplicate processing if a column is listed multiple times
for col_name in columns:
if col_name in processed_columns:
continue
processed_columns.add(col_name)
if col_name == "latency_ms":
# 'latency_ms' is synthetic, its data comes from 'summary.weave.latency_ms'
requested_synthetic_columns.append("latency_ms")
# Ensure the source field is requested from the API
source_field = self.LATENCY_FIELD_MAPPING["latency_ms"]
if source_field not in filtered_columns_for_api:
filtered_columns_for_api.append(source_field)
# Also ensure 'summary' itself is added if not already, as 'summary.weave.latency_ms' implies 'summary'
if (
"summary" not in filtered_columns_for_api
and source_field.startswith("summary.")
):
filtered_columns_for_api.append("summary")
logger.info(
f"Column 'latency_ms' requested: will be synthesized from '{source_field}'. Added '{source_field}' to API columns."
)
elif col_name == "costs":
# 'costs' is synthetic, its data comes from 'summary.weave.costs'
requested_synthetic_columns.append("costs")
# Ensure the source field ('summary') is requested
if "summary" not in filtered_columns_for_api:
filtered_columns_for_api.append("summary")
logger.info(
"Column 'costs' requested: will be synthesized from 'summary.weave.costs'. Added 'summary' to API columns."
)
elif col_name == "status":
# 'status' can be top-level or from 'summary.weave.status'
requested_synthetic_columns.append("status")
# Add 'status' to API columns to try fetching top-level first.
# If not present, it will be synthesized from summary.
if "status" not in filtered_columns_for_api:
filtered_columns_for_api.append("status")
if (
"summary" not in filtered_columns_for_api
): # Also ensure summary for fallback
filtered_columns_for_api.append("summary")
logger.info(
"Column 'status' requested: will attempt direct fetch or synthesize from 'summary.weave.status'."
)
elif col_name in VALID_COLUMNS:
# Direct valid column
if col_name not in filtered_columns_for_api:
filtered_columns_for_api.append(col_name)
elif "." in col_name: # Potentially a dot-separated path
base_field = col_name.split(".")[0]
if base_field in VALID_COLUMNS:
# Valid nested field (e.g., "summary.weave.latency_ms", "attributes.foo")
if col_name not in filtered_columns_for_api:
filtered_columns_for_api.append(col_name)
logger.info(
f"Nested column field '{col_name}' requested, added to API columns."
)
else:
logger.warning(
f"Invalid base field '{base_field}' in nested column '{col_name}'. It will be ignored."
)
invalid_columns_reported.add(col_name)
else:
# Neither a direct valid column, nor a recognized synthetic, nor a valid-looking nested path
logger.warning(
f"Invalid column '{col_name}' requested. It will be ignored."
)
invalid_columns_reported.add(col_name)
# Ensure filtered_columns_for_api does not have duplicates and maintains order as much as possible
# (though order to the API might not matter as much as presence)
final_filtered_columns_for_api = []
seen_in_final = set()
for fc in filtered_columns_for_api:
if fc not in seen_in_final:
final_filtered_columns_for_api.append(fc)
seen_in_final.add(fc)
return (
final_filtered_columns_for_api,
requested_synthetic_columns,
invalid_columns_reported,
)
def _ensure_required_columns_for_synthetic(
self,
filtered_columns: Optional[List[str]],
requested_synthetic_columns: List[str],
) -> Optional[List[str]]:
"""Ensure required columns for synthetic fields are included.
Args:
filtered_columns: List of columns after filtering out synthetic ones.
requested_synthetic_columns: List of requested synthetic columns.
Returns:
Updated filtered columns list with required columns added.
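
        Example (illustrative; the order of the returned list is unspecified
        because a set is used internally):
            >>> sorted(svc._ensure_required_columns_for_synthetic(["id"], ["costs"]))
            ['id', 'summary']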
"""
if not filtered_columns:
filtered_columns = []
required_columns = set(filtered_columns)
# Add required columns for synthesizing costs
if "costs" in requested_synthetic_columns:
# Costs data comes from summary.weave.costs
if "summary" not in required_columns:
logger.info("Adding 'summary' column as it's required for costs data")
required_columns.add("summary")
# Add other required columns for other synthetic fields as needed
return list(required_columns)
def _add_synthetic_columns(
self,
traces: List[Dict[str, Any]],
requested_synthetic_columns: List[str],
invalid_columns: Set[str],
) -> List[Dict[str, Any]]:
"""Add synthetic columns back to the traces and add warnings for invalid columns.
Args:
traces: List of trace dictionaries.
requested_synthetic_columns: List of requested synthetic columns.
invalid_columns: Set of invalid column names that were requested.
Returns:
Updated traces with synthetic columns added and invalid column warnings.
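
        Example (illustrative trace shape; all values are made up):
            >>> traces = [{"id": "c1", "summary": {"weave": {"costs": {"openai": {"total_cost": 0.02}}}}}]
            >>> out = svc._add_synthetic_columns(traces, ["costs"], {"bogus"})
            >>> out[0]["costs"]
            {'openai': {'total_cost': 0.02}}
            >>> out[0]["bogus"]
            'bogus is not a valid column name, no data returned'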
"""
if not requested_synthetic_columns and not invalid_columns:
return traces
updated_traces = []
for trace in traces:
updated_trace = trace.copy()
# Add costs data if requested
if "costs" in requested_synthetic_columns:
costs_data = trace.get("summary", {}).get("weave", {}).get("costs", {})
if costs_data:
logger.debug(
f"Adding synthetic 'costs' column with {len(costs_data)} providers"
)
updated_trace["costs"] = costs_data
else:
logger.warning(f"No costs data found in trace {trace.get('id')}")
updated_trace["costs"] = {}
# Add status from summary if requested
if "status" in requested_synthetic_columns:
status = trace.get("status") # Check if it's already in the trace
if not status:
# Extract from summary.weave.status
status = trace.get("summary", {}).get("weave", {}).get("status")
if status:
logger.debug(
f"Adding synthetic 'status' from summary: {status}"
)
updated_trace["status"] = status
else:
logger.warning(
f"No status data found in trace {trace.get('id')}"
)
updated_trace["status"] = None
# Add latency_ms from summary if requested
if "latency_ms" in requested_synthetic_columns:
latency = trace.get("latency_ms") # Check if it's already in the trace
if latency is None:
# Extract from summary.weave.latency_ms
latency = (
trace.get("summary", {}).get("weave", {}).get("latency_ms")
)
if latency is not None:
logger.debug(
f"Adding synthetic 'latency_ms' from summary: {latency}"
)
updated_trace["latency_ms"] = latency
else:
logger.warning(
f"No latency_ms data found in trace {trace.get('id')}"
)
updated_trace["latency_ms"] = None
# Add warnings for invalid columns
for col in invalid_columns:
warning_message = f"{col} is not a valid column name, no data returned"
updated_trace[col] = warning_message
updated_traces.append(updated_trace)
return updated_traces
def query_traces(
self,
entity_name: str,
project_name: str,
filters: Optional[Dict[str, Any]] = None,
sort_by: str = "started_at",
sort_direction: str = "desc",
limit: Optional[int] = None,
offset: int = 0,
include_costs: bool = True,
include_feedback: bool = True,
columns: Optional[List[str]] = None,
expand_columns: Optional[List[str]] = None,
truncate_length: Optional[int] = 200,
return_full_data: bool = False,
metadata_only: bool = False,
) -> QueryResult:
"""Query traces from the Weave API.
Args:
entity_name: Weights & Biases entity name.
            project_name: Weights & Biases project name.
filters: Dictionary of filter conditions.
sort_by: Field to sort by.
sort_direction: Sort direction ('asc' or 'desc').
limit: Maximum number of results to return.
offset: Number of results to skip (for pagination).
include_costs: Include tracked API cost information in the results.
include_feedback: Include Weave annotations in the results.
columns: List of specific columns to include in the results.
expand_columns: List of columns to expand in the results.
truncate_length: Maximum length for string values.
return_full_data: Whether to include full untruncated trace data.
metadata_only: Whether to only include metadata without traces.
Returns:
QueryResult object with metadata and optionally traces.
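
        Example (illustrative; entity/project names are hypothetical and the
        accepted filter keys depend on QueryBuilder):
            >>> result = service.query_traces(
            ...     entity_name="my-team",
            ...     project_name="my-project",
            ...     sort_by="latency_ms",
            ...     limit=10,
            ...     columns=["id", "op_name", "latency_ms"],
            ... )
            >>> len(result.traces or [])  # at most 10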
"""
# Clear invalid columns from previous requests
self.invalid_columns = set()
# Special handling for cost-based sorting
client_side_cost_sort = sort_by in self.COST_FIELDS
# Handle latency field mapping
if sort_by in self.LATENCY_FIELD_MAPPING:
logger.info(
f"Mapping sort field '{sort_by}' to '{self.LATENCY_FIELD_MAPPING[sort_by]}'"
)
server_sort_by = self.LATENCY_FIELD_MAPPING[sort_by]
server_sort_direction = sort_direction
elif client_side_cost_sort:
include_costs = True
server_sort_by = "started_at"
server_sort_direction = sort_direction
elif sort_by == "latency_ms": # Added specific handling for latency_ms sort
logger.info(
f"Sort by 'latency_ms' requested. Will sort by server field '{self.LATENCY_FIELD_MAPPING['latency_ms']}'."
)
server_sort_by = self.LATENCY_FIELD_MAPPING["latency_ms"]
server_sort_direction = sort_direction
elif "." in sort_by: # Handles general dot-separated paths
base_field = sort_by.split(".")[0]
if base_field in VALID_COLUMNS:
logger.info(f"Using nested sort field for server: {sort_by}")
server_sort_by = sort_by
server_sort_direction = sort_direction
else:
logger.warning(
f"Invalid base field '{base_field}' in sort_by '{sort_by}', falling back to 'started_at'."
)
server_sort_by = "started_at"
server_sort_direction = sort_direction
elif sort_by not in VALID_COLUMNS:
logger.warning(
f"Invalid sort field '{sort_by}', falling back to 'started_at'."
)
server_sort_by = "started_at"
server_sort_direction = sort_direction
else: # sort_by is in VALID_COLUMNS and not a special case
server_sort_by = sort_by
server_sort_direction = sort_direction
# Validate and filter columns using CallSchema
filtered_api_columns, rs_columns, inv_columns = (
self._validate_and_filter_columns(columns)
)
        # Store invalid columns for later warning messages
        self.invalid_columns = inv_columns
        # If 'costs' was requested as a column, make sure cost data is included
        if "costs" in rs_columns:
            include_costs = True
        # Note: adding 'latency_ms' to the synthetic fields and injecting the
        # API columns those fields depend on are handled inside
        # _validate_and_filter_columns.
# Prepare query parameters
query_params = {
"entity_name": entity_name,
"project_name": project_name,
"filters": filters or {},
"sort_by": server_sort_by,
"sort_direction": server_sort_direction,
"limit": None
if client_side_cost_sort
else limit, # No limit if we're sorting by cost
"offset": offset,
"include_costs": include_costs,
"include_feedback": include_feedback,
"columns": filtered_api_columns, # Use the columns intended for the API
"expand_columns": expand_columns,
}
# Build request body
request_body = QueryBuilder.prepare_query_params(query_params)
# Extract synthetic fields if any were specified
synthetic_fields = (
request_body.pop("_synthetic_fields", [])
if "_synthetic_fields" in request_body
else []
)
# Make sure all requested synthetic columns are included in synthetic_fields
for col in rs_columns: # Use rs_columns
if col not in synthetic_fields:
synthetic_fields.append(col)
# Execute query
all_traces = list(self.client.query_traces(request_body))
# Add synthetic columns and invalid column warnings back to the results
if rs_columns or inv_columns: # Use corrected variables
all_traces = self._add_synthetic_columns(
all_traces, rs_columns, inv_columns
)
# Client-side cost-based sorting if needed
if client_side_cost_sort and all_traces:
logger.info(f"Performing client-side sorting by {sort_by}")
# Sort traces by cost
all_traces.sort(
key=lambda t: TraceProcessor.get_cost(t, sort_by),
reverse=(sort_direction == "desc"),
)
# Apply limit if specified
if limit is not None:
all_traces = all_traces[:limit]
# If we need to synthesize fields, do it
if synthetic_fields:
logger.info(f"Synthesizing fields: {synthetic_fields}")
all_traces = [
TraceProcessor.synthesize_fields(trace, synthetic_fields)
for trace in all_traces
]
# Process traces
result = TraceProcessor.process_traces(
traces=all_traces,
truncate_length=truncate_length or 0,
return_full_data=return_full_data,
metadata_only=metadata_only,
)
return result
def query_paginated_traces(
self,
entity_name: str,
project_name: str,
chunk_size: int = 20,
filters: Optional[Dict[str, Any]] = None,
sort_by: str = "started_at",
sort_direction: str = "desc",
target_limit: Optional[int] = None,
include_costs: bool = True,
include_feedback: bool = True,
columns: Optional[List[str]] = None,
expand_columns: Optional[List[str]] = None,
truncate_length: Optional[int] = 200,
return_full_data: bool = False,
metadata_only: bool = False,
) -> QueryResult:
"""Query traces with pagination.
Args:
entity_name: Weights & Biases entity name.
            project_name: Weights & Biases project name.
chunk_size: Number of traces to retrieve in each chunk.
filters: Dictionary of filter conditions.
sort_by: Field to sort by.
sort_direction: Sort direction ('asc' or 'desc').
target_limit: Maximum total number of results to return.
include_costs: Include tracked API cost information in the results.
include_feedback: Include Weave annotations in the results.
columns: List of specific columns to include in the results.
expand_columns: List of columns to expand in the results.
truncate_length: Maximum length for string values.
return_full_data: Whether to include full untruncated trace data.
metadata_only: Whether to only include metadata without traces.
Returns:
QueryResult object with metadata and optionally traces.
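
        Example (illustrative; names are hypothetical). A cost field such as
        "total_cost" is sorted client-side, so this call routes through the
        two-stage logic in _query_for_cost_sorting:
            >>> result = service.query_paginated_traces(
            ...     entity_name="my-team",
            ...     project_name="my-project",
            ...     chunk_size=50,
            ...     target_limit=200,
            ...     sort_by="total_cost",
            ... )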
"""
# Special handling for cost-based sorting
client_side_cost_sort = sort_by in self.COST_FIELDS
# Determine effective_sort_by for the server
effective_sort_by = "started_at" # Default
if sort_by == "latency_ms":
effective_sort_by = self.LATENCY_FIELD_MAPPING["latency_ms"]
logger.info(
f"Paginated sort by 'latency_ms', server will use '{effective_sort_by}'."
)
elif "." in sort_by:
base_field = sort_by.split(".")[0]
if base_field in VALID_COLUMNS:
effective_sort_by = sort_by
logger.info(
f"Paginated sort by nested field '{sort_by}', server will use it directly."
)
else:
logger.warning(
f"Paginated sort by invalid nested field '{sort_by}', defaulting to 'started_at'."
)
elif (
sort_by in VALID_COLUMNS and sort_by not in self.COST_FIELDS
): # Exclude COST_FIELDS as they are client-sorted
effective_sort_by = sort_by
elif (
sort_by not in self.COST_FIELDS
): # If not valid and not cost, warn and default
logger.warning(
f"Paginated sort by invalid field '{sort_by}', defaulting to 'started_at'."
)
# Validate and filter columns using CallSchema
# Pass the original 'columns'
filtered_api_columns, rs_columns, inv_columns = (
self._validate_and_filter_columns(columns)
)
        # Store invalid columns for later warning messages
        self.invalid_columns = inv_columns
        # If 'costs' was requested as a column, make sure cost data is included
        if "costs" in rs_columns:
            include_costs = True
        # Note: the API columns required by synthetic fields are injected by
        # _validate_and_filter_columns.
if client_side_cost_sort:
logger.info(f"Cost-based sorting detected: {sort_by}")
all_traces = self._query_for_cost_sorting(
entity_name=entity_name,
project_name=project_name,
filters=filters,
sort_by=sort_by,
sort_direction=sort_direction,
target_limit=target_limit,
columns=filtered_api_columns, # Pass filtered columns for API
expand_columns=expand_columns,
include_costs=True,
include_feedback=include_feedback,
requested_synthetic_columns=rs_columns, # Pass synthetic columns request
invalid_columns=inv_columns, # Pass invalid columns
)
else:
# Normal paginated query logic
all_traces = []
current_offset = 0
while True:
logger.info(
f"Querying chunk with offset {current_offset}, size {chunk_size}"
)
remaining = (
target_limit - len(all_traces) if target_limit else chunk_size
)
current_chunk_size = (
min(chunk_size, remaining) if target_limit else chunk_size
)
chunk_result = self.query_traces(
entity_name=entity_name,
project_name=project_name,
filters=filters,
sort_by=effective_sort_by,
sort_direction=sort_direction,
limit=current_chunk_size,
offset=current_offset,
include_costs=include_costs,
include_feedback=include_feedback,
columns=columns, # Pass original 'columns' here, query_traces will validate and filter.
# This ensures that if 'latency_ms' was requested, it's handled correctly
# by the nested call to _validate_and_filter_columns inside query_traces.
expand_columns=expand_columns,
return_full_data=True, # We want raw data for now
metadata_only=False,
)
# Get the traces from the QueryResult and handle both None and empty list cases
traces_from_chunk = (
chunk_result.traces if chunk_result and chunk_result.traces else []
)
if not traces_from_chunk:
break
all_traces.extend(traces_from_chunk)
if len(traces_from_chunk) < current_chunk_size or (
target_limit and len(all_traces) >= target_limit
):
break
current_offset += chunk_size
# Process all traces at once with appropriate parameters
if target_limit and all_traces:
all_traces = all_traces[:target_limit]
result = TraceProcessor.process_traces(
traces=all_traces,
truncate_length=truncate_length or 0,
return_full_data=return_full_data,
metadata_only=metadata_only,
)
        logger.debug(
            "query_paginated_traces result JSON size: "
            f"{len(result.model_dump_json(indent=2))} characters"
        )
assert isinstance(result, QueryResult), (
f"Result type must be a QueryResult, found: {type(result)}"
)
return result
def _query_for_cost_sorting(
self,
entity_name: str,
project_name: str,
filters: Optional[Dict[str, Any]] = None,
sort_by: str = "total_cost",
sort_direction: str = "desc",
target_limit: Optional[int] = None,
columns: Optional[List[str]] = None,
expand_columns: Optional[List[str]] = None,
include_costs: bool = True,
include_feedback: bool = True,
requested_synthetic_columns: Optional[List[str]] = None,
invalid_columns: Optional[Set[str]] = None,
) -> List[Dict[str, Any]]:
"""Special two-stage query logic for cost-based sorting.
Args:
entity_name: Weights & Biases entity name.
            project_name: Weights & Biases project name.
filters: Dictionary of filter conditions.
sort_by: Cost field to sort by.
sort_direction: Sort direction ('asc' or 'desc').
target_limit: Maximum number of results to return.
columns: List of specific columns to include in the results.
expand_columns: List of columns to expand in the results.
include_costs: Include tracked API cost information in the results.
include_feedback: Include Weave annotations in the results.
requested_synthetic_columns: List of synthetic columns requested by the user.
invalid_columns: Set of invalid column names that were requested.
Returns:
List of trace dictionaries sorted by the specified cost field.
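
        Example (illustrative walkthrough with made-up IDs and costs): the
        first pass fetches only 'id' and 'summary' for every matching trace,
        say with costs {"a": 0.02, "b": 0.10}. Sorting descending by
        "total_cost" with target_limit=1 selects ["b"], and the second pass
        re-fetches the full details for that call ID via a "call_ids" filter.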
"""
if invalid_columns is None:
invalid_columns = set()
# First pass: Fetch all trace IDs and costs
first_pass_query = {
"entity_name": entity_name,
"project_name": project_name,
"filters": filters or {},
"sort_by": "started_at", # Use a standard sort for the first pass
"sort_direction": "desc",
"limit": 1000000, # Explicitly set a large limit to get all traces
"include_costs": True, # We need costs for sorting
"include_feedback": False, # Don't need feedback for the first pass
"columns": ["id", "summary"], # Need summary for costs data
}
first_pass_request = QueryBuilder.prepare_query_params(first_pass_query)
first_pass_results = list(self.client.query_traces(first_pass_request))
logger.info(
f"First pass of cost sorting request retrieved {len(first_pass_results)} traces"
)
# Filter and sort by cost
filtered_results = [
t
for t in first_pass_results
if TraceProcessor.get_cost(t, sort_by) is not None
]
filtered_results.sort(
key=lambda t: TraceProcessor.get_cost(t, sort_by),
reverse=(sort_direction == "desc"),
)
# Get the IDs of the top N traces
top_ids = (
[t["id"] for t in filtered_results[:target_limit] if "id" in t]
if target_limit
else [t["id"] for t in filtered_results if "id" in t]
)
logger.info(f"After sorting by {sort_by}, selected {len(top_ids)} trace IDs")
if not top_ids:
return []
# Second pass: Fetch the full details for the selected traces
second_pass_query = {
"entity_name": entity_name,
"project_name": project_name,
"filters": {"call_ids": top_ids},
"include_costs": include_costs,
"include_feedback": include_feedback,
"columns": columns,
"expand_columns": expand_columns,
}
# Make sure we request summary if costs were requested
if requested_synthetic_columns and "costs" in requested_synthetic_columns:
if not columns or "summary" not in columns:
if not second_pass_query["columns"]:
second_pass_query["columns"] = ["summary"]
elif "summary" not in second_pass_query["columns"]:
second_pass_query["columns"].append("summary")
logger.info("Added 'summary' to columns for cost data retrieval")
second_pass_request = QueryBuilder.prepare_query_params(second_pass_query)
second_pass_results = list(self.client.query_traces(second_pass_request))
logger.info(f"Second pass retrieved {len(second_pass_results)} traces")
# Add synthetic columns and invalid column warnings back to the results
if requested_synthetic_columns or invalid_columns:
second_pass_results = self._add_synthetic_columns(
second_pass_results,
requested_synthetic_columns or [],
invalid_columns,
)
# Ensure the results are in the same order as the IDs
        id_to_index = {call_id: i for i, call_id in enumerate(top_ids)}
second_pass_results.sort(
key=lambda t: id_to_index.get(t.get("id"), float("inf"))
)
return second_pass_results