mcp-server / load_test.py
NiWaRe's picture
load test and scalability plan
40e4410
raw
history blame
12.1 kB
#!/usr/bin/env python3
"""
Load testing script for W&B MCP Server
Measures concurrent connections, requests/second, and latency
"""
import asyncio
import time
import statistics
from typing import List, Dict, Any, Optional
import httpx
import json
from datetime import datetime
import argparse
import sys
class MCPLoadTester:
    """Load-test driver for a W&B MCP server.

    Opens MCP sessions over HTTP, fires JSON-RPC tool calls against them,
    and accumulates request counts and latency samples in ``self.metrics``.
    """

    def __init__(self, base_url: str = "http://localhost:7860", api_key: Optional[str] = None):
        self.base_url = base_url
        # Placeholder key used when none is supplied. NOTE(review): presumably
        # the server only format-checks keys during load tests — confirm if
        # real authentication is enforced.
        self.api_key = api_key or "test_key_12345678901234567890123456789012345678"
        self.metrics = self._fresh_metrics()

    @staticmethod
    def _fresh_metrics() -> Dict[str, Any]:
        """Return an empty metrics accumulator (used by __init__ and run_load_test)."""
        return {
            "total_requests": 0,
            "successful_requests": 0,
            "failed_requests": 0,
            "response_times": [],
            "session_creation_times": [],
            "tool_call_times": [],
        }

    async def create_session(self, client: httpx.AsyncClient) -> Optional[str]:
        """Initialize an MCP session.

        Sends the JSON-RPC ``initialize`` request with the API key as a Bearer
        token. Returns the ``mcp-session-id`` response header on success, or
        None on any HTTP error or transport failure. Updates ``self.metrics``.
        """
        # perf_counter is monotonic — correct for measuring elapsed intervals,
        # unlike time.time() which can jump on system clock adjustments.
        start_time = time.perf_counter()
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json, text/event-stream",
        }
        payload = {
            "jsonrpc": "2.0",
            "method": "initialize",
            "params": {
                "protocolVersion": "2025-06-18",
                "capabilities": {},
                "clientInfo": {"name": "load_test", "version": "1.0.0"}
            },
            "id": 1
        }
        try:
            response = await client.post(
                f"{self.base_url}/mcp",
                headers=headers,
                json=payload,
                timeout=10
            )
        except Exception as e:
            # Transport-level failures (connect errors, timeouts) count as
            # requests too, so the success-rate denominator stays honest.
            self.metrics["failed_requests"] += 1
            self.metrics["total_requests"] += 1
            print(f"Session creation failed: {e}")
            return None
        elapsed = time.perf_counter() - start_time
        self.metrics["session_creation_times"].append(elapsed)
        self.metrics["total_requests"] += 1
        if response.status_code == 200:
            self.metrics["successful_requests"] += 1
            return response.headers.get("mcp-session-id")
        self.metrics["failed_requests"] += 1
        return None

    async def call_tool(self, client: httpx.AsyncClient, session_id: str, tool_name: str, params: Dict[str, Any]) -> None:
        """Call a tool within an established session and record its latency.

        NOTE(review): only the session id is sent here — no Authorization
        header is re-sent per call. Confirm the server accepts session-only
        auth for tools/call.
        """
        start_time = time.perf_counter()
        headers = {
            "Mcp-Session-Id": session_id,
            "Content-Type": "application/json",
            "Accept": "application/json, text/event-stream",
        }
        payload = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {
                "name": tool_name,
                "arguments": params
            },
            "id": 2
        }
        try:
            response = await client.post(
                f"{self.base_url}/mcp",
                headers=headers,
                json=payload,
                timeout=30
            )
        except Exception as e:
            self.metrics["failed_requests"] += 1
            self.metrics["total_requests"] += 1
            print(f"Tool call failed: {e}")
            return
        elapsed = time.perf_counter() - start_time
        self.metrics["tool_call_times"].append(elapsed)
        self.metrics["response_times"].append(elapsed)
        self.metrics["total_requests"] += 1
        if response.status_code == 200:
            self.metrics["successful_requests"] += 1
        else:
            self.metrics["failed_requests"] += 1

    async def run_client_session(self, client_id: int, num_requests: int, delay: float = 0.1) -> None:
        """Simulate one client: create a session, then issue num_requests tool calls."""
        async with httpx.AsyncClient() as client:
            session_id = await self.create_session(client)
            if not session_id:
                # No session means no tool calls for this simulated client.
                return
            for i in range(num_requests):
                await self.call_tool(
                    client,
                    session_id,
                    "query_wandb_entity_projects",  # Simple tool that doesn't need entity/project
                    {}
                )
                # Small delay between requests to pace each client.
                if delay > 0:
                    await asyncio.sleep(delay)

    async def run_load_test(self, num_clients: int, requests_per_client: int, delay: float = 0.1) -> Dict[str, Any]:
        """Run all client sessions concurrently and report aggregate metrics.

        Returns the populated metrics dict so callers (e.g. the stress test)
        can inspect success rates.
        """
        print(f"\n{'='*60}")
        print(f"Starting Load Test")
        print(f"{'='*60}")
        print(f"Clients: {num_clients}")
        print(f"Requests per client: {requests_per_client}")
        print(f"Total requests: {num_clients * (requests_per_client + 1)}")  # +1 for session creation
        print(f"Server: {self.base_url}")
        print(f"Delay between requests: {delay}s")
        print(f"{'='*60}\n")
        # Reset metrics so repeated runs on the same tester don't accumulate.
        self.metrics = self._fresh_metrics()
        start_time = time.perf_counter()
        tasks = [
            self.run_client_session(i, requests_per_client, delay)
            for i in range(num_clients)
        ]
        print("Running load test...")
        await asyncio.gather(*tasks)
        total_time = time.perf_counter() - start_time
        self.display_results(total_time, num_clients, requests_per_client)
        return self.metrics

    def display_results(self, total_time: float, num_clients: int, requests_per_client: int) -> None:
        """Print a human-readable summary of the collected metrics."""
        print(f"\n{'='*60}")
        print(f"Load Test Results")
        print(f"{'='*60}")
        # Overall metrics (guard against division by zero when nothing ran).
        total_requests = self.metrics["total_requests"]
        success_rate = (self.metrics["successful_requests"] / total_requests * 100) if total_requests > 0 else 0
        print(f"\n📊 Overall Metrics:")
        print(f"  Total Time: {total_time:.2f}s")
        print(f"  Total Requests: {total_requests}")
        print(f"  Successful: {self.metrics['successful_requests']} ({success_rate:.1f}%)")
        print(f"  Failed: {self.metrics['failed_requests']}")
        if total_time > 0:
            print(f"  Requests/Second: {total_requests / total_time:.2f}")
        # Session creation metrics (stdev needs at least two samples).
        if self.metrics["session_creation_times"]:
            print(f"\n🔑 Session Creation:")
            print(f"  Mean: {statistics.mean(self.metrics['session_creation_times']):.3f}s")
            print(f"  Median: {statistics.median(self.metrics['session_creation_times']):.3f}s")
            if len(self.metrics["session_creation_times"]) > 1:
                print(f"  Std Dev: {statistics.stdev(self.metrics['session_creation_times']):.3f}s")
        # Tool call metrics
        if self.metrics["tool_call_times"]:
            print(f"\n🔧 Tool Calls:")
            print(f"  Mean: {statistics.mean(self.metrics['tool_call_times']):.3f}s")
            print(f"  Median: {statistics.median(self.metrics['tool_call_times']):.3f}s")
            if len(self.metrics["tool_call_times"]) > 1:
                print(f"  Std Dev: {statistics.stdev(self.metrics['tool_call_times']):.3f}s")
            print(f"  Min: {min(self.metrics['tool_call_times']):.3f}s")
            print(f"  Max: {max(self.metrics['tool_call_times']):.3f}s")
            # Simple index-based percentiles, clamped to the last element.
            sorted_times = sorted(self.metrics["tool_call_times"])
            p50_idx = len(sorted_times) // 2
            p95_idx = min(int(len(sorted_times) * 0.95), len(sorted_times) - 1)
            p99_idx = min(int(len(sorted_times) * 0.99), len(sorted_times) - 1)
            p50 = sorted_times[p50_idx]
            p95 = sorted_times[p95_idx]
            p99 = sorted_times[p99_idx]
            print(f"\n📈 Latency Percentiles:")
            print(f"  p50: {p50:.3f}s")
            print(f"  p95: {p95:.3f}s")
            print(f"  p99: {p99:.3f}s")
        # Throughput
        print(f"\n⚡ Throughput:")
        print(f"  Concurrent Clients: {num_clients}")
        if total_time > 0:
            print(f"  Requests/Second/Client: {(requests_per_client + 1) / total_time:.2f}")
            print(f"  Total Throughput: {total_requests / total_time:.2f} req/s")
        print(f"\n{'='*60}\n")
