vikramvasudevan committed on
Commit
c6893be
·
verified ·
1 Parent(s): 9d8a7d0

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. app.py +46 -12
  2. db.py +62 -7
  3. graph_helper.py +2 -0
  4. sanatan_assistant.py +33 -0
  5. tools.py +11 -2
app.py CHANGED
@@ -107,6 +107,11 @@ def chat(message, history, thread_id):
107
  return response["messages"][-1].content
108
 
109
 
 
 
 
 
 
110
  async def chat_streaming(message, history, thread_id):
111
  state = {"messages": (history or []) + [{"role": "user", "content": message}]}
112
  config = {"configurable": {"thread_id": thread_id}}
@@ -117,48 +122,58 @@ async def chat_streaming(message, history, thread_id):
117
 
118
  try:
119
  tool_calls = []
 
120
  async for msg, metadata in graph.astream(
121
  state, config=config, stream_mode="messages"
122
  ):
123
  node = metadata.get("langgraph_node", "?")
124
  name = getattr(msg, "name", "-")
 
 
 
 
 
 
 
 
 
 
 
 
125
  full: str = escape(msg.content)
126
  truncated = (full[:MAX_CONTENT] + "…") if len(full) > MAX_CONTENT else full
127
 
128
- processing_message = (
 
129
  f"<div class='thinking-bubble'><em>🤔{random.choice(thinking_verbs)} ...</em></div>"
130
  f"<div style='opacity: 0.1' title='{full}'>"
131
  f"<span>{node}:{name or ''}:</span>"
132
  f"<strong>Looking for : [{message}]</strong> {truncated or '...'}"
133
  f"</div>"
134
  )
 
135
  if (
136
  not isinstance(msg, ToolMessage)
137
  and not isinstance(msg, SystemMessage)
138
  and not isinstance(msg, AIMessageChunk)
139
  ):
140
  logger.info("msg = %s", msg)
141
- # yield processing_message
142
  if isinstance(msg, ToolMessage):
143
  logger.debug("tool message = %s", msg)
144
  html = (
145
- f"<div class='thinking-bubble'><em>🤔{name} : {random.choice(thinking_verbs)} ...</em></div>"
146
  f"<div style='opacity: 0.5'>"
147
  f"<strong>Looking for : [{message}]</strong> {truncated or '...'}"
148
  f"</div>"
149
  )
150
- yield html
151
- # yield f"""
152
- # <div class='thinking-bubble'>🤔 {random.choice(thinking_verbs)}...</div>
153
- # <div style='opacity: 0.5'><strong>[{node} - {name}]</strong>: {escape(msg.content)}</div>"""
154
-
155
  elif isinstance(msg, AIMessageChunk):
156
  if not msg.content:
157
  # logger.warning("*** No Message Chunk!")
158
- yield processing_message
159
  else:
160
  streamed_response += msg.content
161
- yield streamed_response
162
  if(msg.tool_calls):
163
  tool_calls.append(msg.tool_calls)
164
  elif isinstance(msg, AIMessage):
@@ -176,9 +191,10 @@ async def chat_streaming(message, history, thread_id):
176
  f"<strong>Telling myself:</strong> {truncated or '...'}"
177
  f"</div>"
178
  )
179
- yield html
180
 
181
- yield streamed_response
 
182
  except Exception as e:
183
  yield f"Error processing request {str(e)}"
184
 
