minhan6559's picture
Upload 102 files
9e3d618 verified
"""
Retrieval Supervisor - Coordinates CTI Agent, Database Agent, and Grader Agent
This supervisor manages the retrieval pipeline for cybersecurity analysis, coordinating
multiple specialized agents to provide comprehensive threat intelligence and MITRE ATT&CK
technique retrieval.
"""
import json
import os
from typing import Dict, Any, List, Optional
from pathlib import Path
from langchain_core.messages import convert_to_messages
# LangGraph and LangChain imports
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langchain.chat_models import init_chat_model
from langgraph.prebuilt import create_react_agent
from langgraph_supervisor import create_supervisor
# Import your agent classes
from src.agents.cti_agent.cti_agent import CTIAgent
from src.agents.database_agent.agent import DatabaseAgent
# Import prompts
from src.agents.retrieval_supervisor.prompts import (
GRADER_AGENT_PROMPT,
SUPERVISOR_PROMPT_TEMPLATE,
INPUT_MESSAGE_TEMPLATE,
LOG_ANALYSIS_SECTION_TEMPLATE,
CONTEXT_SECTION_TEMPLATE,
)
class RetrievalSupervisor:
    """
    Retrieval Supervisor that coordinates the Database Agent and the Grader Agent
    using LangGraph's supervisor pattern for threat-intelligence / MITRE ATT&CK
    technique retrieval.

    A CTI Agent initializer is also provided, but the CTI Agent is currently not
    instantiated nor added to the supervisor's agent list.
    """

    # Names registered with langgraph_supervisor. Kept as class constants so the
    # output parsers look for exactly the names the graph is built with.
    SUPERVISOR_NAME = "retrieval_supervisor"
    GRADER_AGENT_NAME = "retrieval_grader_agent"
    # Name the parsers historically matched; still accepted for compatibility.
    _LEGACY_SUPERVISOR_NAME = "supervisor"

    def __init__(
        self,
        llm_model: str = "google_genai:gemini-2.0-flash",
        kb_path: str = "./cyber_knowledge_base",
        max_iterations: int = 3,
        llm_client=None,
    ):
        """
        Initialize the Retrieval Supervisor.

        Args:
            llm_model: Model identifier passed to ``init_chat_model``.
            kb_path: Path to the cyber knowledge base used by the Database Agent.
            max_iterations: Maximum iterations for the retrieval pipeline; it is
                substituted into the supervisor prompt.
            llm_client: Optional pre-initialized LLM client (overrides llm_model).
        """
        self.max_iterations = max_iterations
        self.llm_model = llm_model

        # Initialize the supervisor LLM
        if llm_client is not None:
            self.llm_client = llm_client
            print("[INFO] Retrieval Supervisor: Using provided LLM client")
        elif "gpt-oss" in llm_model:
            # GPT-OSS models get extra reasoning controls. reasoning_format="hidden"
            # presumably keeps reasoning text out of message content - confirm
            # against the provider's docs.
            reasoning_effort = "low"
            reasoning_format = "hidden"
            self.llm_client = init_chat_model(
                llm_model,
                temperature=0.1,
                reasoning_effort=reasoning_effort,
                reasoning_format=reasoning_format,
            )
            print(
                f"[INFO] Retrieval Supervisor: Using GPT-OSS model: {llm_model} with reasoning effort: {reasoning_effort}"
            )
        else:
            self.llm_client = init_chat_model(llm_model, temperature=0.1)
            print(f"[INFO] Retrieval Supervisor: Initialized with {llm_model}")

        # Initialize agents.
        # NOTE: the CTI Agent is disabled; re-enable with
        # `self.cti_agent = self._initialize_cti_agent()` and add it in
        # _create_supervisor().
        self.database_agent = self._initialize_database_agent(kb_path)
        self.grader_agent = self._initialize_grader_agent()

        # Create the supervisor graph
        self.supervisor = self._create_supervisor()

    # ------------------------------------------------------------------ #
    # Agent construction
    # ------------------------------------------------------------------ #

    def _initialize_cti_agent(self) -> "CTIAgent":
        """Initialize the CTI Agent; errors are printed and re-raised."""
        try:
            cti_agent = CTIAgent(llm=self.llm_client)
            print("CTI Agent initialized successfully")
            return cti_agent
        except Exception as e:
            print(f"Failed to initialize CTI Agent: {e}")
            raise

    def _initialize_database_agent(self, kb_path: str) -> "DatabaseAgent":
        """Initialize the Database Agent; errors are printed and re-raised."""
        try:
            database_agent = DatabaseAgent(
                kb_path=kb_path,
                llm_client=self.llm_client,
            )
            print("Database Agent initialized successfully")
            return database_agent
        except Exception as e:
            print(f"Failed to initialize Database Agent: {e}")
            raise

    def _initialize_grader_agent(self):
        """Initialize the Grader Agent as a ReAct agent with no tools."""
        return create_react_agent(
            model=self.llm_client,
            tools=[],  # The grader only judges results; it needs no tools.
            prompt=GRADER_AGENT_PROMPT,
            name=self.GRADER_AGENT_NAME,
        )

    def _create_supervisor(self):
        """Build and compile the supervisor graph over the active agents."""
        agents = [
            self.database_agent.agent,  # Database Agent's ReAct agent
            self.grader_agent,  # Grader Agent (ReAct agent)
        ]
        supervisor_prompt = SUPERVISOR_PROMPT_TEMPLATE.format(
            max_iterations=self.max_iterations
        )
        return create_supervisor(
            model=self.llm_client,
            agents=agents,
            prompt=supervisor_prompt,
            add_handoff_back_messages=True,
            supervisor_name=self.SUPERVISOR_NAME,
        ).compile(name=self.SUPERVISOR_NAME)

    # ------------------------------------------------------------------ #
    # Public entry points
    # ------------------------------------------------------------------ #

    def invoke(
        self,
        query: str,
        log_analysis_report: Optional[Dict[str, Any]] = None,
        context: Optional[str] = None,
        trace: bool = False,
    ) -> Dict[str, Any]:
        """
        Invoke the retrieval supervisor pipeline.

        Args:
            query: The intelligence retrieval query/task.
            log_analysis_report: Optional log analysis report from the log
                analysis agent; embedded in the input message as JSON.
            context: Optional additional context.
            trace: Whether to print a trace of the resulting message history.

        Returns:
            Dictionary containing the structured retrieval results.

        Raises:
            Any exception from the pipeline, after printing an error line.
        """
        try:
            input_content = self._build_input_message(
                query, log_analysis_report, context
            )
            initial_state = {"messages": [HumanMessage(content=input_content)]}

            raw_result = self.supervisor.invoke(initial_state)

            if trace:
                self._print_trace_pipeline(raw_result)

            return self._parse_supervisor_output(raw_result, query)

        except Exception as e:
            print(f"[ERROR] Retrieval Supervisor pipeline failed: {e}")
            raise

    def invoke_direct_query(self, query: str, trace: bool = False) -> Dict[str, Any]:
        """Invoke the pipeline with ``query`` as the raw input message.

        Unlike :meth:`invoke`, no input template is applied.
        """
        raw_result = self.supervisor.invoke({"messages": [HumanMessage(content=query)]})
        if trace:
            self._print_trace_pipeline(raw_result)
        return self._parse_supervisor_output(raw_result, query)

    def stream(
        self,
        query: str,
        log_analysis_report: Optional[Dict[str, Any]] = None,
        context: Optional[str] = None,
    ):
        """Run the pipeline with subgraph streaming, pretty-printing each update.

        Only the last message of each node update is printed.
        """
        input_content = self._build_input_message(query, log_analysis_report, context)
        initial_state = {"messages": [HumanMessage(content=input_content)]}
        for chunk in self.supervisor.stream(initial_state, subgraphs=True):
            self._pretty_print_messages(chunk, last_message=True)

    def stream_invoke(
        self,
        query: str,
        log_analysis_report: Optional[Dict[str, Any]] = None,
        context: Optional[str] = None,
    ):
        """
        Stream the retrieval supervisor pipeline execution.

        Args:
            query: The intelligence retrieval query/task.
            log_analysis_report: Optional log analysis report from the log
                analysis agent.
            context: Optional additional context.

        Yields:
            Raw streaming chunks from the supervisor graph. On failure, a single
            ``{"error": <message>}`` dict is yielded instead of raising.
        """
        try:
            input_content = self._build_input_message(
                query, log_analysis_report, context
            )
            initial_state = {"messages": [HumanMessage(content=input_content)]}
            for chunk in self.supervisor.stream(initial_state):
                yield chunk
        except Exception as e:
            yield {"error": str(e)}

    # ------------------------------------------------------------------ #
    # Printing helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Return message content as plain text.

        LangChain message ``content`` may be a string or a list of content parts
        (plain strings or dicts carrying a ``"text"`` key). List parts are joined
        with newlines; ``None`` becomes ``""``; anything else is ``str()``-ed.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for part in content:
                if isinstance(part, str):
                    parts.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    parts.append(part["text"])
            return "\n".join(parts)
        return str(content)

    def _pretty_print_message(self, message, indent=False):
        """Print a message's ``pretty_repr``, optionally tab-indented."""
        pretty_message = message.pretty_repr(html=True)
        if not indent:
            print(pretty_message)
            return
        indented = "\n".join("\t" + c for c in pretty_message.split("\n"))
        print(indented)

    def _pretty_print_messages(self, update, last_message=False):
        """Print one chunk from ``supervisor.stream(..., subgraphs=True)``.

        With subgraph streaming, chunks are ``(namespace, update)`` tuples;
        updates from the parent graph (empty namespace) are skipped.

        Args:
            update: A stream chunk (tuple or ``{node_name: node_update}`` dict).
            last_message: If True, print only the last message of each update.
        """
        is_subgraph = False
        if isinstance(update, tuple):
            ns, update = update
            # skip parent graph updates in the printouts
            if not ns:
                return
            graph_id = ns[-1].split(":")[0]
            print(f"Update from subgraph {graph_id}:")
            print("\n")
            is_subgraph = True

        if not isinstance(update, dict):
            return

        for node_name, node_update in update.items():
            update_label = f"Update from node {node_name}:"
            if is_subgraph:
                update_label = "\t" + update_label
            print(update_label)
            print("\n")

            # Some nodes emit updates with no (or non-dict) payload; nothing to print.
            if not isinstance(node_update, dict) or not node_update.get("messages"):
                continue

            messages = convert_to_messages(node_update["messages"])
            if last_message:
                messages = messages[-1:]

            for m in messages:
                self._pretty_print_message(m, indent=is_subgraph)
            print("\n")

    def _print_trace_pipeline(self, result: Dict[str, Any]):
        """Print a detailed trace of the pipeline's message history."""
        messages = result.get("messages", [])
        if not messages:
            print("[TRACE] No messages found in pipeline result")
            return

        print("\n" + "=" * 60)
        print("PIPELINE EXECUTION TRACE")
        print("=" * 60)

        for i, msg in enumerate(messages, 1):
            print(f"\n--- Message {i} ---")

            if isinstance(msg, HumanMessage):
                print(f"[Human] {msg.content}")

            elif isinstance(msg, AIMessage):
                agent_name = getattr(msg, "name", None) or "agent"
                print(f"[Agent:{agent_name}] {msg.content}")

                # Legacy OpenAI-style function call, if the provider set one.
                function_call = (getattr(msg, "additional_kwargs", None) or {}).get(
                    "function_call"
                )
                if isinstance(function_call, dict):
                    print(
                        f"  [ToolCall] {function_call.get('name')}: {function_call.get('arguments')}"
                    )
                # Standard LangChain tool calls.
                for call in getattr(msg, "tool_calls", None) or []:
                    if isinstance(call, dict):
                        print(f"  [ToolCall] {call.get('name')}: {call.get('args')}")

            elif isinstance(msg, ToolMessage):
                tool_name = getattr(msg, "name", None) or "tool"
                content = self._content_to_text(msg.content)
                # Truncate long content for readability
                preview = content[:300] + ("..." if len(content) > 300 else "")
                print(f"[Tool:{tool_name}] {preview}")

            else:
                print(f"[Message] {getattr(msg, 'content', '')}")

        latest_message = messages[-1]
        if isinstance(latest_message, AIMessage):
            latest_text = self._content_to_text(latest_message.content)
            print("\n--- Final Supervisor Output ---")
            print(latest_text)

            # If the final output mentions a decision, try to report it.
            if "decision" in latest_text.lower():
                decision_json = self._extract_json_from_content(latest_text)
                if decision_json is None:
                    print("\n[INFO] Pipeline completed (decision parsing failed)")
                else:
                    decision = decision_json.get("decision", "unknown")
                    print(f"\n[SUCCESS] Pipeline completed - Decision: {decision}")
                    if decision == "ACCEPT":
                        print("Results accepted by grader")
                    elif decision == "NEEDS_MITRE":
                        print("Additional MITRE technique analysis needed")

        print("\n" + "=" * 60)
        print("TRACE COMPLETED")
        print("=" * 60)

    # ------------------------------------------------------------------ #
    # Input / output shaping
    # ------------------------------------------------------------------ #

    def _build_input_message(
        self,
        query: str,
        log_analysis_report: Optional[Dict[str, Any]],
        context: Optional[str],
    ) -> str:
        """Render the supervisor input message from the prompt templates.

        Empty/None report or context yields an empty section string.
        """
        log_analysis_section = ""
        if log_analysis_report:
            log_analysis_section = LOG_ANALYSIS_SECTION_TEMPLATE.format(
                log_analysis_report=json.dumps(log_analysis_report, indent=2)
            )

        context_section = ""
        if context:
            context_section = CONTEXT_SECTION_TEMPLATE.format(context=context)

        return INPUT_MESSAGE_TEMPLATE.format(
            query=query,
            log_analysis_section=log_analysis_section,
            context_section=context_section,
        )

    @staticmethod
    def _collect_agent_names(messages: List) -> List[str]:
        """Return the distinct non-empty ``name`` values of messages, in first-seen order."""
        names: Dict[str, None] = {}
        for msg in messages:
            name = getattr(msg, "name", None)
            if name:
                names[str(name)] = None
        return list(names)

    def _parse_supervisor_output(
        self, raw_result: Dict[str, Any], original_query: str
    ) -> Dict[str, Any]:
        """Turn the supervisor graph's final state into a structured result dict.

        Uses the most recent non-empty message from the supervisor (falling back
        to the last message), extracts a JSON object from it, and validates it.
        If no JSON object is found, the result is rebuilt from tool messages.
        """
        messages = raw_result.get("messages", [])
        supervisor_names = {self.SUPERVISOR_NAME, self._LEGACY_SUPERVISOR_NAME}

        final_supervisor_message = ""
        for msg in reversed(messages):
            if getattr(msg, "name", None) in supervisor_names:
                text = self._content_to_text(getattr(msg, "content", None))
                if text:
                    final_supervisor_message = text
                    break

        if not final_supervisor_message and messages:
            # Fallback: use the last message
            final_supervisor_message = self._content_to_text(
                getattr(messages[-1], "content", "")
            )

        structured_output = self._extract_json_from_content(final_supervisor_message)
        if structured_output:
            return self._validate_and_enhance_output(
                structured_output, original_query, messages
            )
        return self._create_fallback_output(messages, original_query)

    def _extract_json_from_content(self, content: Any) -> Optional[Dict[str, Any]]:
        """Extract the first JSON object embedded in message content.

        Fenced ```json blocks are tried first; otherwise each ``{`` is tried as
        the start of a JSON value. Only dicts are returned, because callers
        index the result by key.

        Args:
            content: Message content (string or LangChain content-part list).

        Returns:
            The parsed dict, or None if no JSON object is found.
        """
        text = self._content_to_text(content)
        if not text:
            return None

        # Prefer explicitly fenced JSON blocks.
        if "```json" in text:
            for block in text.split("```json")[1:]:
                json_str = block.split("```")[0].strip()
                try:
                    parsed = json.loads(json_str)
                except json.JSONDecodeError:
                    continue
                if isinstance(parsed, dict):
                    return parsed

        # raw_decode parses one complete JSON value starting at the given index,
        # so braces inside string values cannot confuse the match.
        decoder = json.JSONDecoder()
        start_idx = text.find("{")
        while start_idx != -1:
            try:
                parsed, _ = decoder.raw_decode(text, start_idx)
            except json.JSONDecodeError:
                parsed = None
            if isinstance(parsed, dict):
                return parsed
            start_idx = text.find("{", start_idx + 1)
        return None

    def _validate_and_enhance_output(
        self, structured_output: Dict[str, Any], original_query: str, messages: List
    ) -> Dict[str, Any]:
        """Fill missing fields of the supervisor's JSON and add metadata (in place)."""
        structured_output.setdefault("status", "SUCCESS")
        structured_output.setdefault("final_assessment", "ACCEPTED")
        if structured_output.get("retrieved_techniques") is None:
            structured_output["retrieved_techniques"] = []
        if "agents_used" not in structured_output:
            structured_output["agents_used"] = self._collect_agent_names(messages)

        technique_count = len(structured_output["retrieved_techniques"])
        structured_output.setdefault(
            "summary", f"Retrieved {technique_count} MITRE techniques for analysis"
        )
        structured_output.setdefault("iteration_count", 1)

        # Metadata always reflects this call.
        structured_output["query"] = original_query
        structured_output["total_techniques"] = technique_count
        return structured_output

    @staticmethod
    def _normalize_technique(tech: Dict[str, Any]) -> Dict[str, Any]:
        """Map one raw technique record from a tool result to the output schema."""
        # Tactics are always reported as a list.
        tactics = tech.get("tactics", [])
        if isinstance(tactics, str):
            tactics = [tactics] if tactics else []
        elif not isinstance(tactics, list):
            tactics = []
        return {
            "technique_id": tech.get("attack_id", ""),
            "technique_name": tech.get("name", ""),
            "tactic": tactics,
            "description": tech.get("description", ""),
            "relevance_score": tech.get("relevance_score", 0.5),
        }

    def _create_fallback_output(
        self, messages: List, original_query: str
    ) -> Dict[str, Any]:
        """Build a structured result from raw tool messages when JSON parsing fails.

        Techniques are collected from every message whose name contains
        ``"search_techniques"``. LangGraph's ToolNode names tool messages after
        their tool. Content that is not valid JSON, or not a dict, is skipped.
        """
        techniques = []
        for msg in messages:
            name = getattr(msg, "name", None)
            if not name or "search_techniques" not in str(name):
                continue
            content = getattr(msg, "content", None)
            try:
                tool_data = json.loads(content) if isinstance(content, str) else content
            except json.JSONDecodeError:
                continue
            if not isinstance(tool_data, dict):
                continue
            for tech in tool_data.get("techniques") or []:
                if isinstance(tech, dict):
                    techniques.append(self._normalize_technique(tech))

        return {
            "status": "PARTIAL",
            "final_assessment": "NEEDS_MORE_INFO",
            "retrieved_techniques": techniques,
            "agents_used": self._collect_agent_names(messages),
            "summary": f"Retrieved {len(techniques)} MITRE techniques (fallback parsing)",
            "iteration_count": 1,
            "query": original_query,
            "total_techniques": len(techniques),
            "parsing_method": "fallback",
        }

    def _process_results(
        self, result: Dict[str, Any], original_query: str
    ) -> Dict[str, Any]:
        """Summarize a raw supervisor result into a legacy results dict.

        NOTE(review): not called by invoke(); kept for callers that may use it.
        """
        messages = result.get("messages", [])
        supervisor_names = {self.SUPERVISOR_NAME, self._LEGACY_SUPERVISOR_NAME}

        # cti_results is never populated here (the CTI Agent is disabled).
        cti_results: List[str] = []
        database_results: List[str] = []
        grader_decisions: List[str] = []
        for msg in messages:
            agent_name = getattr(msg, "name", None)
            # NOTE(review): confirm "database_agent" is the DatabaseAgent graph's name.
            if agent_name == "database_agent":
                database_results.append(self._content_to_text(msg.content))
            elif agent_name == self.GRADER_AGENT_NAME:
                grader_decisions.append(self._content_to_text(msg.content))

        final_message = ""
        for msg in reversed(messages):
            if (
                isinstance(msg, AIMessage)
                and getattr(msg, "name", None) in supervisor_names
            ):
                final_message = self._content_to_text(msg.content)
                break

        final_assessment = self._determine_final_assessment(
            grader_decisions, final_message
        )
        recommendations = self._extract_recommendations(
            cti_results, database_results, grader_decisions
        )

        return {
            "status": "SUCCESS",
            "query": original_query,
            "agents_used": [
                a for a in self._collect_agent_names(messages) if a.strip()
            ],
            "results": {
                "cti_intelligence": cti_results,
                "mitre_techniques": database_results,
                "quality_assessments": grader_decisions,
                "supervisor_synthesis": final_message,
            },
            "final_assessment": final_assessment,
            "recommendations": recommendations,
            "message_history": messages,
            "summary": self._generate_summary(
                cti_results, database_results, final_assessment
            ),
        }

    def _determine_final_assessment(
        self, grader_decisions: List[str], final_message: str
    ) -> str:
        """Determine the final assessment from grader decisions.

        The latest grader decision's JSON ``"decision"`` field wins when present.
        Otherwise keywords in the combined text are checked. The specific
        ``needs_*`` markers are checked before ``accept``, so text such as
        "not acceptable" cannot mask them.
        """
        if grader_decisions:
            decision_json = self._extract_json_from_content(grader_decisions[-1])
            if decision_json is not None:
                return decision_json.get("decision", "UNKNOWN")

        content = " ".join(
            [self._content_to_text(final_message)]
            + [self._content_to_text(d) for d in grader_decisions]
        ).lower()
        for keyword, assessment in (
            ("needs_both", "NEEDS_BOTH"),
            ("needs_cti", "NEEDS_CTI"),
            ("needs_mitre", "NEEDS_MITRE"),
            ("accept", "ACCEPTED"),
        ):
            if keyword in content:
                return assessment
        return "COMPLETED"

    def _extract_recommendations(
        self,
        cti_results: List[str],
        database_results: List[str],
        grader_decisions: List[str],
    ) -> List[str]:
        """Build up to five de-duplicated recommendations.

        Standard recommendations depend on which result lists are non-empty. The
        grader's ``improvement_suggestions`` (a string or a list) are appended.
        """
        recommendations: List[str] = []

        if cti_results:
            recommendations.append("Review CTI findings for threat actor attribution")
            recommendations.append("Implement IOC-based detection rules")
        if database_results:
            recommendations.append("Map detected techniques to defensive controls")
            recommendations.append("Update threat hunting playbooks")

        for decision in grader_decisions:
            decision_json = self._extract_json_from_content(decision)
            if not decision_json:
                continue
            suggestions = decision_json.get("improvement_suggestions", [])
            if isinstance(suggestions, str):
                # A single suggestion string must not be split into characters.
                recommendations.append(suggestions)
            elif isinstance(suggestions, list):
                recommendations.extend(str(s) for s in suggestions if s)

        # dict.fromkeys de-duplicates while preserving order.
        return list(dict.fromkeys(recommendations))[:5]

    def _generate_summary(
        self, cti_results: List[str], database_results: List[str], final_assessment: str
    ) -> str:
        """Return a one-line, ``" | "``-joined summary of the retrieval results."""
        summary_parts = [
            f"Retrieval Status: {final_assessment}",
            f"CTI Sources Analyzed: {len(cti_results)}",
            f"MITRE Techniques Retrieved: {len(database_results)}",
        ]
        if cti_results:
            summary_parts.append("Threat intelligence gathered from external sources")
        if database_results:
            summary_parts.append("MITRE ATT&CK techniques mapped to findings")
        return " | ".join(summary_parts)
