Sagar Sanghani committed
Commit 66bfc35 · 1 parent: 88401e0

updated with more tools

Files changed (1): model.py (+63, -11)
model.py CHANGED
@@ -2,30 +2,73 @@
  import os
  from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
  from langchain_community.tools import DuckDuckGoSearchRun
+ from langchain_community.tools.tavily_search import TavilySearchResults
  from langchain_community.document_loaders import AsyncHtmlLoader
  from langchain.tools import tool
  from langchain.prompts import ChatPromptTemplate
  from langchain.agents import AgentExecutor, create_tool_calling_agent
  from prompt import get_prompt
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
+
  import re

+ # --- Define Tools ---
+ @tool
+ def multiply(a: int, b: int) -> int:
+     """Multiply two integers."""
+     return a * b
+
+ @tool
+ def add(a: int, b: int) -> int:
+     """Add two integers."""
+     return a + b
+
+ @tool
+ def subtract(a: int, b: int) -> int:
+     """Subtract b from a."""
+     return a - b
+
+ @tool
+ def divide(a: int, b: int) -> float:
+     """Divide a by b, error on zero."""
+     if b == 0:
+         raise ValueError("Cannot divide by zero.")
+     return a / b
+
+ @tool
+ def modulus(a: int, b: int) -> int:
+     """Compute a mod b."""
+     return a % b
+
+ @tool
+ def wiki_search(query: str) -> dict:
+     """Search Wikipedia and return up to 2 documents."""
+     docs = WikipediaLoader(query=query, load_max_docs=2).load()
+     results = [f"<Document source=\"{d.metadata['source']}\" page=\"{d.metadata.get('page','')}\"/>\n{d.page_content}" for d in docs]
+     return {"wiki_results": "\n---\n".join(results)}
+
+ @tool
+ def web_search(query: str) -> dict:
+     """Search Tavily and return up to 3 results."""
+     docs = TavilySearchResults(max_results=3).invoke(query)
+     results = [f"<Document source=\"{d['url']}\"/>\n{d['content']}" for d in docs]
+     return {"web_results": "\n---\n".join(results)}
+
+ @tool
+ def arxiv_search(query: str) -> dict:
+     """Search Arxiv and return up to 3 docs."""
+     docs = ArxivLoader(query=query, load_max_docs=3).load()
+     results = [f"<Document source=\"{d.metadata['source']}\" page=\"{d.metadata.get('page','')}\"/>\n{d.page_content[:1000]}" for d in docs]
+     return {"arxiv_results": "\n---\n".join(results)}
+
  class Model:
      def __init__(self):
          #load_dotenv(find_dotenv())
          self.token = os.getenv("HF_TOKEN")
-         print(f"token: {self.token}")
          self.system_prompt = get_prompt()
          print(f"system_prompt: {self.system_prompt}")
          self.agent_executor = self.setup_model()

-     # Define a tool for the agent to use
-     @tool
-     def scrape_webpage(self,url: str) -> str:
-         """Scrapes a given URL and returns the content."""
-         loader = AsyncHtmlLoader(url)
-         docs = loader.load()
-         return docs[0].page_content# Define the search tool
-
      def get_answer(self, question: str) -> str:
          try:
              result = self.agent_executor.invoke({"input": question})
@@ -52,8 +95,17 @@ class Model:
          search = DuckDuckGoSearchRun()

          # # Define a tool for the agent to use
-         tools = [search, self.scrape_webpage]
-
+         tools = [
+             multiply,
+             add,
+             subtract,
+             divide,
+             modulus,
+             wiki_search,
+             search,
+             arxiv_search,
+         ]
+
          llm = HuggingFaceEndpoint(
              repo_id="Qwen/Qwen2.5-Coder-32B-Instruct",
              huggingfacehub_api_token=self.token,
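
The diff cuts off inside setup_model, right after the HuggingFaceEndpoint arguments. For orientation only, here is a minimal sketch of how the remaining wiring typically looks with the names already imported in model.py (ChatHuggingFace, create_tool_calling_agent, AgentExecutor); the build_agent helper, the prompt layout, and the example invocations are illustrative assumptions, not part of this commit.

from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain.prompts import ChatPromptTemplate
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

def build_agent(token: str, tools: list, system_prompt: str) -> AgentExecutor:
    # Hypothetical completion of setup_model(): wrap the endpoint in a chat
    # model, bind the tools list from the diff, and return an AgentExecutor.
    llm = HuggingFaceEndpoint(
        repo_id="Qwen/Qwen2.5-Coder-32B-Instruct",
        huggingfacehub_api_token=token,
    )
    chat_model = ChatHuggingFace(llm=llm)  # tool calling needs the chat wrapper
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        ("human", "{input}"),
        ("placeholder", "{agent_scratchpad}"),  # required by create_tool_calling_agent
    ])
    agent = create_tool_calling_agent(chat_model, tools, prompt)
    return AgentExecutor(agent=agent, tools=tools, verbose=True)

# The @tool-decorated helpers can also be exercised on their own, e.g.:
#   multiply.invoke({"a": 6, "b": 7})      -> 42
#   wiki_search.invoke({"query": "Qwen"})  -> {"wiki_results": "..."}

With an executor built this way, get_answer's existing self.agent_executor.invoke({"input": question}) call shown in the diff works unchanged; defining the tools at module level is what lets the tools list reference them by bare name inside setup_model.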