from dotenv import load_dotenv, find_dotenv
import os
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from langchain_community.utilities import GoogleSearchAPIWrapper
from langchain_tavily import TavilySearch
from langchain_community.document_loaders import AsyncHtmlLoader
from langchain.tools import tool
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.agents import AgentExecutor, create_tool_calling_agent
from csv_cache import CSVSCache
from prompt import get_prompt
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from enum import Enum
from langchain_core.tools import Tool
import re
# --- Define Tools ---
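# Each function below is decorated with @tool so the agent can invoke it with typed arguments.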
@tool
def multiply(a: int, b: int) -> int:
"""Multiply two integers."""
return a * b
@tool
def add(a: int, b: int) -> int:
"""Add two integers."""
return a + b
@tool
def subtract(a: int, b: int) -> int:
"""Subtract b from a."""
return a - b
@tool
def divide(a: int, b: int) -> float:
"""Divide a by b, error on zero."""
if b == 0:
raise ValueError("Cannot divide by zero.")
return a / b
@tool
def modulus(a: int, b: int) -> int:
"""Compute a mod b."""
return a % b
@tool
def wiki_search(query: str) -> dict:
"""Search Wikipedia and return up to 2 documents."""
docs = WikipediaLoader(query=query, load_max_docs=2).load()
results = [f"<Document source=\"{d.metadata['source']}\" page=\"{d.metadata.get('page','')}\"/>\n{d.page_content}" for d in docs]
return {"wiki_results": "\n---\n".join(results)}
@tool
def arxiv_search(query: str) -> dict:
"""Search Arxiv and return up to 3 docs."""
docs = ArxivLoader(query=query, load_max_docs=3).load()
results = [f"<Document source=\"{d.metadata['source']}\" page=\"{d.metadata.get('page','')}\"/>\n{d.page_content[:1000]}" for d in docs]
return {"arxiv_results": "\n---\n".join(results)}
class LLMProvider(Enum):
"""
An Enum to represent the different LLM providers and their
corresponding environment variable names for API keys.
"""
HUGGINGFACE = ("HuggingFace", "HF_TOKEN")
    HUGGINGFACE_LLAMA = ("HuggingFace Llama", "HF_TOKEN")
GOOGLE_GEMINI = ("Google Gemini", "GOOGLE_API_KEY")
class Model:
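    """Wraps an LLM behind a tool-calling agent and extracts FINAL_ANSWER values from its output."""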
def __init__(self, provider: LLMProvider = LLMProvider.HUGGINGFACE):
load_dotenv(find_dotenv())
self.system_prompt = get_prompt()
print(f"system_prompt: {self.system_prompt}")
self.provider = provider
self.agent_executor = self.setup_model()
def get_answer(self, question: str) -> str:
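        """Invoke the agent on a question and pull the value out of the FINAL_ANSWER:"..." pattern in its output."""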
try:
result = self.agent_executor.invoke({"input": question})
        except Exception as e:
            print(f"An error occurred: {e}")
            # Fall back to an output string the regex below can still parse instead of raising a KeyError.
            result = {"output": 'FINAL_ANSWER:"ERROR"'}
        # The final answer is typically in the 'output' key of the result dictionary.
        final_answer = result['output']
pattern = r'FINAL_ANSWER:"(.*?)"'
match = re.search(pattern, final_answer, re.DOTALL)
if match:
final_answer_value = match.group(1)
print(f"The extracted FINAL_ANSWER is: {final_answer_value}")
else:
print("ERROR: Pattern not found.: {r}")
final_answer_value = "ERROR"
return final_answer_value
def get_chat_with_tools(self, provider: LLMProvider, tools):
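        """Instantiate a chat model for the given provider and bind the supplied tools to it."""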
api_token = os.getenv(provider.value[1])
if not api_token:
raise ValueError(
f"API key for {provider.value[0]} not found. "
f"Please set the '{provider.value[1]}' environment variable."
)
if provider == LLMProvider.HUGGINGFACE:
llm = HuggingFaceEndpoint(
repo_id="Qwen/Qwen3-Next-80B-A3B-Thinking",
huggingfacehub_api_token=api_token,
temperature=0
)
return ChatHuggingFace(llm=llm).bind_tools(tools)
if provider == LLMProvider.HUGGINGFACE_LLAMA:
llm = HuggingFaceEndpoint(
repo_id="meta-llama/Llama-2-7b-chat-hf",
huggingfacehub_api_token=api_token,
temperature=0
)
return ChatHuggingFace(llm=llm).bind_tools(tools)
elif provider == LLMProvider.GOOGLE_GEMINI:
chat = ChatGoogleGenerativeAI(
model="gemini-2.5-flash",
temperature=0
)
# search = GoogleSearchAPIWrapper()
# # Define the Google Search tool correctly
# google_search_tool = Tool(
# name="Google Search",
# description="Search Google for recent information.",
# func=search.run, # Use the run method to execute the search directly
# )
# tools.append(google_search_tool)
return chat.bind_tools(tools)
else:
raise ValueError(f"Unknown LLM provider: {provider}")
def setup_model(self):
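        """Assemble the tools and prompt, then build the tool-calling agent and its executor."""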
tavily_search_tool = TavilySearch(
api_key=os.getenv("TAVILY_API_KEY"),
max_results=5,
topic="general",
)
        # Define the tools available to the agent.
tools = [
multiply,
add,
subtract,
divide,
modulus,
wiki_search,
tavily_search_tool,
arxiv_search,
]
chat = self.get_chat_with_tools(self.provider, tools)
# Create the ReAct prompt template
prompt = ChatPromptTemplate.from_messages(
[
("system", self.system_prompt), # Use the new, detailed ReAct prompt
("human", "{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)
# Create the agent
agent = create_tool_calling_agent(chat, tools, prompt)
# Create the agent executor
return AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
def update_mode(model):
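    """Fill in 'unknown' answers in the CSV cache using the model, then persist the cache."""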
csv = CSVSCache()
df = csv.get_all_entries()
# Loop over the rows using iterrows() for clear, row-by-row logic.
for index, row in df.iterrows():
if row['answer'] == 'unknown':
question = row['question']
print(f"Found unknown answer for question: '{question}'")
# Call the provided LLM function to get the new answer.
llm_response = model.get_answer(question)
# Update the DataFrame at the specific row and column.
# We use .at for efficient single-cell updates.
df.at[index, 'answer'] = llm_response
print(f"Updated with new answer: '{llm_response}'")
if index > 20:
break
print("\nProcessing complete.")
csv.df = df
csv._save_cache()
def main():
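    """Load the cached questions and run the agent on a few sample rows."""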
load_dotenv(find_dotenv())
csv = CSVSCache()
df = csv.get_all_entries()
model = Model(LLMProvider.HUGGINGFACE)
#update_mode(model)
test_questions = [0, 6, 10, 12, 15]
for row in test_questions:
response = model.get_answer(df.iloc[row]['question'])
print(f"the output is: {response}")
if __name__ == "__main__":
main()