File size: 12,044 Bytes
8cf27dc
07e3a65
892cc72
 
 
 
c9e0cf1
065eebf
892cc72
065eebf
892cc72
8cf27dc
 
892cc72
9fb8366
07e3a65
c1b14e1
abbd59c
065eebf
 
 
 
 
 
 
 
8cf27dc
892cc72
c2d4de7
892cc72
065eebf
892cc72
 
 
 
 
 
19b4b62
 
8cf27dc
892cc72
 
 
 
065eebf
 
 
 
 
 
 
19b4b62
892cc72
8cf27dc
c9e0cf1
19b4b62
 
065eebf
 
892cc72
 
19b4b62
892cc72
 
 
 
19b4b62
892cc72
 
 
 
 
 
19b4b62
 
892cc72
 
 
 
 
19b4b62
892cc72
 
 
065eebf
c9e0cf1
 
19b4b62
 
065eebf
c9e0cf1
 
 
19b4b62
c9e0cf1
 
 
 
 
 
 
065eebf
19b4b62
 
 
065eebf
 
19b4b62
065eebf
 
19b4b62
 
 
065eebf
19b4b62
065eebf
19b4b62
065eebf
19b4b62
065eebf
19b4b62
 
 
065eebf
19b4b62
892cc72
19b4b62
 
 
892cc72
19b4b62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
892cc72
19b4b62
892cc72
19b4b62
 
 
 
 
 
 
 
 
065eebf
19b4b62
 
065eebf
19b4b62
 
065eebf
19b4b62
 
065eebf
19b4b62
065eebf
 
19b4b62
065eebf
 
 
 
 
19b4b62
 
065eebf
 
19b4b62
065eebf
19b4b62
 
 
 
 
 
 
 
 
892cc72
065eebf
 
19b4b62
 
065eebf
 
19b4b62
 
 
 
 
 
 
 
 
 
 
 
065eebf
 
19b4b62
 
065eebf
 
 
 
 
 
 
 
 
19b4b62
065eebf
 
 
19b4b62
065eebf
19b4b62
065eebf
19b4b62
065eebf
 
 
19b4b62
065eebf
19b4b62
065eebf
19b4b62
065eebf
 
 
19b4b62
 
 
 
 
 
 
 
065eebf
19b4b62
 
 
 
 
065eebf
 
19b4b62
 
 
 
065eebf
 
 
 
19b4b62
065eebf
19b4b62
065eebf
19b4b62
 
 
 
065eebf
19b4b62
 
 
 
 
 
 
 
 
065eebf
19b4b62
 
065eebf
19b4b62
 
 
 
 
 
892cc72
 
5b546f5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
import os
import time
import logging
import urllib.parse as urlparse
import io
import contextlib
import re
import json
from functools import lru_cache, wraps
from typing import Optional, Dict, Any, List

from dotenv import load_dotenv
from requests.exceptions import RequestException
import serpapi
from llama_index.core import VectorStoreIndex, download_loader
from llama_index.core.schema import Document
from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled, NoTranscriptFound

# --- Correctly import the specific tools from smolagents ---
from smolagents import (
    CodeAgent,
    InferenceClientModel,
    ToolCallingAgent,
    GoogleSearchTool,
    tool,
)

# --- Configuration and Setup ---

def configure_logging():
    """Configure the root logger for INFO-level, timestamped output.

    Delegates to ``logging.basicConfig``, so the call has no effect when the
    root logger already has handlers attached.
    """
    settings = {
        "level": logging.INFO,
        "format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        "datefmt": "%Y-%m-%d %H:%M:%S",
    }
    logging.basicConfig(**settings)

