|
|
"""
|
|
|
Database Agent - A specialized ReAct agent for MITRE ATT&CK technique retrieval
|
|
|
|
|
|
This agent provides semantic search capabilities over the MITRE ATT&CK knowledge base
|
|
|
with support for filtered searches by tactics, platforms, and other metadata.
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
import json
|
|
|
import sys
|
|
|
import time
|
|
|
from typing import List, Dict, Any, Optional, Literal
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
from langchain_core.tools import tool
|
|
|
from langchain_core.messages import HumanMessage, AIMessage
|
|
|
from langchain.chat_models import init_chat_model
|
|
|
from langchain_core.language_models.chat_models import BaseChatModel
|
|
|
from langchain_text_splitters import TokenTextSplitter
|
|
|
from langgraph.prebuilt import create_react_agent
|
|
|
|
|
|
|
|
|
from langsmith import traceable, Client, get_current_run_tree
|
|
|
|
|
|
|
|
|
from src.agents.database_agent.prompts import DATABASE_AGENT_SYSTEM_PROMPT
|
|
|
|
|
|
|
|
|
try:
|
|
|
from src.knowledge_base.cyber_knowledge_base import CyberKnowledgeBase
|
|
|
except Exception as e:
|
|
|
print(
|
|
|
f"[WARNING] Could not import CyberKnowledgeBase. Please adjust import paths. {e}"
|
|
|
)
|
|
|
sys.exit(1)
|
|
|
|
|
|
# Module-level LangSmith client. DatabaseAgent uses it (via create_feedback) to attach
# search/agent performance metrics to the currently traced run.
# The API key comes from the LANGSMITH_API_KEY environment variable (None if unset).
ls_client = Client(api_key=os.getenv("LANGSMITH_API_KEY"))
|
|
|
|
|
|
|
|
|
def truncate_to_tokens(text: str, max_tokens: int) -> str:
    """
    Return the leading portion of ``text`` that fits in ``max_tokens`` tokens.

    Newlines are first flattened to single spaces. The flattened text is then
    split with LangChain's ``TokenTextSplitter`` (``o200k_base`` encoding,
    ``chunk_size=max_tokens``, no overlap), and only the first chunk is kept.

    Args:
        text: The text to shorten. Falsy input (e.g. ``""`` or ``None``) yields ``""``.
        max_tokens: Chunk size, in tokens, passed to the splitter.

    Returns:
        The first chunk produced by the splitter, or ``""`` if there is none.
    """
    # Guard clause: nothing to tokenize.
    if not text:
        return ""

    flattened = text.replace("\n", " ")
    pieces = TokenTextSplitter(
        encoding_name="o200k_base",
        chunk_size=max_tokens,
        chunk_overlap=0,
    ).split_text(flattened)

    if not pieces:
        return ""
    return pieces[0]
|
|
|
|
|
|
|
|
|
class DatabaseAgent:
    """
    A specialized ReAct agent for MITRE ATT&CK technique retrieval and search.

    The agent loads a ``CyberKnowledgeBase`` and gives a LangGraph ReAct agent
    two LangChain tools over it:

    * ``search_techniques``: semantic search.
    * ``search_techniques_filtered``: semantic search restricted by tactics
      and/or platforms.

    Each formatted hit carries ``relevance_score`` and ``rrf_score`` when the
    knowledge base supplies them. When a LangSmith run is active, search
    metrics and agent-level metrics are attached to it as feedback.
    """

    # Upper bound for the ``top_k`` argument accepted by the search tools
    # (kept in sync with the "max: 20" wording in the tool docstrings).
    _MAX_TOP_K = 20

    # Token budget for each technique description returned to the LLM.
    _DESCRIPTION_TOKENS = 300

    def __init__(
        self,
        kb_path: str = "./cyber_knowledge_base",
        llm_client: Optional[BaseChatModel] = None,
    ):
        """
        Initialize the Database Agent.

        Args:
            kb_path: Path to the cyber knowledge base directory.
            llm_client: Chat model to drive the agent. When ``None``, a
                ``google_genai:gemini-2.0-flash`` model with temperature 0.1
                is created via ``init_chat_model``.

        Raises:
            RuntimeError: If the knowledge base cannot be loaded from ``kb_path``.
        """
        self.kb_path = kb_path
        self.kb = self._init_knowledge_base()

        # Explicit None check: any supplied client object must be used,
        # regardless of how it evaluates in a boolean context.
        if llm_client is not None:
            self.llm = llm_client
        else:
            self.llm = init_chat_model(
                "google_genai:gemini-2.0-flash",
                temperature=0.1,
            )
            print(
                "[INFO] Database Agent: Using default LLM model: google_genai:gemini-2.0-flash"
            )

        self.tools = self._create_tools()
        self.agent = self._create_react_agent()

    @traceable(name="database_agent_init_kb")
    def _init_knowledge_base(self) -> CyberKnowledgeBase:
        """
        Create a ``CyberKnowledgeBase`` and load it from ``self.kb_path``.

        Returns:
            The loaded knowledge base.

        Raises:
            RuntimeError: If ``load_knowledge_base`` reports failure.
        """
        kb = CyberKnowledgeBase()

        if kb.load_knowledge_base(self.kb_path):
            print("[SUCCESS] Database Agent: Loaded existing knowledge base")
            return kb

        print(
            f"[ERROR] Database Agent: Could not load knowledge base from {self.kb_path}"
        )
        print("Please ensure the knowledge base is built and available.")
        raise RuntimeError("Knowledge base not available")

    @staticmethod
    def _split_metadata_list(value: Any) -> List[str]:
        """
        Normalize a list-like metadata value into trimmed, non-empty strings.

        Accepts a comma-separated string (e.g. ``"Windows, Linux"``), an
        actual list/tuple/set, or ``None`` (which yields ``[]``). Any other
        value is stringified and treated as a comma-separated string.
        """
        if value is None:
            return []
        if isinstance(value, (list, tuple, set)):
            items = value
        else:
            items = str(value).split(",")
        return [s for s in (str(item).strip() for item in items) if s]

    @traceable(name="database_agent_format_results")
    def _format_results_as_json(self, results) -> List[Dict[str, Any]]:
        """
        Convert knowledge-base search hits into JSON-serializable dicts.

        Each hit is expected to expose ``metadata`` (a mapping) and
        ``page_content`` (a string). Each output dict contains:
        ``attack_id``, ``name``, ``tactics`` (list), ``platforms`` (list),
        ``description`` (page content truncated to ``_DESCRIPTION_TOKENS``
        tokens), ``relevance_score``, ``rrf_score`` and ``mitigation_count``.
        """
        output = []
        for doc in results:
            metadata = doc.metadata
            output.append(
                {
                    "attack_id": metadata.get("attack_id", "Unknown"),
                    "name": metadata.get("name", "Unknown"),
                    "tactics": self._split_metadata_list(metadata.get("tactics")),
                    "platforms": self._split_metadata_list(
                        metadata.get("platforms")
                    ),
                    "description": truncate_to_tokens(
                        doc.page_content, self._DESCRIPTION_TOKENS
                    ),
                    "relevance_score": metadata.get("relevance_score", None),
                    "rrf_score": metadata.get("rrf_score", None),
                    "mitigation_count": metadata.get("mitigation_count", 0),
                }
            )
        return output

    def _log_search_metrics(
        self,
        search_type: str,
        query: str,
        results_count: int,
        execution_time: float,
        success: bool,
    ):
        """
        Attach tool-level search metrics to the current LangSmith run, if any.

        Failures are printed and swallowed on purpose (best effort): metrics
        logging must never break a search.
        """
        try:
            current_run = get_current_run_tree()
            if current_run:
                ls_client.create_feedback(
                    run_id=current_run.id,
                    key="database_search_performance",
                    score=1.0 if success else 0.0,
                    value={
                        "search_type": search_type,
                        "query": query,
                        "results_count": results_count,
                        "execution_time": execution_time,
                        "success": success,
                    },
                )
        except Exception as e:
            print(f"Failed to log search metrics: {e}")

    def _log_agent_performance(
        self, query: str, message_count: int, execution_time: float, success: bool
    ):
        """
        Attach agent-level metrics to the current LangSmith run, if any.

        Failures are printed and swallowed on purpose (best effort).
        """
        try:
            current_run = get_current_run_tree()
            if current_run:
                ls_client.create_feedback(
                    run_id=current_run.id,
                    key="database_agent_performance",
                    score=1.0 if success else 0.0,
                    value={
                        "query": query,
                        "message_count": message_count,
                        "execution_time": execution_time,
                        "success": success,
                        "agent_type": "database_search",
                    },
                )
        except Exception as e:
            print(f"Failed to log agent metrics: {e}")

    def _run_search_tool(
        self,
        query: str,
        top_k: int,
        *,
        metric_type: str,
        result_type: str,
        error_message: str,
        search_kwargs: Optional[Dict[str, Any]] = None,
        extra_fields: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Shared body of the search tools: search, format, log and serialize.

        Args:
            query: Search query string.
            top_k: Requested number of results, clamped to ``[1, _MAX_TOP_K]``.
            metric_type: ``search_type`` value reported to ``_log_search_metrics``.
            result_type: ``search_type`` value written into the JSON payload.
            error_message: ``message`` value used in the error payload.
            search_kwargs: Extra keyword arguments forwarded to ``self.kb.search``.
            extra_fields: Extra payload keys, placed after ``query`` and before
                ``techniques`` in the JSON output.

        Returns:
            An indented JSON string. On any exception, an error payload
            (``error``, ``techniques``, ``message``) is returned instead of
            raising, so the agent receives the failure as a tool result.
        """
        # perf_counter is monotonic, so durations are immune to clock changes.
        start_time = time.perf_counter()
        try:
            top_k = min(max(top_k, 1), self._MAX_TOP_K)

            results = self.kb.search(query, top_k=top_k, **(search_kwargs or {}))
            techniques = self._format_results_as_json(results)

            self._log_search_metrics(
                metric_type,
                query,
                len(techniques),
                time.perf_counter() - start_time,
                True,
            )

            # Insertion order defines the JSON key order seen by the LLM.
            payload: Dict[str, Any] = {"search_type": result_type, "query": query}
            payload.update(extra_fields or {})
            payload["techniques"] = techniques
            payload["total_results"] = len(techniques)
            return json.dumps(payload, indent=2)

        except Exception as e:
            self._log_search_metrics(
                metric_type, query, 0, time.perf_counter() - start_time, False
            )
            return json.dumps(
                {
                    "error": str(e),
                    "techniques": [],
                    "message": error_message,
                },
                indent=2,
            )

    def _create_tools(self):
        """
        Create the search tools for the Database Agent.

        Note: the tool docstrings below are model-facing. ``@tool`` uses them
        as the tool descriptions shown to the LLM.
        """

        @tool
        @traceable(name="database_search_techniques")
        def search_techniques(query: str, top_k: int = 5) -> str:
            """
            Search for MITRE ATT&CK techniques using semantic search.

            Args:
                query: Search query string
                top_k: Number of results to return (default: 5, max: 20)

            Returns:
                JSON string with search results containing technique details
            """
            return self._run_search_tool(
                query,
                top_k,
                metric_type="single_query",
                result_type="single_query",
                error_message="Error occurred during search",
            )

        @tool
        @traceable(name="database_search_techniques_filtered")
        def search_techniques_filtered(
            query: str,
            top_k: int = 5,
            filter_tactics: Optional[List[str]] = None,
            filter_platforms: Optional[List[str]] = None,
        ) -> str:
            """
            Search for MITRE ATT&CK techniques with metadata filters.

            Args:
                query: Search query string
                top_k: Number of results to return (default: 5, max: 20)
                filter_tactics: Filter by specific tactics (e.g., ['defense-evasion', 'privilege-escalation'])
                filter_platforms: Filter by platforms (e.g., ['Windows', 'Linux'])

            Returns:
                JSON string with filtered search results

            Examples of tactics: initial-access, execution, persistence, privilege-escalation,
            defense-evasion, credential-access, discovery, lateral-movement, collection,
            command-and-control, exfiltration, impact

            Examples of platforms: Windows, macOS, Linux, AWS, Azure, GCP, SaaS, Network,
            Containers, Android, iOS
            """
            return self._run_search_tool(
                query,
                top_k,
                metric_type="filtered_query",
                result_type="single_query_filtered",
                error_message="Error occurred during filtered search",
                search_kwargs={
                    "filter_tactics": filter_tactics,
                    "filter_platforms": filter_platforms,
                },
                extra_fields={
                    "filters": {
                        "tactics": filter_tactics,
                        "platforms": filter_platforms,
                    }
                },
            )

        # Both tools must be returned. Previously only `search_techniques`
        # was, which left the filtered tool defined but unreachable by the agent.
        return [search_techniques, search_techniques_filtered]

    def _create_react_agent(self):
        """Create the ReAct agent with the search tools, using the prompt from prompts.py."""
        return create_react_agent(
            model=self.llm,
            tools=self.tools,
            prompt=DATABASE_AGENT_SYSTEM_PROMPT,
            name="database_agent",
        )

    @traceable(name="database_agent_search")
    def search(self, query: str, **kwargs) -> Dict[str, Any]:
        """
        Run the agent on a single query and collect its messages.

        Args:
            query: The search query or question.
            **kwargs: Additional parameters forwarded to ``self.agent.invoke``.

        Returns:
            A dict with ``success``, ``messages`` and ``final_response`` (the
            content of the last message, or ``""``). On failure, ``success`` is
            False, ``messages`` is empty, and ``error`` holds the exception text.
        """
        start_time = time.perf_counter()
        try:
            response = self.agent.invoke(
                {"messages": [HumanMessage(content=query)]}, **kwargs
            )
            # Read the messages once so logging and the return value agree.
            messages = response.get("messages", [])

            self._log_agent_performance(
                query, len(messages), time.perf_counter() - start_time, True
            )

            return {
                "success": True,
                "messages": messages,
                "final_response": messages[-1].content if messages else "",
            }

        except Exception as e:
            self._log_agent_performance(
                query, 0, time.perf_counter() - start_time, False
            )
            return {
                "success": False,
                "error": str(e),
                "messages": [],
                "final_response": f"Error during search: {str(e)}",
            }

    @traceable(name="database_agent_stream_search")
    def stream_search(self, query: str, **kwargs):
        """
        Stream the agent's search process for real-time feedback.

        Args:
            query: The search query or question.
            **kwargs: Additional parameters forwarded to ``self.agent.stream``.

        Yields:
            Chunks produced by ``self.agent.stream``. If an exception occurs,
            a final ``{"error": <message>}`` dict is yielded instead of raising.
        """
        try:
            messages = [HumanMessage(content=query)]
            for chunk in self.agent.stream({"messages": messages}, **kwargs):
                yield chunk
        except Exception as e:
            yield {"error": str(e)}
