Log-Analysis-MultiAgent / src /scripts /test_retrieval_supervisor.py
minhan6559's picture
Upload 126 files
223ef32 verified
raw
history blame
13.9 kB
"""
Test script for the new Retrieval Supervisor pipeline.
This script uses the RetrievalSupervisor class to run the complete retrieval pipeline
on log analysis reports, providing comprehensive threat intelligence and MITRE ATT&CK
technique retrieval.
"""
import os
from dotenv import load_dotenv
from langchain.chat_models import init_chat_model
import sys
import json
import argparse
from typing import Dict, Any, Optional
from pathlib import Path
# Add the project root to the Python path so ``src`` can be imported when this
# script is run directly. The path is resolved first: a relative ``__file__``
# (e.g. when run from inside the scripts directory) would otherwise make
# ``.parent`` collapse to "." and point at the wrong directory.
project_root = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(project_root))

# Import the RetrievalSupervisor; abort with a helpful message if the project
# package cannot be imported.
try:
    from src.agents.retrieval_supervisor.supervisor import RetrievalSupervisor
except ImportError as e:
    print(f"[ERROR] Could not import RetrievalSupervisor: {e}")
    print("Please ensure the supervisor.py file is in the correct location.")
    sys.exit(1)

load_dotenv()
# Re-export the key explicitly, but only when it is set: assigning None to
# os.environ raises TypeError, which previously crashed the script at import
# time whenever GOOGLE_API_KEY was missing.
_google_api_key = os.getenv("GOOGLE_API_KEY")
if _google_api_key is not None:
    os.environ["GOOGLE_API_KEY"] = _google_api_key