@@ -236,6 +252,24 @@ chatInterface = gr.ChatInterface(
236
  chatbot=chatbot,
237
  css="""
238
  <style>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  .thinking-bubble {
240
  opacity: 0.5;
241
  font-style: italic;
 
107
  return response["messages"][-1].content
108
 
109
 
110
def add_node_to_tree(node_tree: list[str], node: str) -> list[str]:
    """Record *node* as the latest visited step in the breadcrumb trail.

    The trail always ends with a spinner placeholder: the current placeholder
    is overwritten in place with the new node label, and a fresh spinner is
    appended so the UI keeps showing an in-progress marker.

    Parameters:
    - node_tree: breadcrumb list of node labels; mutated in place (must be non-empty).
    - node: label of the node just entered.

    Returns:
    - The same (mutated) list, for convenience.
    """
    spinner = "<span class='spinner'>⏳</span>"
    node_tree[-1] = node
    node_tree.append(spinner)
    return node_tree
114
+
115
  async def chat_streaming(message, history, thread_id):
116
  state = {"messages": (history or []) + [{"role": "user", "content": message}]}
117
  config = {"configurable": {"thread_id": thread_id}}
 
122
 
123
  try:
124
  tool_calls = []
125
+ node_tree = ["__start__","<span class='spinner'>⏳</span>"]
126
  async for msg, metadata in graph.astream(
127
  state, config=config, stream_mode="messages"
128
  ):
129
  node = metadata.get("langgraph_node", "?")
130
  name = getattr(msg, "name", "-")
131
+ if(not isinstance(msg, ToolMessage)):
132
+ node_icon = "🧠"
133
+ else:
134
+ node_icon = "⚙️"
135
+ node_label = f"node:{node}"
136
+ tool_label =f"{name or ''}"
137
+ if(tool_label):
138
+ node_label = node_label + f":{tool_label}"
139
+ label = f"{node_icon} {node_label}"
140
+ # checking for -2 last but one. since last entry is always a spinner
141
+ if(node_tree[-2] != label):
142
+ add_node_to_tree(node_tree, label)
143
  full: str = escape(msg.content)
144
  truncated = (full[:MAX_CONTENT] + "…") if len(full) > MAX_CONTENT else full
145
 
146
+ def generate_processing_message():
147
+ return (
148
  f"<div class='thinking-bubble'><em>🤔{random.choice(thinking_verbs)} ...</em></div>"
149
  f"<div style='opacity: 0.1' title='{full}'>"
150
  f"<span>{node}:{name or ''}:</span>"
151
  f"<strong>Looking for : [{message}]</strong> {truncated or '...'}"
152
  f"</div>"
153
  )
154
+
155
  if (
156
  not isinstance(msg, ToolMessage)
157
  and not isinstance(msg, SystemMessage)
158
  and not isinstance(msg, AIMessageChunk)
159
  ):
160
  logger.info("msg = %s", msg)
 
161
  if isinstance(msg, ToolMessage):
162
  logger.debug("tool message = %s", msg)
163
  html = (
164
+ f"<div class='thinking-bubble'><em>🤔{name} tool: {random.choice(thinking_verbs)} ...</em></div>"
165
  f"<div style='opacity: 0.5'>"
166
  f"<strong>Looking for : [{message}]</strong> {truncated or '...'}"
167
  f"</div>"
168
  )
169
+ yield f"### { " → ".join(node_tree)}\n{html}"
 
 
 
 
170
  elif isinstance(msg, AIMessageChunk):
171
  if not msg.content:
172
  # logger.warning("*** No Message Chunk!")
173
+ yield f"### { " → ".join(node_tree)}\n{generate_processing_message()}"
174
  else:
175
  streamed_response += msg.content
176
+ yield f"### { " → ".join(node_tree)}\n{streamed_response}"
177
  if(msg.tool_calls):
178
  tool_calls.append(msg.tool_calls)
179
  elif isinstance(msg, AIMessage):
 
191
  f"<strong>Telling myself:</strong> {truncated or '...'}"
192
  f"</div>"
193
  )
194
+ yield f"### { " → ".join(node_tree)}\n{html}"
195
 
196
+ node_tree[-1] = "✅"
197
+ yield f"### { " → ".join(node_tree)}\n{streamed_response}"
198
  except Exception as e:
199
  yield f"Error processing request {str(e)}"
200
 
 
252
  chatbot=chatbot,
253
  css="""
254
  <style>
255
+ .spinner {
256
+ display: inline-block;
257
+ width: 1em;
258
+ height: 1em;
259
+ border: 2px solid #ccc;
260
+ border-top: 2px solid #333;
261
+ border-radius: 50%;
262
+ animation: spin 1s linear infinite;
263
+ vertical-align: middle;
264
+ margin-left: 0.5em;
265
+ }
266
+
267
+ @keyframes spin {
268
+ 0% { transform: rotate(0deg); }
269
+ 100% { transform: rotate(360deg); }
270
+ }
271
+
272
+
273
  .thinking-bubble {
274
  opacity: 0.5;
275
  font-style: italic;
db.py CHANGED
@@ -1,6 +1,6 @@
1
  from typing import Literal
2
  import chromadb
3
-
4
  from config import SanatanConfig
5
  from embeddings import get_embedding
6
  import logging
@@ -38,6 +38,60 @@ class SanatanDatabase:
38
  )
39
  return response
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def search_by_metadata(
42
  self,
43
  collection_name: str,
@@ -81,10 +135,11 @@ if __name__ == "__main__":
81
  query = input("Search for: ")
82
  if query.strip() == "":
83
  break
84
- response = database.search(
85
- collection_name=collection_name, query=query, n_results=1
86
  )
87
- print("Document: ")
88
- print(response["documents"][0][0])
89
- print("Metadata: ")
90
- print(response["metadatas"][0][0])
 
 
1
  from typing import Literal
2
  import chromadb
3
+ import re, unicodedata
4
  from config import SanatanConfig
5
  from embeddings import get_embedding
6
  import logging
 
38
  )
39
  return response
40
 
41
+ def search_for_literal(
42
+ self, collection_name: str, literal_to_search_for: str, n_results=2
43
+ ):
44
+ logger.info(
45
+ "Searching literally for [%s] in [%s]",
46
+ literal_to_search_for,
47
+ collection_name,
48
+ )
49
+ collection = self.chroma_client.get_or_create_collection(name=collection_name)
50
+
51
+ def normalize(text):
52
+ return unicodedata.normalize("NFKC", text).lower()
53
+
54
+ # 1. Try native contains
55
+ response = collection.query(
56
+ query_texts=[""],
57
+ where_document={"$contains": literal_to_search_for},
58
+ n_results=n_results,
59
+ )
60
+
61
+ if response["documents"] and any(response["documents"]):
62
+ return response
63
+
64
+ # 2. Regex fallback (normalized)
65
+ regex = re.compile(re.escape(normalize(literal_to_search_for)))
66
+
67
+ all_docs = collection.get()
68
+ matched_docs = []
69
+
70
+ for doc, metadata, ids in zip(
71
+ all_docs["documents"], all_docs["metadatas"], all_docs["ids"]
72
+ ):
73
+ for i, d in enumerate(doc):
74
+ if regex.search(normalize(d)):
75
+ matched_docs.append(
76
+ {
77
+ "id": ids[i],
78
+ "document": d,
79
+ "metadata": (
80
+ metadata[i] if isinstance(metadata, list) else metadata
81
+ ),
82
+ }
83
+ )
84
+ if len(matched_docs) >= n_results:
85
+ break
86
+ if len(matched_docs) >= n_results:
87
+ break
88
+
89
+ return {
90
+ "documents": [[d["document"] for d in matched_docs]],
91
+ "ids": [[d["id"] for d in matched_docs]],
92
+ "metadatas": [[d["metadata"] for d in matched_docs]],
93
+ }
94
+
95
  def search_by_metadata(
96
  self,
97
  collection_name: str,
 
135
  query = input("Search for: ")
136
  if query.strip() == "":
137
  break
138
+ response = database.search_for_literal(
139
+ collection_name=collection_name, literal_to_search_for=query, n_results=1
140
  )
141
+ print("Matches" , response)
142
+ # print("Document: ")
143
+ # print(response["documents"][0][0])
144
+ # print("Metadata: ")
145
+ # print(response["metadatas"][0][0])
graph_helper.py CHANGED
@@ -13,6 +13,7 @@ from tools import (
13
  tool_get_standardized_azhwar_names,
14
  tool_search_db_by_metadata,
15
  tool_get_standardized_divya_desam_names,
 
16
  )
17
  from langgraph.prebuilt import ToolNode, tools_condition
18
  from langchain_core.messages import SystemMessage, ToolMessage, HumanMessage
@@ -37,6 +38,7 @@ def generate_graph() -> CompiledStateGraph:
37
  tool_get_standardized_prabandham_names,
38
  tool_get_standardized_divya_desam_names,
39
  tool_search_db_by_metadata,
 
40
  ]
41
  llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.2).bind_tools(tools)
