File size: 2,776 Bytes
f087af6
37fb9b0
9707da3
f087af6
7bac906
 
 
cdf543a
 
 
 
 
7bac906
f087af6
7bac906
cdf543a
 
 
 
 
 
7bac906
 
 
 
 
cdf543a
 
 
 
 
 
 
 
 
 
7bac906
cdf543a
 
9707da3
 
cdf543a
 
9707da3
cdf543a
37fb9b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import re
import uuid
from dotenv import load_dotenv
from langchain.agents import create_agent
from langgraph.checkpoint.memory import InMemorySaver
from langchain_community.tools import DuckDuckGoSearchRun
from langchain.agents.middleware import (
    ModelCallLimitMiddleware,
    ToolCallLimitMiddleware,
)
from langchain.messages import HumanMessage, AIMessage

# Populate os.environ from a local .env file (when one exists) at import time,
# so GEMINI_API_KEY is available by the time LangChainAgent is constructed.
load_dotenv()


class LangChainAgent:
    """Question-answering agent backed by Gemini with DuckDuckGo web search.

    Calling an instance with a question runs the LangChain agent loop and
    returns only the text that follows the ``FINAL ANSWER:`` marker the
    system prompt asks the model to emit.
    """

    # Marker requested by the system prompt. Matching is case-insensitive and
    # tolerates markdown emphasis before the colon (e.g. "**FINAL ANSWER**:").
    _FINAL_ANSWER_RE = re.compile(r"FINAL ANSWER[*\s]*:", re.IGNORECASE)

    def __init__(self):
        """Build the agent graph.

        Raises:
            RuntimeError: if neither GEMINI_API_KEY nor GOOGLE_API_KEY is set.
        """
        # The google_genai model integration reads GOOGLE_API_KEY, while this
        # project stores the key as GEMINI_API_KEY; copy it across when set.
        gemini_key = os.getenv("GEMINI_API_KEY")
        if gemini_key:
            os.environ["GOOGLE_API_KEY"] = gemini_key
        elif not os.getenv("GOOGLE_API_KEY"):
            # Assigning None into os.environ would raise an opaque TypeError;
            # fail early with an actionable message instead.
            raise RuntimeError(
                "GEMINI_API_KEY (or GOOGLE_API_KEY) must be set in the "
                "environment or in a .env file"
            )

        system_prompt = """finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."""

        self.agent = create_agent(
            model="google_genai:gemini-2.5-flash",
            tools=[DuckDuckGoSearchRun()],
            system_prompt=system_prompt,
            # NOTE(review): every call uses a fresh thread_id, so this
            # in-memory checkpointer only accumulates state that is never
            # read again; consider dropping it for long-running processes.
            checkpointer=InMemorySaver(),
            middleware=[
                # Cap model and tool calls per run; on hitting a limit the
                # run ends instead of raising.
                ModelCallLimitMiddleware(run_limit=10, exit_behavior="end"),
                ToolCallLimitMiddleware(run_limit=20, exit_behavior="end"),
            ],
        )

    def __call__(self, question: str) -> str:
        """Answer *question* and return the extracted final answer string."""
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        # A unique thread ID per request keeps questions independent.
        thread_id = str(uuid.uuid4())
        response = self.agent.invoke(
            {"messages": [HumanMessage(content=question)]},
            {"configurable": {"thread_id": thread_id}, "recursion_limit": 50},
        )
        answer = response["messages"][-1].text
        final_answer = self.extract_final_answer(answer)
        print(f"Agent returning answer: {final_answer}")
        return final_answer

    def extract_final_answer(self, text):
        """Return the text following the last 'FINAL ANSWER:' marker in *text*.

        The last occurrence is used so that an earlier mention of the template
        (e.g. the model restating its instructions) is ignored. Surrounding
        whitespace and markdown emphasis/code markers (``*``, backticks) are
        stripped. If no marker is found, a fallback message containing the
        complete text is returned. ``None`` is treated as an empty string.
        """
        text = text or ""
        last_match = None
        for last_match in self._FINAL_ANSWER_RE.finditer(text):
            pass
        if last_match is None:
            return "No final answer provided, here is the complete text : " + text
        # Everything after the marker (across newlines) is the answer, as the
        # prompt asks the model to finish its reply with it.
        answer = text[last_match.end():]
        return answer.strip().strip("*`").strip()