import os
import time
import logging
import urllib.parse as urlparse
import io
import contextlib
import re
import json
from functools import lru_cache, wraps
from typing import Optional, Dict, Any, List

from dotenv import load_dotenv
from requests.exceptions import RequestException
import serpapi
from llama_index.core import VectorStoreIndex, download_loader
from llama_index.core.schema import Document
from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled, NoTranscriptFound

# --- Correctly import the specific tools from smolagents ---
from smolagents import (
    CodeAgent,
    InferenceClientModel,
    ToolCallingAgent,
    GoogleSearchTool,
    tool,
)


# --- Configuration and Setup ---

def configure_logging():
    """Set up root logging (INFO level, timestamped ``[LEVEL] name: message`` lines).

    ``logging.basicConfig`` does nothing once the root logger already has
    handlers, so calling this more than once is harmless.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )


def load_api_keys() -> Dict[str, Optional[str]]:
    """Load API keys from the environment, after reading a ``.env`` file.

    Returns:
        A dict with keys ``'together'`` and ``'serpapi'``; a value is ``None``
        when the corresponding environment variable is not set.

    Raises:
        ValueError: if ``TOGETHER_API_KEY`` is missing (it is required).
    """
    load_dotenv()
    keys = {
        'together': os.getenv('TOGETHER_API_KEY'),
        'serpapi': os.getenv('SERPAPI_API_KEY'),
    }
    for key_name, key_value in keys.items():
        if key_value:
            logging.info(f"✅ {key_name.upper()} API key loaded")
        else:
            logging.warning(f"⚠️ {key_name.upper()} API key not found")
    if not keys['together']:
        raise ValueError("TOGETHER_API_KEY is required but not found.")
    return keys


# --- Custom Exceptions ---

class SerpApiClientException(Exception):
    """Search-client failure; listed as retryable by ``retry``."""


class YouTubeTranscriptApiError(Exception):
    """Transcript-fetching failure; listed as retryable by ``retry``."""


# --- Enhanced Decorators ---

# Exceptions that ``retry`` treats as transient and retries with backoff.
# NOTE(review): TranscriptsDisabled / NoTranscriptFound are unlikely to succeed
# on retry; kept here to preserve existing behavior.
_RETRYABLE_EXCEPTIONS = (
    RequestException,
    SerpApiClientException,
    YouTubeTranscriptApiError,
    TranscriptsDisabled,
    NoTranscriptFound,
)


def retry(max_retries=3, initial_delay=1, backoff=2):
    """Retry the decorated function with exponential backoff.

    Can be used both as ``@retry(...)`` and as bare ``@retry``.

    Args:
        max_retries: Total number of attempts for retryable exceptions.
        initial_delay: Seconds to sleep after the first failed attempt.
        backoff: Multiplier applied to the delay after each failed attempt.

    Returns:
        A decorator. The wrapped function never raises: after exhausting
        retries, or on any non-retryable exception, it returns a
        ``"Tool Error: ..."`` string so the calling agent can keep going.
    """
    # Bare ``@retry`` passes the decorated function in as ``max_retries``.
    if callable(max_retries):
        return retry()(max_retries)

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            delay = initial_delay
            for attempt in range(1, max_retries + 1):
                try:
                    return func(*args, **kwargs)
                except _RETRYABLE_EXCEPTIONS as e:
                    if attempt == max_retries:
                        logging.error(f"{func.__name__} failed after {attempt} attempts: {e}")
                        # Return a descriptive error string instead of raising, which could crash the agent.
                        return f"Tool Error: {func.__name__} failed after {max_retries} attempts. Details: {e}"
                    logging.warning(f"Attempt {attempt} for {func.__name__} failed: {e}. Retrying in {delay} seconds...")
                    time.sleep(delay)
                    delay *= backoff
                except Exception as e:
                    logging.error(f"{func.__name__} failed with a non-retryable error: {e}")
                    return f"Tool Error: A non-retryable error occurred in {func.__name__}: {e}"
        return wrapper
    return decorator


# --- Enhanced Helper Functions ---

_YOUTUBE_BARE_ID_RE = re.compile(r'^[a-zA-Z0-9_-]{11}$')
# Matches watch URLs (``v`` as the first or a later query parameter), short
# youtu.be links, and embed / shorts / live / v path forms.
_YOUTUBE_URL_RE = re.compile(
    r'(?:youtube\.com/watch\?(?:[^#\s]*?&)??v='
    r'|youtu\.be/'
    r'|youtube(?:-nocookie)?\.com/(?:embed|shorts|live|v)/)'
    r'([a-zA-Z0-9_-]{11})'
)


def extract_video_id(url_or_id: str) -> Optional[str]:
    """Extract an 11-character YouTube video ID from a URL or a bare ID.

    Args:
        url_or_id: A bare video ID or a YouTube URL, e.g.
            ``https://www.youtube.com/watch?v=ID``,
            ``https://www.youtube.com/watch?feature=share&v=ID``,
            ``https://youtu.be/ID``, ``.../embed/ID``, ``.../shorts/ID``.

    Returns:
        The video ID, or ``None`` if none could be found.
    """
    if not url_or_id:
        return None
    url_or_id = url_or_id.strip()
    if _YOUTUBE_BARE_ID_RE.match(url_or_id):
        return url_or_id
    match = _YOUTUBE_URL_RE.search(url_or_id)
    return match.group(1) if match else None


def clean_text_output(text: str) -> str:
    """Collapse every whitespace run to a single space and trim the ends."""
    if not text:
        return ""
    return re.sub(r'\s+', ' ', text).strip()


# --- Answer Formatting and Extraction (CRITICAL FOR GAIA) ---

_FINAL_ANSWER_MARKER_RE = re.compile(r'FINAL\s+ANSWER\s*:', re.IGNORECASE)


def extract_final_answer(response: str) -> str:
    """Extract the submitted answer from the agent's full response text.

    Uses the text after the LAST ``FINAL ANSWER:`` marker (case-insensitive),
    up to the end of that line. If the marker is immediately followed by a
    newline, the first non-empty following line is used instead. Without any
    marker, falls back to the last line of the response.

    Args:
        response: The agent's output; non-str values are converted with ``str``.

    Returns:
        The extracted answer, stripped; ``""`` for empty/``None`` input.
    """
    if response is None:
        return ""
    text = str(response).strip()
    if not text:
        return ""

    markers = list(_FINAL_ANSWER_MARKER_RE.finditer(text))
    if markers:
        # The model may mention the marker while reasoning; the last one wins.
        remainder = text[markers[-1].end():]
        for line in remainder.splitlines():
            if line.strip():
                return line.strip()
        return ""

    # Fallback if the pattern is missing
    return text.split('\n')[-1].strip()


# A single number written with thousands separators, e.g. "1,234" or
# "-12,345.6". Without this check such answers would be split into a list.
# Comma-separated values WITH spaces ("2, 4, 6") are still treated as lists.
_THOUSANDS_NUMBER_RE = re.compile(r'^[-+]?\d{1,3}(?:,\d{3})+(?:\.\d+)?$')


def normalize_answer_format(answer: str) -> str:
    """Normalize an extracted answer to GAIA's strict formatting rules.

    - Numbers: ``$``, ``%`` and thousands separators are removed
      (``"$1,234"`` -> ``"1234"``, ``"15%"`` -> ``"15"``).
    - Comma-separated lists: each element is normalized recursively and the
      elements are re-joined with ``", "``.
    - Other strings: a few city abbreviations are expanded
      (``"NYC"`` -> ``"New York City"``); everything else is returned as-is.

    A trailing period is always stripped first.

    Args:
        answer: The raw extracted answer.

    Returns:
        The normalized answer (``""`` for empty input).
    """
    if not answer:
        return ""
    answer = answer.strip().rstrip('.')

    # Strip currency/percent symbols BEFORE testing numeric-ness; otherwise
    # "$42" or "15%" never parse as numbers and are left untouched.
    numeric_candidate = re.sub(r'[$%]', '', answer).strip()
    is_list = ',' in answer and not _THOUSANDS_NUMBER_RE.match(numeric_candidate)

    if is_list:
        # Recursively normalize each element of the list.
        return ', '.join(normalize_answer_format(elem.strip()) for elem in answer.split(','))

    plain_number = numeric_candidate.replace(',', '')
    try:
        # Check if it can be converted to a float (handles integers and floats).
        float(plain_number)
    except ValueError:
        pass
    else:
        return plain_number

    # Is a string: expand common abbreviations.
    abbreviations = {'NYC': 'New York City', 'LA': 'Los Angeles', 'SF': 'San Francisco'}
    return abbreviations.get(answer.upper(), answer)


# --- Agent Wrapper for GAIA Compliance ---

def _coerce_agent_output(output: Any) -> str:
    """Convert whatever ``agent.run`` returned into text.

    ``CodeAgent.run`` may hand back non-string values (ints, floats, lists).
    Lists and tuples are joined with ``", "`` to match GAIA list formatting.
    ``None`` becomes ``""``.
    """
    if output is None:
        return ""
    if isinstance(output, (list, tuple)):
        return ", ".join(str(item) for item in output)
    return str(output)


def create_gaia_agent_wrapper(agent: CodeAgent):
    """Wrap ``agent`` in a callable that returns GAIA-formatted answers.

    Args:
        agent: The agent whose ``run(question)`` produces the raw response.

    Returns:
        A function ``question -> normalized answer string``. If ``agent.run``
        raises, the traceback is logged and ``""`` is returned, so one failing
        question does not abort a whole benchmark run.
    """
    def gaia_compliant_agent(question: str) -> str:
        logging.info(f"Received question for GAIA compliant agent: '{question}'")
        try:
            raw_output = agent.run(question)
        except Exception:
            logging.exception(f"Agent run failed for question: '{question}'")
            return ""
        full_response = _coerce_agent_output(raw_output)
        logging.info(f"Agent raw response:\n---\n{full_response}\n---")

        final_answer = extract_final_answer(full_response)
        normalized_answer = normalize_answer_format(final_answer)

        logging.info(f"Extracted final answer: '{final_answer}'")
        logging.info(f"Normalized answer for submission: '{normalized_answer}'")
        return normalized_answer

    return gaia_compliant_agent


# --- Main Agent Initialization ---

def initialize_agent():
    """Build the GAIA agent and return a GAIA-compliant callable.

    Returns:
        A ``question -> answer`` callable produced by
        ``create_gaia_agent_wrapper``, or ``None`` when the required
        ``TOGETHER_API_KEY`` is missing.
    """
    configure_logging()
    logging.info("🚀 Starting GAIA agent initialization...")

    try:
        api_keys = load_api_keys()
    except ValueError as e:
        logging.error(f"FATAL: {e}")
        return None

    # --- Tool Definitions ---

    # ``retry`` must wrap the cache (not the reverse) so that failures, which
    # ``retry`` turns into "Tool Error: ..." strings, are never memoized.
    # NOTE(review): this helper is defined but not registered as an agent tool.
    @retry()
    @lru_cache(maxsize=64)
    def get_webpage_index(url: str) -> VectorStoreIndex:
        logging.info(f"📄 Indexing webpage: {url}")
        loader = download_loader("BeautifulSoupWebReader")()
        docs = loader.load_data(urls=[url])
        if not docs or not any(len(doc.text.strip()) > 50 for doc in docs):
            raise ValueError(f"No substantial content found in {url}")
        return VectorStoreIndex.from_documents(docs)

    @tool
    def enhanced_python_execution(code: str) -> str:
        """Execute Python code in a restricted namespace and return its printed output.

        Only an allow-listed set of builtins is available, and `import` works only for
        allow-listed modules (math, re, json, datetime, collections, itertools, functools,
        statistics, fractions, decimal, string, random, numpy, pandas, requests).
        Preloaded names: requests, pd (pandas), np (numpy), datetime, math, re, json, collections.

        Args:
            code: Python source to run. Print results with print(); only printed output is returned.
        """
        logging.info(f"🐍 Executing Python code: {code[:200]}...")
        stdout_capture = io.StringIO()
        try:
            # NOTE: this is NOT a security sandbox. The preloaded modules
            # (e.g. requests, pandas) still expose file and network access.
            preloaded_modules = {
                "requests": __import__("requests"),
                "pd": __import__("pandas"),
                "np": __import__("numpy"),
                "datetime": __import__("datetime"),
                "math": __import__("math"),
                "re": __import__("re"),
                "json": __import__("json"),
                "collections": __import__("collections"),
            }
            allowed_modules = {
                "requests", "pandas", "numpy", "datetime", "math", "re", "json",
                "collections", "itertools", "functools", "statistics", "fractions",
                "decimal", "string", "random",
            }

            def guarded_import(name, globals_=None, locals_=None, fromlist=(), level=0):
                # Only absolute imports of allow-listed top-level packages.
                if level != 0 or name.partition(".")[0] not in allowed_modules:
                    raise ImportError(f"Import of '{name}' is not allowed in this sandbox")
                return __import__(name, globals_, locals_, fromlist, level)

            restricted_builtins = {
                'print': print, 'len': len, 'range': range, 'str': str, 'int': int,
                'float': float, 'list': list, 'dict': dict, 'set': set, 'tuple': tuple,
                'max': max, 'min': min, 'sum': sum, 'sorted': sorted, 'round': round,
                'abs': abs, 'all': all, 'any': any, 'bool': bool, 'divmod': divmod,
                'enumerate': enumerate, 'filter': filter, 'isinstance': isinstance,
                'map': map, 'pow': pow, 'reversed': reversed, 'zip': zip,
                'Exception': Exception, 'ValueError': ValueError, 'TypeError': TypeError,
                'KeyError': KeyError, 'IndexError': IndexError,
                'ZeroDivisionError': ZeroDivisionError,
                '__import__': guarded_import,
            }
            # A single namespace dict is used as globals so that functions
            # defined in ``code`` can see the preloaded modules. Passing them
            # as exec's *locals* hides them from nested function bodies.
            namespace = {"__builtins__": restricted_builtins, "__name__": "__main__", **preloaded_modules}
            with contextlib.redirect_stdout(stdout_capture):
                exec(code, namespace)
            result = stdout_capture.getvalue().strip()
            return result if result else "Code executed successfully with no output."
        except Exception as e:
            error_msg = f"Code execution error: {e}"
            logging.error(error_msg)
            return error_msg

    # --- Model and Agent Setup ---
    # NOTE(review): constructing InferenceClientModel presumably makes no
    # network call, so this fallback likely triggers only on construction
    # errors, not when the model is unavailable at inference time. Confirm.
    try:
        model = InferenceClientModel(
            model_id="meta-llama/Llama-3.1-70B-Instruct-Turbo",
            token=api_keys['together'],
            provider="together",
        )
        logging.info("✅ Primary model (Llama 3.1 70B) loaded successfully")
    except Exception as e:
        logging.warning(f"⚠️ Failed to load primary model, falling back. Error: {e}")
        model = InferenceClientModel(
            model_id="Qwen/Qwen2.5-7B-Instruct",
            token=api_keys['together'],
            provider="together",
        )
        logging.info("✅ Fallback model (Qwen 2.5 7B) loaded successfully")

    # GoogleSearchTool reads SERPAPI_API_KEY from the environment (loaded by
    # load_dotenv in load_api_keys); its constructor only accepts ``provider``.
    google_search_tool = GoogleSearchTool(provider="serpapi") if api_keys['serpapi'] else None
    tools_list = [t for t in (google_search_tool, enhanced_python_execution) if t is not None]
    # Name the tools exactly as the agent sees them (e.g. ``web_search``).
    tool_names = ", ".join(f"`{t.name}`" for t in tools_list)

    manager = CodeAgent(
        model=model,
        tools=tools_list,
        instructions=f"""You are a master AI assistant for the GAIA benchmark. Your goal is to provide a single, precise, and final answer.

