|
|
"""
|
|
|
MITRE ATT&CK Cyber Knowledge Base Management Script
|
|
|
|
|
|
This script manages the MITRE ATT&CK techniques knowledge base with:
|
|
|
- Processing techniques.json file containing MITRE ATT&CK data
|
|
|
- Semantic search using google/embeddinggemma-300m embeddings
|
|
|
- Cross-encoder reranking using Qwen/Qwen3-Reranker-0.6B
|
|
|
- Hybrid search combining ChromaDB (semantic) and BM25 (keyword)
|
|
|
- Metadata filtering by tactics, platforms, and technique attributes
|
|
|
|
|
|
Usage:
|
|
|
python build_cyber_database.py ingest --techniques-json ./processed_data/cti/techniques.json
|
|
|
python build_cyber_database.py test --query "process injection"
|
|
|
python build_cyber_database.py test --interactive
|
|
|
python build_cyber_database.py test --query "privilege escalation" --filter-tactics "privilege-escalation" --filter-platforms "Windows"
|
|
|
"""
|
|
|
|
|
|
import argparse
import functools
import os
import sys
from pathlib import Path
from typing import Optional, List
|
|
|
|
|
|
|
|
|
# Make the project root importable so the `src.knowledge_base` import below
# resolves when this file is run directly as a script. The root is taken as
# Path(__file__).parent.parent.parent, i.e. two directory levels above the
# directory containing this script (resolved relative to how __file__ is given).
project_root = Path(__file__).parent.parent.parent
# Insert at the front so the project's own packages take precedence over any
# same-named entries already on sys.path.
sys.path.insert(0, str(project_root))
|
|
|
|
|
|
from langchain.text_splitter import TokenTextSplitter
|
|
|
from src.knowledge_base.cyber_knowledge_base import CyberKnowledgeBase
|
|
|
|
|
|
|
|
|
@functools.lru_cache(maxsize=8)
def _get_token_splitter(max_tokens: int) -> "TokenTextSplitter":
    """Return a cached cl100k_base TokenTextSplitter with ``chunk_size=max_tokens``.

    The splitter (and the tokenizer encoding it wraps) is built once per
    distinct ``max_tokens`` instead of on every call to ``truncate_to_tokens``,
    which runs several times per displayed search result.
    """
    return TokenTextSplitter(
        encoding_name="cl100k_base", chunk_size=max_tokens, chunk_overlap=0
    )


def truncate_to_tokens(text: str, max_tokens: int = 300) -> str:
    """
    Truncate text to a maximum number of tokens using LangChain's TokenTextSplitter.

    Newlines are replaced with spaces first so the preview prints on one line.
    Tokens are counted with the ``cl100k_base`` encoding.

    Args:
        text: The text to truncate. Falsy values (``""``, ``None``) yield ``""``.
        max_tokens: Maximum number of tokens to keep (default: 300). A
            non-positive limit yields ``""``.

    Returns:
        The first chunk of at most ``max_tokens`` tokens, or ``""`` when there
        is nothing to return.
    """
    # Zero tokens means an empty result; this also keeps a non-positive
    # chunk size from ever reaching the splitter.
    if not text or max_tokens <= 0:
        return ""

    cleaned_text = text.replace("\n", " ")

    # The first chunk is exactly the leading `max_tokens` tokens of the text.
    chunks = _get_token_splitter(max_tokens).split_text(cleaned_text)
    return chunks[0] if chunks else ""
