#!/usr/bin/env python3
"""
Load testing script for W&B MCP Server

Measures concurrent connections, requests/second, and latency
"""
import asyncio
import time
import statistics
from typing import List, Dict, Any, Optional
import httpx
import json
from datetime import datetime
import argparse
import sys


class MCPLoadTester:
    """Drives concurrent simulated MCP clients against a server and records latency metrics.

    Each simulated client opens one MCP session (`initialize`) and then issues a
    series of `tools/call` requests; wall-clock timings are accumulated in
    ``self.metrics`` and summarized by :meth:`display_results`.
    """

    def __init__(self, base_url: str = "http://localhost:7860", api_key: Optional[str] = None):
        """
        Args:
            base_url: Root URL of the MCP server; requests go to ``{base_url}/mcp``.
            api_key: Bearer token for the ``initialize`` call. When omitted, a
                syntactically valid dummy key is used so the server's auth
                parsing can still be exercised.
        """
        self.base_url = base_url
        self.api_key = api_key or "test_key_12345678901234567890123456789012345678"
        self.metrics = self._fresh_metrics()

    @staticmethod
    def _fresh_metrics() -> Dict[str, Any]:
        """Return a zeroed metrics accumulator.

        Single source of truth for the metrics shape — used both at
        construction time and to reset state at the start of each load test.
        """
        return {
            "total_requests": 0,
            "successful_requests": 0,
            "failed_requests": 0,
            "response_times": [],
            "session_creation_times": [],
            "tool_call_times": [],
        }

    async def create_session(self, client: httpx.AsyncClient) -> Optional[str]:
        """Initialize an MCP session.

        Sends a JSON-RPC ``initialize`` request and returns the session id from
        the ``mcp-session-id`` response header, or ``None`` on any failure.
        Updates request counters and records the round-trip time.
        """
        start_time = time.time()
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            # Streamable-HTTP servers may answer with either plain JSON or SSE.
            "Accept": "application/json, text/event-stream",
        }
        payload = {
            "jsonrpc": "2.0",
            "method": "initialize",
            "params": {
                "protocolVersion": "2025-06-18",
                "capabilities": {},
                "clientInfo": {"name": "load_test", "version": "1.0.0"},
            },
            "id": 1,
        }
        try:
            response = await client.post(
                f"{self.base_url}/mcp",
                headers=headers,
                json=payload,
                timeout=10,
            )
            elapsed = time.time() - start_time
            self.metrics["session_creation_times"].append(elapsed)
            self.metrics["total_requests"] += 1
            if response.status_code == 200:
                self.metrics["successful_requests"] += 1
                return response.headers.get("mcp-session-id")
            else:
                self.metrics["failed_requests"] += 1
                return None
        except Exception as e:
            # Network/timeout errors still count as attempted requests.
            self.metrics["failed_requests"] += 1
            self.metrics["total_requests"] += 1
            print(f"Session creation failed: {e}")
            return None

    async def call_tool(self, client: httpx.AsyncClient, session_id: str,
                        tool_name: str, params: Dict[str, Any]):
        """Call a tool using the session.

        Issues a JSON-RPC ``tools/call`` and records the latency in both
        ``tool_call_times`` and ``response_times``. Failures (non-200 or
        exceptions) only bump the failure counters.
        """
        start_time = time.time()
        # NOTE(review): unlike create_session, no Authorization header is sent
        # here — presumably the session id carries auth; confirm against server.
        headers = {
            "Mcp-Session-Id": session_id,
            "Content-Type": "application/json",
            "Accept": "application/json, text/event-stream",
        }
        payload = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": tool_name, "arguments": params},
            "id": 2,
        }
        try:
            response = await client.post(
                f"{self.base_url}/mcp",
                headers=headers,
                json=payload,
                timeout=30,
            )
            elapsed = time.time() - start_time
            self.metrics["tool_call_times"].append(elapsed)
            self.metrics["response_times"].append(elapsed)
            self.metrics["total_requests"] += 1
            if response.status_code == 200:
                self.metrics["successful_requests"] += 1
            else:
                self.metrics["failed_requests"] += 1
        except Exception as e:
            self.metrics["failed_requests"] += 1
            self.metrics["total_requests"] += 1
            print(f"Tool call failed: {e}")

    async def run_client_session(self, client_id: int, num_requests: int, delay: float = 0.1):
        """Simulate a client making multiple requests.

        Creates one session, then fires ``num_requests`` tool calls with an
        optional pause of ``delay`` seconds between them. Bails out silently if
        the session could not be established (the failure is already counted).
        """
        async with httpx.AsyncClient() as client:
            # Create session
            session_id = await self.create_session(client)
            if not session_id:
                return

            # Make multiple tool calls
            for i in range(num_requests):
                await self.call_tool(
                    client,
                    session_id,
                    "query_wandb_entity_projects",  # Simple tool that doesn't need entity/project
                    {},
                )
                # Small delay between requests
                if delay > 0:
                    await asyncio.sleep(delay)

    async def run_load_test(self, num_clients: int, requests_per_client: int, delay: float = 0.1):
        """Run the load test with specified parameters.

        Resets the metrics, runs ``num_clients`` client sessions concurrently,
        prints a summary via :meth:`display_results`, and returns the raw
        metrics dict.
        """
        print(f"\n{'='*60}")
        print(f"Starting Load Test")
        print(f"{'='*60}")
        print(f"Clients: {num_clients}")
        print(f"Requests per client: {requests_per_client}")
        print(f"Total requests: {num_clients * (requests_per_client + 1)}")  # +1 for session creation
        print(f"Server: {self.base_url}")
        print(f"Delay between requests: {delay}s")
        print(f"{'='*60}\n")

        # Reset metrics so back-to-back tests don't accumulate.
        self.metrics = self._fresh_metrics()

        start_time = time.time()

        # Run all client sessions concurrently
        tasks = [
            self.run_client_session(i, requests_per_client, delay)
            for i in range(num_clients)
        ]

        # Show progress
        print("Running load test...")
        await asyncio.gather(*tasks)

        total_time = time.time() - start_time

        # Calculate and display results
        self.display_results(total_time, num_clients, requests_per_client)

        return self.metrics

    def display_results(self, total_time: float, num_clients: int, requests_per_client: int):
        """Display load test results.

        Prints overall counters, session-creation and tool-call latency stats
        (mean/median/stdev/min/max plus p50/p95/p99), and throughput figures.
        Purely a reporting step — reads ``self.metrics``, mutates nothing.
        """
        print(f"\n{'='*60}")
        print(f"Load Test Results")
        print(f"{'='*60}")

        # Overall metrics
        total_requests = self.metrics["total_requests"]
        success_rate = (self.metrics["successful_requests"] / total_requests * 100) if total_requests > 0 else 0

        print(f"\nšŸ“Š Overall Metrics:")
        print(f"  Total Time: {total_time:.2f}s")
        print(f"  Total Requests: {total_requests}")
        print(f"  Successful: {self.metrics['successful_requests']} ({success_rate:.1f}%)")
        print(f"  Failed: {self.metrics['failed_requests']}")
        if total_time > 0:
            print(f"  Requests/Second: {total_requests / total_time:.2f}")

        # Session creation metrics
        if self.metrics["session_creation_times"]:
            print(f"\nšŸ”‘ Session Creation:")
            print(f"  Mean: {statistics.mean(self.metrics['session_creation_times']):.3f}s")
            print(f"  Median: {statistics.median(self.metrics['session_creation_times']):.3f}s")
            # stdev needs at least two samples or it raises StatisticsError.
            if len(self.metrics["session_creation_times"]) > 1:
                print(f"  Std Dev: {statistics.stdev(self.metrics['session_creation_times']):.3f}s")

        # Tool call metrics
        if self.metrics["tool_call_times"]:
            print(f"\nšŸ”§ Tool Calls:")
            print(f"  Mean: {statistics.mean(self.metrics['tool_call_times']):.3f}s")
            print(f"  Median: {statistics.median(self.metrics['tool_call_times']):.3f}s")
            if len(self.metrics["tool_call_times"]) > 1:
                print(f"  Std Dev: {statistics.stdev(self.metrics['tool_call_times']):.3f}s")
            print(f"  Min: {min(self.metrics['tool_call_times']):.3f}s")
            print(f"  Max: {max(self.metrics['tool_call_times']):.3f}s")

            # Calculate percentiles (indices clamped so single-sample lists work).
            sorted_times = sorted(self.metrics["tool_call_times"])
            p50_idx = len(sorted_times) // 2
            p95_idx = min(int(len(sorted_times) * 0.95), len(sorted_times) - 1)
            p99_idx = min(int(len(sorted_times) * 0.99), len(sorted_times) - 1)
            p50 = sorted_times[p50_idx]
            p95 = sorted_times[p95_idx]
            p99 = sorted_times[p99_idx]
            print(f"\nšŸ“ˆ Latency Percentiles:")
            print(f"  p50: {p50:.3f}s")
            print(f"  p95: {p95:.3f}s")
            print(f"  p99: {p99:.3f}s")

        # Throughput
        print(f"\n⚔ Throughput:")
        print(f"  Concurrent Clients: {num_clients}")
        if total_time > 0:
            print(f"  Requests/Second/Client: {(requests_per_client + 1) / total_time:.2f}")
            print(f"  Total Throughput: {total_requests / total_time:.2f} req/s")

        print(f"\n{'='*60}\n")