def load_api_keys() -> Dict[str, Optional[str]]:
    """Load the API keys used by the agent from the environment.

    ``load_dotenv()`` is called first so that values from a local ``.env``
    file are visible through ``os.getenv``.

    Returns:
        A dict with the keys ``'together'`` and ``'serpapi'``; a value is
        ``None`` when the corresponding environment variable is unset.

    Raises:
        ValueError: If ``TOGETHER_API_KEY`` is missing or empty.
    """
    load_dotenv()
    keys = {
        'together': os.getenv('TOGETHER_API_KEY'),
        'serpapi': os.getenv('SERPAPI_API_KEY'),
    }

    # Report the status of every key; the values themselves are never logged.
    for key_name, key_value in keys.items():
        label = key_name.upper()
        if key_value:
            logging.info(f"✅ {label} API key loaded")
        else:
            logging.warning(f"⚠️ {label} API key not found")

    # Only the Together key is mandatory; SerpApi merely enables web search.
    if not keys['together']:
        raise ValueError("TOGETHER_API_KEY is required but not found.")
    return keys

# --- Custom Exceptions ---
class SerpApiClientException(Exception):
    """Error type for SerpApi client failures; listed as retryable in ``retry``."""
class YouTubeTranscriptApiError(Exception):
    """Error type for YouTube transcript failures; listed as retryable in ``retry``."""

# --- Enhanced Decorators ---

def retry(max_retries=3, initial_delay=1, backoff=2):
    """Retry decorator with exponential backoff that never raises to the caller.

    Works both as a factory (``@retry(max_retries=5)``) and applied bare
    (``@retry``). In the bare form Python passes the decorated function as the
    first positional argument; that case is detected and the defaults are used.

    Args:
        max_retries: Total number of attempts. Values below 1 are treated as 1
            so the wrapped function is always called at least once.
        initial_delay: Seconds to sleep after the first failed attempt.
        backoff: Factor the delay is multiplied by after every failed attempt.

    Returns:
        A decorator. The wrapped function returns the original result on
        success. On failure it returns a ``"Tool Error: ..."`` string instead
        of raising, so that a failing tool cannot crash the agent.
    """
    # Bare usage: `@retry` hands us the function instead of a retry count.
    if callable(max_retries):
        return retry()(max_retries)

    attempts = max(1, max_retries)

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            delay = initial_delay
            for attempt in range(1, attempts + 1):
                try:
                    return func(*args, **kwargs)
                except (
                    RequestException,
                    SerpApiClientException,
                    YouTubeTranscriptApiError,
                    TranscriptsDisabled,
                    NoTranscriptFound,
                ) as e:
                    # Retryable failure: back off and try again unless this
                    # was the last permitted attempt.
                    if attempt == attempts:
                        logging.error(f"{func.__name__} failed after {attempt} attempts: {e}")
                        return f"Tool Error: {func.__name__} failed after {attempts} attempts. Details: {e}"
                    logging.warning(f"Attempt {attempt} for {func.__name__} failed: {e}. Retrying in {delay} seconds...")
                    time.sleep(delay)
                    delay *= backoff
                except Exception as e:
                    # Anything else is not expected to succeed on a retry.
                    logging.error(f"{func.__name__} failed with a non-retryable error: {e}")
                    return f"Tool Error: A non-retryable error occurred in {func.__name__}: {e}"
        return wrapper
    return decorator

# --- Enhanced Helper Functions ---

def extract_video_id(url_or_id: str) -> Optional[str]:
    """Extract the 11-character YouTube video ID from a URL or a bare ID.

    Recognised forms:
      * a bare 11-character ID,
      * ``youtube.com/watch?...v=<id>`` with ``v`` anywhere in the query,
      * ``youtu.be/<id>``,
      * ``youtube.com`` / ``youtube-nocookie.com`` paths ``/embed/``,
        ``/shorts/``, ``/live/`` and ``/v/``.

    Args:
        url_or_id: A YouTube URL or video ID; surrounding whitespace is ignored.

    Returns:
        The video ID, or ``None`` when the input is empty or unrecognised.
    """
    if not url_or_id:
        return None

    candidate = url_or_id.strip()
    if re.fullmatch(r'[a-zA-Z0-9_-]{11}', candidate):
        return candidate

    patterns = (
        # Path-style URLs: short links, embeds, shorts, live streams, legacy /v/.
        r'(?:youtu\.be/|youtube(?:-nocookie)?\.com/(?:embed|shorts|live|v)/)([a-zA-Z0-9_-]{11})',
        # Watch URLs: `v` may be preceded by other query parameters.
        r'youtube\.com/watch\?(?:[^#\s]*&)?v=([a-zA-Z0-9_-]{11})',
    )
    for pattern in patterns:
        match = re.search(pattern, candidate)
        if match:
            return match.group(1)
    return None

