# Provenance (residue of a web file view, kept as a comment so the module parses):
# author jesusgj, commit a1e218b "Modified files", 10.7 kB.
import os
import time
import logging
import urllib.parse as urlparse
import io
import contextlib
import re
from functools import lru_cache, wraps
from typing import Optional, Dict, Any
from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled, NoTranscriptFound
from dotenv import load_dotenv
from requests.exceptions import RequestException
import serpapi
import wikipedia
from llama_index.core import VectorStoreIndex, download_loader
from llama_index.core.schema import Document
from smolagents import (
CodeAgent,
InferenceClientModel,
GoogleSearchTool,
tool,
)
# --- Configuration and Setup ---
def configure_logging():
    """Configure root logging at INFO level with timestamp, level and logger name.

    Relies on ``logging.basicConfig``, so it has no effect when the root logger
    already has handlers (e.g. when called a second time).
    """
    log_format = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)
def load_api_keys() -> Dict[str, Optional[str]]:
    """Read the Together and SerpAPI keys after loading any ``.env`` file.

    Returns:
        A dict with keys ``'together'`` and ``'serpapi'``; the SerpAPI entry may be None.

    Raises:
        ValueError: if ``TOGETHER_API_KEY`` is unset or empty.
    """
    load_dotenv()
    together_key = os.getenv('TOGETHER_API_KEY')
    if not together_key:
        raise ValueError("TOGETHER_API_KEY is required but not found.")
    return {'together': together_key, 'serpapi': os.getenv('SERPAPI_API_KEY')}
# --- Custom Exceptions ---
class SerpApiClientException(Exception):
    """SerpAPI client failure; listed as a retryable error by the ``retry`` decorator."""
class YouTubeTranscriptApiError(Exception):
    """Project-level YouTube transcript failure (distinct from youtube_transcript_api's own errors)."""
# --- Decorators ---
def retry(max_retries=3, initial_delay=1, backoff=2):
    """Retry a tool function with exponential backoff on transient errors.

    Works both bare (``@retry``) and with arguments (``@retry(max_retries=5)``).
    The wrapped function never raises. On failure it returns a
    ``"Tool Error: ..."`` string, so the agent can read the problem and adapt.

    Args:
        max_retries: Total number of attempts. Values below 1 are treated as 1.
        initial_delay: Seconds to wait after the first failed attempt.
        backoff: Multiplier applied to the delay after each further failure.

    Returns:
        A decorator, or the already-wrapped function when used as bare ``@retry``.
    """
    if callable(max_retries):
        # Bare `@retry`: Python passed the decorated function as `max_retries`.
        return retry()(max_retries)
    # Guarantee at least one call so the wrapper never silently returns None.
    attempts = max(1, max_retries)

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, attempts + 1):
                try:
                    return func(*args, **kwargs)
                except (RequestException, SerpApiClientException, YouTubeTranscriptApiError) as e:
                    # Network/API hiccups: worth retrying after a growing delay.
                    if attempt == attempts:
                        logging.error(f"{func.__name__} failed after {attempt} attempts: {e}")
                        return f"Tool Error: {func.__name__} failed after {attempts} attempts. Details: {e}"
                    time.sleep(initial_delay * (backoff ** (attempt - 1)))
                except Exception as e:
                    # Permanent failures (e.g. TranscriptsDisabled / NoTranscriptFound,
                    # invalid input): retrying cannot help, so report immediately.
                    logging.error(f"{func.__name__} failed with a non-retryable error: {e}")
                    return f"Tool Error: A non-retryable error occurred in {func.__name__}: {e}"
        return wrapper
    return decorator
# --- Answer Formatting and Extraction (CRITICAL FOR GAIA) ---
def extract_final_answer(response: Any) -> str:
    """Extract the final answer from the agent's output.

    Args:
        response: The agent's output. Non-string values are converted with
            ``str()``; a CodeAgent may return the raw value given to
            ``final_answer()``, such as an int.

    Returns:
        The text following the LAST ``FINAL ANSWER:`` marker (case-insensitive),
        up to the end of that line. If there is no marker, the last line of the
        response is returned. None or empty input gives ``""``.
    """
    if response is None:
        return ""
    text = str(response).strip()
    if not text:
        # Only truly empty output is "no answer"; a numeric 0 becomes "0" above.
        return ""
    # No DOTALL: stop at end of line so trailing chatter isn't submitted.
    # Use the last marker in case the model restated the answer after a draft.
    matches = list(re.finditer(r'FINAL\s+ANSWER\s*:\s*(.*)', text, re.IGNORECASE))
    if matches:
        return matches[-1].group(1).strip()
    return text.split('\n')[-1].strip()
def normalize_answer_format(answer: str) -> str:
    """Normalize an extracted answer to GAIA's strict formatting rules.

    - Surrounding whitespace and trailing periods are removed.
    - A single number loses currency/percent signs and thousands separators
      (``"$1,234.50"`` -> ``"1234.50"``, ``"45%"`` -> ``"45"``).
    - A comma-separated list is normalized element by element and re-joined
      with ``", "``.
    - Anything else is returned unchanged.
    """
    if not answer:
        return ""
    answer = answer.strip().rstrip('.')
    number = _normalize_number(answer)
    if number is not None:
        return number
    if ',' in answer:
        return ', '.join(normalize_answer_format(part.strip()) for part in answer.split(','))
    return answer


