Log-Analysis-MultiAgent / src /scripts /build_cyber_database.py
minhan6559's picture
Upload 126 files
223ef32 verified
raw
history blame
19.7 kB
"""
MITRE ATT&CK Cyber Knowledge Base Management Script
This script manages the MITRE ATT&CK techniques knowledge base with:
- Processing techniques.json file containing MITRE ATT&CK data
- Semantic search using google/embeddinggemma-300m embeddings
- Cross-encoder reranking using Qwen/Qwen3-Reranker-0.6B
- Hybrid search combining ChromaDB (semantic) and BM25 (keyword)
- Metadata filtering by tactics, platforms, and technique attributes
Usage:
python build_cyber_database.py ingest --techniques-json ./processed_data/cti/techniques.json
python build_cyber_database.py test --query "process injection"
python build_cyber_database.py test --interactive
python build_cyber_database.py test --query "privilege escalation" --filter-tactics "privilege-escalation" --filter-platforms "Windows"
"""
import argparse
import os
import sys
from functools import lru_cache
from pathlib import Path
from typing import Optional, List
# Add the project root to Python path so we can import from src
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from langchain.text_splitter import TokenTextSplitter
from src.knowledge_base.cyber_knowledge_base import CyberKnowledgeBase
@lru_cache(maxsize=16)
def _get_token_splitter(max_tokens: int) -> "TokenTextSplitter":
    """Return a cached cl100k_base TokenTextSplitter producing chunks of ``max_tokens`` tokens."""
    return TokenTextSplitter(
        encoding_name="cl100k_base", chunk_size=max_tokens, chunk_overlap=0
    )


def truncate_to_tokens(text: str, max_tokens: int = 300) -> str:
    """
    Truncate text to a maximum number of tokens using LangChain's TokenTextSplitter.

    Line breaks (``\\r\\n``, ``\\r``, ``\\n``) are replaced with spaces first so the
    result prints on a single line. Token counts follow the ``cl100k_base``
    encoding. The splitter is cached per ``max_tokens`` so repeated calls (e.g. two
    per displayed search result) do not rebuild it.

    Args:
        text: The text to truncate. ``None`` or ``""`` yields ``""``.
        max_tokens: Maximum number of tokens (default: 300). Values <= 0 yield
            ``""`` because a non-positive chunk size is not a valid splitter
            configuration.

    Returns:
        The first chunk of at most ``max_tokens`` tokens, or ``""`` if there is
        nothing to return.
    """
    if not text or max_tokens <= 0:
        return ""

    # Flatten every kind of line break so previews stay on one line.
    cleaned_text = text.replace("\r\n", " ").replace("\r", " ").replace("\n", " ")

    chunks = _get_token_splitter(max_tokens).split_text(cleaned_text)
    return chunks[0] if chunks else ""
