"""
MITRE ATT&CK Cyber Knowledge Base
This knowledge base processes MITRE ATT&CK techniques from techniques.json and provides:
- Semantic search using google/embeddinggemma-300m embeddings
- Cross-encoder reranking using cross-encoder/ms-marco-MiniLM-L12-v2
- Hybrid search combining ChromaDB (semantic) and BM25 (keyword)
- Multi-query search with Reciprocal Rank Fusion (RRF)
- Metadata filtering by tactics, platforms, and other technique attributes
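
Typical usage (a minimal sketch; paths and queries are illustrative):

    kb = CyberKnowledgeBase()
    kb.build_knowledge_base("techniques.json", persist_dir="./mitre_kb")
    # or, if a persisted knowledge base already exists:
    # kb.load_knowledge_base("./mitre_kb")
    results = kb.search("credential dumping LSASS", top_k=5)
    fused = kb.search_multi_query(
        ["process injection", "code injection into running processes"], top_k=5
    )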
"""
import os
import json
import pickle
from typing import List, Dict, Optional, Any
from pathlib import Path
from collections import defaultdict
from langchain.schema import Document
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_core.runnables import ConfigurableField
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
import torch
from nltk.tokenize import word_tokenize
import nltk
# Ensure the tokenizer data required by word_tokenize is available (used for BM25 preprocessing)
nltk.download("punkt_tab", quiet=True)
# Use newest import paths for langchain
try:
from langchain_chroma import Chroma
except ImportError:
from langchain_community.vectorstores import Chroma
# Use HuggingFaceEmbeddings for google/embeddinggemma-300m
try:
from langchain_huggingface import HuggingFaceEmbeddings
except ImportError:
from langchain_community.embeddings import HuggingFaceEmbeddings
class CyberKnowledgeBase:
"""MITRE ATT&CK knowledge base with semantic search and reranking"""
def __init__(self, embedding_model: str = "google/embeddinggemma-300m"):
"""
Initialize the cyber knowledge base
Args:
embedding_model: Embedding model to use for semantic search
"""
print(f"[INFO] Initializing CyberKnowledgeBase with {embedding_model}")
self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # To force CPU execution, uncomment the following line:
        # self.device = "cpu"
print(f"[INFO] Using device: {self.device}")
        # Initialize embeddings on the selected device (GPU when available)
model_kwargs = {"device": self.device}
self.embeddings = HuggingFaceEmbeddings(
model_name=embedding_model, model_kwargs=model_kwargs
)
# Initialize retrievers as None
self.chroma_retriever = None
self.bm25_retriever = None
self.ensemble_retriever = None
# Initialize reranker
self.cross_encoder = HuggingFaceCrossEncoder(
model_name="cross-encoder/ms-marco-MiniLM-L12-v2",
model_kwargs=model_kwargs,
)
# Store original techniques data for filtering
self.techniques_data = None
def build_knowledge_base(
self,
techniques_json_path: str,
persist_dir: str = "./knowledge_base",
reset: bool = True,
) -> None:
"""
Build knowledge base from techniques.json
Args:
techniques_json_path: Path to the techniques.json file
persist_dir: Directory to persist the knowledge base
reset: Whether to reset existing knowledge base
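
        Expected input (a sketch): each entry in techniques.json is a dict whose
        keys are read by _create_documents, for example (illustrative values):

            {
                "attack_id": "T1055",
                "name": "Process Injection",
                "description": "Adversaries may inject code into processes...",
                "tactics": ["defense-evasion", "privilege-escalation"],
                "platforms": ["Windows", "Linux", "macOS"],
                "is_subtechnique": false,
                "mitigations": ["Behavior Prevention on Endpoint"]
            }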
"""
print("[INFO] Building MITRE ATT&CK knowledge base...")
# Load techniques data
self.techniques_data = self._load_techniques(techniques_json_path)
print(f"[INFO] Loaded {len(self.techniques_data)} techniques")
# Convert to documents
documents = self._create_documents(self.techniques_data)
print(f"[INFO] Created {len(documents)} documents")
# Create directories
os.makedirs(persist_dir, exist_ok=True)
chroma_dir = os.path.join(persist_dir, "chroma")
bm25_path = os.path.join(persist_dir, "bm25_retriever.pkl")
# Build ChromaDB retriever
print("[INFO] Building ChromaDB retriever...")
self.chroma_retriever = self._build_chroma_retriever(
documents, chroma_dir, reset
)
# Build BM25 retriever
print("[INFO] Building BM25 retriever...")
self.bm25_retriever = self._build_bm25_retriever(documents, bm25_path, reset)
# Create ensemble retriever
print("[INFO] Creating ensemble retriever...")
self.ensemble_retriever = self._build_ensemble_retriever(
self.bm25_retriever, self.chroma_retriever
)
# Reranking will be done at search time with dynamic top_k
print("[INFO] Reranker initialized and ready for search...")
print("[SUCCESS] Knowledge base built successfully!")
print("[INFO] Use kb.search(query, top_k) to perform searches.")
print(
"[INFO] Use kb.search_multi_query(queries, top_k) for multi-query RRF search."
)
def load_knowledge_base(self, persist_dir: str = "./knowledge_base") -> bool:
"""
Load existing knowledge base from disk
Args:
persist_dir: Directory where the knowledge base is stored
Returns:
bool: True if loaded successfully, False otherwise
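
        Example (a sketch; paths are illustrative):

            if not kb.load_knowledge_base("./mitre_kb"):
                kb.build_knowledge_base("techniques.json", persist_dir="./mitre_kb")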
"""
print("[INFO] Loading knowledge base from disk...")
chroma_dir = os.path.join(persist_dir, "chroma")
bm25_path = os.path.join(persist_dir, "bm25_retriever.pkl")
try:
# Load ChromaDB
if os.path.exists(chroma_dir):
vectorstore = Chroma(
persist_directory=chroma_dir, embedding_function=self.embeddings
)
self.chroma_retriever = vectorstore.as_retriever(
search_kwargs={"k": 20}
).configurable_fields(
search_kwargs=ConfigurableField(
id="chroma_search_kwargs",
name="Chroma Search Kwargs",
description="Search kwargs for Chroma DB retriever",
)
)
print("[SUCCESS] ChromaDB loaded")
else:
print("[ERROR] ChromaDB not found")
return False
# Load BM25 retriever
if os.path.exists(bm25_path):
with open(bm25_path, "rb") as f:
self.bm25_retriever = pickle.load(f)
print("[SUCCESS] BM25 retriever loaded")
else:
# Rebuild BM25 from ChromaDB if pickle not found
print("[INFO] BM25 pickle not found, rebuilding from ChromaDB...")
all_docs = vectorstore.get(include=["documents", "metadatas"])
documents = all_docs["documents"]
metadatas = all_docs["metadatas"]
doc_objects = []
for doc_content, metadata in zip(documents, metadatas):
if metadata is None:
metadata = {}
doc_obj = Document(page_content=doc_content, metadata=metadata)
doc_objects.append(doc_obj)
self.bm25_retriever = self._build_bm25_retriever(
doc_objects, bm25_path, reset=False
)
# Create ensemble retriever
self.ensemble_retriever = self._build_ensemble_retriever(
self.bm25_retriever, self.chroma_retriever
)
# Reranking will be done at search time with dynamic top_k
print("[INFO] Reranker ready for search...")
print("[SUCCESS] Knowledge base loaded successfully!")
return True
except Exception as e:
print(f"[ERROR] Error loading knowledge base: {e}")
return False
def search(
self,
query: str,
top_k: int = 10,
filter_tactics: Optional[List[str]] = None,
filter_platforms: Optional[List[str]] = None,
) -> List[Document]:
"""
Search for techniques using hybrid retrieval and reranking
Args:
query: Search query
top_k: Number of results to return
filter_tactics: Filter by specific tactics (e.g., ['defense-evasion'])
filter_platforms: Filter by platforms (e.g., ['Windows'])
Returns:
List of retrieved and reranked documents
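
        Example (a sketch; filter values must match the comma-separated strings
        stored in document metadata by _create_documents):

            docs = kb.search(
                "scheduled task persistence",
                top_k=5,
                filter_tactics=["persistence"],
                filter_platforms=["Windows"],
            )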
"""
if not self.ensemble_retriever:
raise ValueError(
"Knowledge base not loaded. Call build_knowledge_base() or load_knowledge_base() first."
)
# Build config for retrievers
config = {
"configurable": {
"bm25_k": top_k * 10, # Get more from BM25 for diversity
"chroma_search_kwargs": {"k": top_k * 10},
}
}
# Get initial results from ensemble retriever
initial_results = self.ensemble_retriever.invoke(query, config=config)
# Create a reranker with the specified top_k for this search
temp_reranker = CrossEncoderReranker(model=self.cross_encoder, top_n=top_k)
# Apply reranking to the initial results
results = temp_reranker.compress_documents(initial_results, query)
# Manually add relevance scores to metadata since CrossEncoderReranker doesn't preserve them
scores = self.cross_encoder.score(
[(query, doc.page_content) for doc in results]
)
for doc, score in zip(results, scores):
doc.metadata["relevance_score"] = float(score)
# Apply metadata filters if specified
if filter_tactics or filter_platforms:
filtered_results = []
for doc in results:
# Check tactics filter
if filter_tactics:
doc_tactics = doc.metadata.get("tactics", "").split(",")
doc_tactics = [
t.strip() for t in doc_tactics if t.strip()
] # Clean empty strings
if not any(tactic in doc_tactics for tactic in filter_tactics):
continue
# Check platforms filter
if filter_platforms:
doc_platforms = doc.metadata.get("platforms", "").split(",")
doc_platforms = [
p.strip() for p in doc_platforms if p.strip()
] # Clean empty strings
if not any(
platform in doc_platforms for platform in filter_platforms
):
continue
filtered_results.append(doc)
results = filtered_results[:top_k]
return results
def search_multi_query(
self,
queries: List[str],
top_k: int = 10,
rerank_query: Optional[str] = None,
filter_tactics: Optional[List[str]] = None,
filter_platforms: Optional[List[str]] = None,
rrf_k: int = 60,
) -> List[Document]:
"""
Search for techniques using multiple queries with Reciprocal Rank Fusion (RRF)
This method performs retrieval for each query separately, then combines the results
using RRF before applying cross-encoder reranking.
Args:
queries: List of search queries
top_k: Number of final results to return after reranking
            rerank_query: Query used for cross-encoder reranking (defaults to queries[0])
filter_tactics: Filter by specific tactics
filter_platforms: Filter by platforms
rrf_k: RRF constant (default: 60, standard value from literature)
Returns:
List of retrieved, RRF-fused, and reranked documents
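
        Example (a sketch; queries are illustrative):

            docs = kb.search_multi_query(
                ["credential dumping LSASS memory", "stealing authentication secrets"],
                top_k=5,
                rerank_query="suspicious access to LSASS process memory",
            )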
"""
if not self.ensemble_retriever:
raise ValueError(
"Knowledge base not loaded. Call build_knowledge_base() or load_knowledge_base() first."
)
if not queries:
return []
# If only one query, use regular search
if len(queries) == 1:
return self.search(queries[0], top_k, filter_tactics, filter_platforms)
print(f"[INFO] Performing multi-query search with {len(queries)} queries")
# Retrieve documents for each query
all_query_results = []
config = {
"configurable": {
"bm25_k": top_k * 15, # Get more documents for RRF fusion
"chroma_search_kwargs": {"k": top_k * 15},
}
}
for i, query in enumerate(queries, 1):
print(f"[INFO] Query {i}/{len(queries)}: '{query}'")
results = self.ensemble_retriever.invoke(query, config=config)
all_query_results.append(results)
# Apply Reciprocal Rank Fusion (RRF)
print(f"[INFO] Applying Reciprocal Rank Fusion (k={rrf_k})")
fused_results = self._reciprocal_rank_fusion(all_query_results, k=rrf_k)
# Get top candidates before reranking (more than final top_k for better reranking)
candidates = fused_results[: top_k * 5]
print(f"[INFO] Reranking {len(candidates)} candidates with cross-encoder")
reference_query = rerank_query or queries[0]
# Create a reranker with the specified top_k
temp_reranker = CrossEncoderReranker(model=self.cross_encoder, top_n=top_k)
# Apply reranking
results = temp_reranker.compress_documents(candidates, reference_query)
# Manually add relevance scores to metadata since CrossEncoderReranker doesn't preserve them
scores = self.cross_encoder.score(
[(reference_query, doc.page_content) for doc in results]
)
for doc, score in zip(results, scores):
doc.metadata["relevance_score"] = float(score)
# Apply metadata filters if specified
if filter_tactics or filter_platforms:
filtered_results = []
for doc in results:
# Check tactics filter
if filter_tactics:
doc_tactics = doc.metadata.get("tactics", "").split(",")
doc_tactics = [t.strip() for t in doc_tactics if t.strip()]
if not any(tactic in doc_tactics for tactic in filter_tactics):
continue
# Check platforms filter
if filter_platforms:
doc_platforms = doc.metadata.get("platforms", "").split(",")
doc_platforms = [p.strip() for p in doc_platforms if p.strip()]
if not any(
platform in doc_platforms for platform in filter_platforms
):
continue
filtered_results.append(doc)
results = filtered_results[:top_k]
print(f"[INFO] Returning {len(results)} final results")
return results
def _reciprocal_rank_fusion(
self, doc_lists: List[List[Document]], k: int = 60
) -> List[Document]:
"""
Apply Reciprocal Rank Fusion to combine multiple ranked lists
RRF score for document d: sum over all rankings r of (1 / (k + rank(d, r)))
where k is a constant (typically 60) and rank is the position in ranking r
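
        For example, with k=60 a document ranked 1st in one query's list and
        3rd in another's receives 1/(60+1) + 1/(60+3) ≈ 0.0323.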
Args:
doc_lists: List of document lists from different queries
k: RRF constant (default: 60)
Returns:
Fused list of documents sorted by RRF score
"""
# Create a mapping from document ID to document and its RRF score
doc_scores = defaultdict(float)
doc_map = {}
# Process each ranking
for doc_list in doc_lists:
for rank, doc in enumerate(doc_list, start=1):
# Use attack_id as unique identifier
doc_id = doc.metadata.get("attack_id", "")
if not doc_id:
# Fallback to content hash if no attack_id
doc_id = hash(doc.page_content)
# Calculate RRF score: 1 / (k + rank)
rrf_score = 1.0 / (k + rank)
doc_scores[doc_id] += rrf_score
# Store document object (keep first occurrence)
if doc_id not in doc_map:
doc_map[doc_id] = doc
# Sort documents by RRF score (descending)
sorted_doc_ids = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)
# Create sorted document list with RRF scores in metadata
fused_docs = []
for doc_id, score in sorted_doc_ids:
doc = doc_map[doc_id]
# Add RRF score to metadata
doc.metadata["rrf_score"] = score
fused_docs.append(doc)
return fused_docs
def get_technique_by_id(self, technique_id: str) -> Optional[Dict[str, Any]]:
"""Get technique data by attack ID"""
if not self.techniques_data:
return None
for technique in self.techniques_data:
if technique.get("attack_id") == technique_id:
return technique
return None
def get_stats(self) -> Dict[str, Any]:
"""Get statistics about the knowledge base"""
stats = {}
if self.chroma_retriever:
try:
vectorstore = self.chroma_retriever.vectorstore
collection = vectorstore._collection
stats["chroma_documents"] = collection.count()
            except Exception:
stats["chroma_documents"] = "Unknown"
if self.bm25_retriever:
try:
stats["bm25_documents"] = len(self.bm25_retriever.docs)
            except Exception:
stats["bm25_documents"] = "Unknown"
stats["ensemble_available"] = self.ensemble_retriever is not None
stats["reranker_available"] = self.cross_encoder is not None
stats["reranker_model"] = self.cross_encoder.model_name
stats["embedding_model"] = self.embeddings.model_name
if self.techniques_data:
stats["total_techniques"] = len(self.techniques_data)
# Count by tactics
tactics_count = {}
for technique in self.techniques_data:
for tactic in technique.get("tactics", []):
tactics_count[tactic] = tactics_count.get(tactic, 0) + 1
stats["techniques_by_tactic"] = tactics_count
# Count by platforms
platforms_count = {}
for technique in self.techniques_data:
for platform in technique.get("platforms", []):
platforms_count[platform] = platforms_count.get(platform, 0) + 1
stats["techniques_by_platform"] = platforms_count
return stats
def _load_techniques(self, json_path: str) -> List[Dict[str, Any]]:
"""Load techniques from JSON file"""
if not os.path.exists(json_path):
raise FileNotFoundError(f"Techniques file not found: {json_path}")
with open(json_path, "r", encoding="utf-8") as f:
techniques = json.load(f)
return techniques
def _create_documents(self, techniques: List[Dict[str, Any]]) -> List[Document]:
"""Convert technique data to LangChain documents"""
documents = []
for technique in techniques:
# Main content for embedding: name + description
page_content = f"Technique: {technique.get('name', 'Unknown')}\n\n"
page_content += f"Description: {technique.get('description', 'No description available')}"
# Create metadata - ChromaDB requires simple data types
metadata = {
"attack_id": technique.get("attack_id", ""),
"name": technique.get("name", ""),
"is_subtechnique": technique.get("is_subtechnique", False),
"platforms": ",".join(
technique.get("platforms", [])
), # Convert list to comma-separated string
"tactics": ",".join(
technique.get("tactics", [])
), # Convert list to comma-separated string
"doc_type": "mitre_technique",
}
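            # Illustrative resulting metadata (actual values depend on techniques.json):
            #   {"attack_id": "T1055", "name": "Process Injection",
            #    "is_subtechnique": False,
            #    "platforms": "Windows,Linux,macOS",
            #    "tactics": "defense-evasion,privilege-escalation",
            #    "doc_type": "mitre_technique"}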
# Add mitigation count to metadata
mitigations = technique.get("mitigations", [])
metadata["mitigation_count"] = len(mitigations)
metadata["mitigations"] = "; ".join(mitigations)
doc = Document(page_content=page_content, metadata=metadata)
documents.append(doc)
return documents
def _build_chroma_retriever(
self, documents: List[Document], chroma_dir: str, reset: bool
):
"""Build ChromaDB retriever"""
if reset and os.path.exists(chroma_dir):
import shutil
shutil.rmtree(chroma_dir)
print("[INFO] Removed existing ChromaDB for rebuild")
# Create Chroma vectorstore
vectorstore = Chroma.from_documents(
documents=documents, embedding=self.embeddings, persist_directory=chroma_dir
)
# Create configurable retriever
retriever = vectorstore.as_retriever(
search_kwargs={"k": 20} # default value
).configurable_fields(
search_kwargs=ConfigurableField(
id="chroma_search_kwargs",
name="Chroma Search Kwargs",
description="Search kwargs for Chroma DB retriever",
)
)
print(f"[SUCCESS] ChromaDB created with {len(documents)} documents")
return retriever
def _build_bm25_retriever(
self, documents: List[Document], bm25_path: str, reset: bool
):
"""Build BM25 retriever"""
# Create BM25 retriever
retriever = BM25Retriever.from_documents(
documents=documents,
k=20, # default value
preprocess_func=word_tokenize,
).configurable_fields(
k=ConfigurableField(
id="bm25_k",
name="BM25 Top K",
description="Number of documents to return from BM25",
)
)
# Save BM25 retriever
try:
with open(bm25_path, "wb") as f:
pickle.dump(retriever, f)
print(f"[SUCCESS] BM25 retriever saved to {bm25_path}")
except Exception as e:
print(f"[WARNING] Could not save BM25 retriever: {e}")
print(f"[SUCCESS] BM25 retriever created with {len(documents)} documents")
return retriever
def _build_ensemble_retriever(self, bm25_retriever, chroma_retriever):
"""Build ensemble retriever combining BM25 and ChromaDB"""
return EnsembleRetriever(
retrievers=[bm25_retriever, chroma_retriever],
weights=[0.3, 0.7], # Favor semantic search slightly
)
def test_cyber_kb(kb: CyberKnowledgeBase, test_queries: List[str]):
"""Test function for the cyber knowledge base"""
print("\n[INFO] Testing Cyber Knowledge Base")
print("=" * 60)
for i, query in enumerate(test_queries, 1):
print(f"\n#{i} Query: '{query}'")
print("-" * 40)
try:
# Test search
results = kb.search(query, top_k=3)
if results:
for j, doc in enumerate(results, 1):
attack_id = doc.metadata.get("attack_id", "Unknown")
name = doc.metadata.get("name", "Unknown")
tactics_str = doc.metadata.get("tactics", "")
platforms_str = doc.metadata.get("platforms", "")
content_preview = doc.page_content[:200].replace("\n", " ")
print(f" {j}. {attack_id} - {name}")
print(f" Tactics: {tactics_str}")
print(f" Platforms: {platforms_str}")
print(f" Preview: {content_preview}...")
print()
else:
print(" No results found")
except Exception as e:
print(f" [ERROR] Error: {e}")
def test_multi_query_search(kb: CyberKnowledgeBase):
"""Test multi-query search with RRF"""
print("\n[INFO] Testing Multi-Query Search with RRF")
print("=" * 60)
# Test case 1: Credential dumping with different query angles
print("\n### Test Case 1: Credential Dumping ###")
queries_1 = [
"credential dumping LSASS memory",
"stealing authentication secrets",
"SAM database access ntds.dit",
]
print(f"Queries: {queries_1}")
results = kb.search_multi_query(queries_1, top_k=5)
print("\nTop 5 Results:")
for i, doc in enumerate(results, 1):
attack_id = doc.metadata.get("attack_id", "Unknown")
name = doc.metadata.get("name", "Unknown")
        rrf_score = doc.metadata.get("rrf_score")
        score_str = f"{rrf_score:.4f}" if isinstance(rrf_score, float) else "N/A"
        print(f" {i}. {attack_id} - {name} (RRF Score: {score_str})")
# Test case 2: Process injection with different perspectives
print("\n\n### Test Case 2: Process Injection ###")
queries_2 = [
"process injection defense evasion",
"code injection into running processes",
"DLL injection CreateRemoteThread",
]
print(f"Queries: {queries_2}")
results = kb.search_multi_query(queries_2, top_k=5)
print("\nTop 5 Results:")
for i, doc in enumerate(results, 1):
attack_id = doc.metadata.get("attack_id", "Unknown")
name = doc.metadata.get("name", "Unknown")
        rrf_score = doc.metadata.get("rrf_score")
        score_str = f"{rrf_score:.4f}" if isinstance(rrf_score, float) else "N/A"
        print(f" {i}. {attack_id} - {name} (RRF Score: {score_str})")
# Example usage
if __name__ == "__main__":
# Initialize knowledge base
kb = CyberKnowledgeBase()
# Path to techniques.json
techniques_path = "../../processed_data/cti/techniques.json"
try:
# Build knowledge base
kb.build_knowledge_base(
techniques_json_path=techniques_path, persist_dir="./mitre_kb", reset=True
)
# Test queries
test_queries = [
"process injection techniques",
"privilege escalation Windows",
"scheduled task persistence",
"credential dumping LSASS",
"lateral movement SMB",
"defense evasion DLL hijacking",
]
# Test the knowledge base
test_cyber_kb(kb, test_queries)
# Test multi-query search with RRF
test_multi_query_search(kb)
# Show stats
print(f"\n[INFO] Knowledge Base Stats:")
stats = kb.get_stats()
for key, value in stats.items():
if isinstance(value, dict):
print(f" {key}:")
for subkey, subvalue in value.items():
print(f" {subkey}: {subvalue}")
else:
print(f" {key}: {value}")
except Exception as e:
print(f"[ERROR] Error: {e}")
import traceback
traceback.print_exc()