def clean_text_output(text: str) -> str:
    """Collapse each whitespace run in *text* to one space and trim both ends.

    A falsy input (``None`` or the empty string) yields ``""``.
    """
    if text:
        collapsed = re.sub(r'\s+', ' ', text)
        return collapsed.strip()
    return ""

# --- Answer Formatting and Extraction (CRITICAL FOR GAIA) ---

def extract_final_answer(response: str) -> str:
    """Extract the final answer from the agent's response.

    The answer is the text following the LAST ``FINAL ANSWER:`` marker
    (case-insensitive, flexible whitespace), so an earlier draft or an echoed
    copy of the instructions cannot leak into the result. When no marker is
    present, the last line of the response is used.

    Args:
        response: The agent output. Non-string values (for example a number
            returned directly by the agent) are converted with ``str``.

    Returns:
        The stripped answer text, or ``""`` for ``None`` / empty input.
    """
    if response is None:
        return ""

    # The agent may hand back a non-string result; a regex search on it would
    # raise TypeError, and a falsy value such as 0 is still a real answer.
    text = response if isinstance(response, str) else str(response)
    if not text:
        return ""

    markers = list(re.finditer(r'FINAL\s+ANSWER\s*:', text, re.IGNORECASE))
    if markers:
        return text[markers[-1].end():].strip()

    # Fallback if the pattern is missing
    lines = text.strip().split('\n')
    return lines[-1].strip()

def normalize_answer_format(answer: str) -> str:
    """Normalize an extracted answer to the strict GAIA submission format.

    Rules, applied in order after trimming whitespace and trailing periods:
      1. Number: ``$`` and ``%`` are removed; if what remains is a number, it
         is returned without thousands separators (``"$1,234.50"`` ->
         ``"1234.50"``). A comma-containing value counts as a number only when
         it follows strict thousands grouping (``1,000`` / ``12,345,678.9``).
      2. List: any other comma-containing value is split on commas, every
         element is normalized recursively, and the result is joined with
         ``", "``.
      3. String: a few known abbreviations are expanded (case-insensitive
         match); anything else is returned unchanged.

    Note: ``"100,200"`` is inherently ambiguous and is read as the number
    100200; write lists with a space after the comma (``"100, 200"``).

    Args:
        answer: The raw answer text.

    Returns:
        The normalized answer, or ``""`` for empty input.
    """
    if not answer:
        return ""

    answer = answer.strip().rstrip('.')

    # Currency and percent signs must not prevent a value from being
    # recognised as a number.
    unit_free = re.sub(r'[$%]', '', answer).strip()

    if ',' in unit_free:
        # Commas are only part of a number when they group thousands.
        if re.fullmatch(r'[+-]?\d{1,3}(?:,\d{3})+(?:\.\d+)?', unit_free):
            return unit_free.replace(',', '')
    else:
        try:
            float(unit_free)
        except ValueError:
            pass
        else:
            return unit_free

    if ',' in answer:
        # Recursively normalize each element of the list
        elements = [elem.strip() for elem in answer.split(',')]
        return ', '.join(normalize_answer_format(elem) for elem in elements)

    # Plain string: expand common abbreviations
    abbreviations = {'NYC': 'New York City', 'LA': 'Los Angeles', 'SF': 'San Francisco'}
    return abbreviations.get(answer.upper(), answer)

# --- Agent Wrapper for GAIA Compliance ---

