#!/usr/bin/env python3
"""
Load testing script for W&B MCP Server
Measures concurrent connections, requests/second, and latency
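
Example invocations (script filename assumed; flags are defined in main() below):
    python load_test.py --mode standard
    python load_test.py --mode stress
    python load_test.py --mode custom --clients 20 --requests 50 --delay 0.05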
"""

import argparse
import asyncio
import statistics
import time
from typing import Any, Dict, Optional

import httpx

class MCPLoadTester:
    def __init__(self, base_url: str = "http://localhost:7860", api_key: Optional[str] = None):
        self.base_url = base_url
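        # Fallback placeholder key for runs where no real W&B API key is supplied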
        self.api_key = api_key or "test_key_12345678901234567890123456789012345678"
        self.metrics = {
            "total_requests": 0,
            "successful_requests": 0,
            "failed_requests": 0,
            "response_times": [],
            "session_creation_times": [],
            "tool_call_times": []
        }
    
    async def create_session(self, client: httpx.AsyncClient) -> Optional[str]:
        """Initialize an MCP session."""
        start_time = time.time()
        
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json, text/event-stream",
        }
        
        payload = {
            "jsonrpc": "2.0",
            "method": "initialize",
            "params": {
                "protocolVersion": "2025-06-18",
                "capabilities": {},
                "clientInfo": {"name": "load_test", "version": "1.0.0"}
            },
            "id": 1
        }
        
        try:
            response = await client.post(
                f"{self.base_url}/mcp",
                headers=headers,
                json=payload,
                timeout=10
            )
            
            elapsed = time.time() - start_time
            self.metrics["session_creation_times"].append(elapsed)
            self.metrics["total_requests"] += 1
            
            if response.status_code == 200:
                self.metrics["successful_requests"] += 1
                return response.headers.get("mcp-session-id")
            else:
                self.metrics["failed_requests"] += 1
                return None
                
        except Exception as e:
            self.metrics["failed_requests"] += 1
            self.metrics["total_requests"] += 1
            print(f"Session creation failed: {e}")
            return None
    
    async def call_tool(self, client: httpx.AsyncClient, session_id: str, tool_name: str, params: Dict[str, Any]):
        """Call a tool using the session."""
        start_time = time.time()
        
        headers = {
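            # Ties this request to the session established during initialize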
            "Mcp-Session-Id": session_id,
            "Content-Type": "application/json",
            "Accept": "application/json, text/event-stream",
        }
        
        payload = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {
                "name": tool_name,
                "arguments": params
            },
            "id": 2
        }
        
        try:
            response = await client.post(
                f"{self.base_url}/mcp",
                headers=headers,
                json=payload,
                timeout=30
            )
            
            elapsed = time.time() - start_time
            self.metrics["tool_call_times"].append(elapsed)
            self.metrics["response_times"].append(elapsed)
            self.metrics["total_requests"] += 1
            
            if response.status_code == 200:
                self.metrics["successful_requests"] += 1
            else:
                self.metrics["failed_requests"] += 1
                
        except Exception as e:
            self.metrics["failed_requests"] += 1
            self.metrics["total_requests"] += 1
            print(f"Tool call failed: {e}")
    
    async def run_client_session(self, client_id: int, num_requests: int, delay: float = 0.1):
        """Simulate a client making multiple requests."""
        async with httpx.AsyncClient() as client:
            # Create session
            session_id = await self.create_session(client)
            if not session_id:
                return
            
            # Make multiple tool calls
            for i in range(num_requests):
                await self.call_tool(
                    client,
                    session_id,
                    "query_wandb_entity_projects",  # Simple tool that doesn't need entity/project
                    {}
                )
                
                # Small delay between requests
                if delay > 0:
                    await asyncio.sleep(delay)
    
    async def run_load_test(self, num_clients: int, requests_per_client: int, delay: float = 0.1):
        """Run the load test with specified parameters."""
        print(f"\n{'='*60}")
        print(f"Starting Load Test")
        print(f"{'='*60}")
        print(f"Clients: {num_clients}")
        print(f"Requests per client: {requests_per_client}")
        print(f"Total requests: {num_clients * (requests_per_client + 1)}")  # +1 for session creation
        print(f"Server: {self.base_url}")
        print(f"Delay between requests: {delay}s")
        print(f"{'='*60}\n")
        
        # Reset metrics
        self.metrics = {
            "total_requests": 0,
            "successful_requests": 0,
            "failed_requests": 0,
            "response_times": [],
            "session_creation_times": [],
            "tool_call_times": []
        }
        
        start_time = time.time()
        
        # Run all client sessions concurrently
        tasks = [
            self.run_client_session(i, requests_per_client, delay)
            for i in range(num_clients)
        ]
        
        # Show progress
        print("Running load test...")
        await asyncio.gather(*tasks)
        
        total_time = time.time() - start_time
        
        # Calculate and display results
        self.display_results(total_time, num_clients, requests_per_client)
        
        return self.metrics
    
    def display_results(self, total_time: float, num_clients: int, requests_per_client: int):
        """Display load test results."""
        print(f"\n{'='*60}")
        print(f"Load Test Results")
        print(f"{'='*60}")
        
        # Overall metrics
        total_requests = self.metrics["total_requests"]
        success_rate = (self.metrics["successful_requests"] / total_requests * 100) if total_requests > 0 else 0
        
        print(f"\n📊 Overall Metrics:")
        print(f"  Total Time: {total_time:.2f}s")
        print(f"  Total Requests: {total_requests}")
        print(f"  Successful: {self.metrics['successful_requests']} ({success_rate:.1f}%)")
        print(f"  Failed: {self.metrics['failed_requests']}")
        if total_time > 0:
            print(f"  Requests/Second: {total_requests / total_time:.2f}")
        
        # Session creation metrics
        if self.metrics["session_creation_times"]:
            print(f"\n🔑 Session Creation:")
            print(f"  Mean: {statistics.mean(self.metrics['session_creation_times']):.3f}s")
            print(f"  Median: {statistics.median(self.metrics['session_creation_times']):.3f}s")
            if len(self.metrics["session_creation_times"]) > 1:
                print(f"  Std Dev: {statistics.stdev(self.metrics['session_creation_times']):.3f}s")
        
        # Tool call metrics
        if self.metrics["tool_call_times"]:
            print(f"\n🔧 Tool Calls:")
            print(f"  Mean: {statistics.mean(self.metrics['tool_call_times']):.3f}s")
            print(f"  Median: {statistics.median(self.metrics['tool_call_times']):.3f}s")
            if len(self.metrics["tool_call_times"]) > 1:
                print(f"  Std Dev: {statistics.stdev(self.metrics['tool_call_times']):.3f}s")
                print(f"  Min: {min(self.metrics['tool_call_times']):.3f}s")
                print(f"  Max: {max(self.metrics['tool_call_times']):.3f}s")
                
                # Nearest-rank percentiles over the sorted latency samples
                sorted_times = sorted(self.metrics["tool_call_times"])
                p50_idx = len(sorted_times) // 2
                p95_idx = min(int(len(sorted_times) * 0.95), len(sorted_times) - 1)
                p99_idx = min(int(len(sorted_times) * 0.99), len(sorted_times) - 1)
                
                p50 = sorted_times[p50_idx]
                p95 = sorted_times[p95_idx]
                p99 = sorted_times[p99_idx]
                
                print(f"\n📈 Latency Percentiles:")
                print(f"  p50: {p50:.3f}s")
                print(f"  p95: {p95:.3f}s")
                print(f"  p99: {p99:.3f}s")
        
        # Throughput
        print(f"\n⚡ Throughput:")
        print(f"  Concurrent Clients: {num_clients}")
        if total_time > 0:
            print(f"  Requests/Second/Client: {(requests_per_client + 1) / total_time:.2f}")
            print(f"  Total Throughput: {total_requests / total_time:.2f} req/s")
        
        print(f"\n{'='*60}\n")


async def run_standard_tests(base_url: str = "http://localhost:7860", api_key: Optional[str] = None):
    """Run standard load test scenarios."""
    tester = MCPLoadTester(base_url, api_key)
    
    # Test 1: Light load (10 clients, 5 requests each)
    print("\n🟢 TEST 1: Light Load")
    await tester.run_load_test(10, 5, delay=0.1)
    
    # Test 2: Medium load (50 clients, 10 requests each)
    print("\n🟡 TEST 2: Medium Load")
    await tester.run_load_test(50, 10, delay=0.05)
    
    # Test 3: Heavy load (100 clients, 20 requests each)
    print("\n🔴 TEST 3: Heavy Load")
    await tester.run_load_test(100, 20, delay=0.01)


async def run_stress_test(base_url: str = "http://localhost:7860", api_key: Optional[str] = None):
    """Run stress test to find breaking point."""
    tester = MCPLoadTester(base_url, api_key)
    
    print("\n🔥 STRESS TEST: Finding Breaking Point")
    print("=" * 60)
    
    client_counts = [10, 25, 50, 100, 200, 500]
    results = []
    
    for clients in client_counts:
        print(f"\nTesting with {clients} concurrent clients...")
        metrics = await tester.run_load_test(clients, 10, delay=0.01)
        
        success_rate = (metrics["successful_requests"] / metrics["total_requests"] * 100) if metrics["total_requests"] > 0 else 0
        results.append((clients, success_rate))
        
        # Stop if success rate drops below 95%
        if success_rate < 95:
            print(f"\n⚠️ Performance degradation detected at {clients} clients")
            print(f"Success rate dropped to {success_rate:.1f}%")
            break
    
    print("\n📊 Stress Test Summary:")
    print("Clients | Success Rate")
    print("--------|-------------")
    for clients, rate in results:
        print(f"{clients:7d} | {rate:6.1f}%")


def main():
    parser = argparse.ArgumentParser(description='Load test W&B MCP Server')
    parser.add_argument('--url', default='http://localhost:7860', help='Server URL')
    parser.add_argument('--api-key', help='W&B API key (optional, uses test key if not provided)')
    parser.add_argument('--mode', choices=['standard', 'stress', 'custom'], default='standard',
                        help='Test mode: standard, stress, or custom')
    parser.add_argument('--clients', type=int, default=10, help='Number of concurrent clients (for custom mode)')
    parser.add_argument('--requests', type=int, default=10, help='Requests per client (for custom mode)')
    parser.add_argument('--delay', type=float, default=0.1, help='Delay between requests in seconds (for custom mode)')
    
    args = parser.parse_args()
    
    print("W&B MCP Server Load Tester")
    print(f"Server: {args.url}")
    print(f"Mode: {args.mode}")
    
    if args.mode == 'standard':
        asyncio.run(run_standard_tests(args.url, args.api_key))
    elif args.mode == 'stress':
        asyncio.run(run_stress_test(args.url, args.api_key))
    else:  # custom
        tester = MCPLoadTester(args.url, args.api_key)
        asyncio.run(tester.run_load_test(args.clients, args.requests, args.delay))


if __name__ == "__main__":
    main()