|
|
|
|
|
|
|
|
|
def validate_techniques_file(techniques_json_path: str) -> bool:
    """Validate that techniques.json exists, parses, and has the expected shape.

    The file must contain a non-empty JSON list whose first entry is an object
    with ``attack_id``, ``name`` and ``description`` keys. Only the first entry
    is inspected. Every problem is reported on stdout.

    Args:
        techniques_json_path: Path to the techniques JSON file.

    Returns:
        True if the file looks valid, False otherwise.
    """
    import json

    if not os.path.exists(techniques_json_path):
        print(f"[ERROR] Techniques file not found: {techniques_json_path}")
        return False

    # Keep the try narrow: only reading/parsing the file is expected to fail.
    try:
        with open(techniques_json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        print(f"[ERROR] Invalid JSON format: {e}")
        return False
    except (OSError, ValueError) as e:
        # OSError: permission denied, path is a directory, ...
        # ValueError: e.g. UnicodeDecodeError for non-UTF-8 content.
        print(f"[ERROR] Error reading techniques file: {e}")
        return False

    if not isinstance(data, list):
        print("[ERROR] Invalid format: techniques.json should contain a list")
        return False

    if not data:
        print("[ERROR] Empty techniques file")
        return False

    first_technique = data[0]
    # Without this check a string entry would pass the field test below,
    # because `"name" in some_string` is a substring match.
    if not isinstance(first_technique, dict):
        print("[ERROR] Invalid format: each technique should be a JSON object")
        return False

    required_fields = ["attack_id", "name", "description"]
    missing_fields = [
        field for field in required_fields if field not in first_technique
    ]
    if missing_fields:
        print(f"[ERROR] Missing required fields in techniques: {missing_fields}")
        return False

    print(f"[SUCCESS] Valid techniques file with {len(data)} techniques")
    return True
|
|
|
|
|
|
|
|
|
def ingest_techniques(args):
    """Validate the techniques file, build the knowledge base, and report its stats.

    Args:
        args: Parsed CLI namespace providing ``techniques_json``,
            ``persist_dir``, ``embedding_model`` and ``reset``.

    Returns:
        True once the knowledge base has been built and its statistics printed;
        False if any exception was raised during the build or the report (the
        traceback is printed). If the techniques file fails validation the
        process exits with status 1 instead of returning.
    """
    banner = "=" * 60
    print(banner)
    print("[INFO] INGESTING MITRE ATT&CK TECHNIQUES")
    print(banner)

    # Validate the input before constructing the knowledge base.
    if not validate_techniques_file(args.techniques_json):
        sys.exit(1)

    kb = CyberKnowledgeBase(embedding_model=args.embedding_model)

    try:
        kb.build_knowledge_base(
            techniques_json_path=args.techniques_json,
            persist_dir=args.persist_dir,
            reset=args.reset,
        )

        # Short summary: dict-valued stats show only their first few entries.
        print("\n[INFO] Knowledge Base Statistics:")
        preview_limit = 5
        for stat_name, stat_value in kb.get_stats().items():
            if not isinstance(stat_value, dict):
                print(f" {stat_name}: {stat_value}")
                continue
            print(f" {stat_name}:")
            for entry_name, entry_value in list(stat_value.items())[:preview_limit]:
                print(f" {entry_name}: {entry_value}")
            hidden_count = len(stat_value) - preview_limit
            if hidden_count > 0:
                print(f" ... and {hidden_count} more")

        print(f"\n[SUCCESS] Knowledge base saved successfully to {args.persist_dir}!")
        return True

    except Exception as e:
        print(f"[ERROR] Error during ingestion: {e}")
        import traceback

        traceback.print_exc()
        return False
|
|
|
|
|
|
|
|
|
def test_retrieval(args):
    """Load an existing knowledge base, print its statistics, and run a test mode.

    Mode selection: ``--interactive`` starts an interactive session; otherwise
    ``--query`` runs a single (optionally filtered) search; with neither, the
    built-in test suite runs. Exits with status 1 if the knowledge base cannot
    be loaded.

    Args:
        args: Parsed CLI namespace with ``embedding_model``, ``persist_dir``,
            ``interactive``, ``query``, ``filter_tactics`` and
            ``filter_platforms``.
    """
    divider = "=" * 60
    print(divider)
    print("[INFO] TESTING CYBER KNOWLEDGE BASE")
    print(divider)

    kb = CyberKnowledgeBase(embedding_model=args.embedding_model)
    if not kb.load_knowledge_base(persist_dir=args.persist_dir):
        print("[ERROR] Failed to load knowledge base. Run 'ingest' first.")
        sys.exit(1)

    # Short summary: dict-valued stats show only their first few entries.
    print("\n[INFO] Knowledge Base Statistics:")
    shown = 5
    for key, value in kb.get_stats().items():
        if isinstance(value, dict):
            print(f" {key}:")
            for subkey, subvalue in list(value.items())[:shown]:
                print(f" {subkey}: {subvalue}")
            if len(value) > shown:
                print(f" ... and {len(value) - shown} more")
        else:
            print(f" {key}: {value}")

    if args.interactive:
        run_interactive_tests(kb)
        return
    if args.query:
        test_single_query(kb, args.query, args.filter_tactics, args.filter_platforms)
        return
    run_test_suite(kb)
|
|
|
|
|
|
|
|
|
def test_single_query(
    kb,
    query: str,
    filter_tactics: Optional[List[str]] = None,
    filter_platforms: Optional[List[str]] = None,
):
    """Search ``kb`` for one query (top 20 hits) and print detailed results.

    Args:
        kb: Knowledge base exposing ``search(query, top_k=..., filter_tactics=...,
            filter_platforms=...)``.
        query: Free-text search query.
        filter_tactics: Optional tactic names, passed through as ``filter_tactics``.
        filter_platforms: Optional platform names, passed through as
            ``filter_platforms``.

    Errors raised while searching or printing are reported with a traceback
    instead of propagating.
    """
    print(f"\n[INFO] Testing Query: '{query}'")
    if filter_tactics:
        print(f"[INFO] Filtering by tactics: {filter_tactics}")
    if filter_platforms:
        print(f"[INFO] Filtering by platforms: {filter_platforms}")
    print("-" * 40)

    search_options = {
        "top_k": 20,
        "filter_tactics": filter_tactics,
        "filter_platforms": filter_platforms,
    }
    try:
        display_detailed_results(kb.search(query, **search_options))
    except Exception as e:
        print(f"[ERROR] Error during search: {e}")
        import traceback

        traceback.print_exc()
|
|
|
|
|
|
|
|
|
def display_detailed_results(results):
    """Display search results with detailed MITRE ATT&CK information.

    Each result is read through ``doc.metadata`` (a mapping) and
    ``doc.page_content`` (a newline-separated string). These are presumably
    the document objects returned by ``kb.search``; confirm against
    CyberKnowledgeBase.

    The description preview comes from the first content line that starts
    with ``Description:`` when there is one; otherwise the whole page content
    is used. Description and mitigation previews are truncated to 300 tokens.

    Args:
        results: Iterable of search results. A falsy value (``None`` or an
            empty list) prints " No results found".
    """
    if not results:
        print(" No results found")
        return

    description_label = "Description:"
    for i, doc in enumerate(results, 1):
        metadata = doc.metadata
        attack_id = metadata.get("attack_id", "Unknown")
        name = metadata.get("name", "Unknown")
        tactics_str = metadata.get("tactics", "")
        platforms_str = metadata.get("platforms", "")
        is_subtechnique = metadata.get("is_subtechnique", False)
        mitigation_count = metadata.get("mitigation_count", 0)
        mitigations = metadata.get("mitigations", "")

        description_line = next(
            (
                line
                for line in doc.page_content.split("\n")
                if line.startswith(description_label)
            ),
            "",
        )
        if description_line:
            # Strip only the leading label. A str.replace() would also delete
            # any "Description: " occurring inside the text itself, and would
            # leave the label in place when no space follows the colon.
            description = description_line[len(description_label):].strip()
            content_preview = truncate_to_tokens(description, 300)
        else:
            content_preview = truncate_to_tokens(doc.page_content, 300)

        mitigation_preview = truncate_to_tokens(mitigations, 300)

        print(f" {i}. {attack_id} - {name}")
        print(f" Type: {'Sub-technique' if is_subtechnique else 'Technique'}")
        print(f" Tactics: {tactics_str if tactics_str else 'None'}")
        print(f" Platforms: {platforms_str if platforms_str else 'None'}")
        print(
            f" Mitigations: {mitigation_preview if mitigation_preview else 'None'}"
        )
        print(f" Mitigation Count: {mitigation_count}")
        print(f" Description: {content_preview}")
        print()
|
|
|
|
|
|
|
|
|
def _parse_filter_command(user_input: str):
    """Split a ``filter ...`` command into ``(filter_tactics, filter_platforms, query)``.

    Accepted form: ``filter [tactics:a,b] [platforms:X,Y] <query words>``.
    Filter tokens are read right after ``filter`` until the first word that is
    not ``tactics:...`` or ``platforms:...`` (keys match case-insensitively).
    Everything from that word on is the query. Empty values, e.g. from a
    trailing comma, are dropped, and a filter left with no values becomes None.
    """
    filter_tactics = None
    filter_platforms = None
    # str.split() with no separator tolerates runs of whitespace; split(" ")
    # would produce empty tokens that end filter parsing prematurely.
    words = user_input.split()[1:]
    consumed = 0
    for word in words:
        key, sep, raw_values = word.partition(":")
        key = key.lower()
        if not sep or key not in ("tactics", "platforms"):
            break
        values = [value for value in raw_values.split(",") if value] or None
        if key == "tactics":
            filter_tactics = values
        else:
            filter_platforms = values
        consumed += 1
    return filter_tactics, filter_platforms, " ".join(words[consumed:])


def run_interactive_tests(kb):
    """Run an interactive testing session with filtering options.

    Supported commands:
      - ``stats``, ``tactics``, ``platforms``
      - ``technique <ID>``
      - ``filter tactics:a,b platforms:X,Y <query>``
      - any other text, searched as a plain query
      - ``quit`` / ``exit`` / ``q``

    Ctrl-C or end of input (Ctrl-D, exhausted piped stdin) also ends the
    session. Errors from individual commands are printed and the session
    continues.
    """
    print("\n[INFO] Interactive Testing Mode")
    print("Available commands:")
    print(" - Enter a query to search")
    print(" - 'stats' to view knowledge base statistics")
    print(" - 'tactics' to list available tactics")
    print(" - 'platforms' to list available platforms")
    print(
        " - 'filter tactics:defense-evasion,privilege-escalation query' to filter by tactics"
    )
    print(" - 'filter platforms:Windows,Linux query' to filter by platforms")
    print(" - 'technique T1055' to get specific technique info")
    print(" - 'quit' to exit")
    print("-" * 50)

    while True:
        try:
            user_input = input("\n[INPUT] Enter command: ").strip()
            command = user_input.lower()

            if command in ["quit", "exit", "q"]:
                break
            if not user_input:
                continue

            if command == "stats":
                display_stats(kb)
                continue
            if command == "tactics":
                display_available_tactics(kb)
                continue
            if command == "platforms":
                display_available_platforms(kb)
                continue
            if command.startswith("technique "):
                display_technique_info(kb, user_input.split(" ", 1)[1].strip())
                continue

            filter_tactics = None
            filter_platforms = None
            query = user_input
            if command.startswith("filter "):
                filter_tactics, filter_platforms, query = _parse_filter_command(
                    user_input
                )

            if not query.strip():
                print("[ERROR] No query provided")
                continue

            print(f"\n[INFO] Search: '{query}'")
            if filter_tactics:
                print(f"[INFO] Filtering by tactics: {filter_tactics}")
            if filter_platforms:
                print(f"[INFO] Filtering by platforms: {filter_platforms}")

            results = kb.search(
                query,
                top_k=20,
                filter_tactics=filter_tactics,
                filter_platforms=filter_platforms,
            )
            display_detailed_results(results)

        except (KeyboardInterrupt, EOFError):
            # EOFError means stdin is closed. It must be handled before the
            # generic handler below: every later input() call would raise it
            # again, so the loop would never end.
            print("\n[INFO] Exiting interactive mode...")
            break
        except Exception as e:
            print(f"[ERROR] Error: {e}")
|
|
|
|
|
|
|
|
|
def display_stats(kb):
    """Print every knowledge-base statistic, expanding dict-valued entries in full."""
    stats_by_name = kb.get_stats()
    print("\n[INFO] Knowledge Base Statistics:")
    for stat_name, stat in stats_by_name.items():
        if not isinstance(stat, dict):
            print(f" {stat_name}: {stat}")
            continue
        print(f" {stat_name}:")
        for item_name, item_value in stat.items():
            print(f" {item_name}: {item_value}")
|
|
|
|
|
|
|
|
|
def display_available_tactics(kb):
    """Print each tactic with its technique count, sorted by tactic name.

    Reads the ``techniques_by_tactic`` mapping from ``kb.get_stats()``; prints a
    notice instead when that mapping is missing or empty.
    """
    tactic_counts = kb.get_stats().get("techniques_by_tactic", {})
    if not tactic_counts:
        print("\n[INFO] No tactics information available")
        return

    print("\n[INFO] Available Tactics:")
    for tactic_name, technique_total in sorted(tactic_counts.items()):
        print(f" {tactic_name}: {technique_total} techniques")
|
|
|
|
|
|
|
|
|
def display_available_platforms(kb):
    """Print each platform with its technique count, sorted by platform name.

    Reads the ``techniques_by_platform`` mapping from ``kb.get_stats()``; prints
    a notice instead when that mapping is missing or empty.
    """
    platform_counts = kb.get_stats().get("techniques_by_platform", {})
    if not platform_counts:
        print("\n[INFO] No platforms information available")
        return

    print("\n[INFO] Available Platforms:")
    for platform_name, technique_total in sorted(platform_counts.items()):
        print(f" {platform_name}: {technique_total} techniques")
|
|
|
|
|
|
|
|
|
def display_technique_info(kb, technique_id: str):
    """Display detailed information about a specific technique.

    The ID is stripped and upper-cased before lookup, so ``t1055`` finds
    ``T1055``. That normalized ID is also the one printed. Description and
    detection text are cut to 500 and 300 characters respectively.

    Args:
        kb: Knowledge base exposing ``get_technique_by_id(id)``, which returns a
            dict of technique fields or a falsy value when the ID is unknown.
        technique_id: Technique ID as typed by the user.
    """
    normalized_id = technique_id.strip().upper()
    technique = kb.get_technique_by_id(normalized_id)
    if not technique:
        print(f"\n[ERROR] Technique {normalized_id} not found")
        return

    def _joined(field):
        # Accept a list of names or an already-joined string. A missing/None
        # field renders as empty instead of raising TypeError in join(); a
        # string is not split into single characters.
        value = technique.get(field)
        if not value:
            return ""
        if isinstance(value, str):
            return value
        return ", ".join(str(item) for item in value)

    print(f"\n[INFO] Technique Details: {normalized_id}")
    print("-" * 40)
    print(f"Name: {technique.get('name', 'Unknown')}")
    print(
        f"Type: {'Sub-technique' if technique.get('is_subtechnique') else 'Technique'}"
    )
    print(f"Tactics: {_joined('tactics')}")
    print(f"Platforms: {_joined('platforms')}")
    # `or []` also covers a key that is present but None.
    print(f"Mitigations: {len(technique.get('mitigations') or [])}")

    description = technique.get("description", "")
    if description:
        print(
            f"Description: {description[:500]}{'...' if len(description) > 500 else ''}"
        )

    detection = technique.get("detection", "")
    if detection:
        print(
            f"Detection: {detection[:300]}{'...' if len(detection) > 300 else ''}"
        )
|
|
|
|
|
|
|
|
|
def run_test_suite(kb):
|
|
|
"""Run comprehensive test suite for cyber techniques"""
|
|
|
|
|
|
test_cases = [
|
|
|
|
|
|
{"query": "process injection", "description": "Process injection techniques"},
|
|
|
{"query": "DLL injection", "description": "DLL injection methods"},
|
|
|
|
|
|
{
|
|
|
"query": "privilege escalation Windows",
|
|
|
"description": "Windows privilege escalation",
|
|
|
},
|
|
|
{"query": "UAC bypass", "description": "UAC bypass techniques"},
|
|
|
|
|
|
{
|
|
|
"query": "scheduled task persistence",
|
|
|
"description": "Scheduled task persistence",
|
|
|
},
|
|
|
{"query": "registry persistence", "description": "Registry-based persistence"},
|
|
|
|
|
|
{
|
|
|
"query": "credential dumping LSASS",
|
|
|
"description": "LSASS credential dumping",
|
|
|
},
|
|
|
{"query": "password spraying", "description": "Password spraying attacks"},
|
|
|
|
|
|
{
|
|
|
"query": "defense evasion DLL hijacking",
|
|
|
"description": "DLL hijacking evasion",
|
|
|
},
|
|
|
{"query": "process hollowing", "description": "Process hollowing technique"},
|
|
|
|
|
|
{"query": "lateral movement SMB", "description": "SMB lateral movement"},
|
|
|
{"query": "remote desktop protocol", "description": "RDP-based movement"},
|
|
|
]
|
|
|
|
|
|
print("\n[INFO] Running Cyber Security Test Suite:")
|
|
|
print("=" * 50)
|
|
|
|
|
|
for i, test_case in enumerate(test_cases, 1):
|
|
|
print(f"\n#{i} {test_case['description']}")
|
|
|
print(f"Query: '{test_case['query']}'")
|
|
|
print("-" * 30)
|
|
|
|
|
|
try:
|
|
|
results = kb.search(test_case["query"], top_k=3)
|
|
|
display_detailed_results(results)
|
|
|
except Exception as e:
|
|
|
print(f"[ERROR] Error: {e}")
|
|
|
|
|
|
|
|
|
def main():
    """Parse command-line arguments and dispatch to the ``ingest`` or ``test`` command.

    ``ingest`` exits with status 0 on success and 1 on failure. With no
    command, the help text is printed and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description="MITRE ATT&CK Cyber Knowledge Base Management",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python build_cyber_database.py ingest --techniques-json ./processed_data/cti/techniques.json
python build_cyber_database.py test --query "process injection"
python build_cyber_database.py test --interactive
python build_cyber_database.py test --query "privilege escalation" --filter-tactics "privilege-escalation"
""",
    )
    commands = parser.add_subparsers(dest="command", help="Available commands")

    def add_storage_options(subparser, persist_help):
        # Both commands share the storage location and embedding model options.
        subparser.add_argument(
            "--persist-dir",
            default="./cyber_knowledge_base",
            help=persist_help,
        )
        subparser.add_argument(
            "--embedding-model",
            default="google/embeddinggemma-300m",
            help="Embedding model name",
        )

    # --- ingest ---
    ingest_cmd = commands.add_parser(
        "ingest", help="Ingest MITRE ATT&CK techniques and build knowledge base"
    )
    ingest_cmd.add_argument(
        "--techniques-json",
        default="./processed_data/cti/techniques.json",
        help="Path to techniques.json file",
    )
    add_storage_options(ingest_cmd, "Directory to store the knowledge base")
    # --reset is already the default; it exists so callers can be explicit.
    # --no-reset writes False to the same destination.
    ingest_cmd.add_argument(
        "--reset",
        action="store_true",
        default=True,
        help="Reset knowledge base before ingestion (default: True)",
    )
    ingest_cmd.add_argument(
        "--no-reset",
        dest="reset",
        action="store_false",
        help="Do not reset existing knowledge base",
    )

    # --- test ---
    test_cmd = commands.add_parser(
        "test", help="Test retrieval on existing knowledge base"
    )
    test_cmd.add_argument("--query", help="Single query to test")
    test_cmd.add_argument(
        "--filter-tactics",
        nargs="+",
        help="Filter by tactics (e.g., --filter-tactics defense-evasion privilege-escalation)",
    )
    test_cmd.add_argument(
        "--filter-platforms",
        nargs="+",
        help="Filter by platforms (e.g., --filter-platforms Windows Linux)",
    )
    test_cmd.add_argument(
        "--interactive", action="store_true", help="Interactive testing mode"
    )
    add_storage_options(test_cmd, "Directory where knowledge base is stored")

    args = parser.parse_args()

    if args.command == "ingest":
        sys.exit(0 if ingest_techniques(args) else 1)
    if args.command == "test":
        test_retrieval(args)
        return

    parser.print_help()
    sys.exit(1)
|
|
|
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|
|