def create_gaia_agent_wrapper(agent: CodeAgent):
    """Wrap *agent* in a callable that returns a GAIA-formatted answer.

    The returned function takes a question string, runs the agent on it,
    passes the raw output through ``extract_final_answer`` and then
    ``normalize_answer_format``, logs each stage, and returns the normalized
    answer.
    """
    def gaia_compliant_agent(question: str) -> str:
        logging.info(f"Received question for GAIA compliant agent: '{question}'")
        raw_output = agent.run(question)
        logging.info(f"Agent raw response:\n---\n{raw_output}\n---")

        # Both post-processing steps run before either result is logged.
        extracted = extract_final_answer(raw_output)
        submission = normalize_answer_format(extracted)

        for message in (
            f"Extracted final answer: '{extracted}'",
            f"Normalized answer for submission: '{submission}'",
        ):
            logging.info(message)

        return submission

    return gaia_compliant_agent

# --- Main Agent Initialization ---

def initialize_agent():
    """Initializes the enhanced multi-disciplinary agent for the GAIA benchmark.

    Steps: configure logging, load API keys, define the tools, create the
    model (with a smaller fallback model if construction fails), build the
    ``CodeAgent`` and wrap it with ``create_gaia_agent_wrapper``.

    Returns:
        A callable ``question -> normalized answer`` on success, or ``None``
        when the mandatory API key is missing.
    """
    configure_logging()
    logging.info("🚀 Starting GAIA agent initialization...")

    try:
        api_keys = load_api_keys()
    except ValueError as e:
        logging.error(f"FATAL: {e}")
        return None

    # --- Tool Definitions ---

    # NOTE: `retry` is a decorator factory and must be called. Because it
    # returns an error string instead of raising, a failed lookup is cached
    # by `lru_cache` like any other result.
    @lru_cache(maxsize=64)
    @retry()
    def get_webpage_index(url: str) -> VectorStoreIndex:
        logging.info(f"📄 Indexing webpage: {url}")
        loader = download_loader("BeautifulSoupWebReader")()
        docs = loader.load_data(urls=[url])
        if not docs or not any(len(doc.text.strip()) > 50 for doc in docs):
            raise ValueError(f"No substantial content found in {url}")
        return VectorStoreIndex.from_documents(docs)

    @tool
    def enhanced_python_execution(code: str) -> str:
        """Executes Python code in a restricted environment and returns the output.

        Args:
            code: Python source code to run. Use print() to produce output. The names requests, pd (pandas), np (numpy), datetime, math, re, json and collections are preloaded; import statements are not available.

        Returns:
            The captured standard output, a confirmation message when nothing was printed, or an error message when execution failed.
        """
        logging.info(f"🐍 Executing Python code: {code[:200]}...")
        stdout_capture = io.StringIO()
        try:
            # SECURITY: a reduced builtins table is NOT a real sandbox; the
            # preloaded modules still give the code network and file access.
            # Only run code from a trusted model/operator.
            restricted_builtins = {
                'print': print, 'len': len, 'range': range, 'str': str, 'int': int, 'float': float,
                'list': list, 'dict': dict, 'set': set, 'tuple': tuple, 'max': max, 'min': min, 'sum': sum,
                'sorted': sorted, 'round': round
            }
            # One shared namespace for globals and locals: with two separate
            # dicts, functions defined by the submitted code could not see
            # the preloaded modules (their globals would hold only builtins).
            sandbox = {
                "__builtins__": restricted_builtins,
                "requests": __import__("requests"), "pd": __import__("pandas"), "np": __import__("numpy"),
                "datetime": __import__("datetime"), "math": __import__("math"), "re": __import__("re"),
                "json": __import__("json"), "collections": __import__("collections")
            }
            with contextlib.redirect_stdout(stdout_capture):
                exec(code, sandbox)

            result = stdout_capture.getvalue().strip()
            return result if result else "Code executed successfully with no output."
        except Exception as e:
            error_msg = f"Code execution error: {e}"
            logging.error(error_msg)
            return error_msg

    # --- Model and Agent Setup ---

    try:
        model = InferenceClientModel(
            model_id="meta-llama/Llama-3.1-70B-Instruct-Turbo",
            token=api_keys['together'],
            provider="together"
        )
        logging.info("✅ Primary model (Llama 3.1 70B) loaded successfully")
    except Exception as e:
        logging.warning(f"⚠️ Failed to load primary model, falling back. Error: {e}")
        model = InferenceClientModel(
            model_id="Qwen/Qwen2.5-7B-Instruct",
            token=api_keys['together'],
            provider="together"
        )
        logging.info("✅ Fallback model (Qwen 2.5 7B) loaded successfully")

    # GoogleSearchTool reads SERPAPI_API_KEY from the environment itself
    # (already populated by load_api_keys); it takes no API-key argument.
    google_search_tool = GoogleSearchTool(provider='serpapi') if api_keys['serpapi'] else None

    # Drop the search tool when it could not be configured.
    tools_list = [
        candidate
        for candidate in (google_search_tool, enhanced_python_execution)
        if candidate
    ]

    manager = CodeAgent(
        model=model,
        tools=tools_list,
        instructions="""You are a master AI assistant for the GAIA benchmark. Your goal is to provide a single, precise, and final answer.
        
        **STRATEGY:**
        1.  **Analyze**: Break down the user's question into steps.
        2.  **Execute**: Use the provided tools (`GoogleSearchTool`, `enhanced_python_execution`) to find the information or perform calculations.
        3.  **Synthesize**: Combine the results of your tool use to form a final answer.
        4.  **Format**: Present your final answer clearly at the end of your response, prefixed with `FINAL ANSWER:`.

        **CRITICAL INSTRUCTION:** You MUST end your entire response with the line `FINAL ANSWER: [Your Final Answer]`. The text that follows this prefix is what will be submitted. Adhere to strict formatting: no extra words, no currency symbols, no commas in numbers.
        - For "What is 2*21?": `FINAL ANSWER: 42`
        - For "Capital of France?": `FINAL ANSWER: Paris`
        - For "What are the first three even numbers?": `FINAL ANSWER: 2, 4, 6`
        """
    )

    logging.info("🎯 GAIA agent initialized successfully!")

    # Return the wrapped, compliant agent instead of the raw manager.
    return create_gaia_agent_wrapper(manager)