**STRATEGY:**
1. **Analyze**: Break down the user's question into steps.
2. **Execute**: Use the provided tools ({tool_names}) to find the information or perform calculations.
3. **Synthesize**: Combine the results of your tool use to form a final answer.
4. **Format**: Present your final answer clearly at the end of your response, prefixed with `FINAL ANSWER:`.

**CRITICAL INSTRUCTION:**
You MUST end your entire response with the line `FINAL ANSWER: [Your Final Answer]`.
The text that follows this prefix is what will be submitted.
Adhere to strict formatting: no extra words, no currency symbols, no commas in numbers.
- For "What is 2*21?": `FINAL ANSWER: 42`
- For "Capital of France?": `FINAL ANSWER: Paris`
- For "What are the first three even numbers?": `FINAL ANSWER: 2, 4, 6`
""",
    )

    logging.info("🎯 GAIA agent initialized successfully!")
    # Return the wrapped, compliant agent instead of the raw manager.
    return create_gaia_agent_wrapper(manager)


# --- Main Execution Block for Local Testing ---

def main():
    """Run the agent on a few sample GAIA-style questions and log the answers and timings."""
    configure_logging()
    logging.info("🧪 Starting local agent testing...")

    agent = initialize_agent()
    if not agent:
        logging.critical("💥 Agent initialization failed. Exiting.")
        return

    test_questions = [
        "What is 15! / (12! * 3!)?",
        "In what year was the Python programming language first released?",
        "What is the square root of 2025?",
    ]

    for i, question in enumerate(test_questions, 1):
        logging.info(f"\n{'='*60}\n🔍 Test Question {i}: {question}\n{'='*60}")
        start_time = time.time()
        # The agent is the GAIA wrapper callable, not a smolagents agent object.
        final_answer = agent(question)
        elapsed_time = time.time() - start_time
        logging.info(f"✅ Submitted Answer: {final_answer}")
        logging.info(f"⏱️ Execution time: {elapsed_time:.2f} seconds")
        time.sleep(1)

    logging.info(f"\n{'='*60}\n🏁 Testing complete!\n{'='*60}")


if __name__ == "__main__":
    main()