Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| import logging | |
| import urllib.parse as urlparse | |
| import io | |
| import contextlib | |
| import re | |
| import json | |
| from functools import lru_cache, wraps | |
| from typing import Optional, Dict, Any, List | |
| from dotenv import load_dotenv | |
| from requests.exceptions import RequestException | |
| import serpapi | |
| from llama_index.core import VectorStoreIndex, download_loader | |
| from llama_index.core.schema import Document | |
| from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled, NoTranscriptFound | |
| # --- Correctly import the specific tools from smolagents --- | |
| from smolagents import ( | |
| CodeAgent, | |
| InferenceClientModel, | |
| ToolCallingAgent, | |
| GoogleSearchTool, | |
| tool, | |
| ) | |
| # --- Configuration and Setup --- | |
def configure_logging():
    """Install an INFO-level root logging handler with timestamped records.

    Delegates to ``logging.basicConfig``, which is a no-op when the root logger
    already has handlers, so repeated calls are harmless.
    """
    record_format = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
    timestamp_format = "%Y-%m-%d %H:%M:%S"
    logging.basicConfig(level=logging.INFO, format=record_format, datefmt=timestamp_format)
def load_api_keys() -> Dict[str, Optional[str]]:
    """Read the Together and SerpAPI keys from the environment after loading ``.env``.

    Logs one line per key, at INFO when it is present and at WARNING when it is missing.

    Returns:
        A dict with keys ``'together'`` and ``'serpapi'``. A value is ``None``
        when the corresponding environment variable is unset.

    Raises:
        ValueError: If ``TOGETHER_API_KEY`` is missing or empty. The model
        cannot be created without it.
    """
    load_dotenv()
    env_var_for = {'together': 'TOGETHER_API_KEY', 'serpapi': 'SERPAPI_API_KEY'}
    keys = {label: os.getenv(var_name) for label, var_name in env_var_for.items()}
    for label, value in keys.items():
        if not value:
            logging.warning(f"β οΈ {label.upper()} API key not found")
            continue
        logging.info(f"β {label.upper()} API key loaded")
    if keys['together']:
        return keys
    raise ValueError("TOGETHER_API_KEY is required but not found.")
| # --- Custom Exceptions --- | |
class SerpApiClientException(Exception):
    """Error type for SerpAPI client failures; the ``retry`` decorator treats it as retryable."""
class YouTubeTranscriptApiError(Exception):
    """Error type for YouTube transcript fetch failures; the ``retry`` decorator treats it as retryable."""