def _normalize_number(text: str) -> Optional[str]:
    """Return *text* as a bare number string, or None if it is not a single number."""
    candidate = text.replace('$', '').replace('%', '').strip()
    if re.fullmatch(r'[+-]?\d{1,3}(?:,\d{3})+(?:\.\d+)?', candidate):
        # Strict thousands grouping ("1,234,567") is one number, not a 3-item list.
        return candidate.replace(',', '')
    if ',' in candidate:
        # Any other comma means a list, which the caller handles element-wise.
        return None
    try:
        float(candidate)
    except ValueError:
        return None
    return candidate
# --- Agent Wrapper for GAIA Compliance ---
def create_gaia_agent_wrapper(agent: CodeAgent):
    """Wrap an agent in a callable that returns a GAIA-formatted answer string.

    Args:
        agent: The configured agent. Only its ``run(question)`` method is used.

    Returns:
        A function that takes a question string and returns the extracted,
        normalized final answer. Exceptions raised by ``agent.run`` propagate
        to the caller.
    """
    def gaia_compliant_agent(question: str) -> str:
        logging.info(f"Received question for GAIA compliant agent: '{question}'")
        full_response = agent.run(question)
        logging.info(f"Agent raw response:\n---\n{full_response}\n---")
        # agent.run may return a non-string value (e.g. the number passed to
        # final_answer()); the extraction helpers operate on text.
        response_text = "" if full_response is None else str(full_response)
        final_answer = extract_final_answer(response_text)
        normalized_answer = normalize_answer_format(final_answer)
        logging.info(f"Normalized answer for submission: '{normalized_answer}'")
        return normalized_answer
    return gaia_compliant_agent
# --- Main Agent Initialization ---
# System prompt for the CodeAgent. Tool names here must match the names under
# which the tools are exposed to the agent's code: smolagents registers
# GoogleSearchTool as `web_search`, and each @tool function under its own name.
_GAIA_INSTRUCTIONS = """You are a master AI assistant for the GAIA benchmark. Your goal is to provide a single, precise, and final answer by writing and executing Python code.
**STRATEGY:**
You have a powerful toolkit. You can write and execute any Python code you need. You also have access to pre-defined tools that you can call from within your code.
1. **Analyze**: Break down the user's question into logical steps.
2. **Plan**: Decide if you need to search the web, query a webpage or video, or perform a calculation.
3. **Execute**: Write a Python script to perform the steps.
* For web searches, use `web_search(query=...)` (only if that tool is listed as available).
* For Wikipedia lookups, use `wikipedia_search(query=...)`.
* To query a specific webpage, use `query_webpage(url=..., query=...)`.
* To query a YouTube video's transcript, use `query_youtube_video(video_url=..., query=...)`.
* For complex calculations or data manipulation, write the Python code directly.
**HOW TO USE TOOLS IN YOUR CODE:**
To solve a problem, you will write a Python code block that calls the necessary tools.
*Example 1: Simple Calculation*
```python
# The user wants to know 15! / (12! * 3!)
import math
result = math.factorial(15) // (math.factorial(12) * math.factorial(3))
final_answer(result)
```
*Example 2: Multi-step question involving web search*
```python
# Find the birth date of the author of 'Pride and Prejudice'
print(web_search(query="author of Pride and Prejudice"))
```
After reading the printed results (e.g. "Jane Austen"), continue in the next step:
```python
print(wikipedia_search(query="Jane Austen"))
```
**CRITICAL INSTRUCTION:** When you have the answer, call `final_answer(...)` with ONLY the answer itself; this is the only part of your response that will be graded. Adhere to strict formatting: no extra words, no units or currency symbols unless asked, no commas in numbers.
"""


def _build_model(api_keys: Dict[str, Optional[str]]):
    """Create the Together-hosted chat model, falling back to a smaller model if construction fails."""
    try:
        model = InferenceClientModel(model_id="meta-llama/Llama-3.1-70B-Instruct-Turbo", token=api_keys['together'], provider="together")
        logging.info("βœ… Primary model (Llama 3.1 70B) loaded successfully")
    except Exception as e:
        # NOTE(review): constructing the client probably does not contact the provider,
        # so an unavailable model id may only fail later inside agent.run(). This
        # fallback covers construction-time errors only; confirm.
        logging.warning(f"⚠️ Failed to load primary model, falling back. Error: {e}")
        model = InferenceClientModel(model_id="Qwen/Qwen2.5-7B-Instruct", token=api_keys['together'], provider="together")
        logging.info("βœ… Fallback model (Qwen 2.5 7B) loaded successfully")
    return model


