File size: 30,439 Bytes
9e3d618
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
"""

Retrieval Supervisor - Coordinates CTI Agent, Database Agent, and Grader Agent



This supervisor manages the retrieval pipeline for cybersecurity analysis, coordinating

multiple specialized agents to provide comprehensive threat intelligence and MITRE ATT&CK

technique retrieval.

"""

import json
import os
from typing import Dict, Any, List, Optional
from pathlib import Path
from langchain_core.messages import convert_to_messages

# LangGraph and LangChain imports
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langchain.chat_models import init_chat_model
from langgraph.prebuilt import create_react_agent
from langgraph_supervisor import create_supervisor

# Import your agent classes
from src.agents.cti_agent.cti_agent import CTIAgent
from src.agents.database_agent.agent import DatabaseAgent

# Import prompts
from src.agents.retrieval_supervisor.prompts import (
    GRADER_AGENT_PROMPT,
    SUPERVISOR_PROMPT_TEMPLATE,
    INPUT_MESSAGE_TEMPLATE,
    LOG_ANALYSIS_SECTION_TEMPLATE,
    CONTEXT_SECTION_TEMPLATE,
)


class RetrievalSupervisor:
    """

    Retrieval Supervisor that coordinates CTI Agent, Database Agent, and Grader Agent

    using LangGraph's supervisor pattern for comprehensive threat intelligence retrieval.

    """

    def __init__(
        self,
        llm_model: str = "google_genai:gemini-2.0-flash",
        kb_path: str = "./cyber_knowledge_base",
        max_iterations: int = 3,
        llm_client=None,
    ):
        """
        Set up the LLM client, the worker agents and the compiled supervisor.

        Args:
            llm_model: Model identifier handed to ``init_chat_model`` when no
                ``llm_client`` is supplied.
            kb_path: Path to the cyber knowledge base, forwarded to the
                Database Agent.
            max_iterations: Iteration limit substituted into the supervisor
                prompt.
            llm_client: Optional pre-initialized LLM client; when truthy it is
                used as-is and ``llm_model`` is only recorded.
        """
        self.llm_model = llm_model
        self.max_iterations = max_iterations

        # Choose the chat model first: every agent created below shares it.
        if llm_client:
            client = llm_client
            banner = "[INFO] Retrieval Supervisor: Using provided LLM client"
        elif "gpt-oss" in llm_model:
            # Models whose identifier contains "gpt-oss" get extra reasoning options.
            effort = "low"
            client = init_chat_model(
                llm_model,
                temperature=0.1,
                reasoning_effort=effort,
                reasoning_format="hidden",
            )
            banner = f"[INFO] Retrieval Supervisor: Using GPT-OSS model: {llm_model} with reasoning effort: {effort}"
        else:
            client = init_chat_model(llm_model, temperature=0.1)
            banner = f"[INFO] Retrieval Supervisor: Initialized with {llm_model}"

        self.llm_client = client
        print(banner)

        # NOTE: the CTI agent is not created here; _initialize_cti_agent()
        # exists on the class but is currently not called.
        self.database_agent = self._initialize_database_agent(kb_path)
        self.grader_agent = self._initialize_grader_agent()

        # The supervisor graph is built last because it needs both agents.
        self.supervisor = self._create_supervisor()

    def _initialize_cti_agent(self) -> CTIAgent:
        """
        Build the CTI Agent on top of the supervisor's LLM client.

        Returns:
            The constructed ``CTIAgent``.

        Raises:
            Exception: Any construction failure is reported on stdout and
                re-raised unchanged.
        """
        try:
            agent = CTIAgent(llm=self.llm_client)
            print("CTI Agent initialized successfully")
        except Exception as exc:
            print(f"Failed to initialize CTI Agent: {exc}")
            raise
        return agent

    def _initialize_database_agent(self, kb_path: str) -> DatabaseAgent:
        """
        Build the Database Agent for the given knowledge base.

        Args:
            kb_path: Knowledge-base path forwarded to ``DatabaseAgent``.

        Returns:
            The constructed ``DatabaseAgent``.

        Raises:
            Exception: Any construction failure is reported on stdout and
                re-raised unchanged.
        """
        try:
            agent = DatabaseAgent(kb_path=kb_path, llm_client=self.llm_client)
            print("Database Agent initialized successfully")
        except Exception as exc:
            print(f"Failed to initialize Database Agent: {exc}")
            raise
        return agent

    def _initialize_grader_agent(self):
        """
        Create the Grader Agent.

        The grader is a ReAct agent driven only by ``GRADER_AGENT_PROMPT``:
        it is given an empty tool list and the name
        ``"retrieval_grader_agent"``.
        """
        grader_options = {
            "model": self.llm_client,
            "tools": [],  # the grader has no tools
            "prompt": GRADER_AGENT_PROMPT,
            "name": "retrieval_grader_agent",
        }
        return create_react_agent(**grader_options)

    def _create_supervisor(self):
        """
        Build and compile the supervisor graph with ``langgraph_supervisor``.

        The supervisor manages the Database Agent's inner agent and the
        Grader Agent, and is named ``"retrieval_supervisor"`` both as a
        supervisor and as a compiled graph.
        """
        # Workers handed to the supervisor: the database agent's inner agent
        # object and the grader.
        workers = [self.database_agent.agent, self.grader_agent]

        # The prompt template takes the iteration limit as its only field.
        prompt_text = SUPERVISOR_PROMPT_TEMPLATE.format(
            max_iterations=self.max_iterations
        )

        workflow = create_supervisor(
            model=self.llm_client,
            agents=workers,
            prompt=prompt_text,
            add_handoff_back_messages=True,
            supervisor_name="retrieval_supervisor",
        )
        return workflow.compile(name="retrieval_supervisor")

    def invoke(
        self,
        query: str,
        log_analysis_report: Optional[Dict[str, Any]] = None,
        context: Optional[str] = None,
        trace: bool = False,
    ) -> Dict[str, Any]:
        """
        Run the retrieval pipeline once and return its structured result.

        Args:
            query: The intelligence retrieval query/task.
            log_analysis_report: Optional log analysis report to embed in the
                input message.
            context: Optional additional context to embed in the input message.
            trace: When true, print a trace of the resulting message history.

        Returns:
            The dictionary produced by ``_parse_supervisor_output``.

        Raises:
            Exception: Any failure is reported on stdout and re-raised.
        """
        try:
            prompt_text = self._build_input_message(
                query, log_analysis_report, context
            )
            # The graph state starts with a single human message.
            state = {"messages": [HumanMessage(content=prompt_text)]}
            pipeline_output = self.supervisor.invoke(state)

            if trace:
                self._print_trace_pipeline(pipeline_output)

            return self._parse_supervisor_output(pipeline_output, query)
        except Exception as exc:
            print(f"[ERROR] Retrieval Supervisor pipeline failed: {exc}")
            raise

    def invoke_direct_query(self, query: str, trace: bool = False) -> Dict[str, Any]:
        """
        Run the pipeline on ``query`` verbatim, without the input template.

        Args:
            query: Text sent to the supervisor as the only human message.
            trace: When true, print a trace of the resulting message history.

        Returns:
            The dictionary produced by ``_parse_supervisor_output``.
        """
        state = {"messages": [HumanMessage(content=query)]}
        pipeline_output = self.supervisor.invoke(state)

        if trace:
            self._print_trace_pipeline(pipeline_output)

        return self._parse_supervisor_output(pipeline_output, query)

    def stream(
        self,
        query: str,
        log_analysis_report: Optional[Dict[str, Any]] = None,
        context: Optional[str] = None,
    ):
        """
        Stream the pipeline and pretty-print each update to stdout.

        Args:
            query: The intelligence retrieval query/task.
            log_analysis_report: Optional log analysis report to embed in the
                input message.
            context: Optional additional context to embed in the input message.

        Returns:
            None. Updates are printed, not returned; only the last message of
            each update is shown.
        """
        prompt_text = self._build_input_message(query, log_analysis_report, context)
        state = {"messages": [HumanMessage(content=prompt_text)]}

        # subgraphs=True asks the graph to also emit updates from nested graphs.
        updates = self.supervisor.stream(state, subgraphs=True)
        for update in updates:
            self._pretty_print_messages(update, last_message=True)

    def _pretty_print_message(self, message, indent=False):
        pretty_message = message.pretty_repr(html=True)
        if not indent:
            print(pretty_message)
            return

        indented = "\n".join("\t" + c for c in pretty_message.split("\n"))
        print(indented)

    def _pretty_print_messages(self, update, last_message=False):
        is_subgraph = False
        if isinstance(update, tuple):
            ns, update = update
            # skip parent graph updates in the printouts
            if len(ns) == 0:
                return

            graph_id = ns[-1].split(":")[0]
            print(f"Update from subgraph {graph_id}:")
            print("\n")
            is_subgraph = True

        for node_name, node_update in update.items():
            update_label = f"Update from node {node_name}:"
            if is_subgraph:
                update_label = "\t" + update_label

            print(update_label)
            print("\n")

            messages = convert_to_messages(node_update["messages"])
            if last_message:
                messages = messages[-1:]

            for m in messages:
                self._pretty_print_message(m, indent=is_subgraph)
            print("\n")

    def _print_trace_pipeline(self, result: Dict[str, Any]):
        """Print detailed trace of the pipeline execution with message flow."""
        messages = result.get("messages", [])

        if not messages:
            print("[TRACE] No messages found in pipeline result")
            return

        print("\n" + "=" * 60)
        print("PIPELINE EXECUTION TRACE")
        print("=" * 60)

        # Print all messages with detailed formatting
        for i, msg in enumerate(messages, 1):
            print(f"\n--- Message {i} ---")

            if isinstance(msg, HumanMessage):
                print(f"[Human] {msg.content}")

            elif isinstance(msg, AIMessage):
                agent_name = getattr(msg, "name", None) or "agent"
                print(f"[Agent:{agent_name}] {msg.content}")

                # Check for function calls
                if (
                    hasattr(msg, "additional_kwargs")
                    and "function_call" in msg.additional_kwargs
                ):
                    fc = msg.additional_kwargs["function_call"]
                    print(f"  [ToolCall] {fc.get('name')}: {fc.get('arguments')}")

            elif isinstance(msg, ToolMessage):
                tool_name = getattr(msg, "name", None) or "tool"
                content = (
                    msg.content if isinstance(msg.content, str) else str(msg.content)
                )
                # Truncate long content for readability
                preview = content[:300] + ("..." if len(content) > 300 else "")
                print(f"[Tool:{tool_name}] {preview}")

            else:
                print(f"[Message] {getattr(msg, 'content', '')}")

        # Print final supervisor decision if available
        if messages:
            latest_message = messages[-1]
            if isinstance(latest_message, AIMessage):
                print(f"\n--- Final Supervisor Output ---")
                print(latest_message.content)

                # Check if this looks like a grader decision
                if "decision" in latest_message.content.lower():
                    try:
                        # Try to parse JSON decision
                        content = latest_message.content
                        if "{" in content and "}" in content:
                            start = content.find("{")
                            end = content.rfind("}") + 1
                            decision_json = json.loads(content[start:end])

                            decision = decision_json.get("decision", "unknown")
                            print(
                                f"\n[SUCCESS] Pipeline completed - Decision: {decision}"
                            )

                            if decision == "ACCEPT":
                                print("Results accepted by grader")
                            elif decision == "NEEDS_MITRE":
                                print("Additional MITRE technique analysis needed")

                    except (json.JSONDecodeError, KeyError):
                        print("\n[INFO] Pipeline completed (decision parsing failed)")

        print("\n" + "=" * 60)
        print("TRACE COMPLETED")
        print("=" * 60)

    def _build_input_message(
        self,
        query: str,
        log_analysis_report: Optional[Dict[str, Any]],
        context: Optional[str],
    ) -> str:
        """
        Render the supervisor's input message from the prompt templates.

        Args:
            query: Text substituted into the ``query`` field.
            log_analysis_report: When truthy, serialized as indented JSON into
                the log-analysis section; otherwise that section is empty.
            context: When truthy, rendered into the context section;
                otherwise that section is empty.

        Returns:
            The formatted input message.
        """
        report_block = ""
        if log_analysis_report:
            report_json = json.dumps(log_analysis_report, indent=2)
            report_block = LOG_ANALYSIS_SECTION_TEMPLATE.format(
                log_analysis_report=report_json
            )

        context_block = (
            CONTEXT_SECTION_TEMPLATE.format(context=context) if context else ""
        )

        return INPUT_MESSAGE_TEMPLATE.format(
            query=query,
            log_analysis_section=report_block,
            context_section=context_block,
        )

    def _parse_supervisor_output(

        self, raw_result: Dict[str, Any], original_query: str

    ) -> Dict[str, Any]:
        """Parse the supervisor's structured output from the raw result."""
        messages = raw_result.get("messages", [])

        # Look for the final supervisor message with structured JSON output
        final_supervisor_message = None
        for msg in reversed(messages):
            if (
                hasattr(msg, "name")
                and msg.name == "supervisor"
                and hasattr(msg, "content")
                and msg.content
            ):
                final_supervisor_message = msg.content
                break

        if not final_supervisor_message:
            # Fallback: use the last message
            if messages:
                final_supervisor_message = (
                    messages[-1].content if hasattr(messages[-1], "content") else ""
                )

        # Try to extract JSON from the supervisor's final message
        structured_output = self._extract_json_from_content(final_supervisor_message)

        if structured_output:
            # Validate and enhance the structured output
            return self._validate_and_enhance_output(
                structured_output, original_query, messages
            )
        else:
            # Fallback: create structured output from message analysis
            return self._create_fallback_output(messages, original_query)

    def _extract_json_from_content(self, content: str) -> Optional[Dict[str, Any]]:
        """Extract JSON from supervisor message content."""
        if not content:
            return None

        # Look for JSON blocks
        if "```json" in content:
            json_blocks = content.split("```json")
            for block in json_blocks[1:]:
                json_str = block.split("```")[0].strip()
                try:
                    return json.loads(json_str)
                except json.JSONDecodeError:
                    continue

        # Look for any JSON-like structures
        start_idx = 0
        while True:
            start_idx = content.find("{", start_idx)
            if start_idx == -1:
                break

            # Find matching closing brace
            brace_count = 0
            end_idx = start_idx
            for i in range(start_idx, len(content)):
                if content[i] == "{":
                    brace_count += 1
                elif content[i] == "}":
                    brace_count -= 1
                    if brace_count == 0:
                        end_idx = i + 1
                        break

            if brace_count == 0:
                json_str = content[start_idx:end_idx]
                try:
                    return json.loads(json_str)
                except json.JSONDecodeError:
                    pass

            start_idx += 1

        return None

    def _validate_and_enhance_output(

        self, structured_output: Dict[str, Any], original_query: str, messages: List

    ) -> Dict[str, Any]:
        """Validate and enhance the structured output."""
        # Ensure required fields exist
        if "status" not in structured_output:
            structured_output["status"] = "SUCCESS"

        if "final_assessment" not in structured_output:
            structured_output["final_assessment"] = "ACCEPTED"

        if "retrieved_techniques" not in structured_output:
            structured_output["retrieved_techniques"] = []

        if "agents_used" not in structured_output:
            # Extract agents used from messages
            agents_used = set()
            for msg in messages:
                if hasattr(msg, "name") and msg.name:
                    agents_used.add(str(msg.name))
            structured_output["agents_used"] = list(agents_used)

        if "summary" not in structured_output:
            technique_count = len(structured_output.get("retrieved_techniques", []))
            structured_output["summary"] = (
                f"Retrieved {technique_count} MITRE techniques for analysis"
            )

        if "iteration_count" not in structured_output:
            structured_output["iteration_count"] = 1

        # Add metadata
        structured_output["query"] = original_query
        structured_output["total_techniques"] = len(
            structured_output.get("retrieved_techniques", [])
        )

        return structured_output

    def _create_fallback_output(

        self, messages: List, original_query: str

    ) -> Dict[str, Any]:
        """Create fallback structured output when JSON parsing fails."""
        # Extract techniques from database agent messages
        techniques = []
        agents_used = set()

        for msg in messages:
            if hasattr(msg, "name") and msg.name:
                agents_used.add(str(msg.name))

                # Look for database agent results
                if "database" in str(msg.name).lower() and hasattr(msg, "content"):
                    try:
                        # Try to extract techniques from tool messages
                        if hasattr(msg, "name") and "search_techniques" in str(
                            msg.name
                        ):
                            tool_data = (
                                json.loads(msg.content)
                                if isinstance(msg.content, str)
                                else msg.content
                            )
                            if "techniques" in tool_data:
                                for tech in tool_data["techniques"]:
                                    # Convert tactics to list format
                                    tactics = tech.get("tactics", [])
                                    if isinstance(tactics, str):
                                        tactics = [tactics] if tactics else []
                                    elif not isinstance(tactics, list):
                                        tactics = []

                                    technique = {
                                        "technique_id": tech.get("attack_id", ""),
                                        "technique_name": tech.get("name", ""),
                                        "tactic": tactics,  # Now as list
                                        "description": tech.get("description", ""),
                                        "relevance_score": tech.get(
                                            "relevance_score", 0.5
                                        ),
                                    }
                                    techniques.append(technique)
                    except (json.JSONDecodeError, TypeError, AttributeError):
                        continue

        return {
            "status": "PARTIAL",
            "final_assessment": "NEEDS_MORE_INFO",
            "retrieved_techniques": techniques,
            "agents_used": list(agents_used),
            "summary": f"Retrieved {len(techniques)} MITRE techniques (fallback parsing)",
            "iteration_count": 1,
            "query": original_query,
            "total_techniques": len(techniques),
            "parsing_method": "fallback",
        }

    def _process_results(

        self, result: Dict[str, Any], original_query: str

    ) -> Dict[str, Any]:
        """Process and format the supervisor results."""

        messages = result.get("messages", [])

        # Extract information from messages
        agents_used = set()
        cti_results = []
        database_results = []
        grader_decisions = []

        for msg in messages:
            if hasattr(msg, "name"):
                agent_name = msg.name
                if agent_name:  # ignore None or empty
                    agents_used.add(str(agent_name))

                if agent_name == "database_agent":
                    database_results.append(msg.content)
                elif agent_name == "retrieval_grader_agent":
                    grader_decisions.append(msg.content)

        # Get final supervisor message
        final_message = ""
        for msg in reversed(messages):
            if (
                isinstance(msg, AIMessage)
                and hasattr(msg, "name")
                and msg.name == "supervisor"
            ):
                final_message = msg.content
                break

        # Determine final assessment
        final_assessment = self._determine_final_assessment(
            grader_decisions, final_message
        )

        # Extract recommendations
        recommendations = self._extract_recommendations(
            cti_results, database_results, grader_decisions
        )

        return {
            "status": "SUCCESS",
            "query": original_query,
            "agents_used": [
                a for a in list(agents_used) if isinstance(a, str) and a.strip()
            ],
            "results": {
                "cti_intelligence": cti_results,
                "mitre_techniques": database_results,
                "quality_assessments": grader_decisions,
                "supervisor_synthesis": final_message,
            },
            "final_assessment": final_assessment,
            "recommendations": recommendations,
            "message_history": messages,
            "summary": self._generate_summary(
                cti_results, database_results, final_assessment
            ),
        }

    def _determine_final_assessment(

        self, grader_decisions: List[str], final_message: str

    ) -> str:
        """Determine the final assessment based on grader decisions."""

        # Look for the latest grader decision
        if grader_decisions:
            latest_decision = grader_decisions[-1]
            try:
                # Try to parse JSON from grader
                if "{" in latest_decision and "}" in latest_decision:
                    start = latest_decision.find("{")
                    end = latest_decision.rfind("}") + 1
                    decision_json = json.loads(latest_decision[start:end])
                    return decision_json.get("decision", "UNKNOWN")
            except json.JSONDecodeError:
                pass

        # Fallback to content analysis
        content = (final_message + " " + " ".join(grader_decisions)).lower()
        if "accept" in content:
            return "ACCEPTED"
        elif "needs_both" in content:
            return "NEEDS_BOTH"
        elif "needs_cti" in content:
            return "NEEDS_CTI"
        elif "needs_mitre" in content:
            return "NEEDS_MITRE"
        else:
            return "COMPLETED"

    def _extract_recommendations(

        self,

        cti_results: List[str],

        database_results: List[str],

        grader_decisions: List[str],

    ) -> List[str]:
        """Extract actionable recommendations from the results."""

        recommendations = []

        # Standard recommendations based on results
        if cti_results:
            recommendations.append("Review CTI findings for threat actor attribution")
            recommendations.append("Implement IOC-based detection rules")

        if database_results:
            recommendations.append("Map detected techniques to defensive controls")
            recommendations.append("Update threat hunting playbooks")

        # Extract specific recommendations from grader
        for decision in grader_decisions:
            try:
                if "{" in decision and "}" in decision:
                    start = decision.find("{")
                    end = decision.rfind("}") + 1
                    decision_json = json.loads(decision[start:end])
                    suggestions = decision_json.get("improvement_suggestions", [])
                    recommendations.extend(suggestions)
            except json.JSONDecodeError:
                continue

        # Remove duplicates and limit
        unique_recommendations = list(dict.fromkeys(recommendations))
        return unique_recommendations[:5]  # Top 5 recommendations

    def _generate_summary(

        self, cti_results: List[str], database_results: List[str], final_assessment: str

    ) -> str:
        """Generate a concise summary of the retrieval results."""

        summary_parts = [
            f"Retrieval Status: {final_assessment}",
            f"CTI Sources Analyzed: {len(cti_results)}",
            f"MITRE Techniques Retrieved: {len(database_results)}",
        ]

        if cti_results:
            summary_parts.append("Threat intelligence gathered from external sources")
        if database_results:
            summary_parts.append("MITRE ATT&CK techniques mapped to findings")

        return " | ".join(summary_parts)

    def stream_invoke(
        self,
        query: str,
        log_analysis_report: Optional[Dict[str, Any]] = None,
        context: Optional[str] = None,
    ):
        """
        Stream the pipeline, yielding each update to the caller.

        Args:
            query: The intelligence retrieval query/task.
            log_analysis_report: Optional log analysis report to embed in the
                input message.
            context: Optional additional context to embed in the input message.

        Yields:
            Each chunk produced by the supervisor's ``stream``. If an
            exception occurs, a final ``{"error": <message>}`` dictionary is
            yielded instead of raising.
        """
        try:
            prompt_text = self._build_input_message(
                query, log_analysis_report, context
            )
            state = {"messages": [HumanMessage(content=prompt_text)]}

            for update in self.supervisor.stream(state):
                yield update
        except Exception as exc:
            # Errors are reported in-band so the consumer's loop ends cleanly.
            yield {"error": str(exc)}