async def run_standard_tests(base_url: str = "http://localhost:7860", api_key: Optional[str] = None) -> None:
    """Run the three canned load-test scenarios (light, medium, heavy).

    Args:
        base_url: MCP server URL to target.
        api_key: Optional W&B API key; the tester falls back to a test key.
    """
    tester = MCPLoadTester(base_url, api_key)
    # Test 1: Light load (10 clients, 5 requests each)
    print("\n🟢 TEST 1: Light Load")
    await tester.run_load_test(10, 5, delay=0.1)
    # Test 2: Medium load (50 clients, 10 requests each)
    print("\n🟡 TEST 2: Medium Load")
    await tester.run_load_test(50, 10, delay=0.05)
    # Test 3: Heavy load (100 clients, 20 requests each)
    print("\n🔴 TEST 3: Heavy Load")
    await tester.run_load_test(100, 20, delay=0.01)
async def run_stress_test(base_url: str = "http://localhost:7860", api_key: Optional[str] = None) -> None:
    """Run increasing client counts until the success rate degrades.

    Stops the escalation as soon as a run's success rate drops below 95%,
    then prints a summary table of clients vs. success rate.

    Args:
        base_url: MCP server URL to target.
        api_key: Optional W&B API key; the tester falls back to a test key.
    """
    tester = MCPLoadTester(base_url, api_key)
    print("\n🔥 STRESS TEST: Finding Breaking Point")
    print("=" * 60)
    client_counts = [10, 25, 50, 100, 200, 500]
    results = []
    for clients in client_counts:
        print(f"\nTesting with {clients} concurrent clients...")
        metrics = await tester.run_load_test(clients, 10, delay=0.01)
        success_rate = (metrics["successful_requests"] / metrics["total_requests"] * 100) if metrics["total_requests"] > 0 else 0
        results.append((clients, success_rate))
        # Stop if success rate drops below 95%
        if success_rate < 95:
            print(f"\n⚠️ Performance degradation detected at {clients} clients")
            print(f"Success rate dropped to {success_rate:.1f}%")
            break
    print("\n📊 Stress Test Summary:")
    print("Clients | Success Rate")
    print("--------|-------------")
    for clients, rate in results:
        print(f"{clients:7d} | {rate:6.1f}%")