async def run_standard_tests(base_url: str = "http://localhost:7860", api_key: Optional[str] = None):
    """Run standard load test scenarios.

    Three fixed scenarios of increasing intensity (light/medium/heavy); each
    resets the tester's metrics internally.
    """
    tester = MCPLoadTester(base_url, api_key)

    # Test 1: Light load (10 clients, 5 requests each)
    print("\n🟢 TEST 1: Light Load")
    await tester.run_load_test(10, 5, delay=0.1)

    # Test 2: Medium load (50 clients, 10 requests each)
    print("\n🟔 TEST 2: Medium Load")
    await tester.run_load_test(50, 10, delay=0.05)

    # Test 3: Heavy load (100 clients, 20 requests each)
    print("\nšŸ”“ TEST 3: Heavy Load")
    await tester.run_load_test(100, 20, delay=0.01)


async def run_stress_test(base_url: str = "http://localhost:7860", api_key: Optional[str] = None):
    """Run stress test to find breaking point.

    Ramps the concurrent client count through fixed steps and stops as soon as
    the success rate drops below 95%, then prints a summary table.
    """
    tester = MCPLoadTester(base_url, api_key)

    print("\nšŸ”„ STRESS TEST: Finding Breaking Point")
    print("=" * 60)

    client_counts = [10, 25, 50, 100, 200, 500]
    results = []

    for clients in client_counts:
        print(f"\nTesting with {clients} concurrent clients...")
        metrics = await tester.run_load_test(clients, 10, delay=0.01)
        success_rate = (metrics["successful_requests"] / metrics["total_requests"] * 100) if metrics["total_requests"] > 0 else 0
        results.append((clients, success_rate))

        # Stop if success rate drops below 95%
        if success_rate < 95:
            print(f"\nāš ļø Performance degradation detected at {clients} clients")
            print(f"Success rate dropped to {success_rate:.1f}%")
            break

    print("\nšŸ“Š Stress Test Summary:")
    print("Clients | Success Rate")
    print("--------|-------------")
    for clients, rate in results:
        print(f"{clients:7d} | {rate:6.1f}%")


