File size: 8,720 Bytes
4e9b4e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import json
import os
import time
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Union

import requests

class RequestsMCPClient:
    """
    A requests-based implementation of an MCP client with configurable request timeout.

    Tools are listed with ``GET {url}/tools`` and invoked with ``POST {url}/run``,
    either as a plain JSON request ('http') or as a Server-Sent Events stream
    ('sse'). The client may be used as a context manager; leaving the ``with``
    block closes the underlying session.
    """

    class _Tool:
        """Callable handle for one server-side tool; calling it runs the tool via ``parent``."""

        def __init__(self, tool_data: Dict[str, Any], parent: Optional["RequestsMCPClient"] = None):
            self.name = tool_data.get("name", "")
            self.description = tool_data.get("description", "")
            self.parameters = tool_data.get("parameters", {})
            # Owning client; get_tools() sets it so __call__ can reach the server.
            self.parent = parent

        def __call__(self, **kwargs):
            return self.parent._call_tool(self.name, kwargs)

        def __repr__(self):
            return f"Tool(name='{self.name}')"

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the MCP client.

        Args:
            config: Dictionary with configuration parameters.
                'url': The URL of the MCP server (default: "http://localhost:11434").
                    A trailing slash is stripped so endpoint paths join cleanly.
                'transport': The transport method ('sse' or 'http'; default: 'sse').
                    Any value other than 'sse' (case-insensitive) uses plain HTTP.
                'timeout': Timeout in seconds for tool calls (default: 60).
                'tools_timeout': Timeout in seconds for the tool-listing request
                    (default: 60).

        Note: requests applies these as connect/read timeouts (the longest wait
        for the server), not as a cap on the total duration of a request.
        """
        self.url = (config.get("url") or "http://localhost:11434").rstrip("/")
        self.transport = config.get("transport", "sse")
        self.timeout = config.get("timeout", 60)
        self.tools_timeout = config.get("tools_timeout", 60)
        self.session = requests.Session()
        self.tools = None

    def get_tools(self) -> List[Any]:
        """
        Fetch the available tools from the MCP server.

        Returns:
            A list of callable tool handles (also stored on ``self.tools``), or an
            empty list when the request times out or fails.
        """
        try:
            response = self.session.get(f"{self.url}/tools", timeout=self.tools_timeout)
            response.raise_for_status()
            # NOTE(review): assumes the body is a JSON array of tool objects — confirm with the server.
            tools_data = response.json()
        except requests.exceptions.Timeout:
            print(f"Timeout when fetching tools from {self.url}")
            return []
        except requests.RequestException as e:
            print(f"Error getting tools: {str(e)}")
            return []
        self.tools = [self._Tool(tool_data, parent=self) for tool_data in tools_data]
        return self.tools

    def _call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """Invoke ``tool_name`` with ``arguments`` over the configured transport."""
        payload = {"name": tool_name, "arguments": arguments}
        if self.transport.lower() == "sse":
            return self._call_tool_sse(payload)
        return self._call_tool_http(payload)

    def _call_tool_http(self, payload: Dict[str, Any]) -> Any:
        """
        Call a tool with a single JSON POST.

        Returns the ``"result"`` field of the JSON response body (None if absent),
        or an ``"Error: ..."`` string on timeout or request failure.
        """
        try:
            print(f"Making HTTP request to {self.url}/run with timeout {self.timeout}s")
            response = self.session.post(
                f"{self.url}/run",
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            return response.json().get("result")
        except requests.exceptions.Timeout:
            return f"Error: Request timed out after {self.timeout} seconds"
        except requests.RequestException as e:
            print(f"Error calling tool via HTTP: {str(e)}")
            return f"Error: {str(e)}"

    def _call_tool_sse(self, payload: Dict[str, Any]) -> Any:
        """
        Call a tool over a Server-Sent Events stream.

        Returns the concatenated streamed result (see ``_collect_sse_result``),
        or an ``"Error: ..."`` string on a server error event, timeout, or
        request failure.
        """
        try:
            print(f"Making SSE request to {self.url}/run with timeout {self.timeout}s")
            # The ``with`` block closes the streamed response, returning the connection
            # to the pool even when parsing stops early (error event or "done").
            with self.session.post(
                f"{self.url}/run",
                json=payload,
                stream=True,
                headers={"Accept": "text/event-stream"},
                timeout=self.timeout,
            ) as response:
                response.raise_for_status()
                # Event streams are always UTF-8. Without this, requests decodes a
                # charset-less text/* body as ISO-8859-1 (garbling non-ASCII) and
                # yields raw bytes when no encoding is known at all.
                response.encoding = "utf-8"
                return self._collect_sse_result(response.iter_lines(decode_unicode=True))
        except requests.exceptions.Timeout:
            return f"Error: Request timed out after {self.timeout} seconds"
        except requests.RequestException as e:
            print(f"Error calling tool via SSE: {str(e)}")
            return f"Error: {str(e)}"

    @staticmethod
    def _collect_sse_result(lines: Iterable[str]) -> str:
        """
        Fold decoded SSE lines into the tool's final result string.

        Empty lines (event separators) are skipped. Each ``data:`` payload is
        parsed as JSON:
        - an object with ``"error"`` aborts and returns ``"Error: <message>"``;
        - an object with ``"result"`` contributes that value (non-string values
          are appended as JSON text); other objects are ignored;
        - payloads that are not JSON, or are JSON but not an object, are
          appended verbatim.
        Reading stops at the first ``event:`` line that contains "done".
        """
        parts: List[str] = []
        for line in lines:
            if not line:
                continue
            if line.startswith("data:"):
                data = line[5:].strip()
                try:
                    event = json.loads(data)
                except json.JSONDecodeError:
                    parts.append(data)
                    continue
                if not isinstance(event, dict):
                    # A bare number/string/list is not an envelope; keeping the raw
                    # text avoids membership/indexing errors on non-dict values.
                    parts.append(data)
                    continue
                if "error" in event:
                    return f"Error: {event['error']}"
                if "result" in event:
                    result = event["result"]
                    parts.append(result if isinstance(result, str) else json.dumps(result))
            elif line.startswith("event:") and "done" in line:
                break
        return "".join(parts)

    def disconnect(self):
        """Close the underlying HTTP session."""
        if self.session:
            self.session.close()
            print("MCP client session closed.")

    def __enter__(self) -> "RequestsMCPClient":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.disconnect()

def query_mcp(prompt, mcp_server_url="http://localhost:11434", model="codellama",
              transport="sse", timeout=36):
    """
    Query a model through an MCP server and return the tool's result.

    Connects to the server, lists its tools, and picks the first tool whose name
    contains "predict", else "generate", else "complete" (case-insensitive),
    falling back to the first listed tool. That tool is called with ``model``
    and ``prompt`` as arguments.

    Args:
        prompt (str): The text prompt to send to the model
        mcp_server_url (str): URL of the MCP server
        model (str): Model name to use
        transport (str): Transport method ('http' or 'sse')
        timeout (int): Request timeout in seconds for the tool call (default: 36)

    Returns:
        The tool's result on success. On failure, a string starting with "Error"
        (no tools available, a request error reported by the client, or an
        unexpected exception). Ordinary exceptions are caught, printed with a
        traceback, and returned as such a string rather than raised.
    """
    # NOTE(review): the 36 s default disagrees with the client's own 60 s default
    # and may be a typo (e.g. for 3600); kept unchanged for caller compatibility.
    mcp_client = None
    try:
        mcp_client = RequestsMCPClient({
            "url": mcp_server_url,
            "transport": transport,
            "timeout": timeout,
        })
        print(f"Connecting to MCP server at {mcp_server_url} with {timeout}s timeout")
        tools = mcp_client.get_tools()
        if not tools:
            return f"Error: No tools available from the MCP server at {mcp_server_url}"
        # Keyword priority first, then listing order within each keyword.
        tool = next(
            (
                candidate
                for keyword in ("predict", "generate", "complete")
                for candidate in tools
                if keyword in candidate.name.lower()
            ),
            tools[0],
        )
        print(f"Using MCP tool: {tool.name}")
        # Argument names are what this function sends; adjust if a server's tool
        # schema expects different parameter names.
        tool_args = {
            "model": model,
            "prompt": prompt,
        }
        # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
        start_time = time.monotonic()
        print(f"Starting tool call at {time.strftime('%H:%M:%S')}")
        result = tool(**tool_args)
        elapsed = time.monotonic() - start_time
        print(f"Tool call completed in {elapsed:.2f} seconds")
        return result
    except Exception as e:
        # Top-level boundary: report and convert to an error string for the caller.
        error_message = f"Error using MCP client: {str(e)}"
        print(error_message)
        import traceback
        traceback.print_exc()
        return error_message
    finally:
        if mcp_client:
            try:
                mcp_client.disconnect()
                print("MCP client disconnected")
            except Exception as e_disconnect:
                print(f"Error disconnecting MCP client: {e_disconnect}")

# Example usage
# if __name__ == "__main__":
#     response = query_mcp(
#         prompt="Write a complex function in Python that processes large datasets",
#         mcp_server_url="http://localhost:11434",
#         model="codellama",
#         transport="sse",
#         timeout=300  # 5 min timeout for long-running tasks
#     )
#     print("Response:", response)