| # --- Enhanced Decorators --- | |
def retry(max_retries=3, initial_delay=1, backoff=2):
    """Build a decorator that retries transient failures with exponential backoff.

    The decorated function is attempted up to ``max_retries`` times. After a
    retryable failure, the wrapper sleeps before trying again. The delay starts
    at ``initial_delay`` seconds and is multiplied by ``backoff`` after every retry.

    Exceptions derived from ``Exception`` never propagate out of the wrapper. In
    two cases it returns a ``"Tool Error: ..."`` string instead, so an agent can
    read the failure rather than crash:
    - retries are exhausted, or
    - a non-retryable exception is raised.

    Args:
        max_retries: Total number of attempts; must be at least 1.
        initial_delay: Seconds to wait before the first retry.
        backoff: Factor applied to the delay after each retry.

    Raises:
        ValueError: If ``max_retries`` is less than 1. This is raised when
        ``retry(...)`` is called. Previously such a value made the wrapper
        return ``None`` without ever calling the function.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be at least 1, got {max_retries}")

    def decorator(func):
        # Errors worth retrying: network failures, SerpAPI errors and YouTube
        # transcript lookups. Built once per decorated function, not per call.
        retryable_exceptions = (
            RequestException,
            SerpApiClientException,
            YouTubeTranscriptApiError,
            TranscriptsDisabled,
            NoTranscriptFound,
        )

        # functools.wraps preserves __name__, __doc__ and (via __wrapped__) the
        # signature. Introspection-based consumers need these, e.g. a tool
        # decorator that builds a schema from the function.
        @wraps(func)
        def wrapper(*args, **kwargs):
            delay = initial_delay
            for attempt in range(1, max_retries + 1):
                try:
                    return func(*args, **kwargs)
                except retryable_exceptions as e:
                    if attempt == max_retries:
                        logging.error(f"{func.__name__} failed after {attempt} attempts: {e}")
                        # Return a descriptive error string instead of raising, which could crash the agent.
                        return f"Tool Error: {func.__name__} failed after {max_retries} attempts. Details: {e}"
                    logging.warning(f"Attempt {attempt} for {func.__name__} failed: {e}. Retrying in {delay} seconds...")
                    time.sleep(delay)
                    delay *= backoff
                except Exception as e:
                    logging.error(f"{func.__name__} failed with a non-retryable error: {e}")
                    return f"Tool Error: A non-retryable error occurred in {func.__name__}: {e}"

        return wrapper

    return decorator
| # --- Enhanced Helper Functions --- | |
# A bare 11-character YouTube video ID.
_BARE_VIDEO_ID_RE = re.compile(r'[a-zA-Z0-9_-]{11}')

# The video ID inside common YouTube URL shapes:
#   youtube.com/watch?v=ID (also when v= is not the first query parameter),
#   youtube.com/embed|v|shorts|live/ID, youtube-nocookie.com/embed/ID, and youtu.be/ID.
# Group 1 is the ID.
_YOUTUBE_URL_ID_RE = re.compile(
    r'(?:youtube(?:-nocookie)?\.com/(?:watch\?(?:[^#\s]*?&)?v=|embed/|v/|shorts/|live/)'
    r'|youtu\.be/)'
    r'([a-zA-Z0-9_-]{11})'
)


def extract_video_id(url_or_id: str) -> Optional[str]:
    """Extract a YouTube video ID from a bare ID or from a YouTube URL.

    Args:
        url_or_id: An 11-character video ID or a YouTube URL. Surrounding
            whitespace is ignored.

    Returns:
        The 11-character video ID, or ``None`` if the input is empty or no ID
        can be found.
    """
    if not url_or_id:
        return None
    candidate = url_or_id.strip()
    if _BARE_VIDEO_ID_RE.fullmatch(candidate):
        return candidate
    match = _YOUTUBE_URL_ID_RE.search(candidate)
    return match.group(1) if match else None
def clean_text_output(text: str) -> str:
    """Collapse every whitespace run to a single space and trim both ends.

    Falsy input (``None`` or ``""``) yields ``""``.
    """
    # str.split() with no argument splits on the same Unicode whitespace that
    # the regex ``\s`` matches, and drops empty pieces. Joining with one space
    # therefore collapses runs and trims the ends.
    return " ".join(text.split()) if text else ""
| # --- Answer Formatting and Extraction (CRITICAL FOR GAIA) --- | |
# Matches the "FINAL ANSWER:" prefix (case-insensitive, flexible spacing).
_FINAL_ANSWER_MARKER_RE = re.compile(r'FINAL\s+ANSWER\s*:\s*', re.IGNORECASE)


def extract_final_answer(response: Any) -> str:
    """Extract the final answer from the agent's full response.

    The text after the LAST ``FINAL ANSWER:`` marker is returned. The prompt
    tells the model to end with that line, so earlier occurrences may be drafts
    or restated instructions. Surrounding markdown emphasis and backticks are
    trimmed. Without a marker, the last line of the response is returned.

    Args:
        response: The agent output. Non-string values (e.g. a bare number
            returned by the agent) are converted with ``str``. ``None`` yields "".

    Returns:
        The extracted answer text, stripped of surrounding whitespace.
    """
    if response is None:
        return ""
    text = (response if isinstance(response, str) else str(response)).strip()
    if not text:
        return ""
    markers = list(_FINAL_ANSWER_MARKER_RE.finditer(text))
    if markers:
        answer = text[markers[-1].end():]
        # Handle e.g. "**FINAL ANSWER:** 42" or "`FINAL ANSWER: 42`".
        return answer.strip().strip('`*').strip()
    # Fallback if the pattern is missing.
    return text.splitlines()[-1].strip()
# A single number written with thousands separators and no spaces,
# e.g. "1,000" or "-12,345.67".
_THOUSANDS_NUMBER_RE = re.compile(r'[-+]?\d{1,3}(?:,\d{3})+(?:\.\d+)?')

# Exact-match (case-insensitive) expansions for common abbreviations.
_ANSWER_ABBREVIATIONS = {'NYC': 'New York City', 'LA': 'Los Angeles', 'SF': 'San Francisco'}


def normalize_answer_format(answer: str) -> str:
    """Normalize an extracted answer to GAIA's strict formatting rules.

    - Trailing periods and surrounding whitespace are removed.
    - A single number loses currency/percent symbols and thousands separators:
      "$1,234.50" -> "1234.50", "15%" -> "15".
    - A comma-separated list has each element normalized recursively and is
      re-joined with ", ".
    - Any other string is expanded if it is a known abbreviation, else kept.

    Note: a comma-joined run of 3-digit groups without spaces (e.g.
    "100,200,300") is read as ONE number, not a list. The agent prompt asks
    for lists as ", "-separated, which is still treated as a list.

    Args:
        answer: Raw extracted answer text.

    Returns:
        The normalized answer, or "" for empty input.
    """
    if not answer:
        return ""
    answer = answer.strip().rstrip('.')

    # Numeric candidate: currency/percent symbols dropped. Before this step
    # float() rejected "$42", so the symbols were never removed.
    numeric_candidate = re.sub(r'[$%]', '', answer).strip()

    # "1,000" is one number, not the list ["1", "000"].
    if _THOUSANDS_NUMBER_RE.fullmatch(numeric_candidate):
        return numeric_candidate.replace(',', '')

    if ',' in answer:
        return ', '.join(normalize_answer_format(part.strip()) for part in answer.split(','))

    try:
        float(numeric_candidate)
    except ValueError:
        return _ANSWER_ABBREVIATIONS.get(answer.upper(), answer)
    return numeric_candidate
| # --- Agent Wrapper for GAIA Compliance --- | |
def create_gaia_agent_wrapper(agent: CodeAgent):
    """Wrap an agent in a callable that returns GAIA-formatted answers.

    The returned function performs four steps:
    1. Run the agent on the question.
    2. Convert the raw result to text.
    3. Extract the final answer with ``extract_final_answer``.
    4. Normalize it with ``normalize_answer_format``.
    Each stage is logged. Exceptions raised by ``agent.run`` propagate unchanged.

    Args:
        agent: An object with a ``run(question)`` method.

    Returns:
        A function ``question -> normalized answer string``.
    """
    def gaia_compliant_agent(question: str) -> str:
        logging.info(f"Received question for GAIA compliant agent: '{question}'")
        full_response = agent.run(question)
        logging.info(f"Agent raw response:\n---\n{full_response}\n---")
        # agent.run may return a non-string value (e.g. a bare number handed to
        # final_answer). Convert it to text so the regex-based extraction cannot
        # raise TypeError.
        response_text = "" if full_response is None else str(full_response)
        final_answer = extract_final_answer(response_text)
        logging.info(f"Extracted final answer: '{final_answer}'")
        normalized_answer = normalize_answer_format(final_answer)
        logging.info(f"Normalized answer for submission: '{normalized_answer}'")
        return normalized_answer

    return gaia_compliant_agent
| # --- Main Agent Initialization --- | |
def initialize_agent():
    """Build the GAIA CodeAgent and return it wrapped for GAIA answer formatting.

    Steps:
    1. Configure logging and load API keys. Return ``None`` if the required
       Together key is missing.
    2. Create the LLM client, falling back to a smaller model if construction fails.
    3. Assemble the tools: web search only when a SerpAPI key is present, plus
       a Python runner.
    4. Construct the ``CodeAgent``.

    Returns:
        The callable produced by ``create_gaia_agent_wrapper``
        (``question -> normalized answer``), or ``None`` on a missing API key.
    """
    configure_logging()
    logging.info("π Starting GAIA agent initialization...")
    try:
        api_keys = load_api_keys()
    except ValueError as e:
        logging.error(f"FATAL: {e}")
        return None

    # --- Tool Definitions ---
    # The agent's tools must be smolagents Tool objects; a plain function is not
    # one. `@tool` converts it, and needs the type hints plus the "Args:"
    # docstring section below (the docstring becomes the tool description).
    @tool
    def enhanced_python_execution(code: str) -> str:
        """Executes Python code with a limited set of builtins and returns everything it printed.

        Preloaded names: requests, pd (pandas), np (numpy), datetime, math, re, json, collections.
        `import` statements are not available. Use print() to output results.

        Args:
            code: Python source code to execute.
        """
        logging.info(f"π Executing Python code: {code[:200]}...")
        stdout_capture = io.StringIO()
        try:
            # NOTE: trimming __builtins__ is NOT a security boundary. The preloaded
            # modules (requests, pandas, numpy) still allow network and file access,
            # and object introspection can reach the full builtins. Do not run
            # untrusted code here.
            restricted_builtins = {
                'print': print, 'len': len, 'range': range, 'str': str, 'int': int, 'float': float,
                'list': list, 'dict': dict, 'set': set, 'tuple': tuple, 'max': max, 'min': min, 'sum': sum,
                'sorted': sorted, 'round': round,
                'abs': abs, 'enumerate': enumerate, 'zip': zip, 'map': map, 'filter': filter,
                'any': any, 'all': all, 'bool': bool, 'pow': pow, 'divmod': divmod,
                'reversed': reversed, 'isinstance': isinstance,
            }
            # Use one globals dict. Names in a separate locals mapping are invisible
            # to functions (and, before 3.12, comprehensions) defined inside `code`.
            exec_globals = {
                "__builtins__": restricted_builtins,
                "requests": __import__("requests"), "pd": __import__("pandas"), "np": __import__("numpy"),
                "datetime": __import__("datetime"), "math": __import__("math"), "re": __import__("re"),
                "json": __import__("json"), "collections": __import__("collections"),
            }
            with contextlib.redirect_stdout(stdout_capture):
                exec(code, exec_globals)
            result = stdout_capture.getvalue().strip()
            return result if result else "Code executed successfully with no output."
        except Exception as e:
            error_msg = f"Code execution error: {e}"
            logging.error(error_msg)
            return error_msg

    # --- Model and Agent Setup ---
    # NOTE(review): the InferenceClientModel constructor may not contact the
    # provider. If so, this fallback only covers local construction errors, not
    # an unavailable model. Confirm.
    try:
        model = InferenceClientModel(
            model_id="meta-llama/Llama-3.1-70B-Instruct-Turbo",
            token=api_keys['together'],
            provider="together"
        )
        logging.info("β Primary model (Llama 3.1 70B) loaded successfully")
    except Exception as e:
        logging.warning(f"β οΈ Failed to load primary model, falling back. Error: {e}")
        model = InferenceClientModel(
            model_id="Qwen/Qwen2.5-7B-Instruct",
            token=api_keys['together'],
            provider="together"
        )
        logging.info("β Fallback model (Qwen 2.5 7B) loaded successfully")

    # smolagents' GoogleSearchTool takes only `provider` and reads SERPAPI_API_KEY
    # from the environment (populated by load_dotenv in load_api_keys). Passing
    # the key as a keyword argument raises TypeError.
    google_search_tool = GoogleSearchTool(provider='serpapi') if api_keys['serpapi'] else None
    tools_list = [t for t in (google_search_tool, enhanced_python_execution) if t is not None]

    # Name the tools the agent actually has, instead of a hard-coded class name.
    tool_names = ", ".join(f"`{t.name}`" for t in tools_list)
    instructions = f"""You are a master AI assistant for the GAIA benchmark. Your goal is to provide a single, precise, and final answer.
