# ==============================================================================
# 1. IMPORTS AND SETUP
# ==============================================================================
import os
from dotenv import load_dotenv
from typing import TypedDict, Annotated, List

# LangChain and LangGraph imports
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_experimental.tools import PythonREPLTool
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode

# ==============================================================================
# 2. LOAD API KEYS AND DEFINE TOOLS
# ==============================================================================
load_dotenv()
hf_token = os.getenv("HF_TOKEN")
tavily_api_key = os.getenv("TAVILY_API_KEY")

if not hf_token or not tavily_api_key:
    # This will show a clear error in the logs if keys are missing
    raise ValueError("HF_TOKEN or TAVILY_API_KEY not set. Please add them to your Space secrets.")
os.environ["TAVILY_API_KEY"] = tavily_api_key

# The agent's tools: web search plus a Python REPL for calculations
tools = [
    TavilySearchResults(
        max_results=3,
        description="A search engine for finding up-to-date information on the web.",
    ),
    PythonREPLTool(),
]
tool_node = ToolNode(tools)
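
# Quick local sanity check for the tool wiring (a sketch; uncomment to run).
# The input shapes below are the library defaults in recent langchain
# releases -- verify against your installed version before relying on them.
# print(tools[0].invoke({"query": "capital of France"}))  # Tavily web search
# print(tools[1].invoke("print(2 + 2)"))                  # Python REPL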

# ==============================================================================
# 3. CONFIGURE THE LLM (THE "BRAIN")
# ==============================================================================
# The model we'll use as the agent's brain
repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# The system prompt gives the agent its mission and instructions
SYSTEM_PROMPT = """You are a highly capable AI agent named 'GAIA-Solver'. Your mission is to accurately answer complex questions.

**Your Instructions:**
1.  **Analyze:** Carefully read the user's question to understand all parts of what is being asked.
2.  **Plan:** Think step-by-step. Break the problem into smaller tasks. Decide which tool is best for each task. (e.g., use 'tavily_search_results_json' for web searches, use 'python_repl' for calculations or code execution).
3.  **Execute:** Call ONE tool at a time.
4.  **Observe & Reason:** After getting a tool's result, observe it. Decide if you have the final answer or if you need to use another tool.
5.  **Final Answer:** Once you are confident, provide a clear, direct, and concise final answer. Do not include your thought process in the final answer.
"""

# Initialize the LLM endpoint. Some TGI backends reject temperature=0, so use
# a small positive value for near-deterministic, less random output.
llm = HuggingFaceEndpoint(
    repo_id=repo_id,
    huggingfacehub_api_token=hf_token,
    temperature=0.01,
    max_new_tokens=2048,
)

# ==============================================================================
# 4. BUILD THE LANGGRAPH AGENT
# ==============================================================================

# Define the Agent's State (its memory). `add_messages` appends each new
# message to the running history instead of overwriting it.
class AgentState(TypedDict):
    messages: Annotated[List[BaseMessage], add_messages]

# HuggingFaceEndpoint is a plain text-completion LLM and has no tool-calling
# support; wrap it in ChatHuggingFace to get the chat interface, then bind the
# tools. The system prompt is applied on every call in agent_node below.
chat_model = ChatHuggingFace(llm=llm)
llm_with_tools = chat_model.bind_tools(tools)

# Define the Agent Node
def agent_node(state):
    # Prepend the system prompt and pass the FULL message history, so the
    # model can see its earlier tool calls and the tool results they produced.
    messages = [SystemMessage(content=SYSTEM_PROMPT)] + state["messages"]
    response = llm_with_tools.invoke(messages)
    return {"messages": [response]}

# Define the Edge Logic
def should_continue(state):
    last_message = state["messages"][-1]
    if getattr(last_message, "tool_calls", None):
        return "tools"  # Route to the tool node
    return END  # End the process

# Assemble the graph
workflow = StateGraph(AgentState)
workflow.add_node("agent", agent_node)
workflow.add_node("tools", tool_node)
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
    "agent",
    should_continue,
    # The map keys must match should_continue's return values exactly.
    {"tools": "tools", END: END},
)
workflow.add_edge("tools", "agent")

# Compile the graph into a runnable app
app = workflow.compile()
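
# Direct-invocation sketch (assumes the secrets above are set; uncomment to
# exercise the graph without the BasicAgent wrapper -- the question is
# illustrative only):
# result = app.invoke(
#     {"messages": [HumanMessage(content="What is 17 * 23?")]},
#     {"recursion_limit": 15},
# )
# print(result["messages"][-1].content)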


# ==============================================================================
# 5. THE BASICAGENT CLASS (FOR THE TEST HARNESS)
# This MUST be at the end, after `app` is defined.
# ==============================================================================
class BasicAgent:
    """
    This is the agent class that the GAIA test harness will use.
    """
    def __init__(self):
        # The compiled LangGraph app is our agent executor
        self.agent_executor = app

    def run(self, question: str) -> str:
        """
        This method is called by the test script with each question.
        It runs the LangGraph agent and returns the final answer.
        """
        print(f"Agent received question (first 80 chars): {question[:80]}...")
        try:
            # Format the input for our graph
            inputs = {"messages": [HumanMessage(content=question)]}

            # Stream the response to get the final answer
            final_response = ""
            for s in self.agent_executor.stream(inputs, {"recursion_limit": 15}):
                if "agent" in s:
                    # The final answer is the content of the last message from the agent node
                    if s["agent"]["messages"][-1].content:
                         final_response = s["agent"]["messages"][-1].content

            # A fallback in case the agent finishes without a clear message
            if not final_response:
                final_response = "Agent finished but did not produce a final answer."

            print(f"Agent returning final answer (first 80 chars): {final_response[:80]}...")
            return final_response

        except Exception as e:
            print(f"An error occurred in agent execution: {e}")
            return f"Error: {e}"