def _build_tools(api_keys: Dict[str, Optional[str]]) -> list:
    """Create the agent's tools; web search is included only when a SerpAPI key is set."""

    # smolagents' @tool builds each tool's schema from the type hints and the
    # docstring's `Args:` section, so every parameter must be documented there.
    # `retry()` is called explicitly so the decorator receives its own arguments.
    @tool
    @retry()
    def query_webpage(url: str, query: str) -> str:
        """Extracts specific information from a webpage by asking a targeted question.

        Args:
            url: Full URL of the webpage to read.
            query: The question to answer from the page's content.
        """
        logging.info(f"πŸ“„ Querying webpage: {url}")
        # NOTE(review): download_loader is deprecated in recent llama_index releases,
        # and VectorStoreIndex uses the globally configured embedding model/LLM
        # (OpenAI unless Settings are changed elsewhere). Confirm both are available.
        loader = download_loader("BeautifulSoupWebReader")()
        docs = loader.load_data(urls=[url])
        if not docs:
            raise ValueError(f"No content could be extracted from {url}")
        index = VectorStoreIndex.from_documents(docs)
        query_engine = index.as_query_engine(response_mode="tree_summarize")
        return str(query_engine.query(query))

    @tool
    @retry()
    def query_youtube_video(video_url: str, query: str) -> str:
        """Extracts specific information from a YouTube video transcript.

        Args:
            video_url: URL of the YouTube video (watch, youtu.be or shorts link).
            query: The question to answer from the video's transcript.
        """
        logging.info(f"🎬 Querying YouTube video: {video_url}")
        video_id_match = re.search(r'(?:v=|\/)([a-zA-Z0-9_-]{11})', video_url)
        if not video_id_match:
            return "Error: Invalid YouTube URL."
        video_id = video_id_match.group(1)
        transcript_api = YouTubeTranscriptApi()
        if hasattr(transcript_api, "fetch"):
            # youtube-transcript-api >= 1.0: instance API; to_raw_data() yields dicts with 'text'.
            snippets = transcript_api.fetch(video_id).to_raw_data()
        else:
            # Older releases: static get_transcript() returns the same dict shape.
            snippets = YouTubeTranscriptApi.get_transcript(video_id)
        doc = Document(text=' '.join(snippet['text'] for snippet in snippets))
        index = VectorStoreIndex.from_documents([doc])
        query_engine = index.as_query_engine()
        return str(query_engine.query(query))

    @tool
    @retry()
    def wikipedia_search(query: str) -> str:
        """Searches Wikipedia for a given query and returns a summary.

        Args:
            query: The topic or page title to look up on Wikipedia.
        """
        try:
            return wikipedia.summary(query, sentences=5)
        except wikipedia.exceptions.PageError:
            return f"No Wikipedia page found for '{query}'."
        except wikipedia.exceptions.DisambiguationError as e:
            return f"Ambiguous query '{query}'. Options: {e.options[:3]}"
        except Exception as e:
            return f"An error occurred during Wikipedia search: {e}"

    tools_list = [query_webpage, query_youtube_video, wikipedia_search]
    if api_keys.get('serpapi'):
        try:
            # GoogleSearchTool's constructor takes only `provider` and reads
            # SERPAPI_API_KEY from the environment (populated by load_dotenv);
            # it does not accept the key as a keyword argument.
            tools_list.insert(0, GoogleSearchTool(provider='serpapi'))
        except Exception as e:
            logging.warning(f"⚠️ Web search tool unavailable, continuing without it. Error: {e}")
    return tools_list


def initialize_agent():
    """Initialize the multi-tool CodeAgent for the GAIA benchmark.

    Returns:
        A callable ``(question: str) -> str`` that produces GAIA-formatted
        answers, or None if the required TOGETHER_API_KEY is missing.
    """
    configure_logging()
    logging.info("πŸš€ Starting GAIA agent initialization...")
    try:
        api_keys = load_api_keys()
    except ValueError as e:
        logging.error(f"FATAL: {e}")
        return None
    model = _build_model(api_keys)
    tools_list = _build_tools(api_keys)
    agent = CodeAgent(model=model, tools=tools_list, instructions=_GAIA_INSTRUCTIONS)
    logging.info("🎯 GAIA agent with unified CodeAgent architecture initialized successfully!")
    return create_gaia_agent_wrapper(agent)
# --- Main Execution Block for Local Testing ---
def main():
    """Run the agent on a few sample GAIA-style questions, logging each answer and its runtime."""
    configure_logging()
    logging.info("πŸ§ͺ Starting local agent testing...")
    agent = initialize_agent()
    if not agent:
        logging.critical("πŸ’₯ Agent initialization failed. Exiting.")
        return

    sample_questions = (
        "What is 15! / (12! * 3!)?",
        "In what year was the Python programming language first released?",
        "What is the square root of 2025?",
    )
    separator = '=' * 60
    for number, question in enumerate(sample_questions, start=1):
        logging.info(f"\n{separator}\nπŸ” Test Question {number}: {question}\n{separator}")
        started = time.time()
        answer = agent(question)
        duration = time.time() - started
        logging.info(f"βœ… Submitted Answer: {answer}")
        logging.info(f"⏱️ Execution time: {duration:.2f} seconds")
        # Short pause between questions to go easy on the inference provider.
        time.sleep(1)
    logging.info(f"\n{separator}\n🏁 Testing complete!\n{separator}")


if __name__ == "__main__":
    main()