# UnlimitedMusicGen/modules/mcp_client.py
# Last commit: "mcp_client v1" (4e9b4e7) by Surn
import os
import json
import time
import requests
from typing import Dict, List, Any, Optional, Union
from pathlib import Path
class RequestsMCPClient:
    """
    A requests-based MCP-style client with a configurable request timeout.

    Talks to a server exposing ``GET {url}/tools`` (tool listing) and
    ``POST {url}/run`` (tool invocation). Tool calls use either plain HTTP
    (one JSON reply) or Server-Sent Events (a streamed ``text/event-stream``
    reply), selected by the ``transport`` setting.
    """

    class Tool:
        """Callable proxy for one server-side tool; calling it runs the tool via ``parent``."""

        def __init__(self, tool_data):
            self.name = tool_data.get("name", "")
            self.description = tool_data.get("description", "")
            self.parameters = tool_data.get("parameters", {})
            # Owning client; assigned by RequestsMCPClient._build_tools.
            self.parent = None

        def __call__(self, **kwargs):
            return self.parent._call_tool(self.name, kwargs)

        def __repr__(self):
            return f"Tool(name='{self.name}')"

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the client from a configuration dictionary.

        Args:
            config: Dictionary with configuration parameters.
                'url': Base URL of the MCP server (default: "http://localhost:11434").
                'transport': 'sse' (case-insensitive) for streamed replies; any
                    other value uses plain HTTP (default: 'sse').
                'timeout': Timeout in seconds for tool calls (default: 60).
                    requests applies it to connecting and to each wait for
                    data, not to the total duration; None waits indefinitely.
                'tools_timeout': Timeout in seconds for the tool listing
                    request (default: 60).
        """
        self.url = config.get("url", "http://localhost:11434")
        self.transport = config.get("transport", "sse")
        self.timeout = config.get("timeout", 60)
        self.tools_timeout = config.get("tools_timeout", 60)
        self.session = requests.Session()
        self.tools = None

    def get_tools(self) -> List[Any]:
        """
        Fetch the tool list from ``GET {url}/tools`` and wrap each entry as a Tool.

        The reply may be a JSON list of tool objects or an object holding that
        list under ``"tools"``. On success the list is also stored in ``self.tools``.

        Returns:
            List of Tool objects, or an empty list if the request fails, times
            out, or the body is not valid JSON.
        """
        try:
            response = self.session.get(f"{self.url}/tools", timeout=self.tools_timeout)
            response.raise_for_status()
            tools_data = response.json()
        except requests.exceptions.Timeout:
            print(f"Timeout when fetching tools from {self.url}")
            return []
        except requests.RequestException as e:
            print(f"Error getting tools: {str(e)}")
            return []
        except ValueError as e:
            # Non-JSON body on requests versions whose JSON error is not a RequestException.
            print(f"Error getting tools: {str(e)}")
            return []
        self.tools = self._build_tools(tools_data)
        return self.tools

    def _build_tools(self, tools_data: Any) -> List[Any]:
        """
        Convert a decoded tool listing into Tool objects bound to this client.

        Accepts a list of tool dicts or a dict with such a list under
        ``"tools"``; entries that are not dicts are skipped.
        """
        if isinstance(tools_data, dict):
            tools_data = tools_data.get("tools", [])
        if not isinstance(tools_data, list):
            return []
        tools = []
        for tool_data in tools_data:
            if not isinstance(tool_data, dict):
                continue
            tool = self.Tool(tool_data)
            tool.parent = self
            tools.append(tool)
        return tools

    def _call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """Run ``tool_name`` with ``arguments`` using the configured transport."""
        payload = {"name": tool_name, "arguments": arguments}
        if self.transport.lower() == "sse":
            return self._call_tool_sse(payload)
        return self._call_tool_http(payload)

    def _call_tool_http(self, payload: Dict[str, Any]) -> Any:
        """
        Run a tool with one ``POST {url}/run`` request and return the reply's ``"result"``.

        Returns:
            The ``"result"`` value (None if absent), or an ``"Error: ..."``
            string on timeout, HTTP/connection failure, a non-JSON body, or a
            JSON body that is not an object.
        """
        try:
            print(f"Making HTTP request to {self.url}/run with timeout {self.timeout}s")
            response = self.session.post(
                f"{self.url}/run",
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            body = response.json()
        except requests.exceptions.Timeout:
            return f"Error: Request timed out after {self.timeout} seconds"
        except requests.RequestException as e:
            print(f"Error calling tool via HTTP: {str(e)}")
            return f"Error: {str(e)}"
        except ValueError as e:
            print(f"Error calling tool via HTTP: {str(e)}")
            return f"Error: {str(e)}"
        if not isinstance(body, dict):
            return f"Error: Unexpected {type(body).__name__} response from {self.url}/run"
        return body.get("result")

    def _call_tool_sse(self, payload: Dict[str, Any]) -> Any:
        """
        Run a tool with ``POST {url}/run`` and collect a Server-Sent Events reply.

        ``self.timeout`` is applied by requests to connecting and to each wait
        for more stream data, not to the whole stream.

        Returns:
            The concatenated event results (see ``_parse_sse_lines``), or an
            ``"Error: ..."`` string on timeout or HTTP/connection failure.
        """
        try:
            print(f"Making SSE request to {self.url}/run with timeout {self.timeout}s")
            # stream=True keeps the connection open while events arrive; the
            # ``with`` block releases it even when parsing stops early.
            with self.session.post(
                f"{self.url}/run",
                json=payload,
                stream=True,
                headers={"Accept": "text/event-stream"},
                timeout=self.timeout,
            ) as response:
                response.raise_for_status()
                # Event streams are always UTF-8; without this, requests falls
                # back to ISO-8859-1 for text/* replies that omit a charset.
                response.encoding = "utf-8"
                return self._parse_sse_lines(response.iter_lines(decode_unicode=True))
        except requests.exceptions.Timeout:
            return f"Error: Request timed out after {self.timeout} seconds"
        except requests.RequestException as e:
            print(f"Error calling tool via SSE: {str(e)}")
            return f"Error: {str(e)}"

    @staticmethod
    def _parse_sse_lines(lines: Any) -> str:
        """
        Fold decoded SSE lines into one result string.

        - ``data:`` lines holding a JSON object: an ``"error"`` key returns
          ``"Error: <error>"`` immediately; a ``"result"`` key is appended
          (non-string results as JSON text, null ignored).
        - ``data:`` lines that are not JSON objects (plain text, bare JSON
          strings/numbers/arrays) are appended verbatim.
        - An ``event:`` line containing "done" stops parsing.
        """
        parts = []
        for line in lines:
            if not line:
                continue
            if line.startswith("data:"):
                data = line[5:].strip()
                try:
                    event = json.loads(data)
                except json.JSONDecodeError:
                    parts.append(data)
                    continue
                if not isinstance(event, dict):
                    # "error" in <str/list> would be a substring/element test, not a key lookup.
                    parts.append(data)
                    continue
                if "error" in event:
                    return f"Error: {event['error']}"
                if "result" in event:
                    result = event["result"]
                    if isinstance(result, str):
                        parts.append(result)
                    elif result is not None:
                        parts.append(json.dumps(result, ensure_ascii=False))
            elif line.startswith("event:") and "done" in line:
                break
        return "".join(parts)

    def disconnect(self):
        """Close the underlying requests session."""
        if self.session:
            self.session.close()
            print("MCP client session closed.")
def select_mcp_tool(tools: List[Any],
                    keywords: tuple = ("predict", "generate", "complete")) -> Optional[Any]:
    """
    Choose which MCP tool to use for a text-generation request.

    Keywords are tried in priority order; for each one, the first tool whose
    name contains it (case-insensitively) is returned. If no name matches any
    keyword, the first tool is returned.

    Args:
        tools: Tool objects exposing a ``name`` attribute.
        keywords: Substrings to look for, highest priority first.

    Returns:
        The selected tool, or None when ``tools`` is empty.
    """
    for keyword in keywords:
        needle = keyword.lower()
        for candidate in tools:
            # Names come from server JSON and may be missing or null.
            name = str(getattr(candidate, "name", "") or "")
            if needle in name.lower():
                return candidate
    return tools[0] if tools else None


def query_mcp(prompt, mcp_server_url="http://localhost:11434", model="codellama",
              transport="sse", timeout=36):
    """
    Send a prompt to a model through an MCP server and return the response.

    Creates a RequestsMCPClient and lists the server's tools. It picks one with
    select_mcp_tool, preferring names containing "predict", then "generate",
    then "complete", and otherwise the first tool. It then calls that tool with
    ``model`` and ``prompt``. The client session is closed before returning.

    Args:
        prompt (str): The text prompt to send to the model.
        mcp_server_url (str): Base URL of the MCP server.
        model (str): Model name passed to the tool as its ``model`` argument.
        transport (str): 'sse' for streamed replies; any other value uses plain HTTP.
        timeout (int): Request timeout in seconds for tool calls (default: 36).

    Returns:
        The tool's result, or a string beginning with "Error" when no tools are
        available or an exception is raised along the way.
    """
    # NOTE(review): 36 s is short for long-running generations; confirm the
    # intended default before relying on it.
    mcp_client = None
    try:
        mcp_client = RequestsMCPClient({
            "url": mcp_server_url,
            "transport": transport,
            "timeout": timeout,
        })
        print(f"Connecting to MCP server at {mcp_server_url} with {timeout}s timeout")

        tools = mcp_client.get_tools()
        if not tools:
            return f"Error: No tools available from the MCP server at {mcp_server_url}"

        tool = select_mcp_tool(tools)
        print(f"Using MCP tool: {tool.name}")

        # Argument names assumed by this helper; adjust if a server's tool schema differs.
        tool_args = {"model": model, "prompt": prompt}

        # perf_counter is monotonic, so the elapsed time is immune to clock changes.
        start_time = time.perf_counter()
        print(f"Starting tool call at {time.strftime('%H:%M:%S')}")
        result = tool(**tool_args)
        elapsed = time.perf_counter() - start_time
        print(f"Tool call completed in {elapsed:.2f} seconds")
        return result
    except Exception as e:
        # Top-level boundary: report the failure to the caller as a string.
        error_message = f"Error using MCP client: {str(e)}"
        print(error_message)
        import traceback
        traceback.print_exc()
        return error_message
    finally:
        if mcp_client:
            try:
                mcp_client.disconnect()
                print("MCP client disconnected")
            except Exception as e_disconnect:
                print(f"Error disconnecting MCP client: {e_disconnect}")
# Example usage
# if __name__ == "__main__":
# response = query_mcp(
# prompt="Write a complex function in Python that processes large datasets",
# mcp_server_url="http://localhost:11434",
# model="codellama",
# transport="sse",
# timeout=300 # 5 min timeout for long-running tasks
# )
# print("Response:", response)