def main():
    """Parse CLI arguments and dispatch to the selected test mode."""
    parser = argparse.ArgumentParser(description='Load test W&B MCP Server')
    parser.add_argument('--url', default='http://localhost:7860', help='Server URL')
    parser.add_argument('--api-key', help='W&B API key (optional, uses test key if not provided)')
    parser.add_argument('--mode', choices=['standard', 'stress', 'custom'], default='standard',
                        help='Test mode: standard, stress, or custom')
    parser.add_argument('--clients', type=int, default=10,
                        help='Number of concurrent clients (for custom mode)')
    parser.add_argument('--requests', type=int, default=10,
                        help='Requests per client (for custom mode)')
    parser.add_argument('--delay', type=float, default=0.1,
                        help='Delay between requests in seconds (for custom mode)')

    args = parser.parse_args()

    print("W&B MCP Server Load Tester")
    print(f"Server: {args.url}")
    print(f"Mode: {args.mode}")

    if args.mode == 'standard':
        asyncio.run(run_standard_tests(args.url, args.api_key))
    elif args.mode == 'stress':
        asyncio.run(run_stress_test(args.url, args.api_key))
    else:  # custom
        tester = MCPLoadTester(args.url, args.api_key)
        asyncio.run(tester.run_load_test(args.clients, args.requests, args.delay))


if __name__ == "__main__":
    main()