42
 
 
13
  tool_get_standardized_azhwar_names,
14
  tool_search_db_by_metadata,
15
  tool_get_standardized_divya_desam_names,
16
+ tool_search_db_for_literal
17
  )
18
  from langgraph.prebuilt import ToolNode, tools_condition
19
  from langchain_core.messages import SystemMessage, ToolMessage, HumanMessage
 
38
  tool_get_standardized_prabandham_names,
39
  tool_get_standardized_divya_desam_names,
40
  tool_search_db_by_metadata,
41
+ tool_search_db_for_literal
42
  ]
43
  llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.2).bind_tools(tools)
44
 
sanatan_assistant.py CHANGED
@@ -206,3 +206,36 @@ def query_by_metadata_field(
206
  response["documents"], response["metadatas"], response["ids"]
207
  )
208
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  response["documents"], response["metadatas"], response["ids"]
207
  )
208
  )
209
+
210
+
211
def query_by_literal_text(
    collection_name: allowedCollections,
    literal_to_search_for: str,
    n_results=5,
):
    """
    Search a scripture collection by a literal. Do NOT use this for semantic search. Only use when the user specifically asks for literal search.

    Parameters:
    - collection_name (str): The name of the scripture collection to search. ...
    - literal_to_search_for (str): The search query.
    - n_results (int): Number of results to return. Default is 5.

    Returns:
    - A list of matching results.
    """
    logger.info("Performing literal search in collection [%s] for [%s]", collection_name, literal_to_search_for)

    # Delegate to the database layer, which tries a native $contains match
    # first and falls back to a normalized substring scan.
    response = sanatanDatabase.search_for_literal(
        collection_name=collection_name,
        literal_to_search_for=literal_to_search_for,
        n_results=n_results,
    )

    # NOTE(review): search_for_literal returns Chroma-style NESTED lists
    # (documents=[[...]], etc.), so this zip pairs the *outer* lists and each
    # doc/meta/id_ below is itself a list, formatted as one blob. This mirrors
    # query_by_metadata_field, but confirm whether per-result formatting
    # (zipping the inner lists) was intended instead.
    return "\n\n".join(
        f"Document: {doc}\nMetadata: {meta}\nID: {id_}"
        for doc, meta, id_ in zip(
            response["documents"], response["metadatas"], response["ids"]
        )
    )
