Sagar Sanghani committed

Commit 73f74f6 · Parent: 960e946

made google work, added llama

Files changed (1): model.py (+39, -19)
model.py CHANGED
@@ -2,11 +2,11 @@ from dotenv import load_dotenv, find_dotenv
 import os
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
-from langchain_community.tools import DuckDuckGoSearchRun
+from langchain_community.tools.google_search.tool import GoogleSearchAPIWrapper
 from langchain_tavily import TavilySearch
 from langchain_community.document_loaders import AsyncHtmlLoader
 from langchain.tools import tool
-from langchain.prompts import ChatPromptTemplate
+from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain.agents import AgentExecutor, create_tool_calling_agent
 from csv_cache import CSVSCache
 from prompt import get_prompt
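Note: the new import pulls GoogleSearchAPIWrapper out of a tool module; the documented home of the wrapper is langchain_community.utilities, so the path above likely works only as a transitive re-export. A minimal sketch of the documented import, assuming GOOGLE_API_KEY and GOOGLE_CSE_ID are set:

```python
# Documented import path for the wrapper (the tool-module path in the
# diff may only work as a re-export).
from langchain_community.utilities import GoogleSearchAPIWrapper

search = GoogleSearchAPIWrapper()  # expects GOOGLE_API_KEY and GOOGLE_CSE_ID
print(search.run("LangChain agents"))
```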
@@ -64,15 +64,18 @@ class LLMProvider(Enum):
     corresponding environment variable names for API keys.
     """
     HUGGINGFACE = ("HuggingFace", "HF_TOKEN")
+    HUGGINGFACE_LLAMA = ("HUGGINGFACE_LLAMA", "HF_TOKEN")
     GOOGLE_GEMINI = ("Google Gemini", "GOOGLE_API_KEY")
 
 class Model:
-    def __init__(self):
-        #load_dotenv(find_dotenv())
-        self.token = os.getenv("HF_TOKEN")
+    def __init__(self, provider: LLMProvider = LLMProvider.HUGGINGFACE):
+        load_dotenv(find_dotenv())
         self.system_prompt = get_prompt()
         print(f"system_prompt: {self.system_prompt}")
+        self.provider = provider
         self.agent_executor = self.setup_model()
+
+
 
     def get_answer(self, question: str) -> str:
         try:
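Note: the rewritten __init__ drops the cached self.token, and a later hunk reads an api_token variable whose assignment falls outside this diff. A hedged sketch of how it is presumably derived from the enum, whose value pairs a display name with an env-var name:

```python
import os

# Hypothetical helper; the real assignment of api_token is not shown in
# this diff. Each LLMProvider value is (display_name, env_var_name).
def resolve_api_token(provider: "LLMProvider") -> str | None:
    _display_name, env_var = provider.value
    return os.getenv(env_var)

# resolve_api_token(LLMProvider.HUGGINGFACE) -> contents of HF_TOKEN
```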
@@ -107,10 +110,19 @@ class Model:
         if provider == LLMProvider.HUGGINGFACE:
             llm = HuggingFaceEndpoint(
                 repo_id="Qwen/Qwen3-Next-80B-A3B-Thinking",
-                huggingfacehub_api_token=self.token,
+                huggingfacehub_api_token=api_token,
+                temperature=0
+            )
+            return ChatHuggingFace(llm=llm).bind_tools(tools)
+
+        if provider == LLMProvider.HUGGINGFACE_LLAMA:
+            llm = HuggingFaceEndpoint(
+                repo_id="meta-llama/Llama-2-7b-chat-hf",
+                huggingfacehub_api_token=api_token,
                 temperature=0
             )
             return ChatHuggingFace(llm=llm).bind_tools(tools)
+
 
 
         elif provider == LLMProvider.GOOGLE_GEMINI:
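Note: both HuggingFace branches return ChatHuggingFace(llm=llm).bind_tools(tools), and since each returns, the following elif chaining off the new Llama if is harmless. bind_tools only attaches tool schemas to every request; whether the model emits usable tool calls depends on the checkpoint, and meta-llama/Llama-2-7b-chat-hf predates native function calling, so tool use there may be unreliable. A minimal sketch of what a bound chat model surfaces, with a hypothetical add tool:

```python
from langchain.tools import tool

@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

# Assuming chat = ChatHuggingFace(llm=llm) as in the branch above:
# bound = chat.bind_tools([add])          # every invoke() now advertises the schema
# result = bound.invoke("What is 2 + 3?")
# result.tool_calls                       # e.g. [{"name": "add", "args": {"a": 2, "b": 3}, ...}]
```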
@@ -118,14 +130,17 @@ class Model:
                 model="gemini-2.5-flash",
                 temperature=0
             )
-            # Define the Google Search tool using the Gemini API's built-in tool
-            google_search_tool = Tool(
-                name="google_search",
-                description="Search Google for recent information.",
-                func=lambda query: chat.invoke(query, tools=[{"googleSearch": {}}]).content
-            )
-            tools = tools.append(google_search_tool)
-            return chat.bind(tools)
+            # search = GoogleSearchAPIWrapper()
+
+            # # Define the Google Search tool correctly
+            # google_search_tool = Tool(
+            #     name="Google Search",
+            #     description="Search Google for recent information.",
+            #     func=search.run, # Use the run method to execute the search directly
+            # )
+
+            # tools.append(google_search_tool)
+            return chat.bind_tools(tools)
 
         else:
             raise ValueError(f"Unknown LLM provider: {provider}")
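Note: the Gemini branch now returns chat.bind_tools(tools), fixing two bugs in the removed lines: chat.bind(tools) was the wrong method, and tools = tools.append(...) set tools to None, since list.append returns None. The replacement search tool is left commented out; a hedged sketch of that wiring, were it enabled (Tool was never imported in this file, which is another reason the removed version could not have run):

```python
from langchain_core.tools import Tool
from langchain_community.utilities import GoogleSearchAPIWrapper

search = GoogleSearchAPIWrapper()  # needs GOOGLE_API_KEY and GOOGLE_CSE_ID
tools: list = []                   # stands in for the method's tool list

google_search_tool = Tool(
    name="google_search",  # identifier-style names are safest for tool calling
    description="Search Google for recent information.",
    func=search.run,       # executes the query and returns a result string
)

tools.append(google_search_tool)  # mutate in place; do not assign the result
```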
@@ -148,14 +163,15 @@ class Model:
             tavily_search_tool,
             arxiv_search,
         ]
-        chat = self.get_chat_with_tools(LLMProvider.HUGGINGFACE, tools)
+        chat = self.get_chat_with_tools(self.provider, tools)
 
         # Create the ReAct prompt template
         prompt = ChatPromptTemplate.from_messages(
             [
                 ("system", self.system_prompt), # Use the new, detailed ReAct prompt
-                ("placeholder", "{agent_scratchpad}"),
                 ("human", "{input}"),
+                MessagesPlaceholder(variable_name="agent_scratchpad"),
+
             ]
         )
 
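Note: the removed ("placeholder", "{agent_scratchpad}") tuple is shorthand that LangChain expands to an optional MessagesPlaceholder, so the swap mostly makes the scratchpad explicit and moves it after the human turn, the layout create_tool_calling_agent expects; that constructor refuses to build if the agent_scratchpad placeholder is missing. The two spellings side by side:

```python
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder

explicit = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful assistant."),
    ("human", "{input}"),
    MessagesPlaceholder(variable_name="agent_scratchpad"),
])

shorthand = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful assistant."),
    ("human", "{input}"),
    ("placeholder", "{agent_scratchpad}"),  # expands to an optional MessagesPlaceholder
])
```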
@@ -193,10 +209,14 @@ def update_mode(model):
 
 def main():
     load_dotenv(find_dotenv())
-    model = Model()
+    csv = CSVSCache()
+    df = csv.get_all_entries()
+    model = Model(LLMProvider.HUGGINGFACE)
     #update_mode(model)
-    response = model.get_answer("Where were the Vietnamese specimens described by Kuznetzov in Nedoshivina's 2010 paper eventually deposited? Just give me the city name without abbreviations.")
-    print(f"the output is: {response}")
+    test_questions = [0, 6, 10, 12, 15]
+    for row in test_questions:
+        response = model.get_answer(df.iloc[row]['question'])
+        print(f"the output is: {response}")
 
 if __name__ == "__main__":
     main()
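Note: main() now spot-checks a fixed set of cached rows instead of one hard-coded question. A hedged extension that would also surface the expected answer, assuming the cache frame carries one (only the 'question' column is visible in this diff):

```python
# 'answer' is a guessed column name; only 'question' appears in this diff.
for row in test_questions:
    question = df.iloc[row]['question']
    response = model.get_answer(question)
    expected = df.iloc[row].get('answer', '<not cached>')
    print(f"Q{row}: got {response!r} | expected {expected!r}")
```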
 