Spaces:
Paused
Paused
| from __future__ import annotations | |
| import logging | |
| import netrc | |
| import os | |
| import subprocess | |
| import sys | |
| from dataclasses import dataclass, field | |
| from typing import Dict, List, Optional | |
| from urllib.parse import urlparse | |
| import simple_parsing | |
| from rich.logging import RichHandler | |
| from rich.console import Console | |
# Silence console chatter from wandb/weave; these env vars are presumably
# read by those libraries when they are imported later — TODO confirm.
os.environ["WANDB_SILENT"] = "True"
os.environ["WEAVE_SILENT"] = "True"
| # Define a handler to redirect logs | |
class RedirectLoggerHandler(logging.Handler):
    """A logging handler that forwards every record to another logger.

    Useful for funnelling records emitted on one logger into a second,
    already-configured logger (e.g. to unify output formatting).
    """

    def __init__(self, target_logger, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Logger that ultimately handles the forwarded records.
        self.target_logger = target_logger

    def emit(self, record):
        try:
            # Render the message through this handler's formatter first so
            # any formatting configured here is preserved downstream.
            rendered = self.format(record)
            # Clone the record with the rendered text; args are emptied
            # because format() has already interpolated them into the msg.
            forwarded = logging.makeLogRecord(
                {**record.__dict__, "msg": rendered, "args": []}
            )
            self.target_logger.handle(forwarded)
        except Exception:
            # Standard logging convention: never raise from emit().
            self.handleError(record)
| # Moved get_rich_logger here | |
def get_rich_logger(
    name: str,
    propagate: bool = False,
    default_level_str: str = "INFO",
    env_var_name: Optional[str] = None,
) -> logging.Logger:
    """
    Configure and return a logger with RichHandler.

    The log level can be set via an environment variable if `env_var_name` is provided.
    Otherwise, it defaults to `default_level_str`.

    Args:
        name: Logger name (typically ``__name__``).
        propagate: Whether records also bubble up to ancestor loggers.
        default_level_str: Level name used when no (valid) env override exists.
        env_var_name: Optional env var whose value, if set, overrides the level.

    Returns:
        The configured ``logging.Logger`` (same object ``logging.getLogger(name)``
        returns; calling this again reconfigures it).
    """
    logger = logging.getLogger(name)
    # Log to stderr so stdout stays clean (e.g. for stdio MCP transport).
    stderr_console = Console(stderr=True)
    _rich_handler = RichHandler(
        console=stderr_console,
        show_time=True,
        show_level=True,
        show_path=False,
        markup=True,
    )
    # Drop any previously attached handlers so repeated calls do not
    # duplicate output; the Rich handler becomes the sole handler.
    if logger.hasHandlers():
        logger.handlers.clear()
    logger.addHandler(_rich_handler)
    # Determine the effective log level string
    # Start with the function's default_level_str (e.g., "INFO")
    effective_level_str = default_level_str.upper()
    source_of_level = f"function default ('{default_level_str}')"
    # An env var value, when present and non-empty, wins over the default.
    if env_var_name:
        env_level_value = os.environ.get(env_var_name)
        if env_level_value:
            effective_level_str = env_level_value.upper()
            source_of_level = (
                f"environment variable '{env_var_name}' ('{env_level_value}')"
            )
    # Attempt to convert the string to a logging level integer
    # (e.g. "DEBUG" -> logging.DEBUG); None signals an unknown level name.
    final_log_level = getattr(logging, effective_level_str, None)
    # If conversion failed, issue a warning and determine a fallback level.
    if not isinstance(final_log_level, int):
        warning_msg_parts = [
            f"Warning: Invalid log level string '{effective_level_str}' from {source_of_level} for logger '{name}'.",
            "Valid levels are DEBUG, INFO, WARNING, ERROR, CRITICAL.",
        ]
        # Check if the issue was with an environment variable and if the original default_level_str is valid
        if (
            env_var_name
            and os.environ.get(env_var_name)
            and effective_level_str != default_level_str.upper()
        ):
            # Bad env value: retry with the function's default level string.
            fallback_to_default_level = getattr(
                logging, default_level_str.upper(), None
            )
            if isinstance(fallback_to_default_level, int):
                final_log_level = fallback_to_default_level
                warning_msg_parts.append(
                    f"Falling back to function default '{default_level_str.upper()}'."
                )
            else:  # Function default is also bad, use hardcoded INFO
                final_log_level = logging.INFO
                warning_msg_parts.append(
                    f"Function default '{default_level_str.upper()}' also invalid. Falling back to INFO."
                )
        else:  # No env var was specified, or env var was not set, or default_level_str itself was bad
            final_log_level = logging.INFO  # Hardcoded ultimate fallback
            warning_msg_parts.append("Falling back to INFO.")
        # Plain print to stderr: the logger itself is not reliably usable yet.
        print(" ".join(warning_msg_parts), file=sys.stderr)
    logger.setLevel(final_log_level)
    logger.propagate = propagate
    return logger
# Module-level logger for this file, configured with a RichHandler.
# Must run after get_rich_logger is defined (import-time execution order).
utils_logger = get_rich_logger(__name__)
| # Define server arguments using a dataclass for simple_parsing | |
@dataclass
class ServerMCPArgs:
    """Arguments for the Weave MCP Server.

    Fix: the class used ``dataclasses.field(...)`` but was never decorated
    with ``@dataclass``, so attributes defaulted to raw ``Field`` objects
    instead of the declared values (breaking ``is None`` checks downstream
    and ``simple_parsing`` parsing, which expects a dataclass).
    """

    # Weights & Biases API key; resolved later via CLI / .netrc / env var.
    wandb_api_key: Optional[str] = field(
        default=None, metadata=dict(help="Weights & Biases API key")
    )
    # W&B entity (team/user) that traced calls are logged under.
    weave_entity: Optional[str] = field(
        default=None,
        metadata=dict(
            help="The Weights & Biases entity to log traced MCP server calls to"
        ),
    )
    # W&B project that traced calls are logged to.
    weave_project: Optional[str] = field(
        default="weave-mcp-server",
        metadata=dict(
            help="The Weights & Biases project to log traced MCP server calls to"
        ),
    )
    # MCP transport selection: stdio (local client) or http (server mode).
    transport: str = field(
        default="stdio",
        metadata=dict(
            help="Transport type: 'stdio' for local MCP client communication or 'http' for HTTP server"
        ),
    )
    # HTTP port; None here, defaulted elsewhere when HTTP transport is used.
    port: Optional[int] = field(
        default=None,
        metadata=dict(
            help="Port to run the HTTP server on. Defaults to 8080 when using HTTP transport."
        ),
    )
    # Bind address for the HTTP server.
    host: str = field(
        default="localhost",
        metadata=dict(help="Host to bind HTTP server to"),
    )
# Lazily populated singleton holding the parsed server arguments;
# read and filled in by get_server_args().
_server_args = None
| # Moved helper functions | |
| def _wandb_base_url() -> str: | |
| return os.getenv("WANDB_BASE_URL", "https://api.wandb.ai") | |
def _wandb_api_key_via_netrc_file(filepath: str) -> str | None:
    """Read the W&B API key for the current base URL from a netrc file.

    Returns None when the file does not exist or has no entry for the
    configured W&B host.
    """
    path = os.path.expanduser(filepath)
    if not os.path.exists(path):
        return None
    host = urlparse(_wandb_base_url()).netloc
    # authenticators() yields (login, account, password) or None; the
    # password slot is where netrc stores the API key.
    credentials = netrc.netrc(path).authenticators(host)
    if not credentials:
        return None
    return credentials[-1]
def _wandb_api_key_via_netrc() -> str | None:
    """Return the first API key found in ~/.netrc or ~/_netrc, else None."""
    candidates = ("~/.netrc", "~/_netrc")
    # Return the first truthy key; next() with a default avoids StopIteration.
    return next(
        (key for key in map(_wandb_api_key_via_netrc_file, candidates) if key),
        None,
    )
def get_server_args() -> ServerMCPArgs:
    """Get the server arguments, parsing them if not already done.

    Lazily initializes the module-level ``_server_args`` singleton and
    resolves the W&B API key with the precedence: CLI args (only when
    PARSE_ARGS_AT_IMPORT=1) > .netrc > WANDB_API_KEY env var > "".

    Returns:
        The shared ServerMCPArgs instance (same object on every call).
    """
    global _server_args
    if _server_args is None:
        _server_args = ServerMCPArgs()  # wandb_api_key is None by default
        # Only parse args when explicitly requested, not at import time
        if os.environ.get("PARSE_ARGS_AT_IMPORT", "0") == "1":
            # This potentially updates _server_args with values from command line,
            # including wandb_api_key if provided as an argument.
            _server_args = simple_parsing.parse(ServerMCPArgs, dest=_server_args)
        # Check netrc file first, if API key not already set (e.g., by CLI)
        if _server_args.wandb_api_key is None:
            netrc_api_key = _wandb_api_key_via_netrc()
            if netrc_api_key:
                os.environ["WANDB_API_KEY"] = netrc_api_key  # Set for other modules
                _server_args.wandb_api_key = netrc_api_key
                # utils_logger.info("W&B API key loaded from .netrc file.")
        # If not set via netrc or CLI, try environment variable
        if _server_args.wandb_api_key is None:
            env_api_key = os.getenv("WANDB_API_KEY")
            if env_api_key:
                _server_args.wandb_api_key = env_api_key
                # utils_logger.info("W&B API key loaded from WANDB_API_KEY environment variable.")
        # If after all methods (CLI, netrc, env var), API key is still None or effectively empty,
        # set to empty string to match previous behavior and log a warning.
        if not _server_args.wandb_api_key:  # Covers None or empty string
            _server_args.wandb_api_key = ""  # Ensure it's an empty string if not found
            utils_logger.warning(
                "W&B API key was not found through command-line arguments, .netrc, or WANDB_API_KEY environment variable. "
                "Services requiring W&B authentication may not function correctly or may fail."
            )
    return _server_args
| def merge_metadata(metadata_list: List[Dict]) -> Dict: | |
| """Merge metadata from multiple query results.""" | |
| if not metadata_list: | |
| return {} | |
| merged = { | |
| "total_traces": 0, | |
| "token_counts": { | |
| "total_tokens": 0, | |
| "input_tokens": 0, | |
| "output_tokens": 0, | |
| "average_tokens_per_trace": 0, | |
| }, | |
| "time_range": {"earliest": None, "latest": None}, | |
| "status_summary": {"success": 0, "error": 0, "other": 0}, | |
| "op_distribution": {}, | |
| } | |
| for metadata in metadata_list: | |
| # Sum up trace counts | |
| merged["total_traces"] += metadata.get("total_traces", 0) | |
| # Sum up token counts | |
| token_counts = metadata.get("token_counts", {}) | |
| merged["token_counts"]["total_tokens"] += token_counts.get("total_tokens", 0) | |
| merged["token_counts"]["input_tokens"] += token_counts.get("input_tokens", 0) | |
| merged["token_counts"]["output_tokens"] += token_counts.get("output_tokens", 0) | |
| # Update time range | |
| time_range = metadata.get("time_range", {}) | |
| if time_range.get("earliest"): | |
| if ( | |
| not merged["time_range"]["earliest"] | |
| or time_range["earliest"] < merged["time_range"]["earliest"] | |
| ): | |
| merged["time_range"]["earliest"] = time_range["earliest"] | |
| if time_range.get("latest"): | |
| if ( | |
| not merged["time_range"]["latest"] | |
| or time_range["latest"] > merged["time_range"]["latest"] | |
| ): | |
| merged["time_range"]["latest"] = time_range["latest"] | |
| # Sum up status counts | |
| status_summary = metadata.get("status_summary", {}) | |
| merged["status_summary"]["success"] += status_summary.get("success", 0) | |
| merged["status_summary"]["error"] += status_summary.get("error", 0) | |
| merged["status_summary"]["other"] += status_summary.get("other", 0) | |
| # Merge op distributions | |
| for op, count in metadata.get("op_distribution", {}).items(): | |
| merged["op_distribution"][op] = merged["op_distribution"].get(op, 0) + count | |
| # Calculate average tokens per trace | |
| if merged["total_traces"] > 0: | |
| merged["token_counts"]["average_tokens_per_trace"] = ( | |
| merged["token_counts"]["total_tokens"] / merged["total_traces"] | |
| ) | |
| return merged | |
def get_git_commit() -> str:
    """Return the short (first 8 chars) hash of the current git HEAD.

    Returns:
        The 8-character commit prefix, or "unknown" when git is missing,
        the command fails (e.g. not inside a repository), or it times out.
    """
    try:
        # check=True turns a non-zero git exit into CalledProcessError;
        # previously a failed command silently returned "" instead of
        # "unknown". timeout guards against git hanging (e.g. on locks).
        result = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            capture_output=True,
            text=True,
            check=True,
            timeout=5,
        )
        return result.stdout.strip()[:8]
    except Exception as e:
        # Reuse the module's named logger rather than rebuilding a fresh
        # Rich-configured logger (and its handler) on every call; the name
        # matches utils_logger, so the same logger object is returned.
        logging.getLogger(__name__).warning(f"Failed to get git commit: {e}")
        return "unknown"