|
|
|
|
|
|
|
|
|
@traceable(name="database_agent_test")
def test_database_agent():
    """
    Smoke-test the Database Agent against a handful of sample queries.

    Builds a ``DatabaseAgent`` with default arguments (knowledge base at
    ``./cyber_knowledge_base``; default LLM when none is supplied), runs each
    sample query through ``DatabaseAgent.search``, and prints a short preview
    of the agent's final answer. Initialization errors are printed and end
    the test early.
    """
    print("Testing Database Agent...")

    try:
        agent = DatabaseAgent()
        print("Database Agent initialized successfully")
    except Exception as e:
        print(f"Failed to initialize Database Agent: {e}")
        return

    test_queries = [
        "Find techniques related to credential dumping and LSASS memory access",
        "What are Windows-specific privilege escalation techniques?",
        "Search for defense evasion techniques that work on Linux platforms",
        "Find lateral movement techniques involving SMB or WMI",
        "What techniques are used for persistence on macOS systems?",
    ]

    preview_chars = 300

    for i, query in enumerate(test_queries, 1):
        print(f"\n--- Test Query {i} ---")
        print(f"Query: {query}")
        print("-" * 50)

        result = agent.search(query)
        if not result["success"]:
            print(f"Search failed: {result['error']}")
            continue

        print("Search completed successfully")

        # Print the last AI message that requested no tool calls, i.e. the final answer.
        # AIMessage always has a `tool_calls` attribute (an empty list when there are
        # none), so a `hasattr` check can never identify the final answer.
        for msg in reversed(result["messages"]):
            if isinstance(msg, AIMessage) and not getattr(msg, "tool_calls", None):
                # Content may be a list of parts for some providers; stringify it then.
                content = msg.content if isinstance(msg.content, str) else str(msg.content)
                suffix = "..." if len(content) > preview_chars else ""
                print(f"Response: {content[:preview_chars]}{suffix}")
                break
|
|
|
|
|
|
|
|
|
# Run the smoke test only when this file is executed directly, not on import.
if __name__ == "__main__":
    test_database_agent()
|
|
|
|