**STRATEGY:**
1. **Analyze**: Break down the user's question into steps.
2. **Execute**: Use the provided tools ({tool_names}) to find the information or perform calculations.
3. **Synthesize**: Combine the results of your tool use to form a final answer.
4. **Format**: Present your final answer clearly at the end of your response, prefixed with `FINAL ANSWER:`.
**CRITICAL INSTRUCTION:** You MUST end your entire response with the line `FINAL ANSWER: [Your Final Answer]`. The text that follows this prefix is what will be submitted. Adhere to strict formatting: no extra words, no currency symbols, no commas in numbers.
- For "What is 2*21?": `FINAL ANSWER: 42`
- For "Capital of France?": `FINAL ANSWER: Paris`
- For "What are the first three even numbers?": `FINAL ANSWER: 2, 4, 6`
"""

    manager = CodeAgent(
        model=model,
        tools=tools_list,
        instructions=instructions,
    )
    logging.info("π― GAIA agent initialized successfully!")
    # Return the wrapped, GAIA-compliant agent rather than the raw CodeAgent.
    return create_gaia_agent_wrapper(manager)
| # --- Main Execution Block for Local Testing --- | |
def main():
    """Smoke-test the GAIA agent locally on a few sample questions.

    Logs each answer and its wall-clock time. If a question raises, the
    traceback is logged and the loop moves on, so a single failure does not
    abort the remaining questions.
    """
    configure_logging()
    logging.info("π§ͺ Starting local agent testing...")
    agent = initialize_agent()
    if agent is None:
        logging.critical("π₯ Agent initialization failed. Exiting.")
        return
    test_questions = [
        "What is 15! / (12! * 3!)?",
        "In what year was the Python programming language first released?",
        "What is the square root of 2025?",
    ]
    for i, question in enumerate(test_questions, 1):
        logging.info(f"\n{'='*60}\nπ Test Question {i}: {question}\n{'='*60}")
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        try:
            # Call the wrapper directly; it already runs the agent and formats the answer.
            final_answer = agent(question)
        except Exception:
            # Top-level harness boundary: record the failure and continue.
            logging.exception(f"Test question {i} failed")
        else:
            logging.info(f"β Submitted Answer: {final_answer}")
        elapsed_time = time.perf_counter() - start_time
        logging.info(f"β±οΈ Execution time: {elapsed_time:.2f} seconds")
        # Brief pause between questions.
        time.sleep(1)
    logging.info(f"\n{'='*60}\nπ Testing complete!\n{'='*60}")
if __name__ == "__main__":
    # Run the local smoke test only when executed as a script, not on import.
    main()