chinmayjha committed on
Commit
8c6064d
·
unverified ·
1 Parent(s): 4af9b90

Improve agent output formatting with inline citations and full sources

Browse files

- Extract answer_with_sources output directly from agent steps to bypass final_answer reformatting
- Add inline [Doc X] citations in the answer section
- Include full Sources section with document metadata, summaries, and key findings
- Update AgentWrapper to extract Step 2 output when answer_with_sources is found
- Increase max_steps from 2 to 4 to allow for more complex queries
- Pass AgentWrapper directly to the UI (instead of unwrapping it) so that its custom run() method is used
- Update both tools/app.py and app.py (HF entry point) to use AgentWrapper
- Simplify UI to display raw agent output without additional parsing

app.py CHANGED
@@ -54,11 +54,9 @@ def main():
54
  # Initialize agent
55
  agent = get_agent(retriever_config_path=Path(retriever_config_path))
56
 
57
- # Get the actual agent from the wrapper
58
- actual_agent = agent._AgentWrapper__agent
59
-
60
  # Launch custom UI
61
- CustomGradioUI(actual_agent).launch(
62
  server_name="0.0.0.0",
63
  server_port=7860,
64
  share=False
 
54
  # Initialize agent
55
  agent = get_agent(retriever_config_path=Path(retriever_config_path))
56
 
57
+ # Pass the AgentWrapper directly so it uses our custom run() method with extraction logic
 
 
58
  # Launch custom UI
59
+ CustomGradioUI(agent).launch(
60
  server_name="0.0.0.0",
61
  server_port=7860,
62
  share=False
src/second_brain_online/application/agents/agents.py CHANGED
@@ -73,19 +73,50 @@ class AgentWrapper:
73
  def run(self, task: str, **kwargs) -> Any:
74
  result = self.__agent.run(task, return_full_result=True, **kwargs)
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  # Extract the raw output from answer_with_sources (Step 2) instead of using final_answer
77
  if hasattr(result, 'steps') and len(result.steps) >= 2:
78
  # Find the step where answer_with_sources was called
79
- for step in result.steps:
80
- if 'tool_calls' in step and step['tool_calls']:
81
  for tool_call in step['tool_calls']:
82
- if tool_call.get('function', {}).get('name') == 'answer_with_sources':
 
 
 
 
 
 
 
 
 
 
83
  # Found the answer_with_sources step - return its observations
84
  if 'observations' in step and step['observations']:
85
- logger.info("Returning raw answer_with_sources output (bypassing final reformatting)")
86
  return step['observations']
87
 
88
  # Fallback to regular result.output
 
89
  if hasattr(result, 'output'):
90
  return result.output
91
 
 
73
  def run(self, task: str, **kwargs) -> Any:
74
  result = self.__agent.run(task, return_full_result=True, **kwargs)
75
 
76
+ # Debug: Print step structure to understand the data
77
+ logger.info(f"Result type: {type(result)}")
78
+ if hasattr(result, 'steps'):
79
+ logger.info(f"Number of steps: {len(result.steps)}")
80
+ for i, step in enumerate(result.steps):
81
+ logger.info(f"Step {i}: type={type(step)}, keys={step.keys() if isinstance(step, dict) else 'not a dict'}")
82
+ if isinstance(step, dict) and 'tool_calls' in step:
83
+ logger.info(f" Tool calls: {step['tool_calls']}")
84
+ if step['tool_calls']:
85
+ for tc in step['tool_calls']:
86
+ tc_type = type(tc)
87
+ if isinstance(tc, dict):
88
+ logger.info(f" Tool call dict: {tc}")
89
+ else:
90
+ logger.info(f" Tool call object: {tc}, type: {tc_type}")
91
+ if hasattr(tc, 'function'):
92
+ logger.info(f" Function: {tc.function}")
93
+ if hasattr(tc, 'name'):
94
+ logger.info(f" Name: {tc.name}")
95
+
96
  # Extract the raw output from answer_with_sources (Step 2) instead of using final_answer
97
  if hasattr(result, 'steps') and len(result.steps) >= 2:
98
  # Find the step where answer_with_sources was called
99
+ for step_idx, step in enumerate(result.steps):
100
+ if isinstance(step, dict) and 'tool_calls' in step and step['tool_calls']:
101
  for tool_call in step['tool_calls']:
102
+ # Handle both dict and object formats
103
+ tool_name = None
104
+ if isinstance(tool_call, dict):
105
+ tool_name = tool_call.get('function', {}).get('name')
106
+ elif hasattr(tool_call, 'function'):
107
+ if hasattr(tool_call.function, 'name'):
108
+ tool_name = tool_call.function.name
109
+ elif hasattr(tool_call, 'name'):
110
+ tool_name = tool_call.name
111
+
112
+ if tool_name == 'answer_with_sources':
113
  # Found the answer_with_sources step - return its observations
114
  if 'observations' in step and step['observations']:
115
+ logger.info(f"✅ Found answer_with_sources at step {step_idx}, returning its observations")
116
  return step['observations']
117
 
118
  # Fallback to regular result.output
119
+ logger.warning("⚠️ answer_with_sources output not found, falling back to result.output")
120
  if hasattr(result, 'output'):
121
  return result.output
122
 
src/second_brain_online/application/ui/custom_gradio_ui.py CHANGED
@@ -1,6 +1,6 @@
1
  import json
2
  import re
3
- from typing import Any, Dict, List, Tuple, Optional
4
  from datetime import datetime
5
 
6
  import gradio as gr
@@ -14,7 +14,12 @@ from second_brain_online.config import settings
14
  class CustomGradioUI:
15
  """Custom Gradio UI for better formatting of agent responses with source attribution."""
16
 
17
- def __init__(self, agent: ToolCallingAgent):
 
 
 
 
 
18
  self.agent = agent
19
  self.mongodb_client = None
20
  self.database = None
@@ -170,6 +175,16 @@ class CustomGradioUI:
170
  # Quick post-processing steps
171
  progress(0.8, desc="✨ Displaying results...")
172
 
 
 
 
 
 
 
 
 
 
 
173
  # Convert result to string
174
  result_str = str(result)
175
 
 
1
  import json
2
  import re
3
+ from typing import Any, Dict, List, Tuple, Optional, Union
4
  from datetime import datetime
5
 
6
  import gradio as gr
 
14
  class CustomGradioUI:
15
  """Custom Gradio UI for better formatting of agent responses with source attribution."""
16
 
17
+ def __init__(self, agent: Union[ToolCallingAgent, Any]):
18
+ """Initialize the UI with either a ToolCallingAgent or AgentWrapper.
19
+
20
+ Args:
21
+ agent: Either a raw ToolCallingAgent or an AgentWrapper that wraps it.
22
+ """
23
  self.agent = agent
24
  self.mongodb_client = None
25
  self.database = None
 
175
  # Quick post-processing steps
176
  progress(0.8, desc="✨ Displaying results...")
177
 
178
+ # CRITICAL DEBUG: Print what result actually is
179
+ print("\n" + "="*80)
180
+ print("DEBUG: WHAT IS RESULT?")
181
+ print("="*80)
182
+ print(f"Type: {type(result)}")
183
+ print(f"Is string? {isinstance(result, str)}")
184
+ print(f"Has 📚 Sources? {'📚 Sources' in str(result) if result else False}")
185
+ print(f"First 1500 chars of result:\n{str(result)[:1500]}")
186
+ print("="*80)
187
+
188
  # Convert result to string
189
  result_str = str(result)
190
 
tools/app.py CHANGED
@@ -35,9 +35,8 @@ def main(retriever_config_path: Path, ui: bool, query: str) -> None:
35
  """
36
  agent = get_agent(retriever_config_path=Path(retriever_config_path))
37
  if ui:
38
- # Get the actual agent from the wrapper
39
- actual_agent = agent._AgentWrapper__agent
40
- CustomGradioUI(actual_agent).launch()
41
  else:
42
  assert query, "Query is required in CLI mode"
43
 
@@ -71,33 +70,12 @@ def main(retriever_config_path: Path, ui: bool, query: str) -> None:
71
  print(f"State: {actual_agent.state}")
72
  print("="*80)
73
 
74
- # Parse the result using the same logic as the UI
75
- ui_instance = CustomGradioUI(None) # We don't need the agent for parsing
76
-
77
- # Get agent logs if available
78
- agent_logs = []
79
- if hasattr(agent, '_AgentWrapper__agent'):
80
- actual_agent = agent._AgentWrapper__agent
81
- if hasattr(actual_agent, 'logs'):
82
- agent_logs = actual_agent.logs
83
-
84
- answer, sources, tools_used = ui_instance.parse_agent_response(result, agent_logs)
85
-
86
- print("\n" + "="*80)
87
- print("DEBUG: PARSED RESULTS")
88
- print("="*80)
89
- print(f"Answer: {answer}")
90
- print(f"Sources ({len(sources)}): {sources}")
91
- print(f"Tools Used: {tools_used}")
92
- print("="*80)
93
-
94
  print("\n" + "="*80)
95
  print("FINAL OUTPUT")
96
  print("="*80)
97
-
98
- # Format the answer for better display
99
- formatted_answer = ui_instance.format_answer(answer)
100
- print(formatted_answer)
101
 
102
 
103
  if __name__ == "__main__":
 
35
  """
36
  agent = get_agent(retriever_config_path=Path(retriever_config_path))
37
  if ui:
38
+ # Pass the AgentWrapper directly so it uses our custom run() method
39
+ CustomGradioUI(agent).launch()
 
40
  else:
41
  assert query, "Query is required in CLI mode"
42
 
 
70
  print(f"State: {actual_agent.state}")
71
  print("="*80)
72
 
73
+ # Display the raw result directly (it's already perfectly formatted)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  print("\n" + "="*80)
75
  print("FINAL OUTPUT")
76
  print("="*80)
77
+ print(result)
78
+ print("="*80)
 
 
79
 
80
 
81
  if __name__ == "__main__":