"""
Retrieval Supervisor - Coordinates CTI Agent, Database Agent, and Grader Agent

This supervisor manages the retrieval pipeline for cybersecurity analysis, coordinating
multiple specialized agents to provide comprehensive threat intelligence and MITRE ATT&CK
technique retrieval.
"""

import json
import os
from pathlib import Path
from typing import Dict, Any, List, Optional

from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    ToolMessage,
    convert_to_messages,
)
from langchain.chat_models import init_chat_model
from langgraph.prebuilt import create_react_agent
from langgraph_supervisor import create_supervisor

from src.agents.cti_agent.cti_agent import CTIAgent
from src.agents.database_agent.agent import DatabaseAgent
from src.agents.retrieval_supervisor.prompts import (
    GRADER_AGENT_PROMPT,
    SUPERVISOR_PROMPT_TEMPLATE,
    INPUT_MESSAGE_TEMPLATE,
    LOG_ANALYSIS_SECTION_TEMPLATE,
    CONTEXT_SECTION_TEMPLATE,
)

class RetrievalSupervisor:
    """
    Retrieval Supervisor that coordinates CTI Agent, Database Agent, and Grader Agent
    using LangGraph's supervisor pattern for comprehensive threat intelligence retrieval.
    """

    def __init__(
        self,
        llm_model: str = "google_genai:gemini-2.0-flash",
        kb_path: str = "./cyber_knowledge_base",
        max_iterations: int = 3,
        llm_client=None,
    ):
        """
        Initialize the Retrieval Supervisor.

        Args:
            llm_model: Model identifier in init_chat_model's provider:model format
            kb_path: Path to the cyber knowledge base
            max_iterations: Maximum iterations for the retrieval pipeline
            llm_client: Optional pre-initialized LLM client (overrides llm_model)
        """
        self.max_iterations = max_iterations
        self.llm_model = llm_model

        if llm_client:
            self.llm_client = llm_client
            print("[INFO] Retrieval Supervisor: Using provided LLM client")
        elif "gpt-oss" in llm_model:
            # GPT-OSS models expose reasoning controls; keep effort low and hide traces.
            reasoning_effort = "low"
            reasoning_format = "hidden"
            self.llm_client = init_chat_model(
                llm_model,
                temperature=0.1,
                reasoning_effort=reasoning_effort,
                reasoning_format=reasoning_format,
            )
            print(
                f"[INFO] Retrieval Supervisor: Using GPT-OSS model: {llm_model} "
                f"with reasoning effort: {reasoning_effort}"
            )
        else:
            self.llm_client = init_chat_model(llm_model, temperature=0.1)
            print(f"[INFO] Retrieval Supervisor: Initialized with {llm_model}")

        # Initialize the sub-agents coordinated by the supervisor. The CTI Agent
        # helper (_initialize_cti_agent) is available but not wired in here.
        self.database_agent = self._initialize_database_agent(kb_path)
        self.grader_agent = self._initialize_grader_agent()

        self.supervisor = self._create_supervisor()

    def _initialize_cti_agent(self) -> CTIAgent:
        """Initialize the CTI Agent."""
        try:
            cti_agent = CTIAgent(llm=self.llm_client)
            print("CTI Agent initialized successfully")
            return cti_agent
        except Exception as e:
            print(f"Failed to initialize CTI Agent: {e}")
            raise

    def _initialize_database_agent(self, kb_path: str) -> DatabaseAgent:
        """Initialize the Database Agent."""
        try:
            database_agent = DatabaseAgent(
                kb_path=kb_path,
                llm_client=self.llm_client,
            )
            print("Database Agent initialized successfully")
            return database_agent
        except Exception as e:
            print(f"Failed to initialize Database Agent: {e}")
            raise

    def _initialize_grader_agent(self):
        """Initialize the Grader Agent as a ReAct agent with no tools."""
        # A tool-less ReAct agent simply reasons over the conversation and
        # returns its grading decision as the final message.
        return create_react_agent(
            model=self.llm_client,
            tools=[],
            prompt=GRADER_AGENT_PROMPT,
            name="retrieval_grader_agent",
        )

    def _create_supervisor(self):
        """Create the supervisor using langgraph_supervisor."""
        agents = [
            self.database_agent.agent,
            self.grader_agent,
        ]

        supervisor_prompt = SUPERVISOR_PROMPT_TEMPLATE.format(
            max_iterations=self.max_iterations
        )

        return create_supervisor(
            model=self.llm_client,
            agents=agents,
            prompt=supervisor_prompt,
            add_handoff_back_messages=True,
            supervisor_name="retrieval_supervisor",
        ).compile(name="retrieval_supervisor")

    def invoke(
        self,
        query: str,
        log_analysis_report: Optional[Dict[str, Any]] = None,
        context: Optional[str] = None,
        trace: bool = False,
    ) -> Dict[str, Any]:
        """
        Invoke the retrieval supervisor pipeline.

        Args:
            query: The intelligence retrieval query/task
            log_analysis_report: Optional log analysis report from the log analysis agent
            context: Optional additional context
            trace: Whether to print a trace of the pipeline execution

        Returns:
            Dictionary containing the structured retrieval results
        """
        try:
            input_content = self._build_input_message(
                query, log_analysis_report, context
            )

            initial_state = {"messages": [HumanMessage(content=input_content)]}

            raw_result = self.supervisor.invoke(initial_state)

            if trace:
                self._print_trace_pipeline(raw_result)

            structured_result = self._parse_supervisor_output(raw_result, query)

            return structured_result
        except Exception as e:
            print(f"[ERROR] Retrieval Supervisor pipeline failed: {e}")
            raise

    def invoke_direct_query(self, query: str, trace: bool = False) -> Dict[str, Any]:
        """Invoke the retrieval supervisor pipeline with a direct query."""
        raw_result = self.supervisor.invoke(
            {"messages": [HumanMessage(content=query)]}
        )
        if trace:
            self._print_trace_pipeline(raw_result)

        structured_result = self._parse_supervisor_output(raw_result, query)
        return structured_result
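
    # For ad-hoc lookups, the supervisor can be queried without a log report,
    # e.g. (illustrative only):
    #
    #   supervisor.invoke_direct_query(
    #       "Which MITRE ATT&CK techniques cover LSASS credential dumping?",
    #       trace=True,
    #   )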

    def stream(
        self,
        query: str,
        log_analysis_report: Optional[Dict[str, Any]] = None,
        context: Optional[str] = None,
    ):
        """Stream the pipeline, pretty-printing each update (including subgraphs)."""
        input_content = self._build_input_message(query, log_analysis_report, context)

        initial_state = {"messages": [HumanMessage(content=input_content)]}

        for chunk in self.supervisor.stream(initial_state, subgraphs=True):
            self._pretty_print_messages(chunk, last_message=True)

    def _pretty_print_message(self, message, indent=False):
        pretty_message = message.pretty_repr(html=True)
        if not indent:
            print(pretty_message)
            return

        indented = "\n".join("\t" + c for c in pretty_message.split("\n"))
        print(indented)

    def _pretty_print_messages(self, update, last_message=False):
        is_subgraph = False
        if isinstance(update, tuple):
            ns, update = update
            # Skip parent-graph updates in the printouts.
            if len(ns) == 0:
                return

            graph_id = ns[-1].split(":")[0]
            print(f"Update from subgraph {graph_id}:")
            print("\n")
            is_subgraph = True

        for node_name, node_update in update.items():
            update_label = f"Update from node {node_name}:"
            if is_subgraph:
                update_label = "\t" + update_label

            print(update_label)
            print("\n")

            messages = convert_to_messages(node_update["messages"])
            if last_message:
                messages = messages[-1:]

            for m in messages:
                self._pretty_print_message(m, indent=is_subgraph)
            print("\n")

    def _print_trace_pipeline(self, result: Dict[str, Any]):
        """Print a detailed trace of the pipeline execution with message flow."""
        messages = result.get("messages", [])

        if not messages:
            print("[TRACE] No messages found in pipeline result")
            return

        print("\n" + "=" * 60)
        print("PIPELINE EXECUTION TRACE")
        print("=" * 60)

        for i, msg in enumerate(messages, 1):
            print(f"\n--- Message {i} ---")

            if isinstance(msg, HumanMessage):
                print(f"[Human] {msg.content}")

            elif isinstance(msg, AIMessage):
                agent_name = getattr(msg, "name", None) or "agent"
                print(f"[Agent:{agent_name}] {msg.content}")

                if (
                    hasattr(msg, "additional_kwargs")
                    and "function_call" in msg.additional_kwargs
                ):
                    fc = msg.additional_kwargs["function_call"]
                    print(f"  [ToolCall] {fc.get('name')}: {fc.get('arguments')}")

            elif isinstance(msg, ToolMessage):
                tool_name = getattr(msg, "name", None) or "tool"
                content = (
                    msg.content if isinstance(msg.content, str) else str(msg.content)
                )
                # Truncate long tool outputs to keep the trace readable.
                preview = content[:300] + ("..." if len(content) > 300 else "")
                print(f"[Tool:{tool_name}] {preview}")

            else:
                print(f"[Message] {getattr(msg, 'content', '')}")

        if messages:
            latest_message = messages[-1]
            if isinstance(latest_message, AIMessage):
                print("\n--- Final Supervisor Output ---")
                print(latest_message.content)

                if "decision" in latest_message.content.lower():
                    try:
                        # Best-effort extraction of the grader's JSON decision.
                        content = latest_message.content
                        if "{" in content and "}" in content:
                            start = content.find("{")
                            end = content.rfind("}") + 1
                            decision_json = json.loads(content[start:end])

                            decision = decision_json.get("decision", "unknown")
                            print(
                                f"\n[SUCCESS] Pipeline completed - Decision: {decision}"
                            )

                            if decision == "ACCEPT":
                                print("Results accepted by grader")
                            elif decision == "NEEDS_MITRE":
                                print("Additional MITRE technique analysis needed")

                    except (json.JSONDecodeError, KeyError):
                        print("\n[INFO] Pipeline completed (decision parsing failed)")

        print("\n" + "=" * 60)
        print("TRACE COMPLETED")
        print("=" * 60)

    def _build_input_message(
        self,
        query: str,
        log_analysis_report: Optional[Dict[str, Any]],
        context: Optional[str],
    ) -> str:
        """Build the input message for the supervisor."""
        log_analysis_section = ""
        if log_analysis_report:
            log_analysis_section = LOG_ANALYSIS_SECTION_TEMPLATE.format(
                log_analysis_report=json.dumps(log_analysis_report, indent=2)
            )

        context_section = ""
        if context:
            context_section = CONTEXT_SECTION_TEMPLATE.format(context=context)

        input_message = INPUT_MESSAGE_TEMPLATE.format(
            query=query,
            log_analysis_section=log_analysis_section,
            context_section=context_section,
        )

        return input_message

    def _parse_supervisor_output(
        self, raw_result: Dict[str, Any], original_query: str
    ) -> Dict[str, Any]:
        """Parse the supervisor's structured output from the raw result."""
        messages = raw_result.get("messages", [])

        # Walk backwards to find the supervisor's final non-empty message.
        final_supervisor_message = None
        for msg in reversed(messages):
            if (
                hasattr(msg, "name")
                and msg.name == "supervisor"
                and hasattr(msg, "content")
                and msg.content
            ):
                final_supervisor_message = msg.content
                break

        if not final_supervisor_message:
            # Fall back to the last message of any kind.
            if messages:
                final_supervisor_message = (
                    messages[-1].content if hasattr(messages[-1], "content") else ""
                )

        structured_output = self._extract_json_from_content(final_supervisor_message)

        if structured_output:
            return self._validate_and_enhance_output(
                structured_output, original_query, messages
            )
        else:
            return self._create_fallback_output(messages, original_query)

    def _extract_json_from_content(self, content: str) -> Optional[Dict[str, Any]]:
        """Extract JSON from supervisor message content."""
        if not content:
            return None

        # First try fenced ```json blocks.
        if "```json" in content:
            json_blocks = content.split("```json")
            for block in json_blocks[1:]:
                json_str = block.split("```")[0].strip()
                try:
                    return json.loads(json_str)
                except json.JSONDecodeError:
                    continue

        # Otherwise scan for balanced-brace candidates and try to parse each.
        start_idx = 0
        while True:
            start_idx = content.find("{", start_idx)
            if start_idx == -1:
                break

            brace_count = 0
            end_idx = start_idx
            for i in range(start_idx, len(content)):
                if content[i] == "{":
                    brace_count += 1
                elif content[i] == "}":
                    brace_count -= 1
                    if brace_count == 0:
                        end_idx = i + 1
                        break

            if brace_count == 0:
                json_str = content[start_idx:end_idx]
                try:
                    return json.loads(json_str)
                except json.JSONDecodeError:
                    pass

            start_idx += 1

        return None
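
    # For example, content such as:
    #     Here is the result: ```json {"decision": "ACCEPT"} ```
    # yields {"decision": "ACCEPT"} via the fenced-block branch, while prose
    # with no balanced braces yields None.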

    def _validate_and_enhance_output(
        self, structured_output: Dict[str, Any], original_query: str, messages: List
    ) -> Dict[str, Any]:
        """Validate and enhance the structured output, filling in missing fields."""
        if "status" not in structured_output:
            structured_output["status"] = "SUCCESS"

        if "final_assessment" not in structured_output:
            structured_output["final_assessment"] = "ACCEPTED"

        if "retrieved_techniques" not in structured_output:
            structured_output["retrieved_techniques"] = []

        if "agents_used" not in structured_output:
            # Recover the set of agents from the message history.
            agents_used = set()
            for msg in messages:
                if hasattr(msg, "name") and msg.name:
                    agents_used.add(str(msg.name))
            structured_output["agents_used"] = list(agents_used)

        if "summary" not in structured_output:
            technique_count = len(structured_output.get("retrieved_techniques", []))
            structured_output["summary"] = (
                f"Retrieved {technique_count} MITRE techniques for analysis"
            )

        if "iteration_count" not in structured_output:
            structured_output["iteration_count"] = 1

        structured_output["query"] = original_query
        structured_output["total_techniques"] = len(
            structured_output.get("retrieved_techniques", [])
        )

        return structured_output

    def _create_fallback_output(
        self, messages: List, original_query: str
    ) -> Dict[str, Any]:
        """Create fallback structured output when JSON parsing fails."""
        techniques = []
        agents_used = set()

        for msg in messages:
            if hasattr(msg, "name") and msg.name:
                agents_used.add(str(msg.name))

                # Salvage technique data from database-agent tool messages.
                if "database" in str(msg.name).lower() and hasattr(msg, "content"):
                    try:
                        if "search_techniques" in str(msg.name):
                            tool_data = (
                                json.loads(msg.content)
                                if isinstance(msg.content, str)
                                else msg.content
                            )
                            if "techniques" in tool_data:
                                for tech in tool_data["techniques"]:
                                    # Normalize tactics to a list of strings.
                                    tactics = tech.get("tactics", [])
                                    if isinstance(tactics, str):
                                        tactics = [tactics] if tactics else []
                                    elif not isinstance(tactics, list):
                                        tactics = []

                                    technique = {
                                        "technique_id": tech.get("attack_id", ""),
                                        "technique_name": tech.get("name", ""),
                                        "tactic": tactics,
                                        "description": tech.get("description", ""),
                                        "relevance_score": tech.get(
                                            "relevance_score", 0.5
                                        ),
                                    }
                                    techniques.append(technique)
                    except (json.JSONDecodeError, TypeError, AttributeError):
                        continue

        return {
            "status": "PARTIAL",
            "final_assessment": "NEEDS_MORE_INFO",
            "retrieved_techniques": techniques,
            "agents_used": list(agents_used),
            "summary": f"Retrieved {len(techniques)} MITRE techniques (fallback parsing)",
            "iteration_count": 1,
            "query": original_query,
            "total_techniques": len(techniques),
            "parsing_method": "fallback",
        }

    def _process_results(
        self, result: Dict[str, Any], original_query: str
    ) -> Dict[str, Any]:
        """Process and format the supervisor results."""
        messages = result.get("messages", [])

        # Collect per-agent outputs from the message history. Note that
        # cti_results stays empty unless a CTI agent is wired into the pipeline.
        agents_used = set()
        cti_results = []
        database_results = []
        grader_decisions = []

        for msg in messages:
            if hasattr(msg, "name"):
                agent_name = msg.name
                if agent_name:
                    agents_used.add(str(agent_name))

                if agent_name == "database_agent":
                    database_results.append(msg.content)
                elif agent_name == "retrieval_grader_agent":
                    grader_decisions.append(msg.content)

        # The supervisor's last message is treated as the final synthesis.
        final_message = ""
        for msg in reversed(messages):
            if (
                isinstance(msg, AIMessage)
                and hasattr(msg, "name")
                and msg.name == "supervisor"
            ):
                final_message = msg.content
                break

        final_assessment = self._determine_final_assessment(
            grader_decisions, final_message
        )

        recommendations = self._extract_recommendations(
            cti_results, database_results, grader_decisions
        )

        return {
            "status": "SUCCESS",
            "query": original_query,
            "agents_used": [
                a for a in list(agents_used) if isinstance(a, str) and a.strip()
            ],
            "results": {
                "cti_intelligence": cti_results,
                "mitre_techniques": database_results,
                "quality_assessments": grader_decisions,
                "supervisor_synthesis": final_message,
            },
            "final_assessment": final_assessment,
            "recommendations": recommendations,
            "message_history": messages,
            "summary": self._generate_summary(
                cti_results, database_results, final_assessment
            ),
        }

    def _determine_final_assessment(
        self, grader_decisions: List[str], final_message: str
    ) -> str:
        """Determine the final assessment based on grader decisions."""
        # Prefer the most recent grader decision if it contains parseable JSON.
        if grader_decisions:
            latest_decision = grader_decisions[-1]
            try:
                if "{" in latest_decision and "}" in latest_decision:
                    start = latest_decision.find("{")
                    end = latest_decision.rfind("}") + 1
                    decision_json = json.loads(latest_decision[start:end])
                    return decision_json.get("decision", "UNKNOWN")
            except json.JSONDecodeError:
                pass

        # Otherwise fall back to keyword matching over all decision text.
        content = (final_message + " " + " ".join(grader_decisions)).lower()
        if "accept" in content:
            return "ACCEPTED"
        elif "needs_both" in content:
            return "NEEDS_BOTH"
        elif "needs_cti" in content:
            return "NEEDS_CTI"
        elif "needs_mitre" in content:
            return "NEEDS_MITRE"
        else:
            return "COMPLETED"

    def _extract_recommendations(
        self,
        cti_results: List[str],
        database_results: List[str],
        grader_decisions: List[str],
    ) -> List[str]:
        """Extract actionable recommendations from the results."""
        recommendations = []

        if cti_results:
            recommendations.append("Review CTI findings for threat actor attribution")
            recommendations.append("Implement IOC-based detection rules")

        if database_results:
            recommendations.append("Map detected techniques to defensive controls")
            recommendations.append("Update threat hunting playbooks")

        # Pull any improvement suggestions out of the graders' JSON decisions.
        for decision in grader_decisions:
            try:
                if "{" in decision and "}" in decision:
                    start = decision.find("{")
                    end = decision.rfind("}") + 1
                    decision_json = json.loads(decision[start:end])
                    suggestions = decision_json.get("improvement_suggestions", [])
                    recommendations.extend(suggestions)
            except json.JSONDecodeError:
                continue

        # De-duplicate while preserving order, then cap at five items.
        unique_recommendations = list(dict.fromkeys(recommendations))
        return unique_recommendations[:5]

    def _generate_summary(
        self, cti_results: List[str], database_results: List[str], final_assessment: str
    ) -> str:
        """Generate a concise summary of the retrieval results."""
        summary_parts = [
            f"Retrieval Status: {final_assessment}",
            f"CTI Sources Analyzed: {len(cti_results)}",
            f"MITRE Techniques Retrieved: {len(database_results)}",
        ]

        if cti_results:
            summary_parts.append("Threat intelligence gathered from external sources")
        if database_results:
            summary_parts.append("MITRE ATT&CK techniques mapped to findings")

        return " | ".join(summary_parts)

    def stream_invoke(
        self,
        query: str,
        log_analysis_report: Optional[Dict[str, Any]] = None,
        context: Optional[str] = None,
    ):
        """
        Stream the retrieval supervisor pipeline execution.

        Args:
            query: The intelligence retrieval query/task
            log_analysis_report: Optional log analysis report from the log analysis agent
            context: Optional additional context

        Yields:
            Streaming updates from the supervisor pipeline
        """
        try:
            input_content = self._build_input_message(
                query, log_analysis_report, context
            )

            initial_state = {"messages": [HumanMessage(content=input_content)]}

            for chunk in self.supervisor.stream(initial_state):
                yield chunk

        except Exception as e:
            yield {"error": str(e)}


def test_retrieval_supervisor():
    """Test the Retrieval Supervisor with sample data."""
    # Sample report mimicking the log analysis agent's output.
    sample_report = {
        "overall_assessment": "ABNORMAL",
        "total_events_analyzed": 245,
        "analysis_summary": "Detected suspicious PowerShell execution with base64 encoding and potential credential access attempts targeting LSASS process",
        "abnormal_events": [
            {
                "event_id": "4688",
                "event_description": "PowerShell process creation with encoded command parameter",
                "why_abnormal": "Base64 encoded command suggests obfuscation and evasion techniques",
                "severity": "HIGH",
                "potential_threat": "Defense evasion or malware execution",
                "attack_category": "defense_evasion",
            },
            {
                "event_id": "4656",
                "event_description": "Handle request to LSASS process memory",
                "why_abnormal": "Unusual access pattern to sensitive authentication process",
                "severity": "CRITICAL",
                "potential_threat": "Credential dumping attack",
                "attack_category": "credential_access",
            },
        ],
    }

    try:
        supervisor = RetrievalSupervisor()

        query = (
            "Analyze the detected PowerShell and LSASS access patterns. Provide "
            "threat intelligence on related attack campaigns and map to MITRE "
            "ATT&CK techniques."
        )

        results = supervisor.invoke(
            query=query,
            log_analysis_report=sample_report,
            context="High-priority security incident requiring immediate threat intelligence",
            trace=True,
        )

        print("=" * 60)
        print("RETRIEVAL RESULTS SUMMARY")
        print("=" * 60)
        print(f"Status: {results['status']}")
        print(f"Final Assessment: {results['final_assessment']}")
        print(f"Agents Used: {', '.join(results['agents_used'])}")
        print(f"\nSummary: {results['summary']}")

        print("\nRecommendations:")
        # Use .get() here: the parsed output is not guaranteed to carry a
        # "recommendations" key.
        for i, rec in enumerate(results.get("recommendations", []), 1):
            print(f"{i}. {rec}")

        return results

    except Exception as e:
        print(f"Test failed: {e}")
        return None


if __name__ == "__main__":
    test_retrieval_supervisor()