# Example usage and testing
def test_retrieval_supervisor():
    """Run the Retrieval Supervisor end-to-end on a sample log analysis report.

    Requires a configured LLM backend and knowledge base. Prints a results
    summary and returns the result dict, or None if anything raised.
    """
    sample_report = {
        "overall_assessment": "ABNORMAL",
        "total_events_analyzed": 245,
        "analysis_summary": "Detected suspicious PowerShell execution with base64 encoding and potential credential access attempts targeting LSASS process",
        "abnormal_events": [
            {
                "event_id": "4688",
                "event_description": "PowerShell process creation with encoded command parameter",
                "why_abnormal": "Base64 encoded command suggests obfuscation and evasion techniques",
                "severity": "HIGH",
                "potential_threat": "Defense evasion or malware execution",
                "attack_category": "defense_evasion",
            },
            {
                "event_id": "4656",
                "event_description": "Handle request to LSASS process memory",
                "why_abnormal": "Unusual access pattern to sensitive authentication process",
                "severity": "CRITICAL",
                "potential_threat": "Credential dumping attack",
                "attack_category": "credential_access",
            },
        ],
    }

    try:
        supervisor = RetrievalSupervisor()

        query = "Analyze the detected PowerShell and LSASS access patterns. Provide threat intelligence on related attack campaigns and map to MITRE ATT&CK techniques."

        results = supervisor.invoke(
            query=query,
            log_analysis_report=sample_report,
            context="High-priority security incident requiring immediate threat intelligence",
            trace=True,
        )

        print("=" * 60)
        print("RETRIEVAL RESULTS SUMMARY")
        print("=" * 60)
        # invoke() output is LLM-driven; use .get() so a missing key does not
        # abort the report. "recommendations" in particular is not produced by
        # invoke()'s parser.
        print(f"Status: {results.get('status')}")
        print(f"Final Assessment: {results.get('final_assessment')}")
        print(f"Agents Used: {', '.join(map(str, results.get('agents_used', [])))}")
        print(f"\nSummary: {results.get('summary')}")

        techniques = results.get("retrieved_techniques", [])
        print(f"\nRetrieved Techniques ({len(techniques)}):")
        for tech in techniques:
            if isinstance(tech, dict):
                print(f"- {tech.get('technique_id', '')} {tech.get('technique_name', '')}")
            else:
                print(f"- {tech}")

        recommendations = results.get("recommendations", [])
        if recommendations:
            print("\nRecommendations:")
            for i, rec in enumerate(recommendations, 1):
                print(f"{i}. {rec}")

        return results

    except Exception as e:
        print(f"Test failed: {e}")
        return None


if __name__ == "__main__":
    test_retrieval_supervisor()