vikramvasudevan commited on
Commit
d7a5e89
·
verified ·
1 Parent(s): e7a8938

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +109 -26
app.py CHANGED
@@ -1,7 +1,9 @@
 
1
  import os
2
  import random
3
  import asyncio
4
  import logging
 
5
  import uuid
6
  from html import escape
7
  import gradio as gr
@@ -107,11 +109,30 @@ def chat(message, history, thread_id):
107
  return response["messages"][-1].content
108
 
109
 
110
- def add_node_to_tree(node_tree : list[str], node : str) -> list[str]:
111
- node_tree[-1] = node
112
- node_tree.append("<span class='spinner'>⏳</span>")
 
 
 
 
 
 
 
 
 
113
  return node_tree
114
 
 
 
 
 
 
 
 
 
 
 
115
  async def chat_streaming(message, history, thread_id):
116
  state = {"messages": (history or []) + [{"role": "user", "content": message}]}
117
  config = {"configurable": {"thread_id": thread_id}}
@@ -121,36 +142,43 @@ async def chat_streaming(message, history, thread_id):
121
  MAX_CONTENT = 500
122
 
123
  try:
124
- tool_calls = []
125
- node_tree = ["__start__","<span class='spinner'>⏳</span>"]
 
 
126
  async for msg, metadata in graph.astream(
127
  state, config=config, stream_mode="messages"
128
  ):
129
  node = metadata.get("langgraph_node", "?")
130
  name = getattr(msg, "name", "-")
131
- if(not isinstance(msg, ToolMessage)):
132
  node_icon = "🧠"
133
  else:
134
  node_icon = "⚙️"
135
  node_label = f"node:{node}"
136
- tool_label =f"{name or ''}"
137
- if(tool_label):
138
  node_label = node_label + f":{tool_label}"
139
  label = f"{node_icon} {node_label}"
 
 
 
 
 
140
  # checking for -2 last but one. since last entry is always a spinner
141
- if(node_tree[-2] != label):
142
- add_node_to_tree(node_tree, label)
143
  full: str = escape(msg.content)
144
  truncated = (full[:MAX_CONTENT] + "…") if len(full) > MAX_CONTENT else full
145
 
146
  def generate_processing_message():
147
  return (
148
- f"<div class='thinking-bubble'><em>🤔{random.choice(thinking_verbs)} ...</em></div>"
149
- f"<div style='opacity: 0.1' title='{full}'>"
150
- f"<span>{node}:{name or ''}:</span>"
151
- f"<strong>Looking for : [{message}]</strong> {truncated or '...'}"
152
- f"</div>"
153
- )
154
 
155
  if (
156
  not isinstance(msg, ToolMessage)
@@ -160,13 +188,17 @@ async def chat_streaming(message, history, thread_id):
160
  logger.info("msg = %s", msg)
161
  if isinstance(msg, ToolMessage):
162
  logger.debug("tool message = %s", msg)
 
163
  html = (
164
- f"<div class='thinking-bubble'><em>🤔{name} tool: {random.choice(thinking_verbs)} ...</em></div>"
165
  f"<div style='opacity: 0.5'>"
166
- f"<strong>Looking for : [{message}]</strong> {truncated or '...'}"
 
 
167
  f"</div>"
168
  )
169
- yield f"### { " → ".join(node_tree)}\n{html}"
 
170
  elif isinstance(msg, AIMessageChunk):
171
  if not msg.content:
172
  # logger.warning("*** No Message Chunk!")
@@ -174,11 +206,30 @@ async def chat_streaming(message, history, thread_id):
174
  else:
175
  streamed_response += msg.content
176
  yield f"### { " → ".join(node_tree)}\n{streamed_response}"
177
- if(msg.tool_calls):
178
- tool_calls.append(msg.tool_calls)
179
- elif isinstance(msg, AIMessage):
180
- if(msg.tool_calls):
181
- tool_calls.append(msg.tool_calls)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  else:
183
  logger.debug("message = ", type(msg), msg.content[:100])
184
  full: str = escape(msg.content)
@@ -192,11 +243,37 @@ async def chat_streaming(message, history, thread_id):
192
  f"</div>"
193
  )
