# ==============================================================================
# 1. IMPORTS AND SETUP
# ==============================================================================
import os
from dotenv import load_dotenv
from typing import TypedDict, Annotated, List
# LangChain and LangGraph imports
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_experimental.tools import PythonREPLTool
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode
# ==============================================================================
# 2. LOAD API KEYS AND DEFINE TOOLS
# ==============================================================================
load_dotenv()
hf_token = os.getenv("HF_TOKEN")
tavily_api_key = os.getenv("TAVILY_API_KEY")
if not hf_token or not tavily_api_key:
    # This will show a clear error in the logs if keys are missing
    raise ValueError("HF_TOKEN or TAVILY_API_KEY not set. Please add them to your Space secrets.")
os.environ["TAVILY_API_KEY"] = tavily_api_key
# The agent's tools
tools = [
    TavilySearchResults(max_results=3, description="A search engine for finding up-to-date information on the web."),
    PythonREPLTool(),
]
tool_node = ToolNode(tools)
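# Quick sanity check (illustrative sketch, not part of the agent flow):
# each tool can also be invoked directly, which is handy when debugging.
#   tools[1].invoke("print(2 + 2)")           # Python REPL -> captured stdout, e.g. "4\n"
#   tools[0].invoke("latest Python release")  # Tavily -> list of search result dicts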
# ==============================================================================
# 3. CONFIGURE THE LLM (THE "BRAIN")
# ==============================================================================
# The model we'll use as the agent's brain
repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# The system prompt gives the agent its mission and instructions
SYSTEM_PROMPT = """You are a highly capable AI agent named 'GAIA-Solver'. Your mission is to accurately answer complex questions.
**Your Instructions:**
1. **Analyze:** Carefully read the user's question to understand all parts of what is being asked.
2. **Plan:** Think step-by-step. Break the problem into smaller tasks and decide which tool is best for each (e.g., 'tavily_search_results_json' for web searches, 'Python_REPL' for calculations or code execution).
3. **Execute:** Call ONE tool at a time.
4. **Observe & Reason:** After getting a tool's result, observe it. Decide if you have the final answer or if you need to use another tool.
5. **Final Answer:** Once you are confident, provide a clear, direct, and concise final answer. Do not include your thought process in the final answer.
"""
# Initialize the LLM endpoint
llm = HuggingFaceEndpoint(
    repo_id=repo_id,
    huggingfacehub_api_token=hf_token,
    temperature=0,  # Set to 0 for deterministic, less random output
    max_new_tokens=2048,
)
# ==============================================================================
# 4. BUILD THE LANGGRAPH AGENT
# ==============================================================================
# Define the Agent's State (its memory)
class AgentState(TypedDict):
    # The reducer (second element of Annotated) appends new messages to the list
    messages: Annotated[List[BaseMessage], lambda x, y: x + y]
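# How the reducer above merges state (illustrative sketch):
#   current state: {"messages": [HumanMessage("question")]}
#   node returns:  {"messages": [AIMessage("answer")]}
#   new state:     {"messages": [HumanMessage("question"), AIMessage("answer")]}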
# HuggingFaceEndpoint is a plain text-completion model and does not support
# .bind_tools(); wrapping it in ChatHuggingFace exposes the chat/tool-calling API.
chat_model = ChatHuggingFace(llm=llm)
llm_with_tools = chat_model.bind_tools(tools)
# Define the Agent Node
def agent_node(state):
    # Prepend the system prompt, then pass the FULL message history so the
    # model can see its prior tool calls and their results on every turn.
    messages = [SystemMessage(content=SYSTEM_PROMPT)] + state["messages"]
    response = llm_with_tools.invoke(messages)
    return {"messages": [response]}
# Define the Edge Logic
def should_continue(state):
    last_message = state["messages"][-1]
    if last_message.tool_calls:
        return "tools"  # Route to the tool node
    return "end"  # Matches the "end" key in the conditional-edge mapping below
# Assemble the graph
workflow = StateGraph(AgentState)
workflow.add_node("agent", agent_node)
workflow.add_node("tools", tool_node)
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
"agent",
should_continue,
{"tools": "tools", "end": END},
)
workflow.add_edge("tools", "agent")
# Compile the graph into a runnable app
app = workflow.compile()
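# Example (sketch, assumes valid API keys): invoking the compiled graph directly.
#   result = app.invoke({"messages": [HumanMessage(content="What is 2 + 2?")]})
#   print(result["messages"][-1].content)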
# ==============================================================================
# 5. THE BASICAGENT CLASS (FOR THE TEST HARNESS)
# This MUST be at the end, after `app` is defined.
# ==============================================================================
class BasicAgent:
    """
    This is the agent class that the GAIA test harness will use.
    """
    def __init__(self):
        # The compiled LangGraph app is our agent executor
        self.agent_executor = app

    def run(self, question: str) -> str:
        """
        This method is called by the test script with each question.
        It runs the LangGraph agent and returns the final answer.
        """
        print(f"Agent received question (first 80 chars): {question[:80]}...")
        try:
            # Format the input for our graph
            inputs = {"messages": [HumanMessage(content=question)]}
            # Stream the graph, keeping the last non-empty message from the agent node
            final_response = ""
            for s in self.agent_executor.stream(inputs, {"recursion_limit": 15}):
                if "agent" in s:
                    # The final answer is the content of the last agent message
                    if s["agent"]["messages"][-1].content:
                        final_response = s["agent"]["messages"][-1].content
            # Fallback in case the agent finishes without a clear message
            if not final_response:
                final_response = "Agent finished but did not produce a final answer."
            print(f"Agent returning final answer (first 80 chars): {final_response[:80]}...")
            return final_response
        except Exception as e:
            print(f"An error occurred in agent execution: {e}")
            return f"Error: {e}"