from __future__ import annotations import logging import netrc import os import subprocess import sys from dataclasses import dataclass, field from typing import Dict, List, Optional from urllib.parse import urlparse import simple_parsing from rich.logging import RichHandler from rich.console import Console os.environ["WANDB_SILENT"] = "True" os.environ["WEAVE_SILENT"] = "True" # Define a handler to redirect logs class RedirectLoggerHandler(logging.Handler): """A handler that redirects log records to another logger.""" def __init__(self, target_logger, *args, **kwargs): super().__init__(*args, **kwargs) self.target_logger = target_logger def emit(self, record): # Format the message using the handler's formatter if it has one # otherwise use the record's message. This ensures consistency # if formatters are used elsewhere. try: msg = self.format(record) new_record = logging.makeLogRecord( { **record.__dict__, "msg": msg, "args": [], # Args are already incorporated into msg by format() } ) self.target_logger.handle(new_record) except Exception: self.handleError(record) # Moved get_rich_logger here def get_rich_logger( name: str, propagate: bool = False, default_level_str: str = "INFO", env_var_name: Optional[str] = None, ) -> logging.Logger: """ Configure and return a logger with RichHandler. The log level can be set via an environment variable if `env_var_name` is provided. Otherwise, it defaults to `default_level_str`. """ logger = logging.getLogger(name) stderr_console = Console(stderr=True) _rich_handler = RichHandler( console=stderr_console, show_time=True, show_level=True, show_path=False, markup=True, ) if logger.hasHandlers(): logger.handlers.clear() logger.addHandler(_rich_handler) # Determine the effective log level string # Start with the function's default_level_str (e.g., "INFO") effective_level_str = default_level_str.upper() source_of_level = f"function default ('{default_level_str}')" if env_var_name: env_level_value = os.environ.get(env_var_name) if env_level_value: effective_level_str = env_level_value.upper() source_of_level = ( f"environment variable '{env_var_name}' ('{env_level_value}')" ) # Attempt to convert the string to a logging level integer final_log_level = getattr(logging, effective_level_str, None) # If conversion failed, issue a warning and determine a fallback level. if not isinstance(final_log_level, int): warning_msg_parts = [ f"Warning: Invalid log level string '{effective_level_str}' from {source_of_level} for logger '{name}'.", "Valid levels are DEBUG, INFO, WARNING, ERROR, CRITICAL.", ] # Check if the issue was with an environment variable and if the original default_level_str is valid if ( env_var_name and os.environ.get(env_var_name) and effective_level_str != default_level_str.upper() ): fallback_to_default_level = getattr( logging, default_level_str.upper(), None ) if isinstance(fallback_to_default_level, int): final_log_level = fallback_to_default_level warning_msg_parts.append( f"Falling back to function default '{default_level_str.upper()}'." ) else: # Function default is also bad, use hardcoded INFO final_log_level = logging.INFO warning_msg_parts.append( f"Function default '{default_level_str.upper()}' also invalid. Falling back to INFO." ) else: # No env var was specified, or env var was not set, or default_level_str itself was bad final_log_level = logging.INFO # Hardcoded ultimate fallback warning_msg_parts.append("Falling back to INFO.") print(" ".join(warning_msg_parts), file=sys.stderr) logger.setLevel(final_log_level) logger.propagate = propagate return logger # Setup module-level logger now that get_rich_logger is defined utils_logger = get_rich_logger(__name__) # Define server arguments using a dataclass for simple_parsing @dataclass class ServerMCPArgs: """Arguments for the Weave MCP Server.""" wandb_api_key: Optional[str] = field( default=None, metadata=dict(help="Weights & Biases API key") ) weave_entity: Optional[str] = field( default=None, metadata=dict( help="The Weights & Biases entity to log traced MCP server calls to" ), ) weave_project: Optional[str] = field( default="weave-mcp-server", metadata=dict( help="The Weights & Biases project to log traced MCP server calls to" ), ) transport: str = field( default="stdio", metadata=dict( help="Transport type: 'stdio' for local MCP client communication or 'http' for HTTP server" ), ) port: Optional[int] = field( default=None, metadata=dict( help="Port to run the HTTP server on. Defaults to 8080 when using HTTP transport." ), ) host: str = field( default="localhost", metadata=dict(help="Host to bind HTTP server to"), ) # Initialize server args global variable _server_args = None # Moved helper functions def _wandb_base_url() -> str: return os.getenv("WANDB_BASE_URL", "https://api.wandb.ai") def _wandb_api_key_via_netrc_file(filepath: str) -> str | None: netrc_path = os.path.expanduser(filepath) if not os.path.exists(netrc_path): return None nrc = netrc.netrc(netrc_path) res = nrc.authenticators(urlparse(_wandb_base_url()).netloc) api_key = None if res: _, _, api_key = res return api_key def _wandb_api_key_via_netrc() -> str | None: for filepath in ("~/.netrc", "~/_netrc"): api_key = _wandb_api_key_via_netrc_file(filepath) if api_key: return api_key return None def get_server_args(): """Get the server arguments, parsing them if not already done.""" global _server_args if _server_args is None: _server_args = ServerMCPArgs() # wandb_api_key is None by default # Only parse args when explicitly requested, not at import time if os.environ.get("PARSE_ARGS_AT_IMPORT", "0") == "1": # This potentially updates _server_args with values from command line, # including wandb_api_key if provided as an argument. _server_args = simple_parsing.parse(ServerMCPArgs, dest=_server_args) # Check netrc file first, if API key not already set (e.g., by CLI) if _server_args.wandb_api_key is None: netrc_api_key = _wandb_api_key_via_netrc() if netrc_api_key: os.environ["WANDB_API_KEY"] = netrc_api_key # Set for other modules _server_args.wandb_api_key = netrc_api_key # utils_logger.info("W&B API key loaded from .netrc file.") # If not set via netrc or CLI, try environment variable if _server_args.wandb_api_key is None: env_api_key = os.getenv("WANDB_API_KEY") if env_api_key: _server_args.wandb_api_key = env_api_key # utils_logger.info("W&B API key loaded from WANDB_API_KEY environment variable.") # If after all methods (CLI, netrc, env var), API key is still None or effectively empty, # set to empty string to match previous behavior and log a warning. if not _server_args.wandb_api_key: # Covers None or empty string _server_args.wandb_api_key = "" # Ensure it's an empty string if not found utils_logger.warning( "W&B API key was not found through command-line arguments, .netrc, or WANDB_API_KEY environment variable. " "Services requiring W&B authentication may not function correctly or may fail." ) return _server_args def merge_metadata(metadata_list: List[Dict]) -> Dict: """Merge metadata from multiple query results.""" if not metadata_list: return {} merged = { "total_traces": 0, "token_counts": { "total_tokens": 0, "input_tokens": 0, "output_tokens": 0, "average_tokens_per_trace": 0, }, "time_range": {"earliest": None, "latest": None}, "status_summary": {"success": 0, "error": 0, "other": 0}, "op_distribution": {}, } for metadata in metadata_list: # Sum up trace counts merged["total_traces"] += metadata.get("total_traces", 0) # Sum up token counts token_counts = metadata.get("token_counts", {}) merged["token_counts"]["total_tokens"] += token_counts.get("total_tokens", 0) merged["token_counts"]["input_tokens"] += token_counts.get("input_tokens", 0) merged["token_counts"]["output_tokens"] += token_counts.get("output_tokens", 0) # Update time range time_range = metadata.get("time_range", {}) if time_range.get("earliest"): if ( not merged["time_range"]["earliest"] or time_range["earliest"] < merged["time_range"]["earliest"] ): merged["time_range"]["earliest"] = time_range["earliest"] if time_range.get("latest"): if ( not merged["time_range"]["latest"] or time_range["latest"] > merged["time_range"]["latest"] ): merged["time_range"]["latest"] = time_range["latest"] # Sum up status counts status_summary = metadata.get("status_summary", {}) merged["status_summary"]["success"] += status_summary.get("success", 0) merged["status_summary"]["error"] += status_summary.get("error", 0) merged["status_summary"]["other"] += status_summary.get("other", 0) # Merge op distributions for op, count in metadata.get("op_distribution", {}).items(): merged["op_distribution"][op] = merged["op_distribution"].get(op, 0) + count # Calculate average tokens per trace if merged["total_traces"] > 0: merged["token_counts"]["average_tokens_per_trace"] = ( merged["token_counts"]["total_tokens"] / merged["total_traces"] ) return merged def get_git_commit(): logger = get_rich_logger(__name__) try: result = subprocess.run( ["git", "rev-parse", "HEAD"], capture_output=True, text=True ) return str(result.stdout.strip())[:8] except Exception as e: logger.warning(f"Failed to get git commit: {e}") return "unknown"