194
  yield f"### { " → ".join(node_tree)}\n{html}"
 
 
195
 
196
  node_tree[-1] = "✅"
197
- yield f"### { " → ".join(node_tree)}\n{streamed_response}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  except Exception as e:
199
- yield f"Error processing request {str(e)}"
 
 
 
 
 
 
200
 
201
 
202
  # UI Elements
@@ -281,6 +358,12 @@ chatInterface = gr.ChatInterface(
281
  50% { opacity: 1; }
282
  100% { opacity: 0.3; }
283
  }
 
 
 
 
 
 
284
  """,
285
  )
286
 
 
1
+ import json
2
  import os
3
  import random
4
  import asyncio
5
  import logging
6
+ import traceback
7
  import uuid
8
  from html import escape
9
  import gradio as gr
 
109
  return response["messages"][-1].content
110
 
111
 
112
+ def add_node_to_tree(
113
+ node_tree: list[str], node_label: str, tooltip: str = "no arguments to show"
114
+ ) -> list[str]:
115
+ if tooltip:
116
+ tooltip = escape(tooltip).replace("'", "&apos;")
117
+ node_with_tooltip = (
118
+ f"<span class='node-label' title='{tooltip}'>{node_label}</span>"
119
+ )
120
+ else:
121
+ node_with_tooltip = node_label
122
+ node_tree[-1] = node_with_tooltip
123
+ node_tree.append("<span class='spinner'>&nbsp;</span>")
124
  return node_tree
125
 
126
+
127
+ def get_args_for_toolcall(tool_calls_buffer: dict, tool_call_id: str):
128
+ return (
129
+ tool_calls_buffer[tool_call_id]["args_str"]
130
+ if tool_call_id in tool_calls_buffer
131
+ and "args_str" in tool_calls_buffer[tool_call_id]
132
+ else ""
133
+ )
134
+
135
+
136
  async def chat_streaming(message, history, thread_id):
137
  state = {"messages": (history or []) + [{"role": "user", "content": message}]}
138
  config = {"configurable": {"thread_id": thread_id}}
 
142
  MAX_CONTENT = 500
143
 
144
  try:
145
+ node_tree = ["🚩", "<span class='spinner'>&nbsp;</span>"]
146
+
147
+ tool_calls_buffer = {}
148
+
149
  async for msg, metadata in graph.astream(
150
  state, config=config, stream_mode="messages"
151
  ):
152
  node = metadata.get("langgraph_node", "?")
153
  name = getattr(msg, "name", "-")
154
+ if not isinstance(msg, ToolMessage):
155
  node_icon = "🧠"
156
  else:
157
  node_icon = "⚙️"
158
  node_label = f"node:{node}"
159
+ tool_label = f"{name or ''}"
160
+ if tool_label:
161
  node_label = node_label + f":{tool_label}"
162
  label = f"{node_icon} {node_label}"
163
+ tooltip = ""
164
+ if isinstance(msg, ToolMessage):
165
+ tooltip = get_args_for_toolcall(tool_calls_buffer, msg.tool_call_id)
166
+ logger.info("tooltip = ", tooltip)
167
+
168
  # checking for -2 last but one. since last entry is always a spinner
169
+ if node_tree[-2] != label:
170
+ add_node_to_tree(node_tree, label, tooltip)
171
  full: str = escape(msg.content)
172
  truncated = (full[:MAX_CONTENT] + "…") if len(full) > MAX_CONTENT else full
173
 
174
  def generate_processing_message():
175
  return (
176
+ f"<div class='thinking-bubble'><em>🤔{random.choice(thinking_verbs)} ...</em></div>"
177
+ f"<div style='opacity: 0.1' title='{full}'>"
178
+ f"<span>{node}:{name or ''}:</span>"
179
+ f"<strong>Looking for : [{message}]</strong> {truncated or '...'}"
180
+ f"</div>"
181
+ )
182
 
183
  if (
184
  not isinstance(msg, ToolMessage)
 
188
  logger.info("msg = %s", msg)
189
  if isinstance(msg, ToolMessage):
190
  logger.debug("tool message = %s", msg)
191
+
192
  html = (
193
+ f"<div class='thinking-bubble'><em>🤔 {msg.name} tool: {random.choice(thinking_verbs)} ...</em></div>"
194
  f"<div style='opacity: 0.5'>"
195
+ f"<strong>Looking for : [{message}]</strong><br>"
196
+ f"<strong>Tool Args:</strong> {tooltip or '(no args)'}<br>"
197
+ f"{truncated or '...'}"
198
  f"</div>"
199
  )
200
+ yield f"### { ' → '.join(node_tree)}\n{html}"
201
+ await asyncio.sleep(5)
202
  elif isinstance(msg, AIMessageChunk):
203
  if not msg.content:
204
  # logger.warning("*** No Message Chunk!")
 
206
  else:
207
  streamed_response += msg.content
208
  yield f"### { " → ".join(node_tree)}\n{streamed_response}"
209
+
210
+ if msg.tool_call_chunks:
211
+ for tool_call_chunk in msg.tool_call_chunks:
212
+ print("*** tool_call_chunk = ", tool_call_chunk)
213
+ if tool_call_chunk["id"] is not None:
214
+ tool_call_id = tool_call_chunk["id"]
215
+
216
+ if tool_call_id not in tool_calls_buffer:
217
+ tool_calls_buffer[tool_call_id] = {
218
+ "name": "",
219
+ "args_str": "",
220
+ "id": tool_call_id,
221
+ "type": "tool_call",
222
+ }
223
+
224
+ # Accumulate tool call name and arguments
225
+ if tool_call_chunk["name"] is not None:
226
+ tool_calls_buffer[tool_call_id]["name"] += tool_call_chunk[
227
+ "name"
228
+ ]
229
+ if tool_call_chunk["args"] is not None:
230
+ tool_calls_buffer[tool_call_id][
231
+ "args_str"
232
+ ] += tool_call_chunk["args"]
233
  else:
234
  logger.debug("message = ", type(msg), msg.content[:100])
235
  full: str = escape(msg.content)
 
243
  f"</div>"
244
  )
245
  yield f"### { " → ".join(node_tree)}\n{html}"
246
+ if getattr(msg, "tool_calls", []):
247
+ logger.info("ELSE::tool_calls = %s", msg.tool_calls)
248
 
249
  node_tree[-1] = "✅"
250
+
251
+ yield (
252
+ f"### {' → '.join(node_tree)}"
253
+ f"\n{streamed_response}"
254
+ )
255
+
256
+ print("************************************")
257
+ # Now, you can process the complete tool calls from the buffer
258
+ for tool_call_id, accumulated_tool_call in tool_calls_buffer.items():
259
+ # Attempt to parse arguments only if the 'args_str' isn't empty
260
+ if accumulated_tool_call["args_str"]:
261
+ try:
262
+ parsed_args = json.loads(accumulated_tool_call["args_str"])
263
+ print(f"Tool Name: {accumulated_tool_call['name']}")
264
+ print(f"Tool Arguments: {parsed_args}")
265
+ except json.JSONDecodeError:
266
+ print(
267
+ f"Partial arguments for tool {accumulated_tool_call['name']}: {accumulated_tool_call['args_str']}"
268
+ )
269
  except Exception as e:
270
+ logger.error("❌❌❌ Error processing request: %s", e)
271
+ traceback.print_exc()
272
+ yield (
273
+ f"❌❌❌ Error processing request {str(e)}\n"
274
+ "here is what I got so far ...\n"
275
+ f"### { " → ".join(node_tree)}\n{streamed_response}"
276
+ )
277
 
278
 
279
  # UI Elements
 
358
  50% { opacity: 1; }
359
  100% { opacity: 0.3; }
360
  }
361
+
362
+ .node-label {
363
+ cursor: help;
364
+ border-bottom: 1px dotted #aaa;
365
+ }
366
+
367
  """,
368
  )
369