|
|
import os |
|
|
import re |
|
|
from smolagents import ( |
|
|
AgentMemory, |
|
|
CodeAgent, |
|
|
InferenceClientModel, |
|
|
FinalAnswerTool, |
|
|
WebSearchTool, |
|
|
) |
|
|
from collections.abc import Callable |
|
|
|
|
|
from smolagents.default_tools import VisitWebpageTool, WikipediaSearchTool |
|
|
from smolagents.models import OpenAIModel |
|
|
|
|
|
|
|
|
class LLMOnlyAgent:
    """Question-answering wrapper around a smolagents ``CodeAgent``.

    The agent gets Wikipedia search, web search and page-visiting tools, and a
    prompt (``self.instructions``) asking for short, normalized answers in the
    ``FINAL ANSWER: ...`` format.
    """

    # Accepted shapes for a final answer; matched against the WHOLE answer.
    _ANSWER_PATTERN = re.compile(
        r"-?\d+(\.\d+)?"  # a single number, optionally negative / decimal
        r"|\w+(\s+\w+){0,4}"  # one to five words
        # a comma separated list of numbers, double-quoted strings or words
        r'|(-?\d+(\.\d+)?|"[^"]*"|\w+)(\s*,\s*(-?\d+(\.\d+)?|"[^"]*"|\w+))+'
    )

    def __init__(self, model=None, *, use_final_answer_checks: bool = False):
        """Build the underlying ``CodeAgent``.

        Args:
            model: smolagents model driving the agent. When omitted, Gemini
                2.5 Flash is used (see ``build_gemini_model``).
            use_final_answer_checks: when True, the callables returned by
                ``final_answer_checks()`` are passed to ``CodeAgent`` via its
                ``final_answer_checks`` argument. Off by default, which keeps
                the previous behavior of not validating answers.

        Raises:
            KeyError: if ``model`` is omitted and ``GEMINI_API_KEY`` is not set
                in the environment.
        """
        self.instructions = """finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."""

        if model is None:
            model = self.build_gemini_model()

        # Only forward the checks when requested, so the default CodeAgent
        # configuration stays exactly as before.
        extra_agent_kwargs = {}
        if use_final_answer_checks:
            extra_agent_kwargs["final_answer_checks"] = self.final_answer_checks()

        self.agent = CodeAgent(
            model=model,
            instructions=self.instructions,
            tools=[
                FinalAnswerTool(),
                WikipediaSearchTool(),
                WebSearchTool(),
                VisitWebpageTool(),
            ],
            additional_authorized_imports=[
                "markdownify",
                "requests",
                "pandas",
                "numpy",
                "chess",
            ],
            max_steps=5,
            planning_interval=3,
            **extra_agent_kwargs,
        )

        print("LLM-only Agent initialized.")

    @staticmethod
    def build_gemini_model():
        """Return Gemini 2.5 Flash served through Google's OpenAI-compatible API.

        Raises:
            KeyError: if ``GEMINI_API_KEY`` is not set in the environment.
        """
        return OpenAIModel(
            max_tokens=8096,
            model_id="gemini-2.5-flash",
            api_base="https://generativelanguage.googleapis.com/v1beta/openai/",
            api_key=os.environ["GEMINI_API_KEY"],
        )

    @staticmethod
    def build_qwen_model():
        """Return Qwen3-Coder-30B-A3B-Instruct via the ``nebius`` provider.

        Alternative backend, e.g. ``LLMOnlyAgent(model=LLMOnlyAgent.build_qwen_model())``.
        """
        return InferenceClientModel(
            max_tokens=8096,
            model_id="Qwen/Qwen3-Coder-30B-A3B-Instruct",
            custom_role_conversions=None,
            provider="nebius",
        )

    def __call__(self, question: str) -> str:
        """Run the agent on ``question`` and return its final answer as text.

        ``self.agent.run`` is not guaranteed to return a string (the agent may
        pass e.g. a number to its final-answer tool), so the result is
        converted with ``str`` to honor the declared return type.
        """
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        answer = str(self.agent.run(question))
        print(f"Agent returning answer: {answer}")
        return answer

    def final_answer_checks(self) -> list[Callable]:
        """Return the answer-validation callables for ``CodeAgent``."""
        return [self.check_func]

    def check_func(self, answer: object, memory: "AgentMemory", **kwargs) -> bool:
        """Return True if ``answer`` has one of the accepted final-answer shapes.

        Accepted shapes (see ``_ANSWER_PATTERN``): a single number, one to five
        words, or a comma separated list of numbers / double-quoted strings /
        single words.

        Args:
            answer: the candidate final answer. It is converted with ``str`` and
                stripped first: non-string answers previously made ``re.match``
                raise ``TypeError``, and surrounding whitespace is irrelevant.
            memory: the agent memory; unused here.
            **kwargs: absorbs extra keyword arguments (e.g. ``agent=``) that
                newer smolagents versions may pass to final-answer checks.
        """
        check = self._ANSWER_PATTERN.fullmatch(str(answer).strip()) is not None
        print(f"FINAL ANSWER CHECK is {check}")
        return check
|
|