from __future__ import annotations

import dataclasses
import enum
import json
import os
from collections import OrderedDict
from collections.abc import Mapping, Sequence
from pathlib import Path
from types import MappingProxyType
from typing import TYPE_CHECKING, Any

import boto3
import botocore
import botocore.exceptions
import gradio as gr
import gradio.themes as gr_themes
import markdown
from langchain_aws import ChatBedrock
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_openai import AzureChatOpenAI
from langgraph.prebuilt import create_react_agent
from openai import OpenAI
from openai.types.chat import ChatCompletion

from tdagent.grcomponents import MutableCheckBoxGroup, MutableCheckBoxGroupEntry

if TYPE_CHECKING:
    from langgraph.graph.graph import CompiledGraph

#### Constants ####


class AgentType(str, enum.Enum):
    """TDAgent type."""

    DATA_ENRICHER = "Data enricher"
    INCIDENT_HANDLER = "Incident handler"
    PEN_TESTER = "PenTester"

    def __str__(self) -> str:  # noqa: D105
        return self.value

AGENT_SYSTEM_MESSAGES = OrderedDict(
    (
        (
            AgentType.DATA_ENRICHER,
            """
You are a cybersecurity incident data enrichment assistant. Analysts
will present information about security incidents and you must use
all the tools at your disposal to enrich the data as much as possible.
""".strip(),
        ),
        (
            AgentType.INCIDENT_HANDLER,
            """
You are a security analyst assistant responsible for collecting, analyzing
and disseminating actionable intelligence related to cyber threats,
vulnerabilities and threat actors.
When presented with potential incident information or tickets, you should
evaluate the presented evidence, gather additional data using any tool at
your disposal and take corrective actions if possible.
Afterwards, generate a cybersecurity report including: key findings, challenges,
actions taken and recommendations.
Never use external means of communication, like emails or SMS, unless
instructed to do so.
""".strip(),
        ),
        (
            AgentType.PEN_TESTER,
            """
You are a cybersecurity pentester. You use tools to analyze domains and try to
discover system vulnerabilities. Always report your findings and suggest next
steps to deep dive where applicable.
""".strip(),
        ),
    ),
)
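# Each prompt above is wrapped into a SystemMessage by gr_make_system_message()
# below and used to pre-condition the ReAct agent.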

GRADIO_ROLE_TO_LG_MESSAGE_TYPE = MappingProxyType(
    {
        "user": HumanMessage,
        "assistant": AIMessage,
    },
)

MODEL_OPTIONS = OrderedDict(  # Initialize with tuples to preserve options order
    (
        (
            "HuggingFace",
            {
                "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.3",
                "Llama 3.1 8B Instruct": "meta-llama/Llama-3.1-8B-Instruct",
                # "Qwen3 235B A22B": "Qwen/Qwen3-235B-A22B",  # Slow inference
                "Microsoft Phi-3.5-mini Instruct": "microsoft/Phi-3.5-mini-instruct",
                # "Deepseek R1 distill-llama 70B": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",  # noqa: E501
                # "Deepseek V3": "deepseek-ai/DeepSeek-V3",
            },
        ),
        (
            "AWS Bedrock",
            {
                "Anthropic Claude 3.5 Sonnet (EU)": (
                    "eu.anthropic.claude-3-5-sonnet-20240620-v1:0"
                ),
                "Anthropic Claude 3.7 Sonnet": (
                    "anthropic.claude-3-7-sonnet-20250219-v1:0"
                ),
                "Claude Sonnet 4": "anthropic.claude-sonnet-4-20250514-v1:0",
            },
        ),
        (
            "Azure OpenAI",
            {
| "GPT-4o": ("ggpt-4o-global-standard"), | |
| "GPT-4o Mini": ("o4-mini"), | |
| "GPT-4.5 Preview": ("gpt-4.5-preview"), | |
| }, | |
| ), | |
| ), | |
| ) | |
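
# Module-level connection flag (held in a gr.State): set on a successful
# connect and reset by _on_change_model_configuration() below so the user is
# warned to reconnect after changing the model configuration.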
CONNECT_STATE_DEFAULT = gr.State()

@dataclasses.dataclass
class ToolInvocationInfo:
    """Information related to a tool invocation by the LLM."""

    name: str
    inputs: Mapping[str, Any]

class ToolsTracerCallback(BaseCallbackHandler):
    """Callback that registers tools invoked by the Agent."""

    def __init__(self) -> None:
        self._tools_trace: list[ToolInvocationInfo] = []

    def on_tool_start(  # noqa: D102
        self,
        serialized: dict[str, Any],
        *args: Any,
        inputs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Any:
        self._tools_trace.append(
            ToolInvocationInfo(
                name=serialized.get("name", "<unknown-function-name>"),
                inputs=inputs if inputs else {},
            ),
        )
        return super().on_tool_start(serialized, *args, inputs=inputs, **kwargs)
    @property
    def tools_trace(self) -> Sequence[ToolInvocationInfo]:
        """Tools trace information."""
        return self._tools_trace

    def clear(self) -> None:
        """Clear tools trace."""
        self._tools_trace.clear()
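
# Typical wiring (see gr_fetch_mcp_tools below): a single tracer instance is
# appended to each MCP tool's callbacks, cleared before every agent run, and
# read back through `tools_trace` when rendering the reply.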

#### Shared variables ####

llm_agent: CompiledGraph | None = None
llm_tools_tracer: ToolsTracerCallback | None = None

#### Utility functions ####

## Bedrock LLM creation ##


def create_bedrock_llm(
    bedrock_model_id: str,
    aws_access_key: str,
    aws_secret_key: str,
    aws_session_token: str,
    aws_region: str,
    temperature: float = 0.8,
    max_tokens: int = 512,
) -> tuple[ChatBedrock | None, str]:
    """Create a LangGraph Bedrock agent."""
    boto3_config = {
        "aws_access_key_id": aws_access_key,
        "aws_secret_access_key": aws_secret_key,
        "aws_session_token": aws_session_token if aws_session_token else None,
        "region_name": aws_region,
    }

    # Verify credentials
    try:
        sts = boto3.client("sts", **boto3_config)
        sts.get_caller_identity()
    except botocore.exceptions.ClientError as err:
        return None, str(err)

    try:
        bedrock_client = boto3.client("bedrock-runtime", **boto3_config)
        llm = ChatBedrock(
            model=bedrock_model_id,
            client=bedrock_client,
            model_kwargs={"temperature": temperature, "max_tokens": max_tokens},
        )
    except Exception as e:  # noqa: BLE001
        return None, str(e)

    return llm, ""

## Hugging Face LLM creation ##


def create_hf_llm(
    hf_model_id: str,
    huggingfacehub_api_token: str | None = None,
    temperature: float = 0.8,
    max_tokens: int = 512,
) -> tuple[ChatHuggingFace | None, str]:
    """Create a LangGraph Hugging Face agent."""
    try:
        llm = HuggingFaceEndpoint(
            model=hf_model_id,
            temperature=temperature,
            max_new_tokens=max_tokens,
            task="text-generation",
            huggingfacehub_api_token=huggingfacehub_api_token,
        )
        chat_llm = ChatHuggingFace(llm=llm)
    except Exception as e:  # noqa: BLE001
        return None, str(e)

    return chat_llm, ""

## OpenAI LLM creation ##


def create_openai_llm(
    model_id: str,
    token_id: str,
) -> tuple[ChatCompletion | None, str]:
    """Create a LangGraph OpenAI agent."""
    try:
        client = OpenAI(
            base_url="https://api.studio.nebius.com/v1/",
            api_key=token_id,
        )
        llm = client.chat.completions.create(
            messages=[],  # needs to be fixed
            model=model_id,
            max_tokens=512,
            temperature=0.8,
        )
    except Exception as e:  # noqa: BLE001
        return None, str(e)

    return llm, ""

def create_azure_llm(
    model_id: str,
    api_version: str,
    endpoint: str,
    token_id: str,
    temperature: float = 0.8,
    max_tokens: int = 512,
) -> tuple[AzureChatOpenAI | None, str]:
    """Create a LangGraph Azure OpenAI agent."""
    try:
        os.environ["AZURE_OPENAI_ENDPOINT"] = endpoint
        os.environ["AZURE_OPENAI_API_KEY"] = token_id
| if "o4-mini" in model_id: | |
| kwargs = {"max_completion_tokens": max_tokens} | |
| else: | |
| kwargs = {"max_tokens": max_tokens} | |
| llm = AzureChatOpenAI( | |
| azure_deployment=model_id, | |
| api_key=token_id, | |
| api_version=api_version, | |
| temperature=temperature, | |
| **kwargs, | |
| ) | |
| except Exception as e: # noqa: BLE001 | |
| return None, str(e) | |
| return llm, "" | |

#### UI functionality ####


async def gr_fetch_mcp_tools(
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
    *,
    trace_tools: bool,
) -> list[BaseTool]:
    """Fetch tools from MCP servers."""
    global llm_tools_tracer  # noqa: PLW0603

    if mcp_servers:
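        # Server names double as client keys (spaces replaced with dashes);
        # every configured server is assumed to expose an SSE endpoint.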
        client = MultiServerMCPClient(
            {
                server.name.replace(" ", "-"): {
                    "url": server.value,
                    "transport": "sse",
                }
                for server in mcp_servers
            },
        )
        tools = await client.get_tools()

        if trace_tools:
            llm_tools_tracer = ToolsTracerCallback()
            for tool in tools:
                if tool.callbacks is None:
                    tool.callbacks = [llm_tools_tracer]
                elif isinstance(tool.callbacks, list):
                    tool.callbacks.append(llm_tools_tracer)
                else:
                    tool.callbacks.add_handler(llm_tools_tracer)
        else:
            llm_tools_tracer = None

        return tools

    return []

def gr_make_system_message(
    agent_type: AgentType,
) -> SystemMessage:
    """Make agent's system message."""
    try:
        system_msg = AGENT_SYSTEM_MESSAGES[agent_type]
    except KeyError as err:
        raise gr.Error(f"Unknown agent type '{agent_type}'") from err

    return SystemMessage(system_msg)

async def gr_connect_to_bedrock(  # noqa: PLR0913
    model_id: str,
    access_key: str,
    secret_key: str,
    session_token: str,
    region: str,
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
    agent_type: AgentType,
    trace_tool_calls: bool,
    temperature: float = 0.8,
    max_tokens: int = 512,
) -> str:
    """Initialize Bedrock agent."""
    global llm_agent  # noqa: PLW0603

    CONNECT_STATE_DEFAULT.value = True
    if not access_key or not secret_key:
        return "❌ Please provide both Access Key ID and Secret Access Key"

    llm, error = create_bedrock_llm(
        model_id,
        access_key.strip(),
        secret_key.strip(),
        session_token.strip(),
        region,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    if llm is None:
        return f"❌ Connection failed: {error}"

    llm_agent = create_react_agent(
        model=llm,
        tools=await gr_fetch_mcp_tools(
            mcp_servers,
            trace_tools=trace_tool_calls,
        ),
        prompt=gr_make_system_message(agent_type=agent_type),
    )

    return "✅ Successfully connected to AWS Bedrock!"

async def gr_connect_to_hf(
    model_id: str,
    hf_access_token_textbox: str | None,
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
    agent_type: AgentType,
    trace_tool_calls: bool,
    temperature: float = 0.8,
    max_tokens: int = 512,
) -> str:
    """Initialize Hugging Face agent."""
    global llm_agent  # noqa: PLW0603

    CONNECT_STATE_DEFAULT.value = True
    llm, error = create_hf_llm(
        model_id,
        hf_access_token_textbox,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    if llm is None:
        return f"❌ Connection failed: {error}"

    llm_agent = create_react_agent(
        model=llm,
        tools=await gr_fetch_mcp_tools(
            mcp_servers,
            trace_tools=trace_tool_calls,
        ),
        prompt=gr_make_system_message(agent_type=agent_type),
    )

    return "✅ Successfully connected to Hugging Face!"

async def gr_connect_to_azure(  # noqa: PLR0913
    model_id: str,
    azure_endpoint: str,
    api_key: str,
    api_version: str,
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
    agent_type: AgentType,
    trace_tool_calls: bool,
    temperature: float = 0.8,
    max_tokens: int = 512,
) -> str:
    """Initialize Azure OpenAI agent."""
    global llm_agent  # noqa: PLW0603

    CONNECT_STATE_DEFAULT.value = True
    llm, error = create_azure_llm(
        model_id,
        api_version=api_version,
        endpoint=azure_endpoint,
        token_id=api_key,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    if llm is None:
        return f"❌ Connection failed: {error}"

    llm_agent = create_react_agent(
        model=llm,
        tools=await gr_fetch_mcp_tools(mcp_servers, trace_tools=trace_tool_calls),
        prompt=gr_make_system_message(agent_type=agent_type),
    )

    return "✅ Successfully connected to Azure OpenAI!"

# async def gr_connect_to_nebius(
#     model_id: str,
#     nebius_access_token_textbox: str,
#     mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
# ) -> str:
#     """Initialize Nebius agent."""
#     global llm_agent
#     connected_state.value = True
#     llm, error = create_openai_llm(model_id, nebius_access_token_textbox)
#     if llm is None:
#         return f"❌ Connection failed: {error}"
#     tools = []
#     if mcp_servers:
#         client = MultiServerMCPClient(
#             {
#                 server.name.replace(" ", "-"): {
#                     "url": server.value,
#                     "transport": "sse",
#                 }
#                 for server in mcp_servers
#             },
#         )
#         tools = await client.get_tools()
#     llm_agent = create_react_agent(
#         model=str(llm),
#         tools=tools,
#         prompt=SYSTEM_MESSAGE,
#     )
#     return "✅ Successfully connected to nebius!"
| with open("exfiltration_ticket.txt") as fhandle: # noqa: PTH123 | |
| exfiltration_ticket = fhandle.read() | |
| with open("sample_kali_linux_1.txt") as fhandle1: # noqa: PTH123 | |
| service_discovery_ticket = fhandle1.read() | |

async def gr_chat_function(  # noqa: D103
    message: str,
    history: list[Mapping[str, str]],
) -> str:
    if llm_agent is None:
        return "Please configure your credentials first."
    messages = []
    for hist_msg in history:
        role = hist_msg["role"]
        message_type = GRADIO_ROLE_TO_LG_MESSAGE_TYPE[role]
        messages.append(message_type(content=hist_msg["content"]))
    messages.append(HumanMessage(content=message))

    try:
        if llm_tools_tracer is not None:
            llm_tools_tracer.clear()
        llm_response = await llm_agent.ainvoke(
            {
                "messages": messages,
            },
        )
        return _add_tools_trace_to_message(
            llm_response["messages"][-1].content,
        )
    except Exception as err:
        raise gr.Error(
            f"We encountered an error while invoking the model:\n{err}",
            print_exception=True,
        ) from err

def _add_tools_trace_to_message(message: str) -> str:
    if not llm_tools_tracer or not llm_tools_tracer.tools_trace:
        return message

    traces = []
    for index, tool_info in enumerate(llm_tools_tracer.tools_trace):
        trace_msg = f" {index}. {tool_info.name}"
        if tool_info.inputs:
            trace_msg += "\n"
            trace_msg += " * Arguments:\n"
            trace_msg += " ```json\n"
            trace_msg += f" {json.dumps(tool_info.inputs, indent=4)}\n"
            trace_msg += " ```\n"
        traces.append(trace_msg)

    return f"{message}\n\n# Tools Trace\n\n" + "\n".join(traces)

def _read_markdown_body_as_html(path: str = "README.md") -> str:
    with Path(path).open(encoding="utf-8") as f:  # Default mode is "r"
        lines = f.readlines()

    # Skip YAML front matter if present
    if lines and lines[0].strip() == "---":
        for i in range(1, len(lines)):
            if lines[i].strip() == "---":
                lines = lines[i + 1 :]  # skip metadata block
                break

    markdown_body = "".join(lines).strip()
    return markdown.markdown(markdown_body)

## UI components ##

custom_css = """
.main-header {
    background: linear-gradient(135deg, #00a388 0%, #ffae00 100%);
    padding: 30px;
    border-radius: 5px;
    margin-bottom: 20px;
    text-align: center;
}
"""

with (
    gr.Blocks(
        theme=gr_themes.Origin(
            primary_hue="teal",
            spacing_size="sm",
            font="sans-serif",
        ),
        title="TDAgent",
        fill_height=True,
        fill_width=True,
        css=custom_css,
    ) as gr_app,
):
    gr.HTML(
        """
        <div class="main-header">
            <h1>👩‍💻 TDAgentTools & TDAgent 👨‍💻</h1>
            <p style="font-size: 1.2em; margin: 10px 0 0 0;">
                Empowering Cybersecurity with Agentic AI
            </p>
        </div>
        """,
    )

    with gr.Tabs():
        with gr.TabItem("About"), gr.Row():
            html_content = _read_markdown_body_as_html("README.md")
            gr.Markdown(html_content)

        with gr.TabItem("TDAgent"), gr.Row():
            with gr.Column(scale=1):
                with gr.Accordion("🔌 MCP Servers", open=False):
                    mcp_list = MutableCheckBoxGroup(
                        values=[
                            MutableCheckBoxGroupEntry(
                                name="TDAgent tools",
                                value="https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
                            ),
                        ],
                        label="MCP Servers",
                        new_value_label="MCP endpoint",
                        new_name_label="MCP endpoint name",
                        new_value_placeholder="https://my-cool-mcp-server.com/mcp/sse",
                        new_name_placeholder="Swiss army knife of MCPs",
                    )

                with gr.Accordion("⚙️ Provider Configuration", open=True):
                    model_provider = gr.Dropdown(
                        choices=list(MODEL_OPTIONS.keys()),
                        value=None,
                        label="Select Model Provider",
                    )
                    ## Amazon Bedrock Configuration ##
                    with gr.Group(visible=False) as aws_bedrock_conf_group:
                        aws_access_key_textbox = gr.Textbox(
                            label="AWS Access Key ID",
                            type="password",
                            placeholder="Enter your AWS Access Key ID",
                        )
                        aws_secret_key_textbox = gr.Textbox(
                            label="AWS Secret Access Key",
                            type="password",
                            placeholder="Enter your AWS Secret Access Key",
                        )
                        aws_region_dropdown = gr.Dropdown(
                            label="AWS Region",
                            choices=[
                                "us-east-1",
                                "us-west-2",
                                "eu-west-1",
                                "eu-central-1",
                                "ap-southeast-1",
                            ],
                            value="eu-west-1",
                        )
                        aws_session_token_textbox = gr.Textbox(
                            label="AWS Session Token",
                            type="password",
                            placeholder="Enter your AWS session token",
                        )

                    ## Huggingface Configuration ##
                    with gr.Group(visible=False) as hf_conf_group:
                        hf_token = gr.Textbox(
                            label="HuggingFace Token",
                            type="password",
                            placeholder="Enter your Hugging Face Access Token",
                        )

                    ## Azure Configuration ##
                    with gr.Group(visible=False) as azure_conf_group:
                        azure_endpoint = gr.Textbox(
                            label="Azure OpenAI Endpoint",
                            type="text",
                            placeholder="Enter your Azure OpenAI Endpoint",
                        )
                        azure_api_token = gr.Textbox(
                            label="Azure Access Token",
                            type="password",
                            placeholder="Enter your Azure OpenAI Access Token",
                        )
                        azure_api_version = gr.Textbox(
                            label="Azure OpenAI API Version",
                            type="text",
                            placeholder="Enter your Azure OpenAI API Version",
                            value="2024-12-01-preview",
                        )
| with gr.Accordion("🧠 Model Configuration", open=True): | |
| model_id_dropdown = gr.Dropdown( | |
| label="Select known model id or type your own below", | |
| choices=[], | |
| visible=False, | |
| ) | |
| model_id_textbox = gr.Textbox( | |
| label="Model ID", | |
| type="text", | |
| placeholder="Enter the model ID", | |
| visible=False, | |
| interactive=True, | |
| ) | |
| # Agent configuration options | |
| with gr.Group(): | |
| agent_system_message_radio = gr.Radio( | |
| choices=list(AGENT_SYSTEM_MESSAGES.keys()), | |
| value=next(iter(AGENT_SYSTEM_MESSAGES.keys())), | |
| label="Agent type", | |
| info=( | |
| "Changes the system message to pre-condition the agent" | |
| " to act in a desired way." | |
| ), | |
| ) | |
| agent_trace_tools_checkbox = gr.Checkbox( | |
| value=False, | |
| label="Trace tool calls", | |
| info=( | |
| "Add the invoked tools trace at the end of the" | |
| " message" | |
| ), | |
| ) | |
| # Initialize the temperature and max tokens based on model specs | |
| temperature = gr.Slider( | |
| label="Temperature", | |
| minimum=0.0, | |
| maximum=1.0, | |
| value=0.8, | |
| step=0.1, | |
| ) | |
| max_tokens = gr.Slider( | |
| label="Max Tokens", | |
| minimum=128, | |
| maximum=8192, | |
| value=2048, | |
| step=64, | |
| ) | |
                connect_aws_bedrock_btn = gr.Button(
                    "🔌 Connect to Bedrock",
                    variant="primary",
                    visible=False,
                )
                connect_hf_btn = gr.Button(
                    "🔌 Connect to Huggingface 🤗",
                    variant="primary",
                    visible=False,
                )
                connect_azure_btn = gr.Button(
                    "🔌 Connect to Azure",
                    variant="primary",
                    visible=False,
                )

                status_textbox = gr.Textbox(
                    label="Connection Status",
                    interactive=False,
                )

            with gr.Column(scale=2):
                chat_interface = gr.ChatInterface(
                    fn=gr_chat_function,
                    type="messages",
                    examples=[exfiltration_ticket, service_discovery_ticket],
                    example_labels=[
                        "Enrich & Handle exfiltration ticket 🕵️‍♂️",
                        "Handle service discovery ticket 🤖💻",
                    ],
                    description="A simple threat analyst agent with MCP tools.",
                )
| with gr.TabItem("Demo"): | |
| gr.Markdown( | |
| """ | |
| This is a demo of TDAgent, a simple threat analyst agent with MCP tools. | |
| You can configure the agent to use different LLM providers and connect to | |
| various MCP servers to access tools. | |
| """, | |
| ) | |
| gr.HTML( | |
| """<iframe width="560" height="315" src="https://www.youtube.com/embed/C6Z9EOW-3lE" frameborder="0" allowfullscreen></iframe>""", # noqa: E501 | |
| ) | |

    ## UI Events ##
    def _toggle_model_choices_ui(
        provider: str,
    ) -> dict[str, Any]:
        if provider in MODEL_OPTIONS:
            model_choices = list(MODEL_OPTIONS[provider].keys())
            return gr.update(
                choices=model_choices,
                value=model_choices[0],
                visible=True,
                interactive=True,
            )
        return gr.update(choices=[], visible=False)

    def _toggle_model_aws_bedrock_conf_ui(
        provider: str,
    ) -> tuple[dict[str, Any], ...]:
        is_aws = provider == "AWS Bedrock"
        return gr.update(visible=is_aws), gr.update(visible=is_aws)

    def _toggle_model_hf_conf_ui(
        provider: str,
    ) -> tuple[dict[str, Any], ...]:
        is_hf = provider == "HuggingFace"
        return gr.update(visible=is_hf), gr.update(visible=is_hf)

    def _toggle_model_azure_conf_ui(
        provider: str,
    ) -> tuple[dict[str, Any], ...]:
        is_azure = provider == "Azure OpenAI"
        return gr.update(visible=is_azure), gr.update(visible=is_azure)

    def _on_change_model_configuration(*args: str) -> Any:  # noqa: ARG001
        # If the model configuration changes after connecting, warn the user
        # that they need to reconnect before chatting again.
        if CONNECT_STATE_DEFAULT.value:
            CONNECT_STATE_DEFAULT.value = False  # Reset the connection flag
            return gr.Warning(
                "When changing model configuration, you need to reconnect.",
                duration=5,
            )
        return gr.update()

    ## Connect Event Listeners ##
    model_provider.change(
        _toggle_model_choices_ui,
        inputs=[model_provider],
        outputs=[model_id_dropdown],
    )
    model_provider.change(
        _toggle_model_aws_bedrock_conf_ui,
        inputs=[model_provider],
        outputs=[aws_bedrock_conf_group, connect_aws_bedrock_btn],
    )
    model_provider.change(
        _toggle_model_hf_conf_ui,
        inputs=[model_provider],
        outputs=[hf_conf_group, connect_hf_btn],
    )
    model_provider.change(
        _toggle_model_azure_conf_ui,
        inputs=[model_provider],
        outputs=[azure_conf_group, connect_azure_btn],
    )

    connect_aws_bedrock_btn.click(
        gr_connect_to_bedrock,
        inputs=[
            model_id_textbox,
            aws_access_key_textbox,
            aws_secret_key_textbox,
            aws_session_token_textbox,
            aws_region_dropdown,
            mcp_list.state,
            agent_system_message_radio,
            agent_trace_tools_checkbox,
            temperature,
            max_tokens,
        ],
        outputs=[status_textbox],
    )
    connect_hf_btn.click(
        gr_connect_to_hf,
        inputs=[
            model_id_textbox,
            hf_token,
            mcp_list.state,
            agent_system_message_radio,
            agent_trace_tools_checkbox,
            temperature,
            max_tokens,
        ],
        outputs=[status_textbox],
    )
    connect_azure_btn.click(
        gr_connect_to_azure,
        inputs=[
            model_id_textbox,
            azure_endpoint,
            azure_api_token,
            azure_api_version,
            mcp_list.state,
            agent_system_message_radio,
            agent_trace_tools_checkbox,
            temperature,
            max_tokens,
        ],
        outputs=[status_textbox],
    )
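
    # Keep the free-form textbox in sync with the dropdown: picking a known
    # model name fills in its provider-specific model id.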
    model_id_dropdown.change(
        lambda x, y: (
            gr.update(
                value=MODEL_OPTIONS.get(y, {}).get(x),
                visible=True,
            )
            if x
            else model_id_textbox.value
        ),
        inputs=[model_id_dropdown, model_provider],
        outputs=[model_id_textbox],
    )

    model_provider.change(
        _on_change_model_configuration,
        inputs=[model_provider],
    )
    model_id_dropdown.change(
        _on_change_model_configuration,
        inputs=[model_id_dropdown, model_provider],
    )

## Entry Point ##

if __name__ == "__main__":
    gr_app.launch()