#!/usr/bin/env python3
"""
Simple Integrated Pipeline - Direct connection between Log Analysis Agent and Retrieval Supervisor
This file replaces the complex full_pipeline structure with a straightforward LangGraph
that passes log analysis results directly to the retrieval supervisor.
"""
# Example model override: --model groq:openai/gpt-oss-120b
import os
import sys
import time
from pathlib import Path
from typing import Dict, Any, Optional, TypedDict
from langchain.chat_models import init_chat_model
from dotenv import load_dotenv
# LangGraph imports
from langgraph.graph import StateGraph, END, START
from langchain_core.messages import HumanMessage
# Add project root to path for agent imports
# Since we're in src/full_pipeline/, go up two levels to project root
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from src.agents.log_analysis_agent.agent import LogAnalysisAgent
from src.agents.retrieval_supervisor.supervisor import RetrievalSupervisor
from src.agents.response_agent.response_agent import ResponseAgent
# Simple state for the pipeline
class PipelineState(TypedDict):
log_file: str
log_analysis_result: Dict[str, Any]
retrieval_result: Dict[str, Any]
response_analysis: Dict[str, Any]
query: str
tactic: str
markdown_report: str
def create_simple_pipeline(
model_name: str = "google_genai:gemini-2.0-flash",
temperature: float = 0.1,
log_agent_output_dir: str = "analysis",
response_agent_output_dir: str = "final_response",
):
"""
Create the simplified pipeline that directly connects the agents.
Args:
        model_name: Name of the model to use (e.g., "google_genai:gemini-2.0-flash", "groq:gpt-oss-120b", "groq:llama-3.1-8b-instant")
        temperature: Temperature for model generation
        log_agent_output_dir: Directory where the log analysis agent writes its output
        response_agent_output_dir: Directory where the response agent writes its output
Returns:
Compiled pipeline workflow
"""
# Initialize LLM client directly
print("\n" + "=" * 60)
print("INITIALIZING LLM CLIENT")
print("=" * 60)
print(f"Model: {model_name}")
print(f"Temperature: {temperature}")
print("=" * 60 + "\n")
if "gpt-oss" in model_name:
reasoning_effort = "medium"
reasoning_format = "hidden"
llm_client = init_chat_model(
model_name,
temperature=temperature,
reasoning_effort=reasoning_effort,
reasoning_format=reasoning_format,
)
print(
f"[INFO] Using GPT-OSS model: {model_name} with reasoning effort: {reasoning_effort}"
)
else:
llm_client = init_chat_model(model_name, temperature=temperature)
print(f"[INFO] Initialized with {model_name}")
# Initialize agents with shared LLM client
log_agent = LogAnalysisAgent(
output_dir=log_agent_output_dir, max_iterations=2, llm_client=llm_client
)
retrieval_supervisor = RetrievalSupervisor(
kb_path="./cyber_knowledge_base", max_iterations=2, llm_client=llm_client
)
response_agent = ResponseAgent(
model_name=model_name,
output_dir=response_agent_output_dir,
llm_client=llm_client,
)
def run_log_analysis(state: PipelineState) -> PipelineState:
"""Run log analysis and capture results."""
print("\n" + "=" * 60)
print("PHASE 1: LOG ANALYSIS")
print("=" * 60)
log_file = state["log_file"]
print(f"Analyzing log file: {log_file}")
# Run log analysis (agent should not print its own phase headers)
analysis_result = log_agent.analyze(log_file)
# Store results in state
state["log_analysis_result"] = analysis_result
print(
f"\nLog Analysis Assessment: {analysis_result.get('overall_assessment', 'UNKNOWN')}"
)
print(f"Abnormal Events: {len(analysis_result.get('abnormal_events', []))}")
return state
def run_retrieval_with_context(state: PipelineState) -> PipelineState:
"""Transform log analysis results and run retrieval supervisor."""
print("\n" + "=" * 60)
print("PHASE 2: THREAT INTELLIGENCE RETRIEVAL")
print("=" * 60)
# Get log analysis results
log_analysis_result = state["log_analysis_result"]
assessment = log_analysis_result.get("overall_assessment", "UNKNOWN")
# Create retrieval query based on log analysis
query = create_retrieval_query(log_analysis_result, state.get("query"))
print(f"Generated retrieval query based on {assessment} assessment")
print("\nStarting retrieval supervisor with log analysis context...\n")
        # Run retrieval supervisor (set trace=True to stream the agent conversations to the terminal)
        retrieval_result = retrieval_supervisor.invoke(
            query=query,
            log_analysis_report=log_analysis_result,
            context=state.get("query"),
            trace=False,
        )
# Store retrieval results in state
state["retrieval_result"] = retrieval_result
return state
def run_response_analysis(state: PipelineState) -> PipelineState:
"""Run response agent to create Event ID → MITRE technique mappings."""
print("\n" + "=" * 60)
print("PHASE 3: RESPONSE CORRELATION ANALYSIS")
print("=" * 60)
print("Creating Event ID → MITRE technique mappings...")
# Run response agent analysis (agent should not print its own phase headers)
response_analysis, markdown_report = response_agent.analyze_and_map(
log_analysis_result=state["log_analysis_result"],
retrieval_result=state["retrieval_result"],
log_file=state["log_file"],
tactic=state.get("tactic"),
)
# Store response analysis in state
state["response_analysis"] = response_analysis
# Store the markdown report in state
state["markdown_report"] = markdown_report
# The output path is already saved by analyze_and_map
print(f"Analysis complete! Results saved to final_response folder.")
print(f"\n" + "=" * 60)
print("PIPELINE COMPLETED")
print("=" * 60)
return state
# Create the workflow
workflow = StateGraph(PipelineState)
# Add nodes
workflow.add_node("log_analysis", run_log_analysis)
workflow.add_node("retrieval", run_retrieval_with_context)
workflow.add_node("response", run_response_analysis)
# Define flow
workflow.set_entry_point("log_analysis")
workflow.add_edge("log_analysis", "retrieval")
workflow.add_edge("retrieval", "response")
workflow.add_edge("response", END)
return workflow.compile(name="simple_integrated_pipeline")
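# Minimal sketch of driving the compiled graph directly (analyze_log_file below wraps
# exactly this flow; the log path and model name here are placeholders):
#
#   pipeline = create_simple_pipeline(model_name="google_genai:gemini-2.0-flash")
#   state = pipeline.invoke({
#       "log_file": "sample_log.json",
#       "log_analysis_result": {},
#       "retrieval_result": {},
#       "response_analysis": {},
#       "query": "",
#       "tactic": "",
#       "markdown_report": "",
#   })
#   print(state["markdown_report"])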
def create_retrieval_query(
    log_analysis_result: Dict[str, Any], user_query: Optional[str] = None
) -> str:
"""Transform log analysis results into a retrieval query."""
assessment = log_analysis_result.get("overall_assessment", "UNKNOWN")
analysis_summary = log_analysis_result.get("analysis_summary", "")
abnormal_events = log_analysis_result.get("abnormal_events", [])
if assessment == "NORMAL" and not user_query:
return "Analyze this normal log activity and provide baseline threat intelligence for monitoring purposes."
query_parts = [
"Analyze the detected security anomalies and provide comprehensive threat intelligence.",
"",
f"Log Analysis Assessment: {assessment}",
f"Summary: {analysis_summary}",
"",
]
if abnormal_events:
query_parts.append("Detected Anomalies:")
for i, event in enumerate(abnormal_events[:5], 1): # Top 5 events
event_desc = event.get("event_description", "Unknown event")
severity = event.get("severity", "Unknown")
event_id = event.get("event_id", "N/A")
query_parts.append(f"{i}. Event {event_id} [{severity}]: {event_desc}")
query_parts.append("")
# Add intelligence requirements
query_parts.extend(
[
"Intelligence Requirements:",
"1. Map findings to relevant MITRE ATT&CK techniques and tactics",
"2. Provide threat actor attribution and campaign intelligence",
"3. Generate actionable IOCs and detection recommendations",
"4. Assess threat severity and recommend response actions",
]
)
if user_query:
query_parts.extend(["", f"Additional Context: {user_query}"])
return "\n".join(query_parts)
def analyze_log_file(
log_file: str,
    query: Optional[str] = None,
    tactic: Optional[str] = None,
model_name: str = "google_genai:gemini-2.0-flash",
temperature: float = 0.1,
log_agent_output_dir: str = "analysis",
response_agent_output_dir: str = "final_response",
):
"""
Analyze a single log file through the integrated pipeline.
Args:
log_file: Path to the log file to analyze
query: Optional user query for additional context
tactic: Optional tactic name for organizing output
        model_name: Name of the model to use (e.g., "google_genai:gemini-2.0-flash", "groq:gpt-oss-120b", "groq:llama-3.1-8b-instant")
temperature: Temperature for model generation
log_agent_output_dir: Directory to save log agent output
response_agent_output_dir: Directory to save response agent output
"""
if not os.path.exists(log_file):
print(f"Error: Log file not found: {log_file}")
return
print(f"Starting integrated pipeline analysis...")
print(f"Log file: {log_file}")
print(f"Model: {model_name}")
if tactic:
print(f"Tactic: {tactic}")
print(f"User query: {query or 'None'}")
# Create pipeline with specified model
pipeline = create_simple_pipeline(
model_name=model_name,
temperature=temperature,
log_agent_output_dir=log_agent_output_dir,
response_agent_output_dir=response_agent_output_dir,
)
# Initialize state
initial_state = {
"log_file": log_file,
"log_analysis_result": {},
"retrieval_result": {},
"response_analysis": {},
"query": query or "",
"tactic": tactic or "",
"markdown_report": "",
}
# Run pipeline
start_time = time.time()
final_state = pipeline.invoke(initial_state)
end_time = time.time()
print(f"\nTotal execution time: {end_time - start_time:.2f} seconds")
print("Analysis complete!")
return final_state
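# Example programmatic call (a sketch; the log path and query are placeholders and the
# required API keys are assumed to be set via the environment / .env):
#
#   from dotenv import load_dotenv
#   load_dotenv()
#   state = analyze_log_file(
#       "sample_log.json",
#       query="Focus on credential access attacks",
#       model_name="groq:llama-3.3-70b-versatile",
#   )
#   if state:
#       print(state["markdown_report"])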
def main():
"""Main entry point."""
if len(sys.argv) < 2:
print(
"Usage: python simple_pipeline.py <log_file> [query] [--model MODEL_NAME]"
)
print("\nExamples:")
print(" python simple_pipeline.py sample_log.json")
print(
" python simple_pipeline.py sample_log.json 'Focus on credential access attacks'"
)
print(" python simple_pipeline.py sample_log.json --model gpt-oss-120b")
print("\nAvailable models:")
print(" - google_genai:gemini-2.0-flash")
print(" - google_genai:gemini-1.5-flash")
print(" - groq:gpt-oss-120b")
print(" - groq:gpt-oss-20b")
print(" - groq:llama-3.1-8b-instant")
print(" - groq:llama-3.3-70b-versatile")
sys.exit(1)
log_file = sys.argv[1]
query = None
model_name = "gemini-2.0-flash" # Default model
temperature = 0.1
log_agent_output_dir = "analysis"
response_agent_output_dir = "final_response"
# Parse arguments
i = 2
while i < len(sys.argv):
if sys.argv[i] == "--model" and i + 1 < len(sys.argv):
model_name = sys.argv[i + 1]
i += 2
else:
query = sys.argv[i]
i += 1
# Setup environment
load_dotenv()
# Run analysis
try:
final_state = analyze_log_file(
log_file,
query,
tactic=None,
model_name=model_name,
temperature=temperature,
log_agent_output_dir=log_agent_output_dir,
response_agent_output_dir=response_agent_output_dir,
)
print(final_state["markdown_report"])
except Exception as e:
print(f"Error: {e}")
sys.exit(1)
if __name__ == "__main__":
main()