def load_log_analysis_report(file_path: str) -> Dict[str, Any]:
    """Read and decode a log analysis report stored as a JSON file.

    Args:
        file_path: Location of the JSON report on disk.

    Returns:
        The decoded JSON document.

    Exits:
        Prints an ``[ERROR]`` line and terminates the process with status 1
        if the file is missing, is not valid JSON, or cannot be read for any
        other reason.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as handle:
            parsed = json.load(handle)
        print(f"[SUCCESS] Loaded log analysis report from {file_path}")
        return parsed
    except FileNotFoundError:
        failure = f"[ERROR] Log analysis report file not found: {file_path}"
    except json.JSONDecodeError as exc:
        failure = f"[ERROR] Invalid JSON in log analysis report: {exc}"
    except Exception as exc:
        failure = f"[ERROR] Error loading report: {exc}"

    # Every failure path ends the script the same way.
    print(failure)
    sys.exit(1)
def validate_report_structure(report: Dict[str, Any]) -> bool:
    """Check that a report contains the fields the retrieval pipeline expects.

    The top level must contain ``overall_assessment`` and ``analysis_summary``.
    If ``abnormal_events`` is present it must be a list of dicts, each holding
    ``event_id``, ``event_description``, ``why_abnormal`` and ``severity``.
    Checking stops at the first problem found, which is printed as a
    ``[WARNING]`` line.

    Returns:
        True when every check passes, False otherwise.
    """
    missing_top = next(
        (name for name in ("overall_assessment", "analysis_summary") if name not in report),
        None,
    )
    if missing_top is not None:
        print(f"[WARNING] Missing field '{missing_top}' in report")
        return False

    # Abnormal events are optional; only validate them when present.
    if "abnormal_events" not in report:
        return True

    events = report["abnormal_events"]
    if not isinstance(events, list):
        print("[WARNING] 'abnormal_events' should be a list")
        return False

    event_fields = ("event_id", "event_description", "why_abnormal", "severity")
    for index, event in enumerate(events):
        if not isinstance(event, dict):
            print(f"[WARNING] Event {index} is not a dictionary")
            return False
        absent = next((name for name in event_fields if name not in event), None)
        if absent is not None:
            print(f"[WARNING] Event {index} missing field '{absent}'")
            return False
    return True
def run_retrieval_pipeline(
    report_path: str,
    llm_model: str = "google_genai:gemini-2.5-flash",
    kb_path: str = "./cyber_knowledge_base",
    max_iterations: int = 3,
    context: Optional[str] = None,
    interactive: bool = False,
):
    """Run the complete retrieval pipeline using RetrievalSupervisor.

    Loads the JSON report at ``report_path`` and validates its structure; a
    failed validation only prints a warning. It then builds a
    ``RetrievalSupervisor``, invokes it with a fixed query plus the report,
    and prints the results.

    Args:
        report_path: Path to the log analysis report JSON file.
        llm_model: Model identifier forwarded to ``RetrievalSupervisor``.
        kb_path: Knowledge base directory forwarded to ``RetrievalSupervisor``.
        max_iterations: Iteration cap forwarded to ``RetrievalSupervisor``.
        context: Optional free-text context forwarded to ``supervisor.invoke``.
        interactive: Not used by this function; kept so existing callers
            passing it keep working.

    Returns:
        Whatever ``supervisor.invoke`` returned. Returns ``None`` if the
        supervisor could not be created or the pipeline raised an exception.

    Note:
        ``load_log_analysis_report`` calls ``sys.exit(1)`` when the report
        cannot be read, so a bad ``report_path`` raises ``SystemExit`` here.
    """
    # Load the log analysis report
    report = load_log_analysis_report(report_path)

    # Validate report structure (non-fatal)
    if not validate_report_structure(report):
        print("[WARNING] Report structure validation failed, but continuing...")

    # Initialize the RetrievalSupervisor
    print("\n" + "=" * 60)
    print("INITIALIZING RETRIEVAL SUPERVISOR")
    print("=" * 60)
    try:
        supervisor = RetrievalSupervisor(
            llm_model=llm_model,
            kb_path=kb_path,
            max_iterations=max_iterations,
        )
    except Exception as e:
        print(f"[ERROR] Failed to initialize RetrievalSupervisor: {e}")
        return None

    # A fixed query is used; the report content reaches the supervisor through
    # the ``log_analysis_report`` argument below. (generate_query_from_report()
    # can build a report-specific query but is not used here.)
    query = "Analyze this IOCs report from log analysis agent and retrieve relevant MITRE ATT&CK techniques"

    # A JSON file whose top level is not an object has no .get(); don't crash
    # just while printing the banner.
    assessment = (
        report.get("overall_assessment", "Unknown")
        if isinstance(report, dict)
        else "Unknown"
    )

    print("\n" + "=" * 60)
    print("RUNNING RETRIEVAL PIPELINE")
    print("=" * 60)
    print(f"Query: {query}")
    print(f"Report Assessment: {assessment}")
    print(f"Context: {context}")
    print()

    # Execute the retrieval pipeline
    try:
        results = supervisor.invoke(
            query=query,
            log_analysis_report=report,
            context=context,
            trace=True,
        )
    except Exception as e:
        print(f"[ERROR] Pipeline execution failed: {e}")
        return None

    # Display lives outside the pipeline try-block: a formatting problem must
    # not be reported as a pipeline failure, nor discard valid results.
    try:
        display_results(results)
    except Exception as e:
        print(f"[WARNING] Could not display results: {e}")
    return results
def generate_query_from_report(report: Dict[str, Any], max_events: int = 3) -> str:
    """Build a natural-language retrieval query from a log analysis report.

    The query consists of these parts, joined with newlines:

    1. A fixed opening instruction.
    2. Up to ``max_events`` abnormal events, each showing its
       ``event_description``, ``potential_threat`` and ``attack_category``.
       Missing values get "Unknown ..." placeholders.
    3. The report's ``analysis_summary``, when that key is present.
    4. A fixed list of requested deliverables.

    Args:
        report: Parsed log analysis report.
        max_events: Maximum number of abnormal events to include. The default
            of 3 matches the previous hard-coded limit; negative values are
            treated as 0.

    Returns:
        The multi-line query string.
    """
    # Base query components
    query_parts = [
        "Analyze the detected security anomalies and provide comprehensive threat intelligence."
    ]

    # Only a list/tuple of dict events can be described; anything else (a
    # malformed report) is skipped instead of crashing on ``.get``.
    raw_events = report.get("abnormal_events") or []
    if not isinstance(raw_events, (list, tuple)):
        raw_events = []
    events = [event for event in raw_events if isinstance(event, dict)]
    events = events[: max(0, max_events)]

    if events:
        query_parts.append("Focus on the following detected anomalies:")
        for i, event in enumerate(events, 1):
            event_desc = event.get("event_description", "Unknown event")
            threat = event.get("potential_threat", "Unknown threat")
            category = event.get("attack_category", "Unknown category")
            query_parts.append(
                f"{i}. {event_desc} - Potential: {threat} (Category: {category})"
            )

    # Add analysis summary if available
    if "analysis_summary" in report:
        query_parts.append(f"Analysis Summary: {report['analysis_summary']}")

    # Add specific intelligence requirements
    query_parts.extend(
        [
            "",
            "Please provide:",
            "1. Relevant threat intelligence from CTI sources",
            "2. MITRE ATT&CK technique mapping and tactical analysis",
            "3. Actionable recommendations for threat hunting and defense",
            "4. IOCs and indicators for detection rules",
        ]
    )
    return "\n".join(query_parts)
def _truncate(text: str, limit: int) -> str:
    """Return ``text`` cut to ``limit`` characters, appending "..." if it was longer."""
    return text[:limit] + "..." if len(text) > limit else text


def _print_previews(header: str, items) -> None:
    """Print ``header`` followed by each item as a numbered 200-character preview."""
    print(header)
    for i, item in enumerate(items, 1):
        print(f" {i}. {_truncate(str(item), 200)}")


def display_results(results: Dict[str, Any]):
    """Print the retrieval results in a formatted way.

    Reads these optional keys from ``results``; missing keys get placeholders:

    - ``status``, ``final_assessment``, ``agents_used``, ``summary``,
      ``total_techniques``, ``iteration_count``
    - ``retrieved_techniques``: a list of dicts
    - ``recommendations``: a list
    - ``results``: a legacy dict with ``cti_intelligence``,
      ``mitre_techniques`` and ``quality_assessments`` lists

    Args:
        results: The mapping produced by the retrieval pipeline.
    """
    print("\n" + "=" * 60)
    print("RETRIEVAL RESULTS")
    print("=" * 60)

    # Basic status information
    print(f"Status: {results.get('status', 'Unknown')}")
    print(f"Final Assessment: {results.get('final_assessment', 'Unknown')}")
    # Tolerate a null agent list and non-string agent entries.
    agents_used = results.get("agents_used") or []
    print(f"Agents Used: {', '.join(str(agent) for agent in agents_used)}")
    print(f"Summary: {results.get('summary', 'No summary available')}")
    print(f"Total Techniques: {results.get('total_techniques', 0)}")
    print(f"Iteration Count: {results.get('iteration_count', 0)}")

    # Display structured techniques
    retrieved_techniques = results.get("retrieved_techniques", [])
    if retrieved_techniques:
        print(f"\nRetrieved MITRE Techniques ({len(retrieved_techniques)}):")
        for i, technique in enumerate(retrieved_techniques, 1):
            print(f"\n {i}. {technique.get('technique_id', 'N/A')}: {technique.get('technique_name', 'N/A')}")
            print(f" Tactic: {technique.get('tactic', 'N/A')}")
            print(f" Relevance Score: {technique.get('relevance_score', 0)}")
            # A null description used to crash len(); treat it like a missing one.
            description = technique.get("description")
            if description is None:
                description = "No description"
            print(f" Description: {_truncate(str(description), 100)}")
    else:
        print("\nNo techniques retrieved")

    # Display recommendations (if available)
    recommendations = results.get("recommendations", [])
    if recommendations:
        print(f"\nRecommendations ({len(recommendations)}):")
        for i, rec in enumerate(recommendations, 1):
            print(f" {i}. {rec}")

    # Display detailed results (legacy format for backward compatibility)
    detailed_results = results.get("results", {})
    if detailed_results:
        print("\nDetailed Results (Legacy Format):")

        cti_intelligence = detailed_results.get("cti_intelligence", [])
        if cti_intelligence:
            _print_previews(
                f"\n CTI Intelligence ({len(cti_intelligence)} sources):",
                cti_intelligence,
            )

        mitre_techniques = detailed_results.get("mitre_techniques", [])
        if mitre_techniques:
            _print_previews(
                f"\n MITRE Techniques ({len(mitre_techniques)} retrieved):",
                mitre_techniques,
            )

        quality_assessments = detailed_results.get("quality_assessments", [])
        if quality_assessments:
            _print_previews(
                f"\n Quality Assessments ({len(quality_assessments)}):",
                quality_assessments,
            )

    print("\n" + "=" * 60)
def interactive_mode():
    """Run in interactive mode for multiple reports.

    Prompts once for the LLM model and the knowledge-base path; empty input
    keeps the defaults. It then repeatedly asks for a report JSON path and
    runs ``run_retrieval_pipeline`` on it. The loop ends when the user types
    ``quit``/``exit``/``q`` or input reaches end-of-file.

    A missing file or an unreadable report is reported and the loop
    continues, instead of terminating the whole session.
    """
    print("\n=== INTERACTIVE MODE ===")

    default_model = "google_genai:gemini-2.5-flash"
    default_kb_path = "./cyber_knowledge_base"

    # Get initial configuration
    try:
        model = input(f"LLM Model (default: {default_model}): ").strip() or default_model
        kb_path = (
            input(f"Knowledge Base Path (default: {default_kb_path}): ").strip()
            or default_kb_path
        )
    except EOFError:
        # stdin closed (e.g. piped input ran out): end the session quietly.
        print()
        return

    print("Enter path to a log analysis JSON report, or 'quit' to exit:")
    while True:
        try:
            user_input = input("\nReport JSON path: ").strip()
        except EOFError:
            print()
            break
        if user_input.lower() in ("quit", "exit", "q"):
            break
        if not user_input:
            continue
        if not os.path.isfile(user_input):
            print(f"Error: report file not found: {user_input}")
            continue
        try:
            run_retrieval_pipeline(
                report_path=user_input,
                llm_model=model,
                kb_path=kb_path,
                interactive=False,
            )
        except SystemExit:
            # load_log_analysis_report() signals unreadable / invalid JSON via
            # sys.exit(1); in an interactive session that must not end the loop.
            print("Error: could not process report (see message above).")
        except Exception as e:
            print(f"Error: {str(e)}")
def main():
    """Parse command-line arguments and run the retrieval pipeline.

    Exit status depends on the outcome:

    - 1 when the report path does not exist.
    - 1 when the pipeline returns ``None``.
    - 1 on any unexpected error; ``--verbose`` also prints the traceback.
    - 0 on Ctrl+C.

    With ``--interactive``, the interactive loop starts after the first
    report has been processed.
    """
    # Single source of truth for defaults, matching run_retrieval_pipeline()
    # and interactive_mode() (the CLI previously defaulted to gemini-2.0-flash
    # while its own examples and the rest of the script used 2.5-flash).
    default_model = "google_genai:gemini-2.5-flash"
    default_kb_path = "./cyber_knowledge_base"

    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description="Test the new Retrieval Supervisor pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python test_retrieval_supervisor.py report.json
  python test_retrieval_supervisor.py report.json --model google_genai:gemini-2.5-flash
  python test_retrieval_supervisor.py report.json --interactive
  python test_retrieval_supervisor.py report.json --context "Urgent security incident"
""",
    )
    parser.add_argument("report_path", help="Path to the log analysis report JSON file")
    parser.add_argument(
        "--model",
        default=default_model,
        help=f"LLM model name (default: {default_model})",
    )
    parser.add_argument(
        "--kb-path",
        default=default_kb_path,
        help=f"Path to the cyber knowledge base directory (default: {default_kb_path})",
    )
    parser.add_argument(
        "--max-iterations",
        type=int,
        default=3,
        help="Maximum iterations for the retrieval pipeline (default: 3)",
    )
    parser.add_argument(
        "--context", help="Additional context for the analysis (optional)"
    )
    parser.add_argument(
        "--interactive",
        "-i",
        action="store_true",
        help="Run in interactive mode after pipeline completion",
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true", help="Enable verbose output"
    )
    args = parser.parse_args()

    # Validate report path
    if not os.path.exists(args.report_path):
        print(f"[ERROR] Report file not found: {args.report_path}")
        sys.exit(1)

    # A missing knowledge base is only a warning; the pipeline may still fail.
    if not os.path.exists(args.kb_path):
        print(f"[WARNING] Knowledge base path not found: {args.kb_path}")
        print(
            "The pipeline may fail if the knowledge base is not properly initialized."
        )

    # Run the retrieval pipeline
    try:
        results = run_retrieval_pipeline(
            report_path=args.report_path,
            llm_model=args.model,
            kb_path=args.kb_path,
            max_iterations=args.max_iterations,
            context=args.context,
            interactive=args.interactive,
        )
        if results is None:
            print("[ERROR] Pipeline execution failed")
            sys.exit(1)

        # Interactive mode
        if args.interactive:
            interactive_mode()

        print("\n[SUCCESS] Pipeline completed successfully!")
    except KeyboardInterrupt:
        print("\n[INFO] Pipeline interrupted by user")
        sys.exit(0)
    except Exception as e:
        print(f"[ERROR] Unexpected error: {e}")
        if args.verbose:
            import traceback

            traceback.print_exc()
        sys.exit(1)
if __name__ == "__main__":
main()