def validate_techniques_file(techniques_json_path: str) -> bool:
    """
    Validate that techniques.json exists and looks ingestible.

    The file must parse as JSON and be a non-empty list. Its first entry must be
    a JSON object containing "attack_id", "name" and "description". Only the
    first entry is inspected. Every problem is reported on stdout with an
    ``[ERROR]`` prefix; success is reported with ``[SUCCESS]``.

    Args:
        techniques_json_path: Path to the techniques.json file.

    Returns:
        True when the file passes these checks, False otherwise.
    """
    import json

    if not os.path.exists(techniques_json_path):
        print(f"[ERROR] Techniques file not found: {techniques_json_path}")
        return False

    # Only the read/parse step can raise. Keep the try narrow so the structural
    # checks below report precise messages.
    try:
        with open(techniques_json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        print(f"[ERROR] Invalid JSON format: {e}")
        return False
    except Exception as e:
        # Deliberate boundary: any other read failure (directory, permissions,
        # bad encoding, ...) is reported and treated as an invalid file.
        print(f"[ERROR] Error reading techniques file: {e}")
        return False

    if not isinstance(data, list):
        print("[ERROR] Invalid format: techniques.json should contain a list")
        return False

    if not data:
        print("[ERROR] Empty techniques file")
        return False

    first_technique = data[0]
    # Without this check, `field in first_technique` would do a substring test
    # on a string entry and raise on a number.
    if not isinstance(first_technique, dict):
        print(
            "[ERROR] Invalid format: each technique should be a JSON object, "
            f"got {type(first_technique).__name__}"
        )
        return False

    required_fields = ["attack_id", "name", "description"]
    missing_fields = [
        field for field in required_fields if field not in first_technique
    ]
    if missing_fields:
        print(f"[ERROR] Missing required fields in techniques: {missing_fields}")
        return False

    print(f"[SUCCESS] Valid techniques file with {len(data)} techniques")
    return True
def ingest_techniques(args):
    """
    Ingest MITRE ATT&CK techniques and build the knowledge base.

    The techniques file is validated first. If validation fails, the process
    exits with status 1. Otherwise the knowledge base is built with
    ``args.reset`` passed through, and a statistics summary is printed. Nested
    statistics show at most 5 entries each.

    Args:
        args: Parsed ``ingest`` namespace. Reads ``techniques_json``,
            ``embedding_model``, ``persist_dir`` and ``reset``.

    Returns:
        True if the knowledge base was built and summarised. False if any
        exception was raised; its traceback is printed.
    """
    print("=" * 60)
    print("[INFO] INGESTING MITRE ATT&CK TECHNIQUES")
    print("=" * 60)

    # Validate techniques file
    if not validate_techniques_file(args.techniques_json):
        sys.exit(1)

    try:
        # Constructed inside the try so that a failure while creating the
        # knowledge base is reported like any other ingestion error (message,
        # traceback, False) instead of escaping as an unhandled exception.
        kb = CyberKnowledgeBase(embedding_model=args.embedding_model)

        kb.build_knowledge_base(
            techniques_json_path=args.techniques_json,
            persist_dir=args.persist_dir,
            reset=args.reset,
        )

        # Show final statistics
        print("\n[INFO] Knowledge Base Statistics:")
        stats = kb.get_stats()
        for key, value in stats.items():
            if isinstance(value, dict):
                print(f" {key}:")
                for subkey, subvalue in list(value.items())[:5]:  # Show first 5 items
                    print(f" {subkey}: {subvalue}")
                if len(value) > 5:
                    print(f" ... and {len(value) - 5} more")
            else:
                print(f" {key}: {value}")

        print(f"\n[SUCCESS] Knowledge base saved successfully to {args.persist_dir}!")
        return True

    except Exception as e:
        print(f"[ERROR] Error during ingestion: {e}")
        import traceback

        traceback.print_exc()
        return False
def test_retrieval(args):
    """
    Load a persisted knowledge base, print its statistics, then exercise retrieval.

    The mode is taken from ``args``:
    - ``--interactive`` takes priority.
    - Otherwise ``--query`` runs a single search, with optional tactic and
      platform filters.
    - With neither, the built-in test suite runs.

    Nested statistics show at most 5 entries each. Exits with status 1 if the
    knowledge base cannot be loaded.
    """
    banner = "=" * 60
    print(banner)
    print("[INFO] TESTING CYBER KNOWLEDGE BASE")
    print(banner)

    knowledge_base = CyberKnowledgeBase(embedding_model=args.embedding_model)
    if not knowledge_base.load_knowledge_base(persist_dir=args.persist_dir):
        print("[ERROR] Failed to load knowledge base. Run 'ingest' first.")
        sys.exit(1)

    print("\n[INFO] Knowledge Base Statistics:")
    preview_limit = 5
    for stat_name, stat_value in knowledge_base.get_stats().items():
        if not isinstance(stat_value, dict):
            print(f" {stat_name}: {stat_value}")
            continue
        print(f" {stat_name}:")
        for sub_name, sub_value in list(stat_value.items())[:preview_limit]:
            print(f" {sub_name}: {sub_value}")
        hidden_count = len(stat_value) - preview_limit
        if hidden_count > 0:
            print(f" ... and {hidden_count} more")

    if args.interactive:
        run_interactive_tests(knowledge_base)
    elif args.query:
        test_single_query(
            knowledge_base, args.query, args.filter_tactics, args.filter_platforms
        )
    else:
        run_test_suite(knowledge_base)
def test_single_query(
    kb,
    query: str,
    filter_tactics: Optional[List[str]] = None,
    filter_platforms: Optional[List[str]] = None,
):
    """
    Run one search against ``kb`` and print the detailed results.

    Up to 20 results are requested, narrowed by the optional tactic and platform
    filters. Any exception raised by the search or the display is printed along
    with its traceback rather than propagated.
    """
    print(f"\n[INFO] Testing Query: '{query}'")
    if filter_tactics:
        print(f"[INFO] Filtering by tactics: {filter_tactics}")
    if filter_platforms:
        print(f"[INFO] Filtering by platforms: {filter_platforms}")
    print("-" * 40)

    try:
        found = kb.search(
            query,
            top_k=20,
            filter_tactics=filter_tactics,
            filter_platforms=filter_platforms,
        )
        display_detailed_results(found)
    except Exception as exc:
        print(f"[ERROR] Error during search: {exc}")
        import traceback

        traceback.print_exc()
def display_detailed_results(results):
"""Display search results with detailed MITRE ATT&CK information"""
if results:
for i, doc in enumerate(results, 1):
attack_id = doc.metadata.get("attack_id", "Unknown")
name = doc.metadata.get("name", "Unknown")
tactics_str = doc.metadata.get("tactics", "")
platforms_str = doc.metadata.get("platforms", "")
is_subtechnique = doc.metadata.get("is_subtechnique", False)
mitigation_count = doc.metadata.get("mitigation_count", 0)
mitigations = doc.metadata.get("mitigations", "")
# Get content preview from description
content_lines = doc.page_content.split("\n")
description_line = next(
(line for line in content_lines if line.startswith("Description:")), ""
)
if description_line:
description = description_line.replace("Description: ", "")
content_preview = truncate_to_tokens(description, 300)
else:
content_preview = truncate_to_tokens(doc.page_content, 300)
mitigation_preview = truncate_to_tokens(mitigations, 300)
print(f" {i}. {attack_id} - {name}")
print(f" Type: {'Sub-technique' if is_subtechnique else 'Technique'}")
print(f" Tactics: {tactics_str if tactics_str else 'None'}")
print(f" Platforms: {platforms_str if platforms_str else 'None'}")
print(
f" Mitigations: {mitigation_preview if mitigation_preview else 'None'}"
)
print(f" Mitigation Count: {mitigation_count}")
print(f" Description: {content_preview}")
print()
else:
print(" No results found")
def run_interactive_tests(kb):
"""Run interactive testing session with filtering options"""
print("\n[INFO] Interactive Testing Mode")
print("Available commands:")
print(" - Enter a query to search")
print(" - 'stats' to view knowledge base statistics")
print(" - 'tactics' to list available tactics")
print(" - 'platforms' to list available platforms")
print(
" - 'filter tactics:defense-evasion,privilege-escalation query' to filter by tactics"
)
print(" - 'filter platforms:Windows,Linux query' to filter by platforms")
print(" - 'technique T1055' to get specific technique info")
print(" - 'quit' to exit")
print("-" * 50)
while True:
try:
user_input = input("\n[INPUT] Enter command: ").strip()
if user_input.lower() in ["quit", "exit", "q"]:
break
if not user_input:
continue
# Handle special commands
if user_input.lower() == "stats":
display_stats(kb)
continue
if user_input.lower() == "tactics":
display_available_tactics(kb)
continue
if user_input.lower() == "platforms":
display_available_platforms(kb)
continue
# Handle technique lookup
if user_input.lower().startswith("technique "):
technique_id = user_input.split(" ", 1)[1].strip()
display_technique_info(kb, technique_id)
continue
# Handle filtered queries
filter_tactics = None
filter_platforms = None
query = user_input
if user_input.lower().startswith("filter "):
# Parse filter command: "filter tactics:a,b platforms:x,y query text"
parts = user_input.split(" ")
query_start = 1
for i, part in enumerate(parts[1:], 1):
if part.startswith("tactics:"):
filter_tactics = part.split(":", 1)[1].split(",")
query_start = i + 1
elif part.startswith("platforms:"):
filter_platforms = part.split(":", 1)[1].split(",")
query_start = i + 1
else:
break
query = " ".join(parts[query_start:])
if not query.strip():
print("[ERROR] No query provided")
continue
# Regular search
print(f"\n[INFO] Search: '{query}'")
if filter_tactics:
print(f"[INFO] Filtering by tactics: {filter_tactics}")
if filter_platforms:
print(f"[INFO] Filtering by platforms: {filter_platforms}")
results = kb.search(
query,
top_k=20,
filter_tactics=filter_tactics,
filter_platforms=filter_platforms,
)
display_detailed_results(results)
except KeyboardInterrupt:
print("\n[INFO] Exiting interactive mode...")
break
except Exception as e:
print(f"[ERROR] Error: {e}")
def display_stats(kb):
    """Print every statistic from ``kb.get_stats()``; nested dicts are listed in full."""
    statistics = kb.get_stats()
    print("\n[INFO] Knowledge Base Statistics:")
    for name, entry in statistics.items():
        if not isinstance(entry, dict):
            print(f" {name}: {entry}")
            continue
        print(f" {name}:")
        for sub_name, sub_entry in entry.items():
            print(f" {sub_name}: {sub_entry}")
def display_available_tactics(kb):
    """Print each tactic, alphabetically, with its technique count from ``techniques_by_tactic``."""
    tactic_counts = kb.get_stats().get("techniques_by_tactic", {})
    if not tactic_counts:
        print("\n[INFO] No tactics information available")
        return
    print("\n[INFO] Available Tactics:")
    for tactic, count in sorted(tactic_counts.items()):
        print(f" {tactic}: {count} techniques")
def display_available_platforms(kb):
    """Print each platform, alphabetically, with its technique count from ``techniques_by_platform``."""
    platform_counts = kb.get_stats().get("techniques_by_platform", {})
    if not platform_counts:
        print("\n[INFO] No platforms information available")
        return
    print("\n[INFO] Available Platforms:")
    for platform, count in sorted(platform_counts.items()):
        print(f" {platform}: {count} techniques")
def display_technique_info(kb, technique_id: str):
    """
    Print details for one technique fetched via ``kb.get_technique_by_id``.

    The ID is upper-cased for the lookup but echoed as typed. The description is
    cut to 500 characters and the detection text to 300; "..." is appended
    whenever text was cut. A missing (or empty) result prints an error line.
    """
    details = kb.get_technique_by_id(technique_id.upper())
    if not details:
        print(f"\n[ERROR] Technique {technique_id} not found")
        return

    kind = "Sub-technique" if details.get("is_subtechnique") else "Technique"
    tactics = ", ".join(details.get("tactics", []))
    platforms = ", ".join(details.get("platforms", []))
    mitigation_total = len(details.get("mitigations", []))

    print(f"\n[INFO] Technique Details: {technique_id}")
    print("-" * 40)
    print(f"Name: {details.get('name', 'Unknown')}")
    print(f"Type: {kind}")
    print(f"Tactics: {tactics}")
    print(f"Platforms: {platforms}")
    print(f"Mitigations: {mitigation_total}")

    description = details.get("description", "")
    if description:
        ellipsis = "..." if len(description) > 500 else ""
        print(f"Description: {description[:500]}{ellipsis}")

    detection = details.get("detection", "")
    if detection:
        ellipsis = "..." if len(detection) > 300 else ""
        print(f"Detection: {detection[:300]}{ellipsis}")
def run_test_suite(kb):
    """
    Run a fixed set of representative ATT&CK queries against ``kb``.

    Each query requests the top 3 results, which are printed in detail. A failure
    in one query is printed and does not stop the remaining queries.
    """
    test_cases = [
        # Process injection techniques
        {"query": "process injection", "description": "Process injection techniques"},
        {"query": "DLL injection", "description": "DLL injection methods"},
        # Privilege escalation
        {"query": "privilege escalation Windows", "description": "Windows privilege escalation"},
        {"query": "UAC bypass", "description": "UAC bypass techniques"},
        # Persistence
        {"query": "scheduled task persistence", "description": "Scheduled task persistence"},
        {"query": "registry persistence", "description": "Registry-based persistence"},
        # Credential access
        {"query": "credential dumping LSASS", "description": "LSASS credential dumping"},
        {"query": "password spraying", "description": "Password spraying attacks"},
        # Defense evasion
        {"query": "defense evasion DLL hijacking", "description": "DLL hijacking evasion"},
        {"query": "process hollowing", "description": "Process hollowing technique"},
        # Lateral movement
        {"query": "lateral movement SMB", "description": "SMB lateral movement"},
        {"query": "remote desktop protocol", "description": "RDP-based movement"},
    ]

    print("\n[INFO] Running Cyber Security Test Suite:")
    print("=" * 50)

    for number, case in enumerate(test_cases, start=1):
        query = case["query"]
        label = case["description"]
        print(f"\n#{number} {label}")
        print(f"Query: '{query}'")
        print("-" * 30)
        try:
            hits = kb.search(query, top_k=3)
            display_detailed_results(hits)
        except Exception as err:
            print(f"[ERROR] Error: {err}")
def _build_arg_parser() -> argparse.ArgumentParser:
    """Assemble the CLI parser with its ``ingest`` and ``test`` subcommands."""
    parser = argparse.ArgumentParser(
        description="MITRE ATT&CK Cyber Knowledge Base Management",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python build_cyber_database.py ingest --techniques-json ./processed_data/cti/techniques.json
python build_cyber_database.py test --query "process injection"
python build_cyber_database.py test --interactive
python build_cyber_database.py test --query "privilege escalation" --filter-tactics "privilege-escalation"
""",
    )
    subcommands = parser.add_subparsers(dest="command", help="Available commands")

    # `ingest`: build the knowledge base from a techniques.json file.
    ingest = subcommands.add_parser(
        "ingest", help="Ingest MITRE ATT&CK techniques and build knowledge base"
    )
    ingest.add_argument(
        "--techniques-json",
        default="./processed_data/cti/techniques.json",
        help="Path to techniques.json file",
    )
    ingest.add_argument(
        "--persist-dir",
        default="./cyber_knowledge_base",
        help="Directory to store the knowledge base",
    )
    ingest.add_argument(
        "--embedding-model",
        default="google/embeddinggemma-300m",
        help="Embedding model name",
    )
    # --reset and --no-reset share the `reset` destination, which defaults to True.
    ingest.add_argument(
        "--reset",
        action="store_true",
        default=True,
        help="Reset knowledge base before ingestion (default: True)",
    )
    ingest.add_argument(
        "--no-reset",
        dest="reset",
        action="store_false",
        help="Do not reset existing knowledge base",
    )

    # `test`: query an existing knowledge base.
    tester = subcommands.add_parser(
        "test", help="Test retrieval on existing knowledge base"
    )
    tester.add_argument("--query", help="Single query to test")
    tester.add_argument(
        "--filter-tactics",
        nargs="+",
        help="Filter by tactics (e.g., --filter-tactics defense-evasion privilege-escalation)",
    )
    tester.add_argument(
        "--filter-platforms",
        nargs="+",
        help="Filter by platforms (e.g., --filter-platforms Windows Linux)",
    )
    tester.add_argument(
        "--interactive", action="store_true", help="Interactive testing mode"
    )
    tester.add_argument(
        "--persist-dir",
        default="./cyber_knowledge_base",
        help="Directory where knowledge base is stored",
    )
    tester.add_argument(
        "--embedding-model",
        default="google/embeddinggemma-300m",
        help="Embedding model name",
    )
    return parser


def main():
    """
    Parse the command line and dispatch to the chosen subcommand.

    - ``ingest`` exits with status 0 on success and 1 on failure.
    - ``test`` delegates to ``test_retrieval``.
    - With no subcommand, the help text is printed and the process exits with
      status 1.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.command == "test":
        test_retrieval(args)
        return

    if args.command == "ingest":
        succeeded = ingest_techniques(args)
        sys.exit(0 if succeeded else 1)

    parser.print_help()
    sys.exit(1)


if __name__ == "__main__":
    main()