|
|
"""
|
|
|
MITRE ATT&CK Cyber Knowledge Base
|
|
|
|
|
|
This knowledge base processes MITRE ATT&CK techniques from techniques.json and provides:
|
|
|
- Semantic search using google/embeddinggemma-300m embeddings
|
|
|
- Cross-encoder reranking using cross-encoder/ms-marco-MiniLM-L12-v2
|
|
|
- Hybrid search combining ChromaDB (semantic) and BM25 (keyword)
|
|
|
- Multi-query search with Reciprocal Rank Fusion (RRF)
|
|
|
- Metadata filtering by tactics, platforms, and other technique attributes
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
import json
|
|
|
import pickle
|
|
|
from typing import List, Dict, Optional, Any
|
|
|
from pathlib import Path
|
|
|
from collections import defaultdict
|
|
|
|
|
|
from langchain.schema import Document
|
|
|
from langchain.retrievers import EnsembleRetriever
|
|
|
from langchain_community.retrievers import BM25Retriever
|
|
|
from langchain_core.runnables import ConfigurableField
|
|
|
from langchain.retrievers.document_compressors import CrossEncoderReranker
|
|
|
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
|
|
|
|
|
import torch
|
|
|
|
|
|
from nltk.tokenize import word_tokenize
|
|
|
import nltk
|
|
|
|
|
|
# word_tokenize (used as the BM25 preprocess_func) needs the "punkt_tab"
# tokenizer tables. Only download them when they are not already installed,
# so importing this module does not contact the NLTK server every time.
try:
    nltk.data.find("tokenizers/punkt_tab")
except LookupError:
    nltk.download("punkt_tab")
|
|
|
|
|
|
|
|
|
try:
|
|
|
from langchain_chroma import Chroma
|
|
|
except ImportError:
|
|
|
from langchain_community.vectorstores import Chroma
|
|
|
|
|
|
|
|
|
try:
|
|
|
from langchain_huggingface import HuggingFaceEmbeddings
|
|
|
except ImportError:
|
|
|
from langchain_community.embeddings import HuggingFaceEmbeddings
|
|
|
|
|
|
|
|
|
class CyberKnowledgeBase:
    """MITRE ATT&CK knowledge base with hybrid retrieval and cross-encoder reranking.

    Technique documents are indexed twice: in a ChromaDB vector store (dense
    embeddings) and in a BM25 keyword index. An ``EnsembleRetriever`` merges
    both candidate lists (weights 0.3 BM25 / 0.7 Chroma), and a cross-encoder
    then rescores the merged candidates against the query.
    """

    def __init__(
        self,
        embedding_model: str = "google/embeddinggemma-300m",
        reranker_model: str = "cross-encoder/ms-marco-MiniLM-L12-v2",
    ):
        """
        Initialize the cyber knowledge base

        Args:
            embedding_model: Embedding model to use for semantic search
            reranker_model: Cross-encoder model used to rerank retrieved
                candidates (defaults to the model previously hard-coded here)
        """
        print(f"[INFO] Initializing CyberKnowledgeBase with {embedding_model}")

        # Both the embedder and the cross-encoder run on the GPU if one exists.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"[INFO] Using device: {self.device}")

        model_kwargs = {"device": self.device}
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model, model_kwargs=model_kwargs
        )

        # Populated by build_knowledge_base() or load_knowledge_base().
        self.chroma_retriever = None
        self.bm25_retriever = None
        self.ensemble_retriever = None

        self.cross_encoder = HuggingFaceCrossEncoder(
            model_name=reranker_model,
            model_kwargs=model_kwargs,
        )

        # Raw technique dicts from techniques.json; only set by
        # build_knowledge_base() (load_knowledge_base() leaves it as None).
        self.techniques_data = None

    # ------------------------------------------------------------------
    # Building / loading
    # ------------------------------------------------------------------

    def build_knowledge_base(
        self,
        techniques_json_path: str,
        persist_dir: str = "./knowledge_base",
        reset: bool = True,
    ) -> None:
        """
        Build knowledge base from techniques.json

        Args:
            techniques_json_path: Path to the techniques.json file
            persist_dir: Directory to persist the knowledge base
            reset: If True, delete any existing ChromaDB directory before
                indexing. If False, documents are written into the existing
                store; when every document has a unique ``attack_id`` those
                ids are used so re-indexing overwrites instead of duplicating.

        Raises:
            FileNotFoundError: If ``techniques_json_path`` does not exist.
        """
        print("[INFO] Building MITRE ATT&CK knowledge base...")

        self.techniques_data = self._load_techniques(techniques_json_path)
        print(f"[INFO] Loaded {len(self.techniques_data)} techniques")

        documents = self._create_documents(self.techniques_data)
        print(f"[INFO] Created {len(documents)} documents")

        os.makedirs(persist_dir, exist_ok=True)
        chroma_dir = os.path.join(persist_dir, "chroma")
        bm25_path = os.path.join(persist_dir, "bm25_retriever.pkl")

        print("[INFO] Building ChromaDB retriever...")
        self.chroma_retriever = self._build_chroma_retriever(
            documents, chroma_dir, reset
        )

        print("[INFO] Building BM25 retriever...")
        self.bm25_retriever = self._build_bm25_retriever(documents, bm25_path, reset)

        print("[INFO] Creating ensemble retriever...")
        self.ensemble_retriever = self._build_ensemble_retriever(
            self.bm25_retriever, self.chroma_retriever
        )

        print("[INFO] Reranker initialized and ready for search...")
        print("[SUCCESS] Knowledge base built successfully!")
        print("[INFO] Use kb.search(query, top_k) to perform searches.")
        print(
            "[INFO] Use kb.search_multi_query(queries, top_k) for multi-query RRF search."
        )

    def load_knowledge_base(self, persist_dir: str = "./knowledge_base") -> bool:
        """
        Load existing knowledge base from disk

        The ChromaDB store is required. The BM25 retriever is unpickled if
        present, otherwise rebuilt from the documents stored in ChromaDB
        (and re-pickled). ``techniques_data`` is not restored, so
        ``get_technique_by_id`` returns None after a load-only session.

        Args:
            persist_dir: Directory where the knowledge base is stored

        Returns:
            bool: True if loaded successfully, False otherwise (errors are
            printed, not raised)
        """
        print("[INFO] Loading knowledge base from disk...")

        chroma_dir = os.path.join(persist_dir, "chroma")
        bm25_path = os.path.join(persist_dir, "bm25_retriever.pkl")

        try:
            if not os.path.exists(chroma_dir):
                print("[ERROR] ChromaDB not found")
                return False

            vectorstore = Chroma(
                persist_directory=chroma_dir, embedding_function=self.embeddings
            )
            self.chroma_retriever = self._make_chroma_retriever(vectorstore)
            print("[SUCCESS] ChromaDB loaded")

            if os.path.exists(bm25_path):
                # SECURITY: pickle.load can execute arbitrary code. Only load
                # pickles written by _build_bm25_retriever; never point
                # persist_dir at untrusted data.
                with open(bm25_path, "rb") as f:
                    self.bm25_retriever = pickle.load(f)
                print("[SUCCESS] BM25 retriever loaded")
            else:
                print("[INFO] BM25 pickle not found, rebuilding from ChromaDB...")
                stored = vectorstore.get(include=["documents", "metadatas"])
                doc_objects = [
                    Document(page_content=content, metadata=metadata or {})
                    for content, metadata in zip(
                        stored["documents"], stored["metadatas"]
                    )
                ]
                self.bm25_retriever = self._build_bm25_retriever(
                    doc_objects, bm25_path, reset=False
                )

            self.ensemble_retriever = self._build_ensemble_retriever(
                self.bm25_retriever, self.chroma_retriever
            )

            print("[INFO] Reranker ready for search...")
            print("[SUCCESS] Knowledge base loaded successfully!")
            return True

        except Exception as e:
            print(f"[ERROR] Error loading knowledge base: {e}")
            return False

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------

    def search(
        self,
        query: str,
        top_k: int = 10,
        filter_tactics: Optional[List[str]] = None,
        filter_platforms: Optional[List[str]] = None,
    ) -> List[Document]:
        """
        Search for techniques using hybrid retrieval and reranking

        Filters are applied to the retrieved candidate pool *before*
        reranking. Up to ``top_k`` matching techniques are therefore returned
        even when the best unfiltered hits do not match the filters.

        Args:
            query: Search query
            top_k: Number of results to return
            filter_tactics: Keep documents having any of these tactics
                (e.g., ['defense-evasion'])
            filter_platforms: Keep documents having any of these platforms
                (e.g., ['Windows'])

        Returns:
            Up to ``top_k`` documents ordered by cross-encoder score (highest
            first). Each is a copy of a retrieved document whose metadata has
            an added ``relevance_score``.

        Raises:
            ValueError: If no knowledge base has been built or loaded.
        """
        self._require_loaded()

        # Over-fetch so that reranking (and filtering) has room to work.
        candidates = self.ensemble_retriever.invoke(
            query, config=self._retrieval_config(top_k * 10)
        )
        candidates = self._apply_filters(candidates, filter_tactics, filter_platforms)
        return self._rerank(query, candidates, top_k)

    def search_multi_query(
        self,
        queries: List[str],
        top_k: int = 10,
        rerank_query: Optional[str] = None,
        filter_tactics: Optional[List[str]] = None,
        filter_platforms: Optional[List[str]] = None,
        rrf_k: int = 60,
    ) -> List[Document]:
        """
        Search for techniques using multiple queries with Reciprocal Rank Fusion (RRF)

        Each query is retrieved separately, the ranked lists are fused with
        RRF, filters are applied to the fused list, and the best
        ``top_k * 5`` remaining candidates are reranked with the
        cross-encoder.

        Args:
            queries: List of search queries
            top_k: Number of final results to return after reranking
            rerank_query: Query used for cross-encoder reranking
                (defaults to ``queries[0]``)
            filter_tactics: Filter by specific tactics
            filter_platforms: Filter by platforms
            rrf_k: RRF constant (default: 60, standard value from literature)

        Returns:
            Up to ``top_k`` documents ordered by cross-encoder score. Their
            metadata carries ``rrf_score`` and ``relevance_score``. A single
            query without ``rerank_query`` is delegated to ``search`` and
            carries no ``rrf_score``.

        Raises:
            ValueError: If no knowledge base has been built or loaded.
        """
        self._require_loaded()

        if not queries:
            return []

        # RRF over one list is a no-op; plain search is equivalent and cheaper.
        # An explicit rerank_query must still be honoured, so only shortcut
        # when none was given.
        if len(queries) == 1 and rerank_query is None:
            return self.search(queries[0], top_k, filter_tactics, filter_platforms)

        print(f"[INFO] Performing multi-query search with {len(queries)} queries")

        config = self._retrieval_config(top_k * 15)
        all_query_results = []
        for i, query in enumerate(queries, 1):
            print(f"[INFO] Query {i}/{len(queries)}: '{query}'")
            all_query_results.append(self.ensemble_retriever.invoke(query, config=config))

        print(f"[INFO] Applying Reciprocal Rank Fusion (k={rrf_k})")
        fused_results = self._reciprocal_rank_fusion(all_query_results, k=rrf_k)

        # Filter first, then cap the (expensive) cross-encoder pool.
        candidates = self._apply_filters(
            fused_results, filter_tactics, filter_platforms
        )[: top_k * 5]

        print(f"[INFO] Reranking {len(candidates)} candidates with cross-encoder")
        reference_query = rerank_query or queries[0]
        results = self._rerank(reference_query, candidates, top_k)

        print(f"[INFO] Returning {len(results)} final results")
        return results

    def _reciprocal_rank_fusion(
        self, doc_lists: List[List[Document]], k: int = 60
    ) -> List[Document]:
        """
        Apply Reciprocal Rank Fusion to combine multiple ranked lists

        RRF score for document d: sum over all rankings r of (1 / (k + rank(d, r)))
        where k is a constant (typically 60) and rank is the 1-based position
        in ranking r.

        Args:
            doc_lists: List of document lists from different queries
            k: RRF constant (default: 60)

        Returns:
            Fused list sorted by RRF score (highest first). Each entry is a
            copy of the first-seen document, with ``rrf_score`` added to its
            metadata; the input documents are not modified.
        """
        doc_scores = defaultdict(float)
        doc_map = {}

        for doc_list in doc_lists:
            for rank, doc in enumerate(doc_list, start=1):
                key = self._doc_key(doc)
                doc_scores[key] += 1.0 / (k + rank)
                # Keep the first occurrence as the representative document.
                doc_map.setdefault(key, doc)

        ranked = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)
        return [self._with_metadata(doc_map[key], rrf_score=score) for key, score in ranked]

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------

    def get_technique_by_id(self, technique_id: str) -> Optional[Dict[str, Any]]:
        """Return the raw technique dict whose ``attack_id`` matches, else None.

        Only works after build_knowledge_base(); load_knowledge_base() does
        not populate ``techniques_data``.
        """
        if not self.techniques_data:
            return None
        return next(
            (t for t in self.techniques_data if t.get("attack_id") == technique_id),
            None,
        )

    def get_stats(self) -> Dict[str, Any]:
        """Get statistics about the knowledge base.

        Document counts fall back to "Unknown" if the underlying retriever
        cannot report them.
        """
        stats: Dict[str, Any] = {}

        if self.chroma_retriever is not None:
            try:
                # The retriever is wrapped by configurable_fields(); read the
                # store from the wrapped VectorStoreRetriever.
                vectorstore = self._unwrap(self.chroma_retriever).vectorstore
                stats["chroma_documents"] = vectorstore._collection.count()
            except Exception:
                stats["chroma_documents"] = "Unknown"

        if self.bm25_retriever is not None:
            try:
                stats["bm25_documents"] = len(self._unwrap(self.bm25_retriever).docs)
            except Exception:
                stats["bm25_documents"] = "Unknown"

        stats["ensemble_available"] = self.ensemble_retriever is not None
        stats["reranker_available"] = self.cross_encoder is not None
        stats["reranker_model"] = self.cross_encoder.model_name
        stats["embedding_model"] = self.embeddings.model_name

        if self.techniques_data:
            stats["total_techniques"] = len(self.techniques_data)

            tactics_count = defaultdict(int)
            platforms_count = defaultdict(int)
            for technique in self.techniques_data:
                for tactic in technique.get("tactics") or []:
                    tactics_count[tactic] += 1
                for platform in technique.get("platforms") or []:
                    platforms_count[platform] += 1
            stats["techniques_by_tactic"] = dict(tactics_count)
            stats["techniques_by_platform"] = dict(platforms_count)

        return stats

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _require_loaded(self) -> None:
        """Raise ValueError unless a knowledge base has been built or loaded."""
        if not self.ensemble_retriever:
            raise ValueError(
                "Knowledge base not loaded. Call build_knowledge_base() or load_knowledge_base() first."
            )

    @staticmethod
    def _retrieval_config(k: int) -> Dict[str, Any]:
        """Runtime config setting how many candidates BM25 and Chroma return."""
        return {
            "configurable": {
                "bm25_k": k,
                "chroma_search_kwargs": {"k": k},
            }
        }

    @staticmethod
    def _split_csv(value: Any) -> List[str]:
        """Split a comma-joined metadata string into trimmed, non-empty parts."""
        return [part.strip() for part in str(value or "").split(",") if part.strip()]

    @classmethod
    def _apply_filters(
        cls,
        documents: List[Document],
        filter_tactics: Optional[List[str]],
        filter_platforms: Optional[List[str]],
    ) -> List[Document]:
        """Keep documents matching ANY requested tactic and ANY requested platform.

        An empty/None filter imposes no constraint. Matching is exact on the
        comma-separated ``tactics`` / ``platforms`` metadata strings.
        """
        if not filter_tactics and not filter_platforms:
            return list(documents)

        wanted_tactics = set(filter_tactics or [])
        wanted_platforms = set(filter_platforms or [])
        kept = []
        for doc in documents:
            if wanted_tactics and wanted_tactics.isdisjoint(
                cls._split_csv(doc.metadata.get("tactics", ""))
            ):
                continue
            if wanted_platforms and wanted_platforms.isdisjoint(
                cls._split_csv(doc.metadata.get("platforms", ""))
            ):
                continue
            kept.append(doc)
        return kept

    def _rerank(
        self, query: str, documents: List[Document], top_n: int
    ) -> List[Document]:
        """Score ``documents`` against ``query`` with one cross-encoder pass.

        Documents are stably sorted by descending score and the first
        ``top_n`` are returned as copies with ``relevance_score`` recorded in
        their metadata, so the model never has to be run twice.
        """
        if not documents:
            return []
        pairs = [(query, doc.page_content) for doc in documents]
        scores = [float(score) for score in self.cross_encoder.score(pairs)]
        ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
        return [
            self._with_metadata(doc, relevance_score=score)
            for doc, score in ranked[:top_n]
        ]

    @staticmethod
    def _with_metadata(doc: Document, **extra: Any) -> Document:
        """Return a copy of ``doc`` with ``extra`` merged into its metadata.

        Retrievers may hand back the same Document objects they store
        internally; annotating copies keeps per-search scores from leaking
        into later searches.
        """
        return Document(page_content=doc.page_content, metadata={**doc.metadata, **extra})

    @staticmethod
    def _doc_key(doc: Document) -> tuple:
        """Identity used for RRF: ``attack_id`` when present, else the text."""
        attack_id = doc.metadata.get("attack_id", "")
        return ("attack_id", attack_id) if attack_id else ("content", doc.page_content)

    @staticmethod
    def _unwrap(runnable: Any) -> Any:
        """Return the runnable wrapped by configurable_fields(), or itself."""
        return getattr(runnable, "default", runnable)

    def _load_techniques(self, json_path: str) -> List[Dict[str, Any]]:
        """Load the list of technique dicts from a UTF-8 JSON file.

        Raises:
            FileNotFoundError: If ``json_path`` does not exist.
        """
        if not os.path.exists(json_path):
            raise FileNotFoundError(f"Techniques file not found: {json_path}")

        with open(json_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _create_documents(self, techniques: List[Dict[str, Any]]) -> List[Document]:
        """Convert technique dicts into LangChain documents.

        ``page_content`` holds the name and description. List-valued fields
        are flattened to strings (``platforms``/``tactics`` comma-joined,
        ``mitigations`` "; "-joined) because they are stored as vector-store
        metadata.
        """
        documents = []

        for technique in techniques:
            page_content = f"Technique: {technique.get('name', 'Unknown')}\n\n"
            page_content += f"Description: {technique.get('description', 'No description available')}"

            # `or []` also covers keys that are present but set to null.
            mitigations = technique.get("mitigations") or []
            metadata = {
                "attack_id": technique.get("attack_id", ""),
                "name": technique.get("name", ""),
                "is_subtechnique": technique.get("is_subtechnique", False),
                "platforms": ",".join(technique.get("platforms") or []),
                "tactics": ",".join(technique.get("tactics") or []),
                "doc_type": "mitre_technique",
                "mitigation_count": len(mitigations),
                "mitigations": "; ".join(mitigations),
            }

            documents.append(Document(page_content=page_content, metadata=metadata))

        return documents

    @staticmethod
    def _make_chroma_retriever(vectorstore):
        """Wrap a Chroma store as a retriever whose search kwargs are runtime-configurable."""
        return vectorstore.as_retriever(search_kwargs={"k": 20}).configurable_fields(
            search_kwargs=ConfigurableField(
                id="chroma_search_kwargs",
                name="Chroma Search Kwargs",
                description="Search kwargs for Chroma DB retriever",
            )
        )

    def _build_chroma_retriever(
        self, documents: List[Document], chroma_dir: str, reset: bool
    ):
        """Embed ``documents`` into a persistent ChromaDB store and wrap it as a retriever.

        With ``reset`` the existing directory is deleted first. Otherwise,
        when every document has a unique, non-empty ``attack_id``, those ids
        are passed to Chroma so re-indexing replaces entries instead of
        appending duplicates.
        """
        if reset and os.path.exists(chroma_dir):
            import shutil

            shutil.rmtree(chroma_dir)
            print("[INFO] Removed existing ChromaDB for rebuild")

        attack_ids = [doc.metadata.get("attack_id", "") for doc in documents]
        ids = (
            attack_ids
            if attack_ids and all(attack_ids) and len(set(attack_ids)) == len(attack_ids)
            else None
        )

        vectorstore = Chroma.from_documents(
            documents=documents,
            embedding=self.embeddings,
            persist_directory=chroma_dir,
            ids=ids,
        )

        retriever = self._make_chroma_retriever(vectorstore)
        print(f"[SUCCESS] ChromaDB created with {len(documents)} documents")
        return retriever

    def _build_bm25_retriever(
        self, documents: List[Document], bm25_path: str, reset: bool
    ):
        """Build a BM25 retriever (NLTK word tokenization) and pickle it to ``bm25_path``.

        ``reset`` is accepted for symmetry with _build_chroma_retriever but is
        unused: the index is always rebuilt and the pickle overwritten. A
        failure to write the pickle is reported as a warning, not raised.
        """
        retriever = BM25Retriever.from_documents(
            documents=documents,
            k=20,
            preprocess_func=word_tokenize,
        ).configurable_fields(
            k=ConfigurableField(
                id="bm25_k",
                name="BM25 Top K",
                description="Number of documents to return from BM25",
            )
        )

        try:
            with open(bm25_path, "wb") as f:
                pickle.dump(retriever, f)
            print(f"[SUCCESS] BM25 retriever saved to {bm25_path}")
        except Exception as e:
            print(f"[WARNING] Could not save BM25 retriever: {e}")

        print(f"[SUCCESS] BM25 retriever created with {len(documents)} documents")
        return retriever

    def _build_ensemble_retriever(self, bm25_retriever, chroma_retriever):
        """Combine BM25 (weight 0.3) and ChromaDB (weight 0.7) into one retriever."""
        return EnsembleRetriever(
            retrievers=[bm25_retriever, chroma_retriever],
            weights=[0.3, 0.7],
        )
|
|
|
|
|
|
|
|
|
def test_cyber_kb(kb: CyberKnowledgeBase, test_queries: List[str]):
    """Run every query through ``kb.search`` and print the top three hits.

    Any exception raised while handling one query is printed and the loop
    moves on to the next query.
    """
    print("\n[INFO] Testing Cyber Knowledge Base")
    print("=" * 60)

    for query_number, query_text in enumerate(test_queries, start=1):
        print(f"\n#{query_number} Query: '{query_text}'")
        print("-" * 40)

        try:
            hits = kb.search(query_text, top_k=3)

            if not hits:
                print(" No results found")
                continue

            for rank, hit in enumerate(hits, start=1):
                meta = hit.metadata
                technique_id = meta.get("attack_id", "Unknown")
                technique_name = meta.get("name", "Unknown")
                tactic_list = meta.get("tactics", "")
                platform_list = meta.get("platforms", "")
                snippet = hit.page_content[:200].replace("\n", " ")

                print(f" {rank}. {technique_id} - {technique_name}")
                print(f" Tactics: {tactic_list}")
                print(f" Platforms: {platform_list}")
                print(f" Preview: {snippet}...")
                print()

        except Exception as e:
            print(f" [ERROR] Error: {e}")
|
|
|
|
|
|
|
|
|
def test_multi_query_search(kb: CyberKnowledgeBase):
    """Exercise ``kb.search_multi_query`` on a few hand-written query groups.

    For each case the top results are printed with their RRF fusion score.
    A missing or non-numeric score is printed as "N/A". Previously the
    "N/A" default was fed into a ``:.4f`` format spec, which raises
    ValueError.
    """
    print("\n[INFO] Testing Multi-Query Search with RRF")
    print("=" * 60)

    test_cases = [
        (
            "Credential Dumping",
            [
                "credential dumping LSASS memory",
                "stealing authentication secrets",
                "SAM database access ntds.dit",
            ],
        ),
        (
            "Process Injection",
            [
                "process injection defense evasion",
                "code injection into running processes",
                "DLL injection CreateRemoteThread",
            ],
        ),
    ]
    top_k = 5

    for case_number, (title, queries) in enumerate(test_cases, 1):
        # Extra blank line between cases, matching the original layout.
        separator = "\n" if case_number == 1 else "\n\n"
        print(f"{separator}### Test Case {case_number}: {title} ###")
        print(f"Queries: {queries}")
        results = kb.search_multi_query(queries, top_k=top_k)

        print(f"\nTop {top_k} Results:")
        for i, doc in enumerate(results, 1):
            attack_id = doc.metadata.get("attack_id", "Unknown")
            name = doc.metadata.get("name", "Unknown")
            rrf_score = doc.metadata.get("rrf_score")
            score_text = (
                f"{rrf_score:.4f}" if isinstance(rrf_score, (int, float)) else "N/A"
            )
            print(f" {i}. {attack_id} - {name} (RRF Score: {score_text})")
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    kb = CyberKnowledgeBase()

    # Resolved against the current working directory, not this file's location.
    techniques_path = "../../processed_data/cti/techniques.json"

    try:
        kb.build_knowledge_base(
            techniques_json_path=techniques_path, persist_dir="./mitre_kb", reset=True
        )

        test_queries = [
            "process injection techniques",
            "privilege escalation Windows",
            "scheduled task persistence",
            "credential dumping LSASS",
            "lateral movement SMB",
            "defense evasion DLL hijacking",
        ]

        test_cyber_kb(kb, test_queries)
        test_multi_query_search(kb)

        print("\n[INFO] Knowledge Base Stats:")
        stats = kb.get_stats()
        for key, value in stats.items():
            if isinstance(value, dict):
                print(f" {key}:")
                for subkey, subvalue in value.items():
                    print(f" {subkey}: {subvalue}")
            else:
                print(f" {key}: {value}")

    except Exception as e:
        print(f"[ERROR] Error: {e}")
        import traceback

        traceback.print_exc()
        # Report failure to the shell/CI instead of exiting with status 0.
        raise SystemExit(1)
|
|
|
|