""" |
|
|
CTI Bench Evaluation Runner |
|
|
|
|
|
This script provides a command-line interface to run the CTI Bench evaluation |
|
|
with your Retrieval Supervisor system. |
|
|
""" |
|
|
|
|
|
import argparse
import os
import sys
from pathlib import Path
from typing import Optional

from dotenv import load_dotenv
from huggingface_hub import login as huggingface_login
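
# Make the project root importable so the src.* packages resolve when this
# script is run directly.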
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from src.evaluator.cti_bench_evaluator import CTIBenchEvaluator
from src.agents.retrieval_supervisor.supervisor import RetrievalSupervisor


def setup_environment(
    dataset_dir: str = "cti_bench/datasets", output_dir: str = "cti_bench/eval_output"
):
    """Set up the environment for evaluation."""
    load_dotenv()

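    # load_dotenv() already copies the values from .env into os.environ; the
    # checks below simply re-export the API keys the evaluation may need.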
if os.getenv("GOOGLE_API_KEY"): |
|
|
os.environ["GOOGLE_API_KEY"] = os.getenv("GOOGLE_API_KEY") |
|
|
|
|
|
if os.getenv("GROQ_API_KEY"): |
|
|
os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY") |
|
|
|
|
|
if os.getenv("OPENAI_API_KEY"): |
|
|
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY") |
|
|
|
|
|
if os.getenv("HF_TOKEN"): |
|
|
huggingface_login(token=os.getenv("HF_TOKEN")) |
|
|
|
|
|
|
|
|
os.makedirs(dataset_dir, exist_ok=True) |
|
|
os.makedirs(output_dir, exist_ok=True) |
|
|
|
|
|
|
|
|
dataset_path = Path(dataset_dir) |
|
|
ate_file = dataset_path / "cti-ate.tsv" |
|
|
mcq_file = dataset_path / "cti-mcq.tsv" |
|
|
|
|
|
if not ate_file.exists() or not mcq_file.exists(): |
|
|
print("ERROR: CTI Bench dataset files not found!") |
|
|
print(f"Expected files:") |
|
|
print(f" - {ate_file}") |
|
|
print(f" - {mcq_file}") |
|
|
print( |
|
|
"Please download the CTI Bench dataset and place the files in the correct location." |
|
|
) |
|
|
sys.exit(1) |
|
|
|
|
|
return True |
|
|
|
|
|
|
|
|
def run_evaluation_quick_test(
    dataset_dir: str,
    output_dir: str,
    llm_model: str,
    kb_path: str,
    max_iterations: int,
    num_samples: int = 2,
    datasets: str = "all",
):
    """Run a quick test with a few samples."""
    print("Running quick test evaluation...")

    try:
        supervisor = RetrievalSupervisor(
            llm_model=llm_model,
            kb_path=kb_path,
            max_iterations=max_iterations,
        )

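        # The evaluator wraps the supervisor and handles dataset loading,
        # filtering, and metric computation.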
        evaluator = CTIBenchEvaluator(
            supervisor=supervisor,
            dataset_dir=dataset_dir,
            output_dir=output_dir,
        )

        ate_df, mcq_df = evaluator.load_datasets()
        ate_filtered = evaluator.filter_dataset(ate_df, "ate")
        mcq_filtered = evaluator.filter_dataset(mcq_df, "mcq")

        print(f"Testing with first {num_samples} samples of each dataset...")

        ate_sample = ate_filtered.head(num_samples)
        mcq_sample = mcq_filtered.head(num_samples)

        ate_results = None
        mcq_results = None
        ate_metrics = None
        mcq_metrics = None

        if datasets in ["ate", "all"]:
            print("\nEvaluating ATE dataset...")
            ate_results = evaluator.evaluate_ate_dataset(ate_sample)
            ate_metrics = evaluator.calculate_ate_metrics(ate_results)

        if datasets in ["mcq", "all"]:
            print("\nEvaluating MCQ dataset...")
            mcq_results = evaluator.evaluate_mcq_dataset(mcq_sample)
            mcq_metrics = evaluator.calculate_mcq_metrics(mcq_results)

        print("\nQuick Test Results:")
        if ate_metrics:
            print(f"ATE - Macro F1: {ate_metrics.get('macro_f1', 0.0):.3f}")
            print(f"ATE - Success Rate: {ate_metrics.get('success_rate', 0.0):.3f}")
        if mcq_metrics:
            print(f"MCQ - Accuracy: {mcq_metrics.get('accuracy', 0.0):.3f}")
            print(f"MCQ - Success Rate: {mcq_metrics.get('success_rate', 0.0):.3f}")

        return True

    except Exception as e:
        print(f"Quick test failed: {e}")
        import traceback

        traceback.print_exc()
        return False


def run_csv_metrics_calculation(
    csv_path: str,
    output_dir: str,
    model_name: Optional[str] = None,
):
    """Calculate metrics from an existing CSV results file."""
    print("Calculating metrics from existing CSV file...")

    try:
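        # No supervisor is needed when recomputing metrics from an existing
        # results CSV, so the evaluator is constructed without one.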
        evaluator = CTIBenchEvaluator(
            supervisor=None,
            dataset_dir="",
            output_dir=output_dir,
        )

        results = evaluator.calculate_metrics_from_csv(
            csv_path=csv_path,
            model_name=model_name,
        )

        print("CSV metrics calculation completed successfully!")
        return True

    except Exception as e:
        print(f"CSV metrics calculation failed: {e}")
        import traceback

        traceback.print_exc()
        return False


def run_full_evaluation(
    dataset_dir: str,
    output_dir: str,
    llm_model: str,
    kb_path: str,
    max_iterations: int,
    datasets: str = "all",
):
    """Run the complete evaluation."""
    print("Running full evaluation...")

    try:
        supervisor = RetrievalSupervisor(
            llm_model=llm_model,
            kb_path=kb_path,
            max_iterations=max_iterations,
        )

        evaluator = CTIBenchEvaluator(
            supervisor=supervisor,
            dataset_dir=dataset_dir,
            output_dir=output_dir,
        )

        if datasets == "all":
            results = evaluator.run_full_evaluation()
        elif datasets == "ate":
            results = evaluator.run_ate_evaluation()
        elif datasets == "mcq":
            results = evaluator.run_mcq_evaluation()
        else:
            print(f"Invalid dataset selection: {datasets}")
            return False

        print("Full evaluation completed successfully!")
        return True

    except Exception as e:
        print(f"Full evaluation failed: {e}")
        import traceback

        traceback.print_exc()
        return False


def test_supervisor_connection(llm_model: str, kb_path: str):
    """Test the supervisor connection."""
    try:
        supervisor = RetrievalSupervisor(
            llm_model=llm_model,
            kb_path=kb_path,
            max_iterations=1,
        )
        response = supervisor.invoke_direct_query("Test query: What is T1071?")
        print("Supervisor connection successful!")
        print(f"Sample response length: {len(str(response))} characters")
        return True
    except Exception as e:
        print(f"Supervisor connection failed: {e}")
        return False


def parse_arguments():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="CTI Bench Evaluation Runner",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Run quick test with default settings
  python cti_bench_evaluation.py --mode quick

  # Run full evaluation with custom settings
  python cti_bench_evaluation.py --mode full --llm-model google_genai:gemini-2.0-flash --max-iterations 5

  # Run full evaluation on ATE dataset only
  python cti_bench_evaluation.py --mode full --datasets ate

  # Run full evaluation on MCQ dataset only
  python cti_bench_evaluation.py --mode full --datasets mcq

  # Test supervisor connection
  python cti_bench_evaluation.py --mode test

  # Run quick test with 5 samples
  python cti_bench_evaluation.py --mode quick --num-samples 5

  # Calculate metrics from existing CSV file
  python cti_bench_evaluation.py --mode csv --csv-path cti_bench/eval_output/cti-ate_gemini-2.0-flash_20251024_193022.csv

  # Calculate metrics from CSV with custom model name
  python cti_bench_evaluation.py --mode csv --csv-path results.csv --csv-model-name my-model
""",
    )

    parser.add_argument(
        "--mode",
        choices=["quick", "full", "test", "csv"],
        required=True,
        help="Evaluation mode: 'quick' for a quick test, 'full' for the complete evaluation, 'test' for a connection test, 'csv' for processing an existing CSV file",
    )

    parser.add_argument(
        "--datasets",
        choices=["ate", "mcq", "all"],
        default="all",
        help="Which datasets to evaluate: 'ate' for CTI-ATE only, 'mcq' for CTI-MCQ only, 'all' for both (default: all)",
    )

    parser.add_argument(
        "--dataset-dir",
        default="cti_bench/datasets",
        help="Directory containing the CTI Bench dataset files (default: cti_bench/datasets)",
    )

    parser.add_argument(
        "--output-dir",
        default="cti_bench/eval_output",
        help="Directory for evaluation output files (default: cti_bench/eval_output)",
    )

    parser.add_argument(
        "--llm-model",
        default="google_genai:gemini-2.0-flash",
        help="LLM model to use (default: google_genai:gemini-2.0-flash)",
    )

    parser.add_argument(
        "--kb-path",
        default="./cyber_knowledge_base",
        help="Path to the knowledge base (default: ./cyber_knowledge_base)",
    )

    parser.add_argument(
        "--max-iterations",
        type=int,
        default=3,
        help="Maximum iterations for the supervisor (default: 3)",
    )

    parser.add_argument(
        "--num-samples",
        type=int,
        default=2,
        help="Number of samples for the quick test (default: 2)",
    )

    parser.add_argument(
        "--csv-path",
        help="Path to an existing CSV results file (required for csv mode)",
    )

    parser.add_argument(
        "--csv-model-name",
        help="Model name to use in the summary (optional; extracted from the filename if not provided)",
    )

    return parser.parse_args()


def main():
    """Main function."""
    args = parse_arguments()

    print("CTI Bench Evaluation Runner")
    print("=" * 50)

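    # CSV mode only post-processes an existing results file, so it skips the
    # dataset checks in setup_environment() and needs only the output directory.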
if args.mode != "csv": |
|
|
if not setup_environment(args.dataset_dir, args.output_dir): |
|
|
return |
|
|
else: |
|
|
|
|
|
os.makedirs(args.output_dir, exist_ok=True) |
|
|
|
|
|
|
|
|
if args.mode == "quick": |
|
|
success = run_evaluation_quick_test( |
|
|
dataset_dir=args.dataset_dir, |
|
|
output_dir=args.output_dir, |
|
|
llm_model=args.llm_model, |
|
|
kb_path=args.kb_path, |
|
|
max_iterations=args.max_iterations, |
|
|
num_samples=args.num_samples, |
|
|
datasets=args.datasets, |
|
|
) |
|
|
elif args.mode == "full": |
|
|
success = run_full_evaluation( |
|
|
dataset_dir=args.dataset_dir, |
|
|
output_dir=args.output_dir, |
|
|
llm_model=args.llm_model, |
|
|
kb_path=args.kb_path, |
|
|
max_iterations=args.max_iterations, |
|
|
datasets=args.datasets, |
|
|
) |
|
|
elif args.mode == "test": |
|
|
success = test_supervisor_connection( |
|
|
llm_model=args.llm_model, kb_path=args.kb_path |
|
|
) |
|
|
elif args.mode == "csv": |
|
|
|
|
|
if not args.csv_path: |
|
|
print("ERROR: --csv-path is required for csv mode") |
|
|
sys.exit(1) |
|
|
|
|
|
|
|
|
if not os.path.exists(args.csv_path): |
|
|
print(f"ERROR: CSV file not found: {args.csv_path}") |
|
|
sys.exit(1) |
|
|
|
|
|
success = run_csv_metrics_calculation( |
|
|
csv_path=args.csv_path, |
|
|
output_dir=args.output_dir, |
|
|
model_name=args.csv_model_name, |
|
|
) |
|
|
|
|
|
if success: |
|
|
print("\nOperation completed successfully!") |
|
|
else: |
|
|
print("\nOperation failed!") |
|
|
sys.exit(1) |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|