Update agent.py
agent.py CHANGED
@@ -14,10 +14,12 @@ import mimetypes
 import os
 import re
 import tempfile
+import time
 from typing import List, Dict, Any, Optional
 import json
 import requests
 from urllib.parse import urlparse
+import random
 
 from smolagents import (
     CodeAgent,
@@ -42,6 +44,39 @@ def _download_file(file_id: str) -> bytes:
     resp.raise_for_status()
     return resp.content
 
+# --------------------------------------------------------------------------- #
+# Rate limiting helper
+# --------------------------------------------------------------------------- #
+class RateLimiter:
+    """Simple rate limiter to prevent Anthropic API rate limit errors"""
+    def __init__(self, requests_per_minute=20, burst=3):
+        self.requests_per_minute = requests_per_minute
+        self.burst = burst
+        self.request_times = []
+
+    def wait(self):
+        """Wait if needed to avoid exceeding rate limits"""
+        now = time.time()
+        # Remove timestamps older than 1 minute
+        self.request_times = [t for t in self.request_times if now - t < 60]
+
+        # If we've made too many requests in the last minute, wait
+        if len(self.request_times) >= self.requests_per_minute:
+            oldest = min(self.request_times)
+            sleep_time = 60 - (now - oldest) + 1  # +1 for safety
+            print(f"Rate limit approaching. Waiting {sleep_time:.2f} seconds before next request...")
+            time.sleep(sleep_time)
+
+        # Add current timestamp to the list
+        self.request_times.append(time.time())
+
+        # Add a small random delay to avoid bursts of requests
+        if len(self.request_times) > self.burst:
+            time.sleep(random.uniform(0.2, 1.0))
+
+# Global rate limiter instance
+RATE_LIMITER = RateLimiter(requests_per_minute=25)  # Keep below 40 for safety
+
 # --------------------------------------------------------------------------- #
 # custom tool: fetch GAIA attachments
 # --------------------------------------------------------------------------- #
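A minimal usage sketch of the limiter's sliding window, assuming the `RateLimiter` above is importable from this file (`agent.py` on the import path); the small parameter values are illustrative only:

```python
from agent import RateLimiter  # assumes agent.py is importable

limiter = RateLimiter(requests_per_minute=5, burst=2)  # tiny window for demo purposes
for i in range(7):
    limiter.wait()  # the 6th call inside one minute sleeps until the window frees up
    print(f"request {i} dispatched")
```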
@@ -224,6 +259,81 @@ def analyze_excel_file(file_path: str, query: str) -> str:
     except Exception as e:
         return f"Error analyzing Excel file: {str(e)}"
 
+# --------------------------------------------------------------------------- #
+# Custom LiteLLM model with rate limiting and error handling
+# --------------------------------------------------------------------------- #
+class RateLimitedClaudeModel:
+    def __init__(
+        self,
+        model_id: str = "anthropic/claude-3-5-sonnet-20240620",
+        api_key: Optional[str] = None,
+        temperature: float = 0.1,
+        max_tokens: int = 1024,
+        max_retries: int = 3,
+        retry_delay: int = 5,
+    ):
+        """
+        Initialize a Claude model with rate limiting and error handling
+
+        Args:
+            model_id: The model ID to use
+            api_key: The API key to use
+            temperature: The temperature to use
+            max_tokens: The maximum number of tokens to generate
+            max_retries: The maximum number of retries on rate limit errors
+            retry_delay: The initial delay between retries (will increase exponentially)
+        """
+        # Get API key
+        if api_key is None:
+            api_key = os.getenv("ANTHROPIC_API_KEY")
+            if not api_key:
+                raise ValueError("No Anthropic token provided. Please set ANTHROPIC_API_KEY environment variable or pass api_key parameter.")
+
+        self.model_id = model_id
+        self.api_key = api_key
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.max_retries = max_retries
+        self.retry_delay = retry_delay
+
+        # Create the underlying LiteLLM model
+        self.model = LiteLLMModel(
+            model_id=model_id,
+            api_key=api_key,
+            temperature=temperature
+        )
+
+    def __call__(self, prompt: str, system_instruction: str, **kwargs) -> str:
+        """
+        Call the model with rate limiting and error handling
+
+        Args:
+            prompt: The prompt to generate from
+            system_instruction: The system instruction to use
+
+        Returns:
+            The generated text
+        """
+        retries = 0
+        while True:
+            try:
+                # Wait according to rate limiter
+                RATE_LIMITER.wait()
+
+                # Call the model
+                return self.model(prompt, system_instruction=system_instruction, **kwargs)
+
+            except Exception as e:
+                # Check if it's a rate limit error
+                if "rate_limit_error" in str(e) and retries < self.max_retries:
+                    retries += 1
+                    sleep_time = self.retry_delay * (2 ** (retries - 1))  # Exponential backoff
+                    print(f"Rate limit exceeded, retrying in {sleep_time} seconds (attempt {retries}/{self.max_retries})...")
+                    time.sleep(sleep_time)
+                else:
+                    # If it's not a rate limit error or we've exceeded max retries, raise
+                    raise
+
 # --------------------------------------------------------------------------- #
 # GAIAAgent class
 # --------------------------------------------------------------------------- #
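With the defaults above (`retry_delay=5`, `max_retries=3`), the exponential backoff in `__call__` works out to waits of 5 s, 10 s, and 20 s before the exception is re-raised; a standalone arithmetic check:

```python
retry_delay, max_retries = 5, 3  # defaults from RateLimitedClaudeModel
for retries in range(1, max_retries + 1):
    print(f"attempt {retries}: wait {retry_delay * 2 ** (retries - 1)} s")
# attempt 1: wait 5 s
# attempt 2: wait 10 s
# attempt 3: wait 20 s
```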
@@ -233,7 +343,8 @@ class GAIAAgent:
         api_key: Optional[str] = None,
         temperature: float = 0.1,
         verbose: bool = False,
-        system_prompt: Optional[str] = None
+        system_prompt: Optional[str] = None,
+        max_tokens: int = 1024,
     ):
         """
         Initialize a GAIAAgent with Claude model
@@ -243,6 +354,7 @@
             temperature: Temperature for text generation
             verbose: Enable verbose logging
             system_prompt: Custom system prompt (optional)
+            max_tokens: Maximum number of tokens to generate per response
         """
         # Set verbosity
         self.verbose = verbose
@@ -260,15 +372,16 @@ All answers are graded by exact string match, so format carefully!"""
         if self.verbose:
             print(f"Using Anthropic token: {api_key[:5]}...")
 
-        # Initialize Claude model
-        self.model = LiteLLMModel(
+        # Initialize Claude model with rate limiting
+        self.model = RateLimitedClaudeModel(
             model_id="anthropic/claude-3-5-sonnet-20240620",  # Use Claude 3.5 Sonnet
             api_key=api_key,
-            temperature=temperature
+            temperature=temperature,
+            max_tokens=max_tokens,
         )
 
         if self.verbose:
-            print(f"Initialized model: LiteLLMModel - anthropic/claude-3-5-sonnet-20240620")
+            print(f"Initialized model: RateLimitedClaudeModel - anthropic/claude-3-5-sonnet-20240620")
 
         # Initialize default tools
         self.tools = [
@@ -334,8 +447,12 @@ All answers are graded by exact string match, so format carefully!"""
         # If there's a file, read it and include its content in the context
         if task_file_path:
             try:
+                # Limit file content size to avoid token limits
+                max_file_size = 10000  # Characters
                 with open(task_file_path, 'r', errors='ignore') as f:
-                    file_content = f.read()
+                    file_content = f.read(max_file_size)
+                    if len(file_content) >= max_file_size:
+                        file_content = file_content[:max_file_size] + "... [content truncated to prevent exceeding token limits]"
 
                 # Determine file type from extension
                 import os
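The same capped-read pattern, distilled into a standalone helper for reference (the name `read_capped` is illustrative, not part of the file). Note that `f.read(n)` already stops at `n` characters, so the length check only decides whether to append the truncation marker:

```python
MAX_FILE_SIZE = 10000  # characters, matching max_file_size in the hunk above

def read_capped(path: str) -> str:
    """Read at most MAX_FILE_SIZE characters, marking truncation."""
    with open(path, "r", errors="ignore") as f:
        content = f.read(MAX_FILE_SIZE)  # never reads past the cap
    if len(content) >= MAX_FILE_SIZE:
        content += "... [content truncated to prevent exceeding token limits]"
    return content
```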
@@ -343,11 +460,11 @@
 
                 context = f"""
 Question: {question}
-This question has an associated file. Here is the file content:
+This question has an associated file. Here is the file content (it may be truncated):
 ```{file_ext}
 {file_content}
 ```
-Analyze the file content
+Analyze the available file content to answer the question.
 """
             except Exception as file_e:
                 try:
@@ -385,12 +502,12 @@ This question appears to be in reversed text. Here's the reversed version:
 Now answer the question above. Remember to format your answer exactly as requested.
 """
 
-        # Add a prompt to ensure precise answers
+        # Add a prompt to ensure precise answers but keep it concise
         full_prompt = f"""{context}
 When answering, provide ONLY the precise answer requested.
 Do not include explanations, steps, reasoning, or additional text.
 Be direct and specific. GAIA benchmark requires exact matching answers.
-
+Example: If asked "What is the capital of France?", respond just with "Paris".
 """
 
         # Run the agent with the question
@@ -486,8 +603,9 @@ class ClaudeAgent:
             # Create GAIAAgent instance
             self.agent = GAIAAgent(
                 api_key=api_key,
-                temperature=0.1,
-                verbose=True,
+                temperature=0.1,  # Use low temperature for precise answers
+                verbose=True,  # Enable verbose logging
+                max_tokens=1024,  # Reduce max tokens to avoid hitting rate limits
             )
         except Exception as e:
             print(f"Error initializing GAIAAgent: {e}")
@@ -506,6 +624,9 @@ class ClaudeAgent:
         try:
             print(f"Received question: {question[:100]}..." if len(question) > 100 else f"Received question: {question}")
 
+            # Add delay between questions to respect rate limits
+            time.sleep(random.uniform(0.5, 2.0))
+
             # Detect reversed text
             if question.startswith(".") or ".rewsna eht sa" in question:
                 print("Detected reversed text question")
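That pre-question jitter adds roughly 1.25 s per question on average (the mean of a uniform draw on [0.5, 2.0]), on top of whatever `RATE_LIMITER.wait()` imposes inside each model call; a quick standalone check of that expectation:

```python
import random

random.seed(0)  # reproducible illustration
delays = [random.uniform(0.5, 2.0) for _ in range(100_000)]
print(sum(delays) / len(delays))  # ~1.25, the mean of U(0.5, 2.0)
```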