Spaces:
Sleeping
Sleeping
| import asyncio | |
| import os | |
| import json | |
| import base64 | |
| from typing import List, Dict, Any, Union | |
| from contextlib import AsyncExitStack | |
| from io import BytesIO | |
| from PIL import Image | |
| import gradio as gr | |
| from gradio.components.chatbot import ChatMessage | |
| from mcp import ClientSession, StdioServerParameters | |
| from mcp.client.stdio import stdio_client | |
| from dotenv import load_dotenv | |
| from langchain_openai import ChatOpenAI | |
# Pull settings from a local .env file into the process environment;
# the OPENROUTER_* variables are read from os.environ further down.
load_dotenv()

# A single module-level event loop is shared by every synchronous UI
# callback, which drive the async MCP calls through
# `loop.run_until_complete(...)`.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
class MCPClientWrapper:
    """Glue between the Gradio UI, one MCP stdio server, and a chat model.

    The public methods (`connect`, `process_message`) are synchronous
    wrappers meant to be used as Gradio callbacks; they run the async
    implementations on the module-level event loop `loop`.

    Attributes:
        session: Active MCP `ClientSession`, or None while disconnected.
        exit_stack: `AsyncExitStack` owning the stdio transport and session.
        mistral: LangChain `ChatOpenAI` model (pointed at OpenRouter via the
            OPENROUTER_API_KEY / OPENROUTER_API_BASE_URL environment variables).
        tools: Tools advertised by the server, as dicts with the keys
            "name", "description" and "input_schema".
    """

    # Upper bound on generated tokens for each model call.
    MAX_TOKENS = 1000

    def __init__(self):
        self.session = None
        self.exit_stack = None
        self.mistral = ChatOpenAI(
            model_name="mistralai/mistral-small",
            temperature=0.7,
            openai_api_key=os.getenv("OPENROUTER_API_KEY"),
            openai_api_base=os.getenv("OPENROUTER_API_BASE_URL"),
        )
        self.tools = []

    def connect(self, server_path: str) -> str:
        """Connect to the MCP server script and return a status line for the UI."""
        return loop.run_until_complete(self._connect(server_path))

    async def _connect(self, server_path: str) -> str:
        """Start `server_path` as a stdio MCP server and cache its tool list.

        Scripts ending in ".py" are launched with `python`, anything else
        with `node`. Any previous connection is closed first.

        Returns:
            A human-readable message listing the available tool names.
        """
        if self.exit_stack:
            await self.exit_stack.aclose()
        # Forget the old connection before dialing the new one, so a failed
        # reconnect cannot leave a closed session looking usable.
        self.session = None
        self.tools = []
        self.exit_stack = AsyncExitStack()

        command = "python" if server_path.endswith(".py") else "node"
        server_params = StdioServerParameters(
            command=command,
            args=[server_path],
            env={"PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"},
        )

        stdio_transport = await self.exit_stack.enter_async_context(
            stdio_client(server_params)
        )
        self.stdio, self.write = stdio_transport
        session = await self.exit_stack.enter_async_context(
            ClientSession(self.stdio, self.write)
        )
        await session.initialize()
        self.session = session

        response = await self.session.list_tools()
        self.tools = [
            {
                "name": tool.name,
                "description": tool.description,
                "input_schema": tool.inputSchema,
            }
            for tool in response.tools
        ]
        tool_names = [tool["name"] for tool in self.tools]
        return f"Connected to MCP server. Available tools: {', '.join(tool_names)}"

    # The history annotations below are quoted so that ChatMessage is not
    # evaluated when the class body is executed.
    def process_message(
        self, message: str, history: "List[Union[Dict[str, Any], ChatMessage]]"
    ) -> tuple:
        """Gradio chat callback: answer `message` given the chat `history`.

        Returns:
            A `(new_history, textbox)` pair; the textbox is returned empty so
            the input field is cleared after each submission.
        """
        if not self.session:
            return history + [
                {"role": "user", "content": message},
                {"role": "assistant", "content": "Please connect to an MCP server first."},
            ], gr.Textbox(value="")

        new_messages = loop.run_until_complete(self._process_query(message, history))
        return (
            history + [{"role": "user", "content": message}] + new_messages,
            gr.Textbox(value=""),
        )

    def _openai_tools(self) -> list:
        """Return `self.tools` as OpenAI-style function specs for `bind_tools`."""
        return [
            {
                "type": "function",
                "function": {
                    "name": tool["name"],
                    "description": tool.get("description") or "",
                    "parameters": tool.get("input_schema")
                    or {"type": "object", "properties": {}},
                },
            }
            for tool in self.tools
        ]

    @staticmethod
    def _message_text(content) -> str:
        """Extract the plain text from a chat-model message `content` value.

        `content` is either a string or a list of blocks (strings, or dicts
        whose "type" is "text"); other block types are ignored.
        """
        if isinstance(content, str):
            return content
        parts = []
        for block in content or []:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                parts.append(block.get("text", ""))
        return "".join(parts)

    @staticmethod
    def _content_to_text(content) -> str:
        """Flatten an MCP tool result `content` value into one string.

        List items that expose a `.text` attribute contribute that text;
        anything else is stringified. Items are joined with newlines.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, (list, tuple)):
            return "\n".join(
                item if isinstance(item, str) else str(getattr(item, "text", item))
                for item in content
            )
        return str(content)

    async def _process_query(
        self, message: str, history: "List[Union[Dict[str, Any], ChatMessage]]"
    ):
        """Ask the model, run any tools it requests, and build chat messages.

        Returns:
            A list of Gradio "messages"-format dicts (assistant role) holding
            the model's reply, one collapsible group per tool call, and the
            model's follow-up after each tool result.
        """
        chat_messages = []
        for entry in history:
            if isinstance(entry, ChatMessage):
                role, content = entry.role, entry.content
            else:
                role, content = entry.get("role"), entry.get("content")
            # Only plain-text turns can be replayed to the model; image
            # messages carry a dict payload and are skipped.
            if role in ("user", "assistant", "system") and isinstance(content, str):
                chat_messages.append({"role": role, "content": content})
        chat_messages.append({"role": "user", "content": message})

        model = (
            self.mistral.bind_tools(self._openai_tools()) if self.tools else self.mistral
        )
        response = await model.ainvoke(chat_messages, max_tokens=self.MAX_TOKENS)

        result_messages = []
        reply_text = self._message_text(response.content)
        if reply_text:
            result_messages.append({"role": "assistant", "content": reply_text})

        for index, call in enumerate(getattr(response, "tool_calls", None) or []):
            tool_name = call["name"]
            tool_args = call.get("args") or {}
            # The index keeps metadata ids distinct when the model calls the
            # same tool more than once in a single reply.
            call_key = f"{tool_name}_{index}"

            call_message = {
                "role": "assistant",
                "content": f"I'll use the {tool_name} tool to help answer your question.",
                "metadata": {
                    "title": f"Using tool: {tool_name}",
                    "log": f"Parameters: {json.dumps(tool_args, ensure_ascii=True)}",
                    "status": "pending",
                    "id": f"tool_call_{call_key}",
                },
            }
            result_messages.append(call_message)
            result_messages.append({
                "role": "assistant",
                "content": "```json\n" + json.dumps(tool_args, indent=2, ensure_ascii=True) + "\n```",
                "metadata": {
                    "parent_id": f"tool_call_{call_key}",
                    "id": f"params_{call_key}",
                    "title": "Tool Parameters",
                },
            })

            result = await self.session.call_tool(tool_name, tool_args)
            call_message["metadata"]["status"] = "done"

            result_messages.append({
                "role": "assistant",
                "content": "Here are the results from the tool:",
                "metadata": {
                    "title": f"Tool Result for {tool_name}",
                    "status": "done",
                    "id": f"result_{call_key}",
                },
            })

            result_text = self._content_to_text(result.content)
            try:
                parsed = json.loads(result_text)
            except ValueError:
                parsed = None

            if isinstance(parsed, dict) and parsed.get("type") == "image" and "url" in parsed:
                result_messages.append({
                    "role": "assistant",
                    "content": {
                        "path": parsed["url"],
                        "alt_text": parsed.get("message", "Generated image"),
                    },
                    "metadata": {
                        "parent_id": f"result_{call_key}",
                        "id": f"image_{call_key}",
                        "title": "Generated Image",
                    },
                })
            else:
                # Anything that is not a recognised image payload is shown verbatim.
                result_messages.append({
                    "role": "assistant",
                    "content": "```\n" + result_text + "\n```",
                    "metadata": {
                        "parent_id": f"result_{call_key}",
                        "id": f"raw_result_{call_key}",
                        "title": "Raw Output",
                    },
                })

            # Feed the result back as a plain user turn and let the model
            # summarise it; no tools are offered on this follow-up call.
            chat_messages.append({
                "role": "user",
                "content": f"Tool result for {tool_name}: {result_text}",
            })
            follow_up = await self.mistral.ainvoke(
                chat_messages, max_tokens=self.MAX_TOKENS
            )
            follow_up_text = self._message_text(follow_up.content)
            if follow_up_text:
                result_messages.append({"role": "assistant", "content": follow_up_text})

        return result_messages

    # Image-processing helpers

    def image_to_base64(self, image):
        """Return `image` (a PIL image) PNG-encoded as a base64 string, or None."""
        if image is None:
            return None
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()

    async def process_image(self, image, operation, target_format=None, width=None, height=None):
        """Run one image operation through the connected MCP server.

        Args:
            image: PIL image to process.
            operation: One of "Remove Background", "Change Format",
                "Resize Image" or "Visualize Image".
            target_format: Required for "Change Format".
            width, height: Required for "Resize Image".

        Returns:
            `(result_image, status_message)`; `result_image` is None whenever
            the operation could not be completed.
        """
        if not self.session:
            return None, "Please connect to an MCP server first."
        if image is None:
            return None, "No image provided."

        try:
            img_base64 = self.image_to_base64(image)

            if operation == "Remove Background":
                # NOTE(review): base64 data is sent under the "url" key of a
                # tool named *_from_url - confirm the server accepts that.
                tool_name, tool_args = "remove_background_from_url", {"url": img_base64}
            elif operation == "Change Format":
                if not target_format:
                    return None, "Please select a target format."
                tool_name = "change_format"
                tool_args = {
                    "image_base64": img_base64,
                    "target_format": target_format.lower(),
                }
            elif operation == "Resize Image":
                if not width or not height:
                    return None, "Please provide width and height."
                tool_name = "resize_image"
                tool_args = {
                    "image_base64": img_base64,
                    "width": int(width),
                    "height": int(height),
                }
            elif operation == "Visualize Image":
                tool_name, tool_args = "visualize_base64_image", {"image_base64": img_base64}
            else:
                return None, "Unknown operation."

            result = await self.session.call_tool(tool_name, tool_args)

            # Tool results come back as content blocks rather than a bare
            # string, so flatten them to text before parsing.
            result_text = self._content_to_text(result.content)
            try:
                result_data = json.loads(result_text)
            except ValueError:
                return None, f"Error decoding result: {result_text}"

            if isinstance(result_data, dict) and "image_base64" in result_data:
                img_data = base64.b64decode(result_data["image_base64"])
                return Image.open(BytesIO(img_data)), "Image processed successfully."
            return None, f"Unexpected result format: {result_text}"
        except Exception as e:
            # UI boundary: report the failure in the status box instead of raising.
            return None, f"Error processing image: {str(e)}"
# Single shared wrapper instance; the Gradio callbacks wired up in
# gradio_interface() all go through it.
client = MCPClientWrapper()
def gradio_interface():
    """Build the Gradio app: connection bar, chat tab, and image-processing tab.

    All callbacks are routed through the module-level `client`.

    Returns:
        The assembled `gr.Blocks` app; the caller is responsible for launching it.
    """
    with gr.Blocks(title="MCP Assistant") as demo:
        gr.Markdown("# MCP Assistant")
        gr.Markdown("Connect to your MCP server to chat or process images")

        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                server_path = gr.Textbox(
                    label="Server Script Path",
                    placeholder="Enter path to server script",
                    value="mcp_server.py",
                )
            with gr.Column(scale=1):
                connect_btn = gr.Button("Connect")
        status = gr.Textbox(label="Connection Status", interactive=False)

        with gr.Tabs():
            with gr.TabItem("Chat Interface"):
                chatbot = gr.Chatbot(
                    value=[],
                    height=500,
                    type="messages",
                    show_copy_button=True,
                    avatar_images=("👤", "🤖"),
                )
                with gr.Row(equal_height=True):
                    msg = gr.Textbox(
                        label="Your Question",
                        placeholder="Ask about the available tools or how to process images",
                        scale=4,
                    )
                    clear_btn = gr.Button("Clear Chat", scale=1)

            with gr.TabItem("Image Processing"):
                with gr.Row():
                    with gr.Column():
                        input_image = gr.Image(label="Input Image", type="pil")
                        operation = gr.Radio(
                            ["Remove Background", "Change Format", "Resize Image", "Visualize Image"],
                            label="Select Operation",
                            value="Visualize Image",
                        )
                        # Operation-specific inputs start hidden and are
                        # revealed by update_options() below.
                        with gr.Group():
                            target_format = gr.Dropdown(
                                ["png", "jpeg", "webp"],
                                label="Target Format",
                                value="png",
                                visible=False,
                            )
                        with gr.Group():
                            with gr.Row():
                                width = gr.Number(label="Width", value=300, visible=False)
                                height = gr.Number(label="Height", value=300, visible=False)
                        process_btn = gr.Button("Process Image")
                    with gr.Column():
                        output_image = gr.Image(label="Processed Image")
                        output_message = gr.Textbox(label="Status")

        # Connect to server
        connect_btn.click(client.connect, inputs=server_path, outputs=status)

        # Chat functionality
        msg.submit(client.process_message, [msg, chatbot], [chatbot, msg])
        clear_btn.click(lambda: [], None, chatbot)

        # Image processing functionality
        def update_options(op):
            """Show only the inputs that the selected operation needs."""
            # gr.update(visible=...) toggles visibility; returning a bare bool
            # would overwrite each component's *value* instead.
            resizing = op == "Resize Image"
            return {
                target_format: gr.update(visible=op == "Change Format"),
                width: gr.update(visible=resizing),
                height: gr.update(visible=resizing),
            }

        operation.change(
            update_options, inputs=operation, outputs=[target_format, width, height]
        )

        def process_image_wrapper(image, selected_operation, fmt, new_width, new_height):
            """Synchronous bridge from the button click to the async client call."""
            return loop.run_until_complete(
                client.process_image(image, selected_operation, fmt, new_width, new_height)
            )

        process_btn.click(
            process_image_wrapper,
            inputs=[input_image, operation, target_format, width, height],
            outputs=[output_image, output_message],
        )

    return demo
if __name__ == "__main__":
    # A missing key is only warned about; the UI is still started.
    api_key = os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        print("Warning: OPENROUTER_API_KEY not found in environment. Please set it in your .env file.")

    interface = gradio_interface()
    interface.launch(debug=True)