Spaces:
Sleeping
Sleeping
| """ | |
| MITRE ATT&CK Cyber Knowledge Base | |
| This knowledge base processes MITRE ATT&CK techniques from techniques.json and provides: | |
| - Semantic search using google/embeddinggemma-300m embeddings | |
| - Cross-encoder reranking using cross-encoder/ms-marco-MiniLM-L12-v2 | |
| - Hybrid search combining ChromaDB (semantic) and BM25 (keyword) | |
| - Multi-query search with Reciprocal Rank Fusion (RRF) | |
| - Metadata filtering by tactics, platforms, and other technique attributes | |
| """ | |
| import os | |
| import json | |
| import pickle | |
| from typing import List, Dict, Optional, Any | |
| from pathlib import Path | |
| from collections import defaultdict | |
| from langchain.schema import Document | |
| from langchain.retrievers import EnsembleRetriever | |
| from langchain_community.retrievers import BM25Retriever | |
| from langchain_core.runnables import ConfigurableField | |
| from langchain.retrievers.document_compressors import CrossEncoderReranker | |
| from langchain_community.cross_encoders import HuggingFaceCrossEncoder | |
| import torch | |
| from nltk.tokenize import word_tokenize | |
| import nltk | |
| nltk.download("punkt_tab") | |
| # Use newest import paths for langchain | |
| try: | |
| from langchain_chroma import Chroma | |
| except ImportError: | |
| from langchain_community.vectorstores import Chroma | |
| # Use HuggingFaceEmbeddings for google/embeddinggemma-300m | |
| try: | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| except ImportError: | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
class CyberKnowledgeBase:
    """MITRE ATT&CK knowledge base with hybrid retrieval and reranking.

    Techniques are indexed twice: in a Chroma vector store (semantic search)
    and in a BM25 index (keyword search). Both are combined through an
    ensemble retriever, and the candidates are re-scored by a cross-encoder.
    """

    def __init__(self, embedding_model: str = "google/embeddinggemma-300m"):
        """
        Initialize the cyber knowledge base

        Loads the embedding model and the cross-encoder reranker on the GPU
        when one is available, otherwise on the CPU. No index is built or
        loaded here; call build_knowledge_base() or load_knowledge_base().

        Args:
            embedding_model: Embedding model to use for semantic search
        """
        print(f"[INFO] Initializing CyberKnowledgeBase with {embedding_model}")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"[INFO] Using device: {self.device}")

        # Both models are placed on the same device.
        model_kwargs = {"device": self.device}
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model, model_kwargs=model_kwargs
        )

        # Retrievers are created by build_knowledge_base()/load_knowledge_base().
        self.chroma_retriever = None
        self.bm25_retriever = None
        self.ensemble_retriever = None

        # Cross-encoder used to rerank retrieved candidates at search time.
        self.cross_encoder = HuggingFaceCrossEncoder(
            model_name="cross-encoder/ms-marco-MiniLM-L12-v2",
            model_kwargs=model_kwargs,
        )

        # Raw technique records; only populated by build_knowledge_base().
        self.techniques_data = None

    def build_knowledge_base(
        self,
        techniques_json_path: str,
        persist_dir: str = "./knowledge_base",
        reset: bool = True,
    ) -> None:
        """
        Build knowledge base from techniques.json

        NOTE(review): with reset=False and an existing Chroma directory, the
        documents are presumably added on top of the stored collection
        (duplicates) -- confirm against Chroma.from_documents before relying
        on reset=False.

        Args:
            techniques_json_path: Path to the techniques.json file
            persist_dir: Directory to persist the knowledge base
            reset: Whether to delete an existing Chroma store before building

        Raises:
            FileNotFoundError: If techniques_json_path does not exist
        """
        print("[INFO] Building MITRE ATT&CK knowledge base...")

        self.techniques_data = self._load_techniques(techniques_json_path)
        print(f"[INFO] Loaded {len(self.techniques_data)} techniques")

        documents = self._create_documents(self.techniques_data)
        print(f"[INFO] Created {len(documents)} documents")

        os.makedirs(persist_dir, exist_ok=True)
        chroma_dir = os.path.join(persist_dir, "chroma")
        bm25_path = os.path.join(persist_dir, "bm25_retriever.pkl")

        print("[INFO] Building ChromaDB retriever...")
        self.chroma_retriever = self._build_chroma_retriever(
            documents, chroma_dir, reset
        )

        print("[INFO] Building BM25 retriever...")
        self.bm25_retriever = self._build_bm25_retriever(documents, bm25_path, reset)

        print("[INFO] Creating ensemble retriever...")
        self.ensemble_retriever = self._build_ensemble_retriever(
            self.bm25_retriever, self.chroma_retriever
        )

        # Reranking happens at search time with a per-call top_k.
        print("[INFO] Reranker initialized and ready for search...")
        print("[SUCCESS] Knowledge base built successfully!")
        print("[INFO] Use kb.search(query, top_k) to perform searches.")
        print(
            "[INFO] Use kb.search_multi_query(queries, top_k) for multi-query RRF search."
        )

    def load_knowledge_base(self, persist_dir: str = "./knowledge_base") -> bool:
        """
        Load existing knowledge base from disk

        The raw technique records are not restored here, so
        get_technique_by_id() and the per-tactic/platform statistics stay
        empty until build_knowledge_base() has been called.

        Args:
            persist_dir: Directory where the knowledge base is stored

        Returns:
            bool: True if loaded successfully, False otherwise
        """
        print("[INFO] Loading knowledge base from disk...")
        chroma_dir = os.path.join(persist_dir, "chroma")
        bm25_path = os.path.join(persist_dir, "bm25_retriever.pkl")

        try:
            if not os.path.exists(chroma_dir):
                print("[ERROR] ChromaDB not found")
                return False

            vectorstore = Chroma(
                persist_directory=chroma_dir, embedding_function=self.embeddings
            )
            self.chroma_retriever = self._make_chroma_retriever(vectorstore)
            print("[SUCCESS] ChromaDB loaded")

            if os.path.exists(bm25_path):
                # SECURITY: pickle.load executes arbitrary code from the file.
                # Only load knowledge-base directories you created yourself.
                with open(bm25_path, "rb") as f:
                    self.bm25_retriever = pickle.load(f)
                print("[SUCCESS] BM25 retriever loaded")
            else:
                # Rebuild the keyword index from what the vector store holds.
                print("[INFO] BM25 pickle not found, rebuilding from ChromaDB...")
                stored = vectorstore.get(include=["documents", "metadatas"])
                doc_objects = [
                    Document(page_content=content, metadata=metadata or {})
                    for content, metadata in zip(
                        stored["documents"], stored["metadatas"]
                    )
                ]
                self.bm25_retriever = self._build_bm25_retriever(
                    doc_objects, bm25_path, reset=False
                )

            self.ensemble_retriever = self._build_ensemble_retriever(
                self.bm25_retriever, self.chroma_retriever
            )

            # Reranking happens at search time with a per-call top_k.
            print("[INFO] Reranker ready for search...")
            print("[SUCCESS] Knowledge base loaded successfully!")
            return True
        except Exception as e:
            # Loading is best-effort: report the problem and signal failure.
            print(f"[ERROR] Error loading knowledge base: {e}")
            return False

    def search(
        self,
        query: str,
        top_k: int = 10,
        filter_tactics: Optional[List[str]] = None,
        filter_platforms: Optional[List[str]] = None,
    ) -> List[Document]:
        """
        Search for techniques using hybrid retrieval and reranking

        Metadata filters are applied to the retrieved candidates *before*
        reranking, so a filtered search can still return up to top_k matches
        instead of only the matches that happen to be in the unfiltered top_k.

        Args:
            query: Search query
            top_k: Number of results to return
            filter_tactics: Filter by specific tactics (e.g., ['defense-evasion'])
            filter_platforms: Filter by platforms (e.g., ['Windows'])

        Returns:
            List of retrieved and reranked documents; each carries a
            "relevance_score" entry in its metadata

        Raises:
            ValueError: If no knowledge base has been built or loaded
        """
        if not self.ensemble_retriever:
            raise ValueError(
                "Knowledge base not loaded. Call build_knowledge_base() or load_knowledge_base() first."
            )

        # Over-fetch from both retrievers so the reranker has room to work.
        config = {
            "configurable": {
                "bm25_k": top_k * 10,
                "chroma_search_kwargs": {"k": top_k * 10},
            }
        }
        candidates = self.ensemble_retriever.invoke(query, config=config)
        candidates = self._apply_filters(candidates, filter_tactics, filter_platforms)
        return self._rerank(query, candidates, top_k)

    def search_multi_query(
        self,
        queries: List[str],
        top_k: int = 10,
        rerank_query: Optional[str] = None,
        filter_tactics: Optional[List[str]] = None,
        filter_platforms: Optional[List[str]] = None,
        rrf_k: int = 60,
    ) -> List[Document]:
        """
        Search for techniques using multiple queries with Reciprocal Rank Fusion (RRF)

        Retrieval runs once per query; the ranked lists are fused with RRF,
        filtered by metadata, and the best candidates are reranked by the
        cross-encoder against a single reference query.

        Args:
            queries: List of search queries
            top_k: Number of final results to return after reranking
            rerank_query: Query used for cross-encoder reranking; defaults to
                the first entry of queries
            filter_tactics: Filter by specific tactics
            filter_platforms: Filter by platforms
            rrf_k: RRF constant (default: 60)

        Returns:
            List of retrieved, RRF-fused, and reranked documents

        Raises:
            ValueError: If no knowledge base has been built or loaded
        """
        if not self.ensemble_retriever:
            raise ValueError(
                "Knowledge base not loaded. Call build_knowledge_base() or load_knowledge_base() first."
            )
        if not queries:
            return []

        # A single query without a separate rerank query is a plain search.
        # With a rerank query we must take the fused path, otherwise the
        # caller's rerank_query would be silently ignored.
        if len(queries) == 1 and not rerank_query:
            return self.search(queries[0], top_k, filter_tactics, filter_platforms)

        print(f"[INFO] Performing multi-query search with {len(queries)} queries")

        # Fetch more documents per query than for a single search so that
        # the fusion step has enough overlap to work with.
        config = {
            "configurable": {
                "bm25_k": top_k * 15,
                "chroma_search_kwargs": {"k": top_k * 15},
            }
        }
        all_query_results = []
        for i, query in enumerate(queries, 1):
            print(f"[INFO] Query {i}/{len(queries)}: '{query}'")
            all_query_results.append(
                self.ensemble_retriever.invoke(query, config=config)
            )

        print(f"[INFO] Applying Reciprocal Rank Fusion (k={rrf_k})")
        fused_results = self._reciprocal_rank_fusion(all_query_results, k=rrf_k)
        fused_results = self._apply_filters(
            fused_results, filter_tactics, filter_platforms
        )

        # Rerank more candidates than are finally returned.
        candidates = fused_results[: top_k * 5]
        print(f"[INFO] Reranking {len(candidates)} candidates with cross-encoder")

        reference_query = rerank_query or queries[0]
        results = self._rerank(reference_query, candidates, top_k)

        print(f"[INFO] Returning {len(results)} final results")
        return results

    def _reciprocal_rank_fusion(
        self, doc_lists: List[List[Document]], k: int = 60
    ) -> List[Document]:
        """
        Apply Reciprocal Rank Fusion to combine multiple ranked lists

        RRF score for document d: sum over all rankings r of (1 / (k + rank(d, r)))
        where rank is the 1-based position of d in ranking r. Documents are
        identified by their "attack_id" metadata, falling back to the page
        content when no id is present.

        Args:
            doc_lists: List of document lists from different queries
            k: RRF constant (default: 60)

        Returns:
            Copies of the fused documents sorted by descending RRF score, with
            the score stored under "rrf_score" in their metadata. The input
            documents are not modified.
        """
        doc_scores = defaultdict(float)
        doc_map = {}

        for doc_list in doc_lists:
            for rank, doc in enumerate(doc_list, start=1):
                # Tuple fallback key cannot collide with a real attack id.
                doc_id = doc.metadata.get("attack_id") or (
                    "content",
                    doc.page_content,
                )
                doc_scores[doc_id] += 1.0 / (k + rank)
                # Keep the first occurrence as the representative document.
                doc_map.setdefault(doc_id, doc)

        # sorted() is stable, so ties keep first-seen order.
        ranked = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)
        return [
            self._annotated_copy(doc_map[doc_id], rrf_score=score)
            for doc_id, score in ranked
        ]

    def _rerank(self, query: str, candidates: List[Document], top_k: int) -> List[Document]:
        """
        Score candidates against query with the cross-encoder and keep the best

        Each candidate is scored exactly once; the score is attached to the
        returned copies under "relevance_score".

        Args:
            query: Query the candidates are scored against
            candidates: Documents to rerank
            top_k: Maximum number of documents to return

        Returns:
            Up to top_k document copies, sorted by descending relevance score
        """
        if not candidates:
            return []
        scores = self.cross_encoder.score(
            [(query, doc.page_content) for doc in candidates]
        )
        ranked = sorted(
            zip(candidates, scores), key=lambda pair: pair[1], reverse=True
        )
        return [
            self._annotated_copy(doc, relevance_score=float(score))
            for doc, score in ranked[:top_k]
        ]

    def _apply_filters(
        self,
        docs: List[Document],
        filter_tactics: Optional[List[str]],
        filter_platforms: Optional[List[str]],
    ) -> List[Document]:
        """
        Keep documents matching at least one requested tactic and platform

        A filter that is None or empty is not applied. Matching is exact
        (case-sensitive) against the comma-separated metadata values.
        """
        if not filter_tactics and not filter_platforms:
            return list(docs)

        kept = []
        for doc in docs:
            if filter_tactics and self._split_csv(
                doc.metadata.get("tactics")
            ).isdisjoint(filter_tactics):
                continue
            if filter_platforms and self._split_csv(
                doc.metadata.get("platforms")
            ).isdisjoint(filter_platforms):
                continue
            kept.append(doc)
        return kept

    @staticmethod
    def _split_csv(value: Optional[str]) -> set:
        """Split a comma-separated metadata string into a set of trimmed, non-empty parts."""
        return {part.strip() for part in (value or "").split(",") if part.strip()}

    @staticmethod
    def _annotated_copy(doc, **extra_metadata):
        """
        Return a shallow copy of doc whose metadata has extra entries added

        Copying keeps search-time scores out of the document objects that
        the retrievers hand back, so scores from one search cannot show up
        in the results of a later one.
        """
        import copy

        clone = copy.copy(doc)
        clone.metadata = {**doc.metadata, **extra_metadata}
        return clone

    def get_technique_by_id(self, technique_id: str) -> Optional[Dict[str, Any]]:
        """
        Get technique data by attack ID

        Args:
            technique_id: Value to match against each record's "attack_id"

        Returns:
            The first matching technique record, or None when there is no
            match or no technique data has been loaded
        """
        if not self.techniques_data:
            return None
        return next(
            (
                technique
                for technique in self.techniques_data
                if technique.get("attack_id") == technique_id
            ),
            None,
        )

    def get_stats(self) -> Dict[str, Any]:
        """
        Get statistics about the knowledge base

        Returns:
            Dict with document counts per index ("Unknown" when a count
            cannot be read), model names, availability flags and, when
            technique data is loaded, technique counts per tactic/platform
        """
        stats: Dict[str, Any] = {}

        if self.chroma_retriever:
            try:
                collection = self.chroma_retriever.vectorstore._collection
                stats["chroma_documents"] = collection.count()
            except Exception:
                stats["chroma_documents"] = "Unknown"

        if self.bm25_retriever:
            try:
                stats["bm25_documents"] = len(self.bm25_retriever.docs)
            except Exception:
                stats["bm25_documents"] = "Unknown"

        stats["ensemble_available"] = self.ensemble_retriever is not None
        stats["reranker_available"] = self.cross_encoder is not None
        stats["reranker_model"] = self.cross_encoder.model_name
        stats["embedding_model"] = self.embeddings.model_name

        if self.techniques_data:
            stats["total_techniques"] = len(self.techniques_data)

            tactics_count = defaultdict(int)
            platforms_count = defaultdict(int)
            for technique in self.techniques_data:
                for tactic in technique.get("tactics") or []:
                    tactics_count[tactic] += 1
                for platform in technique.get("platforms") or []:
                    platforms_count[platform] += 1
            stats["techniques_by_tactic"] = dict(tactics_count)
            stats["techniques_by_platform"] = dict(platforms_count)

        return stats

    def _load_techniques(self, json_path: str) -> List[Dict[str, Any]]:
        """
        Load techniques from JSON file

        Raises:
            FileNotFoundError: If json_path does not exist
        """
        if not os.path.exists(json_path):
            raise FileNotFoundError(f"Techniques file not found: {json_path}")
        with open(json_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _create_documents(self, techniques: List[Dict[str, Any]]) -> List[Document]:
        """
        Convert technique data to LangChain documents

        The page content (name + description) is what gets embedded and
        indexed; everything else goes into flat metadata. List-valued fields
        are joined into strings, and missing or null lists are treated as empty.
        """
        documents = []
        for technique in techniques:
            page_content = f"Technique: {technique.get('name', 'Unknown')}\n\n"
            page_content += f"Description: {technique.get('description', 'No description available')}"

            mitigations = technique.get("mitigations") or []
            metadata = {
                "attack_id": technique.get("attack_id", ""),
                "name": technique.get("name", ""),
                "is_subtechnique": technique.get("is_subtechnique", False),
                "platforms": ",".join(technique.get("platforms") or []),
                "tactics": ",".join(technique.get("tactics") or []),
                "doc_type": "mitre_technique",
                "mitigation_count": len(mitigations),
                "mitigations": "; ".join(str(item) for item in mitigations),
            }
            documents.append(Document(page_content=page_content, metadata=metadata))
        return documents

    def _make_chroma_retriever(self, vectorstore):
        """Wrap a Chroma vector store in a retriever whose search kwargs can be set per call."""
        return vectorstore.as_retriever(
            search_kwargs={"k": 20}  # default value
        ).configurable_fields(
            search_kwargs=ConfigurableField(
                id="chroma_search_kwargs",
                name="Chroma Search Kwargs",
                description="Search kwargs for Chroma DB retriever",
            )
        )

    def _build_chroma_retriever(
        self, documents: List[Document], chroma_dir: str, reset: bool
    ):
        """
        Build ChromaDB retriever

        Args:
            documents: Documents to embed and store
            chroma_dir: Directory the vector store is persisted to
            reset: Delete an existing store at chroma_dir first
        """
        if reset and os.path.exists(chroma_dir):
            import shutil

            shutil.rmtree(chroma_dir)
            print("[INFO] Removed existing ChromaDB for rebuild")

        vectorstore = Chroma.from_documents(
            documents=documents, embedding=self.embeddings, persist_directory=chroma_dir
        )
        retriever = self._make_chroma_retriever(vectorstore)
        print(f"[SUCCESS] ChromaDB created with {len(documents)} documents")
        return retriever

    def _build_bm25_retriever(
        self, documents: List[Document], bm25_path: str, reset: bool
    ):
        """
        Build BM25 retriever

        The retriever is always rebuilt from documents and written to
        bm25_path; a failure to save is reported but not raised.

        Args:
            documents: Documents to index
            bm25_path: Pickle file the retriever is saved to
            reset: Currently unused; kept for signature compatibility
        """
        retriever = BM25Retriever.from_documents(
            documents=documents,
            k=20,  # default value
            preprocess_func=word_tokenize,
        ).configurable_fields(
            k=ConfigurableField(
                id="bm25_k",
                name="BM25 Top K",
                description="Number of documents to return from BM25",
            )
        )

        try:
            with open(bm25_path, "wb") as f:
                pickle.dump(retriever, f)
            print(f"[SUCCESS] BM25 retriever saved to {bm25_path}")
        except Exception as e:
            # Persisting is optional; the in-memory retriever still works.
            print(f"[WARNING] Could not save BM25 retriever: {e}")

        print(f"[SUCCESS] BM25 retriever created with {len(documents)} documents")
        return retriever

    def _build_ensemble_retriever(self, bm25_retriever, chroma_retriever):
        """Build ensemble retriever combining BM25 and ChromaDB"""
        return EnsembleRetriever(
            retrievers=[bm25_retriever, chroma_retriever],
            weights=[0.3, 0.7],  # Weight semantic search above keyword search
        )
def test_cyber_kb(kb: CyberKnowledgeBase, test_queries: List[str]):
    """Run each query through kb.search (top 3) and print a summary of every hit.

    Errors raised while handling a query are printed and do not stop the
    remaining queries.
    """
    print("\n[INFO] Testing Cyber Knowledge Base")
    print("=" * 60)

    for query_no, query in enumerate(test_queries, start=1):
        print(f"\n#{query_no} Query: '{query}'")
        print("-" * 40)
        try:
            hits = kb.search(query, top_k=3)
            if not hits:
                print(" No results found")
                continue

            for position, hit in enumerate(hits, start=1):
                # Gather every field first, then print the summary block.
                meta = hit.metadata
                attack_id = meta.get("attack_id", "Unknown")
                name = meta.get("name", "Unknown")
                tactics_str = meta.get("tactics", "")
                platforms_str = meta.get("platforms", "")
                content_preview = hit.page_content[:200].replace("\n", " ")

                print(f" {position}. {attack_id} - {name}")
                print(f" Tactics: {tactics_str}")
                print(f" Platforms: {platforms_str}")
                print(f" Preview: {content_preview}...")
                print()
        except Exception as exc:
            print(f" [ERROR] Error: {exc}")
def test_multi_query_search(kb: CyberKnowledgeBase):
    """Test multi-query search with RRF

    Runs two fixed groups of related queries through
    kb.search_multi_query(top_k=5) and prints the id, name and RRF score of
    every returned document. A document without an "rrf_score" entry is
    printed with "N/A" instead of raising a formatting error.
    """
    print("\n[INFO] Testing Multi-Query Search with RRF")
    print("=" * 60)

    # (header line, queries that approach the same topic from different angles)
    test_cases = [
        (
            "\n### Test Case 1: Credential Dumping ###",
            [
                "credential dumping LSASS memory",
                "stealing authentication secrets",
                "SAM database access ntds.dit",
            ],
        ),
        (
            "\n\n### Test Case 2: Process Injection ###",
            [
                "process injection defense evasion",
                "code injection into running processes",
                "DLL injection CreateRemoteThread",
            ],
        ),
    ]

    for header, queries in test_cases:
        print(header)
        print(f"Queries: {queries}")
        results = kb.search_multi_query(queries, top_k=5)
        print("\nTop 5 Results:")
        for i, doc in enumerate(results, 1):
            attack_id = doc.metadata.get("attack_id", "Unknown")
            name = doc.metadata.get("name", "Unknown")
            rrf_score = _format_rrf_score(doc.metadata.get("rrf_score"))
            print(f" {i}. {attack_id} - {name} (RRF Score: {rrf_score})")


def _format_rrf_score(score) -> str:
    """Format a numeric RRF score with four decimals; anything else becomes "N/A"."""
    if isinstance(score, (int, float)):
        return f"{score:.4f}"
    return "N/A"
# Example usage: build the knowledge base, run the smoke tests, print stats.
if __name__ == "__main__":
    kb = CyberKnowledgeBase()

    # Location of the processed MITRE ATT&CK techniques file.
    techniques_path = "../../processed_data/cti/techniques.json"

    try:
        # Always rebuild from scratch for the demo run.
        kb.build_knowledge_base(
            techniques_json_path=techniques_path,
            persist_dir="./mitre_kb",
            reset=True,
        )

        test_queries = [
            "process injection techniques",
            "privilege escalation Windows",
            "scheduled task persistence",
            "credential dumping LSASS",
            "lateral movement SMB",
            "defense evasion DLL hijacking",
        ]

        # Single-query searches first, then the multi-query RRF path.
        test_cyber_kb(kb, test_queries)
        test_multi_query_search(kb)

        print("\n[INFO] Knowledge Base Stats:")
        stats = kb.get_stats()
        for stat_name, stat_value in stats.items():
            if not isinstance(stat_value, dict):
                print(f" {stat_name}: {stat_value}")
                continue
            # Nested dictionaries are printed as a heading plus their entries.
            print(f" {stat_name}:")
            for entry_name, entry_value in stat_value.items():
                print(f" {entry_name}: {entry_value}")
    except Exception as err:
        print(f"[ERROR] Error: {err}")
        import traceback

        traceback.print_exc()