def main():
    """Parse CLI arguments and dispatch to the selected test mode."""
    parser = argparse.ArgumentParser(description='Load test W&B MCP Server')
    parser.add_argument('--url', default='http://localhost:7860', help='Server URL')
    parser.add_argument('--api-key', help='W&B API key (optional, uses test key if not provided)')
    parser.add_argument('--mode', choices=['standard', 'stress', 'custom'], default='standard',
                       help='Test mode: standard, stress, or custom')
    parser.add_argument('--clients', type=int, default=10, help='Number of concurrent clients (for custom mode)')
    parser.add_argument('--requests', type=int, default=10, help='Requests per client (for custom mode)')
    parser.add_argument('--delay', type=float, default=0.1, help='Delay between requests in seconds (for custom mode)')
    args = parser.parse_args()

    print("W&B MCP Server Load Tester")
    print(f"Server: {args.url}")
    print(f"Mode: {args.mode}")

    # Build the coroutine for the chosen mode, then run it once.
    if args.mode == 'standard':
        coro = run_standard_tests(args.url, args.api_key)
    elif args.mode == 'stress':
        coro = run_stress_test(args.url, args.api_key)
    else:  # custom
        tester = MCPLoadTester(args.url, args.api_key)
        coro = tester.run_load_test(args.clients, args.requests, args.delay)
    asyncio.run(coro)
# Script entry point: e.g. `python load_test.py --mode standard`
if __name__ == "__main__":
    main()