# --- Main Execution Block for Local Testing ---

def main():
    """Tests the agent with sample GAIA-style questions.

    Initializes the agent, asks each sample question in turn and logs the
    submitted answer together with the elapsed time. A failure on one
    question is logged and does not stop the remaining questions.
    """
    configure_logging()
    logging.info("🧪 Starting local agent testing...")

    agent = initialize_agent()
    if not agent:
        logging.critical("💥 Agent initialization failed. Exiting.")
        return

    test_questions = [
        "What is 15! / (12! * 3!)?",
        "In what year was the Python programming language first released?",
        "What is the square root of 2025?",
    ]

    for i, question in enumerate(test_questions, 1):
        logging.info(f"\n{'='*60}\n🔍 Test Question {i}: {question}\n{'='*60}")
        # perf_counter is monotonic, so the measurement is immune to
        # wall-clock adjustments.
        start_time = time.perf_counter()

        try:
            # Call the agent wrapper directly, not agent.run()
            final_answer = agent(question)
        except Exception:
            # Keep going: one failing question should not hide the others.
            logging.exception(f"Test Question {i} failed")
        else:
            elapsed_time = time.perf_counter() - start_time
            logging.info(f"✅ Submitted Answer: {final_answer}")
            logging.info(f"⏱️ Execution time: {elapsed_time:.2f} seconds")
        time.sleep(1)

    logging.info(f"\n{'='*60}\n🏁 Testing complete!\n{'='*60}")

if __name__ == "__main__":
    main()