import os
import re
from enum import Enum

from dotenv import load_dotenv, find_dotenv
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.tools import tool
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.utilities import GoogleSearchAPIWrapper  # used only by the optional Google Search tool below
from langchain_core.tools import Tool  # used only by the optional Google Search tool below
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from langchain_tavily import TavilySearch

from csv_cache import CSVSCache
from prompt import get_prompt


# --- Define Tools ---
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a."""
    return a - b


@tool
def divide(a: int, b: int) -> float:
    """Divide a by b, error on zero."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b


@tool
def modulus(a: int, b: int) -> int:
    """Compute a mod b."""
    return a % b


@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia and return up to 2 documents."""
    docs = WikipediaLoader(query=query, load_max_docs=2).load()
    results = [
        f"<Document source=\"{d.metadata['source']}\" page=\"{d.metadata.get('page', '')}\"/>\n{d.page_content}"
        for d in docs
    ]
    return {"wiki_results": "\n---\n".join(results)}


@tool
def arxiv_search(query: str) -> dict:
    """Search arXiv and return up to 3 documents."""
    docs = ArxivLoader(query=query, load_max_docs=3).load()
    results = [
        f"<Document source=\"{d.metadata.get('source', '')}\" page=\"{d.metadata.get('page', '')}\"/>\n{d.page_content[:1000]}"
        for d in docs
    ]
    return {"arxiv_results": "\n---\n".join(results)}


class LLMProvider(Enum):
    """
    An Enum that maps each LLM provider to a display name and the
    environment variable holding its API key.
    """
    HUGGINGFACE = ("HuggingFace", "HF_TOKEN")
    HUGGINGFACE_LLAMA = ("HuggingFace Llama", "HF_TOKEN")
    GOOGLE_GEMINI = ("Google Gemini", "GOOGLE_API_KEY")
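
# Example: LLMProvider.GOOGLE_GEMINI.value is ("Google Gemini", "GOOGLE_API_KEY"),
# so value[0] is the display name and value[1] is the API-key env var.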


class Model:
    def __init__(self, provider: LLMProvider = LLMProvider.HUGGINGFACE):
        load_dotenv(find_dotenv())
        self.system_prompt = get_prompt()
        print(f"system_prompt: {self.system_prompt}")
        self.provider = provider
        self.agent_executor = self.setup_model()

    def get_answer(self, question: str) -> str:
        try:
            result = self.agent_executor.invoke({"input": question})
        except Exception as e:
            print(f"An error occurred: {e}")
            return "ERROR"
        # The final answer is typically in the 'output' key of the result dictionary.
        final_answer = result.get("output", "")
        pattern = r'FINAL_ANSWER:"(.*?)"'
        match = re.search(pattern, final_answer, re.DOTALL)
        if match:
            final_answer_value = match.group(1)
            print(f"The extracted FINAL_ANSWER is: {final_answer_value}")
        else:
            print(f"ERROR: Pattern not found in: {final_answer}")
            final_answer_value = "ERROR"
        return final_answer_value
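
    # Note: the regex above assumes the system prompt (from prompt.get_prompt)
    # instructs the model to finish with FINAL_ANSWER:"<answer>"; e.g. an output
    # of 'The capital is Paris. FINAL_ANSWER:"Paris"' extracts to "Paris".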

    def get_chat_with_tools(self, provider: LLMProvider, tools):
        api_token = os.getenv(provider.value[1])
        if not api_token:
            raise ValueError(
                f"API key for {provider.value[0]} not found. "
                f"Please set the '{provider.value[1]}' environment variable."
            )
        if provider == LLMProvider.HUGGINGFACE:
            llm = HuggingFaceEndpoint(
                repo_id="Qwen/Qwen3-Next-80B-A3B-Thinking",
                huggingfacehub_api_token=api_token,
                temperature=0,
            )
            return ChatHuggingFace(llm=llm).bind_tools(tools)
        elif provider == LLMProvider.HUGGINGFACE_LLAMA:
            llm = HuggingFaceEndpoint(
                repo_id="meta-llama/Llama-2-7b-chat-hf",
                huggingfacehub_api_token=api_token,
                temperature=0,
            )
            return ChatHuggingFace(llm=llm).bind_tools(tools)
        elif provider == LLMProvider.GOOGLE_GEMINI:
            # ChatGoogleGenerativeAI reads GOOGLE_API_KEY from the environment.
            chat = ChatGoogleGenerativeAI(
                model="gemini-2.5-flash",
                temperature=0,
            )
            # Optional Google Search tool (currently disabled):
            # search = GoogleSearchAPIWrapper()
            # google_search_tool = Tool(
            #     name="Google Search",
            #     description="Search Google for recent information.",
            #     func=search.run,  # run() executes the search directly
            # )
            # tools.append(google_search_tool)
            return chat.bind_tools(tools)
        else:
            raise ValueError(f"Unknown LLM provider: {provider}")

    def setup_model(self):
        tavily_search_tool = TavilySearch(
            api_key=os.getenv("TAVILY_API_KEY"),
            max_results=5,
            topic="general",
        )
        # Define the tools the agent can use.
        tools = [
            multiply,
            add,
            subtract,
            divide,
            modulus,
            wiki_search,
            tavily_search_tool,
            arxiv_search,
        ]
        chat = self.get_chat_with_tools(self.provider, tools)
        # Create the ReAct-style prompt template.
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_prompt),  # the detailed ReAct prompt
                ("human", "{input}"),
                MessagesPlaceholder(variable_name="agent_scratchpad"),
            ]
        )
        # Create the agent.
        agent = create_tool_calling_agent(chat, tools, prompt)
        # Create the agent executor.
        return AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
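

# Usage sketch (hypothetical question): the executor loops model -> tool calls ->
# observations until the model emits a final answer; handle_parsing_errors=True
# feeds malformed tool calls back to the model instead of raising. E.g.:
#   model = Model(LLMProvider.GOOGLE_GEMINI)
#   model.get_answer("What is 6 * 7?")  # agent calls multiply, then replies FINAL_ANSWER:"42"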


def update_mode(model):
    csv = CSVSCache()
    df = csv.get_all_entries()
    # Loop over the rows using iterrows() for clear, row-by-row logic.
    for index, row in df.iterrows():
        if row['answer'] == 'unknown':
            question = row['question']
            print(f"Found unknown answer for question: '{question}'")
            # Call the provided LLM function to get the new answer.
            llm_response = model.get_answer(question)
            # Update the DataFrame at the specific row and column.
            # We use .at for efficient single-cell updates.
            df.at[index, 'answer'] = llm_response
            print(f"Updated with new answer: '{llm_response}'")
        # Only process the first rows while testing.
        if index > 20:
            break
    print("\nProcessing complete.")
    csv.df = df
    csv._save_cache()
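

# Assumption: CSVSCache (local csv_cache module) exposes get_all_entries(),
# returning a pandas DataFrame with 'question' and 'answer' columns, plus a df
# attribute and a _save_cache() method that persists the DataFrame to disk.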


def main():
    load_dotenv(find_dotenv())
    csv = CSVSCache()
    df = csv.get_all_entries()
    model = Model(LLMProvider.HUGGINGFACE)
    # update_mode(model)
    test_questions = [0, 6, 10, 12, 15]
    for row in test_questions:
        response = model.get_answer(df.iloc[row]['question'])
        print(f"the output is: {response}")


if __name__ == "__main__":
    main()