# Example usage and testing
def test_retrieval_supervisor():
    """
    Run the Retrieval Supervisor once on a canned report (manual smoke test).

    A default ``RetrievalSupervisor`` is created and invoked with tracing
    enabled, then the main result fields are printed.

    Returns:
        The result dictionary, or ``None`` if anything failed (the error is
        printed).
    """

    # Sample log analysis report
    sample_report = {
        "overall_assessment": "ABNORMAL",
        "total_events_analyzed": 245,
        "analysis_summary": "Detected suspicious PowerShell execution with base64 encoding and potential credential access attempts targeting LSASS process",
        "abnormal_events": [
            {
                "event_id": "4688",
                "event_description": "PowerShell process creation with encoded command parameter",
                "why_abnormal": "Base64 encoded command suggests obfuscation and evasion techniques",
                "severity": "HIGH",
                "potential_threat": "Defense evasion or malware execution",
                "attack_category": "defense_evasion",
            },
            {
                "event_id": "4656",
                "event_description": "Handle request to LSASS process memory",
                "why_abnormal": "Unusual access pattern to sensitive authentication process",
                "severity": "CRITICAL",
                "potential_threat": "Credential dumping attack",
                "attack_category": "credential_access",
            },
        ],
    }

    try:
        # Initialize supervisor
        supervisor = RetrievalSupervisor()

        # Test query
        query = "Analyze the detected PowerShell and LSASS access patterns. Provide threat intelligence on related attack campaigns and map to MITRE ATT&CK techniques."

        # Execute retrieval with trace enabled
        results = supervisor.invoke(
            query=query,
            log_analysis_report=sample_report,
            context="High-priority security incident requiring immediate threat intelligence",
            trace=True,
        )

        # Display results
        print("=" * 60)
        print("RETRIEVAL RESULTS SUMMARY")
        print("=" * 60)
        print(f"Status: {results['status']}")
        print(f"Final Assessment: {results['final_assessment']}")
        print(f"Agents Used: {', '.join(results['agents_used'])}")
        print(f"\nSummary: {results['summary']}")

        # invoke() does not add a "recommendations" entry to its result, so
        # indexing it directly raised KeyError and made this function always
        # report failure. Read it defensively instead.
        print("\nRecommendations:")
        for i, rec in enumerate(results.get("recommendations", []), 1):
            print(f"{i}. {rec}")

        return results

    except Exception as e:
        print(f"Test failed: {e}")
        return None


if __name__ == "__main__":
    test_retrieval_supervisor()