tools.py CHANGED
@@ -6,7 +6,7 @@ from config import SanatanConfig
6
  from nalayiram_helper import get_standardized_azhwar_names, get_standardized_divya_desam_names
7
  from push_notifications_helper import push
8
  from serperdev_helper import search as search_web
9
- from sanatan_assistant import format_scripture_answer, query, query_by_metadata_field
10
 
11
  tool_push = Tool(
12
  name="push", description="Send a push notification to the user", func=push
@@ -17,7 +17,16 @@ allowed_collections = [s["collection_name"] for s in SanatanConfig.scriptures]
17
  tool_search_db = StructuredTool.from_function(
18
  query,
19
  description=(
20
- "Do a vector Search within a specific scripture collection. "
 
 
 
 
 
 
 
 
 
21
  f"The collection_name must be one of: {', '.join(allowed_collections)}."
22
  "Use this to find relevant scripture verses or explanations based on the given query."
23
  # "If the query doesn't yield any relevant results, then call `tool_search_db_by_metadata` tool to search specifically by a given metadata field (only if specific field from metadata has been mentioned)."
 
6
  from nalayiram_helper import get_standardized_azhwar_names, get_standardized_divya_desam_names
7
  from push_notifications_helper import push
8
  from serperdev_helper import search as search_web
9
+ from sanatan_assistant import format_scripture_answer, query, query_by_metadata_field, query_by_literal_text
10
 
11
  tool_push = Tool(
12
  name="push", description="Send a push notification to the user", func=push
 
17
  tool_search_db = StructuredTool.from_function(
18
  query,
19
  description=(
20
+ "Do a semantic vector search within a specific scripture collection. "
21
+ f"The collection_name must be one of: {', '.join(allowed_collections)}."
22
+ "Use this to narrow down relevant scripture verses or explanations based on the given query."
23
+ ),
24
+ )
25
+
26
+ tool_search_db_for_literal = StructuredTool.from_function(
27
+ query_by_literal_text,
28
+ description=(
29
+ "Do a literal search within a specific scripture collection (only if user specifically asks for a literal search or if semantic search does not yield relevant results)."
30
  f"The collection_name must be one of: {', '.join(allowed_collections)}."
31
  "Use this to find relevant scripture verses or explanations based on the given query."
32
  # "If the query doesn't yield any relevant results, then call `tool_search_db_by_metadata` tool to search specifically by a given metadata field (only if specific field from metadata has been mentioned)."