""" MITRE ATT&CK Cyber Knowledge Base Management Script This script manages the MITRE ATT&CK techniques knowledge base with: - Processing techniques.json file containing MITRE ATT&CK data - Semantic search using google/embeddinggemma-300m embeddings - Cross-encoder reranking using Qwen/Qwen3-Reranker-0.6B - Hybrid search combining ChromaDB (semantic) and BM25 (keyword) - Metadata filtering by tactics, platforms, and technique attributes Usage: python build_cyber_database.py ingest --techniques-json ./mitre_data/techniques.json python build_cyber_database.py test --query "process injection" python build_cyber_database.py test --interactive python build_cyber_database.py test --query "privilege escalation" --filter-tactics "privilege-escalation" --filter-platforms "Windows" """ import argparse import os import sys from pathlib import Path from typing import Optional, List # Add the project root to Python path so we can import from src project_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(project_root)) from langchain.text_splitter import TokenTextSplitter from src.knowledge_base.cyber_knowledge_base import CyberKnowledgeBase def truncate_to_tokens(text: str, max_tokens: int = 300) -> str: """ Truncate text to a maximum number of tokens using LangChain's TokenTextSplitter. Args: text: The text to truncate max_tokens: Maximum number of tokens (default: 300) Returns: Truncated text within the token limit """ if not text: return "" # Clean the text by replacing newlines with spaces cleaned_text = text.replace("\n", " ") # Use TokenTextSplitter to split by tokens splitter = TokenTextSplitter( encoding_name="cl100k_base", chunk_size=max_tokens, chunk_overlap=0 ) chunks = splitter.split_text(cleaned_text) return chunks[0] if chunks else "" def validate_techniques_file(techniques_json_path: str) -> bool: """Validate that techniques.json exists and is readable""" if not os.path.exists(techniques_json_path): print(f"[ERROR] Techniques file not found: {techniques_json_path}") return False try: import json with open(techniques_json_path, "r", encoding="utf-8") as f: data = json.load(f) if not isinstance(data, list): print(f"[ERROR] Invalid format: techniques.json should contain a list") return False if len(data) == 0: print(f"[ERROR] Empty techniques file") return False # Check first item has required fields first_technique = data[0] required_fields = ["attack_id", "name", "description"] missing_fields = [ field for field in required_fields if field not in first_technique ] if missing_fields: print(f"[ERROR] Missing required fields in techniques: {missing_fields}") return False print(f"[SUCCESS] Valid techniques file with {len(data)} techniques") return True except json.JSONDecodeError as e: print(f"[ERROR] Invalid JSON format: {e}") return False except Exception as e: print(f"[ERROR] Error reading techniques file: {e}") return False def ingest_techniques(args): """Ingest MITRE ATT&CK techniques and build knowledge base""" print("=" * 60) print("[INFO] INGESTING MITRE ATT&CK TECHNIQUES") print("=" * 60) # Validate techniques file if not validate_techniques_file(args.techniques_json): sys.exit(1) # Initialize knowledge base kb = CyberKnowledgeBase(embedding_model=args.embedding_model) try: # Build knowledge base kb.build_knowledge_base( techniques_json_path=args.techniques_json, persist_dir=args.persist_dir, reset=args.reset, ) # Show final statistics print("\n[INFO] Knowledge Base Statistics:") stats = kb.get_stats() for key, value in stats.items(): if isinstance(value, dict): print(f" {key}:") for subkey, subvalue in list(value.items())[:5]: # Show first 5 items print(f" {subkey}: {subvalue}") if len(value) > 5: print(f" ... and {len(value) - 5} more") else: print(f" {key}: {value}") print(f"\n[SUCCESS] Knowledge base saved successfully to {args.persist_dir}!") return True except Exception as e: print(f"[ERROR] Error during ingestion: {e}") import traceback traceback.print_exc() return False def test_retrieval(args): """Test retrieval on existing knowledge base""" print("=" * 60) print("[INFO] TESTING CYBER KNOWLEDGE BASE") print("=" * 60) # Load knowledge base kb = CyberKnowledgeBase(embedding_model=args.embedding_model) # Load knowledge base success = kb.load_knowledge_base(persist_dir=args.persist_dir) if not success: print("[ERROR] Failed to load knowledge base. Run 'ingest' first.") sys.exit(1) # Show knowledge base stats print("\n[INFO] Knowledge Base Statistics:") stats = kb.get_stats() for key, value in stats.items(): if isinstance(value, dict): print(f" {key}:") for subkey, subvalue in list(value.items())[:5]: # Show first 5 items print(f" {subkey}: {subvalue}") if len(value) > 5: print(f" ... and {len(value) - 5} more") else: print(f" {key}: {value}") if args.interactive: # Interactive testing mode run_interactive_tests(kb) elif args.query: # Single query testing test_single_query(kb, args.query, args.filter_tactics, args.filter_platforms) else: # Run default test suite run_test_suite(kb) def test_single_query( kb, query: str, filter_tactics: Optional[List[str]] = None, filter_platforms: Optional[List[str]] = None, ): """Test a single query with filters""" print(f"\n[INFO] Testing Query: '{query}'") if filter_tactics: print(f"[INFO] Filtering by tactics: {filter_tactics}") if filter_platforms: print(f"[INFO] Filtering by platforms: {filter_platforms}") print("-" * 40) try: # Test search with filters results = kb.search( query, top_k=20, filter_tactics=filter_tactics, filter_platforms=filter_platforms, ) display_detailed_results(results) except Exception as e: print(f"[ERROR] Error during search: {e}") import traceback traceback.print_exc() def display_detailed_results(results): """Display search results with detailed MITRE ATT&CK information""" if results: for i, doc in enumerate(results, 1): attack_id = doc.metadata.get("attack_id", "Unknown") name = doc.metadata.get("name", "Unknown") tactics_str = doc.metadata.get("tactics", "") platforms_str = doc.metadata.get("platforms", "") is_subtechnique = doc.metadata.get("is_subtechnique", False) mitigation_count = doc.metadata.get("mitigation_count", 0) mitigations = doc.metadata.get("mitigations", "") # Get content preview from description content_lines = doc.page_content.split("\n") description_line = next( (line for line in content_lines if line.startswith("Description:")), "" ) if description_line: description = description_line.replace("Description: ", "") content_preview = truncate_to_tokens(description, 300) else: content_preview = truncate_to_tokens(doc.page_content, 300) mitigation_preview = truncate_to_tokens(mitigations, 300) print(f" {i}. {attack_id} - {name}") print(f" Type: {'Sub-technique' if is_subtechnique else 'Technique'}") print(f" Tactics: {tactics_str if tactics_str else 'None'}") print(f" Platforms: {platforms_str if platforms_str else 'None'}") print( f" Mitigations: {mitigation_preview if mitigation_preview else 'None'}" ) print(f" Mitigation Count: {mitigation_count}") print(f" Description: {content_preview}") print() else: print(" No results found") def run_interactive_tests(kb): """Run interactive testing session with filtering options""" print("\n[INFO] Interactive Testing Mode") print("Available commands:") print(" - Enter a query to search") print(" - 'stats' to view knowledge base statistics") print(" - 'tactics' to list available tactics") print(" - 'platforms' to list available platforms") print( " - 'filter tactics:defense-evasion,privilege-escalation query' to filter by tactics" ) print(" - 'filter platforms:Windows,Linux query' to filter by platforms") print(" - 'technique T1055' to get specific technique info") print(" - 'quit' to exit") print("-" * 50) while True: try: user_input = input("\n[INPUT] Enter command: ").strip() if user_input.lower() in ["quit", "exit", "q"]: break if not user_input: continue # Handle special commands if user_input.lower() == "stats": display_stats(kb) continue if user_input.lower() == "tactics": display_available_tactics(kb) continue if user_input.lower() == "platforms": display_available_platforms(kb) continue # Handle technique lookup if user_input.lower().startswith("technique "): technique_id = user_input.split(" ", 1)[1].strip() display_technique_info(kb, technique_id) continue # Handle filtered queries filter_tactics = None filter_platforms = None query = user_input if user_input.lower().startswith("filter "): # Parse filter command: "filter tactics:a,b platforms:x,y query text" parts = user_input.split(" ") query_start = 1 for i, part in enumerate(parts[1:], 1): if part.startswith("tactics:"): filter_tactics = part.split(":", 1)[1].split(",") query_start = i + 1 elif part.startswith("platforms:"): filter_platforms = part.split(":", 1)[1].split(",") query_start = i + 1 else: break query = " ".join(parts[query_start:]) if not query.strip(): print("[ERROR] No query provided") continue # Regular search print(f"\n[INFO] Search: '{query}'") if filter_tactics: print(f"[INFO] Filtering by tactics: {filter_tactics}") if filter_platforms: print(f"[INFO] Filtering by platforms: {filter_platforms}") results = kb.search( query, top_k=20, filter_tactics=filter_tactics, filter_platforms=filter_platforms, ) display_detailed_results(results) except KeyboardInterrupt: print("\n[INFO] Exiting interactive mode...") break except Exception as e: print(f"[ERROR] Error: {e}") def display_stats(kb): """Display detailed knowledge base statistics""" stats = kb.get_stats() print("\n[INFO] Knowledge Base Statistics:") for key, value in stats.items(): if isinstance(value, dict): print(f" {key}:") for subkey, subvalue in value.items(): print(f" {subkey}: {subvalue}") else: print(f" {key}: {value}") def display_available_tactics(kb): """Display available tactics""" stats = kb.get_stats() tactics = stats.get("techniques_by_tactic", {}) if tactics: print("\n[INFO] Available Tactics:") for tactic, count in sorted(tactics.items()): print(f" {tactic}: {count} techniques") else: print("\n[INFO] No tactics information available") def display_available_platforms(kb): """Display available platforms""" stats = kb.get_stats() platforms = stats.get("techniques_by_platform", {}) if platforms: print("\n[INFO] Available Platforms:") for platform, count in sorted(platforms.items()): print(f" {platform}: {count} techniques") else: print("\n[INFO] No platforms information available") def display_technique_info(kb, technique_id: str): """Display detailed information about a specific technique""" technique = kb.get_technique_by_id(technique_id.upper()) if technique: print(f"\n[INFO] Technique Details: {technique_id}") print("-" * 40) print(f"Name: {technique.get('name', 'Unknown')}") print( f"Type: {'Sub-technique' if technique.get('is_subtechnique') else 'Technique'}" ) print(f"Tactics: {', '.join(technique.get('tactics', []))}") print(f"Platforms: {', '.join(technique.get('platforms', []))}") print(f"Mitigations: {len(technique.get('mitigations', []))}") description = technique.get("description", "") if description: print( f"Description: {description[:500]}{'...' if len(description) > 500 else ''}" ) detection = technique.get("detection", "") if detection: print( f"Detection: {detection[:300]}{'...' if len(detection) > 300 else ''}" ) else: print(f"\n[ERROR] Technique {technique_id} not found") def run_test_suite(kb): """Run comprehensive test suite for cyber techniques""" test_cases = [ # Process injection techniques {"query": "process injection", "description": "Process injection techniques"}, {"query": "DLL injection", "description": "DLL injection methods"}, # Privilege escalation { "query": "privilege escalation Windows", "description": "Windows privilege escalation", }, {"query": "UAC bypass", "description": "UAC bypass techniques"}, # Persistence { "query": "scheduled task persistence", "description": "Scheduled task persistence", }, {"query": "registry persistence", "description": "Registry-based persistence"}, # Credential access { "query": "credential dumping LSASS", "description": "LSASS credential dumping", }, {"query": "password spraying", "description": "Password spraying attacks"}, # Defense evasion { "query": "defense evasion DLL hijacking", "description": "DLL hijacking evasion", }, {"query": "process hollowing", "description": "Process hollowing technique"}, # Lateral movement {"query": "lateral movement SMB", "description": "SMB lateral movement"}, {"query": "remote desktop protocol", "description": "RDP-based movement"}, ] print("\n[INFO] Running Cyber Security Test Suite:") print("=" * 50) for i, test_case in enumerate(test_cases, 1): print(f"\n#{i} {test_case['description']}") print(f"Query: '{test_case['query']}'") print("-" * 30) try: results = kb.search(test_case["query"], top_k=3) display_detailed_results(results) except Exception as e: print(f"[ERROR] Error: {e}") def main(): """Main entry point with argument parsing""" parser = argparse.ArgumentParser( description="MITRE ATT&CK Cyber Knowledge Base Management", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: python build_cyber_database.py ingest --techniques-json ./mitre_data/techniques.json python build_cyber_database.py test --query "process injection" python build_cyber_database.py test --interactive python build_cyber_database.py test --query "privilege escalation" --filter-tactics "privilege-escalation" """, ) # Subcommands subparsers = parser.add_subparsers(dest="command", help="Available commands") # Ingest command ingest_parser = subparsers.add_parser( "ingest", help="Ingest MITRE ATT&CK techniques and build knowledge base" ) ingest_parser.add_argument( "--techniques-json", default="./mitre_data/techniques.json", help="Path to techniques.json file", ) ingest_parser.add_argument( "--persist-dir", default="./cyber_knowledge_base", help="Directory to store the knowledge base", ) ingest_parser.add_argument( "--embedding-model", default="google/embeddinggemma-300m", help="Embedding model name", ) ingest_parser.add_argument( "--reset", action="store_true", default=True, help="Reset knowledge base before ingestion (default: True)", ) ingest_parser.add_argument( "--no-reset", dest="reset", action="store_false", help="Do not reset existing knowledge base", ) # Test command test_parser = subparsers.add_parser( "test", help="Test retrieval on existing knowledge base" ) test_parser.add_argument("--query", help="Single query to test") test_parser.add_argument( "--filter-tactics", nargs="+", help="Filter by tactics (e.g., --filter-tactics defense-evasion privilege-escalation)", ) test_parser.add_argument( "--filter-platforms", nargs="+", help="Filter by platforms (e.g., --filter-platforms Windows Linux)", ) test_parser.add_argument( "--interactive", action="store_true", help="Interactive testing mode" ) test_parser.add_argument( "--persist-dir", default="./cyber_knowledge_base", help="Directory where knowledge base is stored", ) test_parser.add_argument( "--embedding-model", default="google/embeddinggemma-300m", help="Embedding model name", ) args = parser.parse_args() if args.command == "ingest": success = ingest_techniques(args) sys.exit(0 if success else 1) elif args.command == "test": test_retrieval(args) else: parser.print_help() sys.exit(1) if __name__ == "__main__": main()