Nymbo committed on
Commit e30c5c2 · verified · 1 Parent(s): 37dcc6f

adding Filterer process for better Deep_Research reports

Files changed (1)
  1. Modules/Deep_Research.py +208 -79
Modules/Deep_Research.py CHANGED
@@ -4,10 +4,10 @@ import os
 import re
 import tempfile
 import time
-from collections import deque
+from collections import OrderedDict, deque
 from concurrent.futures import Future, ThreadPoolExecutor, as_completed
 from datetime import datetime
-from typing import Annotated, Dict, List, Tuple
+from typing import Annotated, Callable, Dict, List, Tuple
 from urllib.parse import urlparse
 
 import gradio as gr
@@ -63,6 +63,14 @@ RESEARCHER_SYSTEM_PROMPT = (
     "</planning_rules>\n\n"
 )
 
+FILTERER_SYSTEM_PROMPT = (
+    "You are Nymbot Filterer, an analyst who selects the most relevant sources for a research task. "
+    "You will be given a summary of the research topic (and optional search queries) followed by multiple fetched documents. "
+    "Each document includes its URL and a truncated excerpt. Evaluate how well each source helps answer the research topic. "
+    "Return only the URLs that should be used for the final research step. Output plain text with exactly one URL per line and no additional commentary, bullets, numbering, or explanations. "
+    "If no sources are relevant, return an empty string."
+)
+
 
 class SlowHost(Exception):
     pass
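The filterer contract above is deliberately plain: one URL per line, nothing else. A minimal sketch of how a caller might sanity-check a reply against that contract before parsing; the reply string and the helper name are hypothetical, not part of this commit:

# Hypothetical reply; a conforming filterer response is just newline-separated URLs.
reply = "https://example.com/a\nhttps://example.org/b"

def looks_like_url_list(text: str) -> bool:
    # Every non-empty line should be a bare http(s) URL with no commentary around it.
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    return bool(lines) and all(line.startswith(("http://", "https://")) and " " not in line for line in lines)

print(looks_like_url_list(reply))  # True for a conforming reply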
@@ -161,6 +169,51 @@ def _build_research_prompt(summary: str, queries: List[str], url_list: List[str]
     return "\n\n".join(prompt_parts)
 
 
+def _build_filter_prompt(summary: str, queries: List[str], pages_map: Dict[str, str]) -> str:
+    populated = [q for q in queries if q and q.strip()]
+    summary_text = summary or ""
+    prompt_sections: List[str] = []
+    prompt_sections.append("<research_topic_summary>\n" + summary_text + "\n</research_topic_summary>")
+    if populated:
+        prompt_sections.append("<search_queries>\n" + "\n".join(populated) + "\n</search_queries>")
+    sources: List[str] = []
+    for idx, (url, text) in enumerate(pages_map.items(), start=1):
+        content = text.strip()
+        if not content:
+            continue
+        sources.append(f"[Source {idx}] URL: {url}\n\n{content}")
+    sources_joined, truncated = _truncate_join(sources, max_chars=60_000)
+    prompt_sections.append("<candidate_sources>\n" + sources_joined + ("\n\n[NOTE] Sources truncated due to context limits." if truncated else "") + "\n</candidate_sources>")
+    prompt_sections.append(
+        "<task>\nIdentify which of the provided URLs should be retained for the final research synthesis. "
+        "Consider coverage, credibility, and relevance to the research topic. "
+        "Return ONLY the URLs you choose, with one URL per line and no additional text.\n</task>"
+    )
+    return "\n\n".join(prompt_sections)
+
+
+def _parse_filterer_output(raw: str, allowed_urls: List[str]) -> List[str]:
+    if not raw:
+        return []
+    allowed_set = {url.strip(): idx for idx, url in enumerate(allowed_urls)}
+    found_indices: set[int] = set()
+    for line in raw.splitlines():
+        candidate = line.strip()
+        if not candidate:
+            continue
+        if candidate in allowed_set:
+            found_indices.add(allowed_set[candidate])
+            continue
+        match = re.search(r"https?://[^\s]+", candidate)
+        if not match:
+            continue
+        url = match.group(0).rstrip(".,);]")
+        if url in allowed_set:
+            found_indices.add(allowed_set[url])
+    selected = [allowed_urls[idx] for idx in sorted(found_indices)]
+    return selected
+
+
 def _write_report_tmp(text: str) -> str:
     tmp_dir = tempfile.mkdtemp(prefix="deep_research_")
     path = os.path.join(tmp_dir, "research_report.txt")
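A usage sketch for _parse_filterer_output with made-up URLs and a deliberately messy model reply. The import assumes the module and its dependencies are installed; note that selected URLs come back in the order of allowed_urls, not the order the model listed them:

# Illustrative only: the URLs and the reply text are invented.
from Modules.Deep_Research import _parse_filterer_output

allowed = ["https://site-a.test/page", "https://site-b.test/doc", "https://site-c.test/post"]
raw_reply = "1. https://site-b.test/doc\nSome stray commentary the parser ignores\nhttps://site-a.test/page"

print(_parse_filterer_output(raw_reply, allowed))
# ['https://site-a.test/page', 'https://site-b.test/doc']  (allowed-list order, numbering and prose stripped)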
@@ -169,6 +222,76 @@ def _write_report_tmp(text: str) -> str:
     return path
 
 
+def _fetch_pages_within_budget(urls: List[str], char_limit: int, time_left_fn: Callable[[], float]) -> OrderedDict:
+    pages: dict[str, str] = {}
+    if not urls:
+        return OrderedDict()
+    queue = deque(urls)
+    attempts: dict[str, int] = {url: 0 for url in urls}
+    max_attempts = 2
+    max_workers = min(12, max(4, len(urls)))
+    in_flight: dict[Future, str] = {}
+    delayed: list[tuple[float, str]] = []
+
+    def schedule_next(executor: ThreadPoolExecutor) -> None:
+        while queue and len(in_flight) < max_workers:
+            url = queue.popleft()
+            if url in pages:
+                continue
+            attempts.setdefault(url, 0)
+            if attempts[url] >= max_attempts:
+                continue
+            attempts[url] += 1
+            tl = time_left_fn()
+            if tl <= 0.1:
+                return
+            per_timeout = 10.0 if tl > 15 else (5.0 if tl > 8 else 2.0)
+            future = executor.submit(_fetch_page_markdown_fast, url, char_limit, per_timeout)
+            in_flight[future] = url
+
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        schedule_next(executor)
+        while (in_flight or queue or delayed) and time_left_fn() > 0.2:
+            now = time.time()
+            if delayed:
+                ready: list[tuple[float, str]] = []
+                not_ready: list[tuple[float, str]] = []
+                for ready_time, delayed_url in delayed:
+                    (ready if ready_time <= now else not_ready).append((ready_time, delayed_url))
+                delayed = not_ready
+                for _, delayed_url in ready:
+                    queue.append(delayed_url)
+                if ready:
+                    schedule_next(executor)
+            done = [future for future in list(in_flight.keys()) if future.done()]
+            if not done:
+                if not queue and delayed:
+                    next_ready = min((t for t, _ in delayed), default=time.time())
+                    sleep_for = max(0.0, next_ready - time.time())
+                    time.sleep(max(0.02, min(0.25, sleep_for)))
+                else:
+                    time.sleep(0.05)
+                continue
+            for future in done:
+                url = in_flight.pop(future)
+                try:
+                    md = future.result()
+                    if md and not md.startswith("Unsupported content type") and not md.startswith("An error occurred"):
+                        pages[url] = md
+                        try:
+                            print(f"[FETCH OK] {url} (chars={len(md)})", flush=True)
+                        except Exception:
+                            pass
+                except SlowHost:
+                    if time_left_fn() > 5.0:
+                        delayed.append((time.time() + 3.0, url))
+                except Exception:
+                    pass
+            schedule_next(executor)
+    ordered = OrderedDict((url, pages[url]) for url in urls if url in pages)
+    return ordered
+
+
 @autodoc(
     summary=TOOL_SUMMARY,
 )
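A sketch of how _fetch_pages_within_budget is meant to be driven: the caller passes a closure over a wall-clock deadline, mirroring the time_left() helper inside Deep_Research. The URLs and the 30-second budget are placeholders, and the import assumes the module's dependencies are available:

import time
from Modules.Deep_Research import _fetch_pages_within_budget

deadline = time.time() + 30.0  # overall budget for this fetch pass

def time_left() -> float:
    return max(0.0, deadline - time.time())

urls = ["https://example.com/article", "https://example.org/report"]
pages = _fetch_pages_within_budget(urls, char_limit=3000, time_left_fn=time_left)
print(list(pages.keys()))  # OrderedDict keyed by URL, in the original input order of successful fetches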
@@ -217,6 +340,11 @@ def Deep_Research(
     def time_left() -> float:
         return max(0.0, deadline - time.time())
 
+    now_dt = datetime.now().astimezone()
+    date_str = now_dt.strftime("%A, %B %d, %Y %I:%M %p %Z").strip()
+    if not date_str:
+        date_str = now_dt.isoformat()
+
     all_urls: list[str] = []
     tasks = []
     with ThreadPoolExecutor(max_workers=min(5, sum(1 for q in queries if q.strip())) or 1) as executor:
@@ -279,71 +407,79 @@ def Deep_Research(
         return any(path.endswith(ext) for ext in skip_exts)
 
     all_urls = [url for url in all_urls if not _skip_url(url)]
-    pages: dict[str, str] = {}
-    if all_urls:
-        queue = deque(all_urls)
-        attempts: dict[str, int] = {url: 0 for url in all_urls}
-        max_attempts = 2
-        max_workers = min(12, max(4, len(all_urls)))
-        in_flight: dict[Future, str] = {}
-        delayed: list[tuple[float, str]] = []
-
-        def schedule_next(executor: ThreadPoolExecutor) -> None:
-            while queue and len(in_flight) < max_workers:
-                url = queue.popleft()
-                if url in pages:
-                    continue
-                if attempts[url] >= max_attempts:
-                    continue
-                attempts[url] += 1
-                tl = time_left()
-                per_timeout = 10.0 if tl > 15 else (5.0 if tl > 8 else 2.0)
-                future = executor.submit(_fetch_page_markdown_fast, url, 3000, per_timeout)
-                in_flight[future] = url
-
-        with ThreadPoolExecutor(max_workers=max_workers) as executor:
-            schedule_next(executor)
-            while (in_flight or queue) and time_left() > 0.2:
-                now = time.time()
-                if delayed:
-                    ready = []
-                    not_ready = []
-                    for ready_time, url in delayed:
-                        (ready if ready_time <= now else not_ready).append((ready_time, url))
-                    delayed = not_ready
-                    for _, url in ready:
-                        queue.append(url)
-                    if ready:
-                        schedule_next(executor)
-                done = [future for future in list(in_flight.keys()) if future.done()]
-                if not done:
-                    if not queue and delayed:
-                        sleep_for = max(0.02, min(0.25, max(0.0, min(t for t, _ in delayed) - time.time())))
-                        time.sleep(sleep_for)
-                    else:
-                        time.sleep(0.05)
-                else:
-                    for future in done:
-                        url = in_flight.pop(future)
-                        try:
-                            md = future.result()
-                            if md and not md.startswith("Unsupported content type") and not md.startswith("An error occurred"):
-                                pages[url] = md
-                                try:
-                                    print(f"[FETCH OK] {url} (chars={len(md)})", flush=True)
-                                except Exception:
-                                    pass
-                        except SlowHost:
-                            if time_left() > 5.0:
-                                delayed.append((time.time() + 3.0, url))
-                        except Exception:
-                            pass
-                    schedule_next(executor)
+    truncated_pages = OrderedDict()
+    if all_urls and time_left() > 0.2:
+        truncated_pages = _fetch_pages_within_budget(all_urls, 3000, time_left)
+    print(
+        f"[PIPELINE] Initial fetch complete: candidates={len(all_urls)}, truncated_documents={len(truncated_pages)}, time_left={time_left():.2f}s",
+        flush=True,
+    )
+
+    def _invoke_chat(messages, provider: str, max_tokens: int, temp: float, top_p: float):
+        client = InferenceClient(provider=provider, api_key=HF_TEXTGEN_TOKEN)
+        return client.chat.completions.create(
+            model="Qwen/Qwen3-235B-A22B-Thinking-2507",
+            messages=messages,
+            max_tokens=max_tokens,
+            temperature=temp,
+            top_p=top_p,
+        )
+
+    filtered_urls: List[str] = list(truncated_pages.keys())
+    filter_output = ""
+    filter_used_fallback = False
+    filter_success = False
+    if truncated_pages and time_left() > 3.0:
+        filter_prompt = _build_filter_prompt(summary or "", [q for q in queries if q.strip()], truncated_pages)
+        filter_messages = [
+            {"role": "system", "content": FILTERER_SYSTEM_PROMPT},
+            {"role": "user", "content": f"The current date is {date_str}. Consider how recent each source is when deciding relevance."},
+            {"role": "user", "content": filter_prompt},
+        ]
+        filter_completion = None
+        try:
+            print("[FILTER] Attempt 1: provider=cerebras, max_tokens=2048", flush=True)
+            filter_completion = _invoke_chat(filter_messages, "cerebras", 2048, 0.2, 0.9)
+        except Exception as exc1:
+            print(f"[FILTER] Attempt 1 failed: {str(exc1)[:200]}", flush=True)
+            try:
+                print("[FILTER] Attempt 2: provider=auto, max_tokens=2048", flush=True)
+                filter_completion = _invoke_chat(filter_messages, "auto", 2048, 0.2, 0.9)
+            except Exception as exc2:
+                print(f"[FILTER] Attempt 2 failed: {str(exc2)[:200]}", flush=True)
+        if filter_completion and filter_completion.choices:
+            filter_output = filter_completion.choices[0].message.content or ""
+            filtered_urls = _parse_filterer_output(filter_output, list(truncated_pages.keys()))
+            filter_success = bool(filter_output.strip()) and bool(filtered_urls)
+    if not filtered_urls:
+        filter_used_fallback = True
+        fallback_count = min(8, len(truncated_pages))
+        filtered_urls = list(truncated_pages.keys())[:fallback_count]
+    max_final_urls = 20
+    if len(filtered_urls) > max_final_urls:
+        filter_used_fallback = True
+        filtered_urls = filtered_urls[:max_final_urls]
+    if not filter_success:
+        filter_used_fallback = True
+    print(
+        f"[FILTER] Selected URLs={len(filtered_urls)}, fallback={filter_used_fallback}, time_left={time_left():.2f}s",
+        flush=True,
+    )
+
+    final_pages_fetched = OrderedDict()
+    if filtered_urls and time_left() > 0.2:
+        final_pages_fetched = _fetch_pages_within_budget(filtered_urls, 8000, time_left)
+    merged_pages = OrderedDict()
+    for url in filtered_urls:
+        content = final_pages_fetched.get(url) or truncated_pages.get(url) or ""
+        if content:
+            merged_pages[url] = content
+    pages = merged_pages
+    print(
+        f"[PIPELINE] Final fetch complete: retained_documents={len(pages)}, time_left={time_left():.2f}s",
+        flush=True,
+    )
     prompt = _build_research_prompt(summary=summary or "", queries=[q for q in queries if q.strip()], url_list=list(pages.keys()), pages_map=pages)
-    now = datetime.now().astimezone()
-    date_str = now.strftime("%A, %B %d, %Y %I:%M %p %Z").strip()
-    if not date_str:
-        date_str = now.isoformat()
     system_message = {"role": "system", "content": RESEARCHER_SYSTEM_PROMPT}
     date_message = {"role": "user", "content": f"The current date is {date_str}. Return only the research report."}
     messages = [
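The hunk above is the heart of the change: a cheap truncated fetch feeds the filterer, and only the URLs it keeps (capped at 20, with a top-8 fallback when filtering fails) are re-fetched at full depth. A condensed restatement of that flow, using the helper names from this diff; run_filter_model is a hypothetical callable standing in for the _invoke_chat retry logic, and the helpers are assumed to be in scope:

def select_and_fetch(all_urls, summary, queries, time_left, run_filter_model):
    # Stage 1: shallow fetch (3000 chars per page) so the filterer sees every candidate cheaply.
    truncated = _fetch_pages_within_budget(all_urls, 3000, time_left)
    # Stage 2: ask the filterer which URLs to keep; fall back to the first 8 if it returns nothing usable.
    raw = run_filter_model(_build_filter_prompt(summary, queries, truncated))
    keep = _parse_filterer_output(raw, list(truncated.keys())) or list(truncated.keys())[:8]
    keep = keep[:20]  # hard cap on sources passed to the researcher
    # Stage 3: deep fetch (8000 chars) of the survivors, reusing the shallow text if a re-fetch fails.
    deep = _fetch_pages_within_budget(keep, 8000, time_left)
    return {url: deep.get(url) or truncated.get(url) for url in keep if deep.get(url) or truncated.get(url)}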
@@ -358,19 +494,9 @@ def Deep_Research(
     print(f"[PIPELINE] Fetch complete: pages={len(pages)}, unique_urls={len(pages.keys())}, prompt_chars={prompt_chars}", flush=True)
     print("[PIPELINE] Starting inference (provider=cerebras, model=Qwen/Qwen3-235B-A22B-Thinking-2507)", flush=True)
 
-    def _run_inference(provider: str, max_tokens: int, temp: float, top_p: float):
-        client = InferenceClient(provider=provider, api_key=HF_TEXTGEN_TOKEN)
-        return client.chat.completions.create(
-            model="Qwen/Qwen3-235B-A22B-Thinking-2507",
-            messages=messages,
-            max_tokens=max_tokens,
-            temperature=temp,
-            top_p=top_p,
-        )
-
     try:
         print("[LLM] Attempt 1: provider=cerebras, max_tokens=32768", flush=True)
-        completion = _run_inference("cerebras", max_tokens=32768, temp=0.3, top_p=0.95)
+        completion = _invoke_chat(messages, "cerebras", max_tokens=32768, temp=0.3, top_p=0.95)
     except Exception as exc1:
         print(f"[LLM] Attempt 1 failed: {str(exc1)[:200]}", flush=True)
         try:
@@ -386,12 +512,12 @@ def Deep_Research(
                 {"role": "user", "content": prompt2},
             ]
             print("[LLM] Attempt 2: provider=cerebras (trimmed), max_tokens=16384", flush=True)
-            completion = _run_inference("cerebras", max_tokens=16384, temp=0.7, top_p=0.95)
+            completion = _invoke_chat(messages, "cerebras", max_tokens=16384, temp=0.7, top_p=0.95)
         except Exception as exc2:
             print(f"[LLM] Attempt 2 failed: {str(exc2)[:200]}", flush=True)
             try:
                 print("[LLM] Attempt 3: provider=auto, max_tokens=8192", flush=True)
-                completion = _run_inference("auto", max_tokens=8192, temp=0.7, top_p=0.95)
+                completion = _invoke_chat(messages, "auto", max_tokens=8192, temp=0.7, top_p=0.95)
             except Exception as exc3:
                 _log_call_end("Deep_Research", f"error={_truncate_for_log(str(exc3), 260)}")
                 raise gr.Error(f"Researcher model call failed: {exc3}")
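The researcher call now shares the filterer's _invoke_chat helper and degrades through a ladder of attempts: cerebras at full budget, cerebras again with a trimmed prompt and half the tokens, then provider=auto. A generic sketch of that pattern, not the repository's code (it omits the prompt trimming the real attempt 2 performs before retrying):

def call_with_fallbacks(invoke, messages):
    # Shrink max_tokens and loosen the provider on each failure, mirroring the ladder in the diff.
    attempts = [("cerebras", 32768, 0.3), ("cerebras", 16384, 0.7), ("auto", 8192, 0.7)]
    last_exc = None
    for provider, max_tokens, temp in attempts:
        try:
            return invoke(messages, provider, max_tokens=max_tokens, temp=temp, top_p=0.95)
        except Exception as exc:  # keep trying cheaper configurations
            last_exc = exc
    raise RuntimeError(f"all attempts failed: {last_exc}")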
@@ -423,6 +549,9 @@ def Deep_Research(
     except Exception:
         pass
     links_text = "\n".join([f"[{i+1}] {url}" for i, url in enumerate(pages.keys())])
+    if links_text:
+        sources_section = "\n\n## Sources\n" + "\n".join([f"[{i+1}] {url}" for i, url in enumerate(pages.keys())])
+        report = report.rstrip() + sources_section
     file_path = _write_report_tmp(report)
     elapsed = time.time() - start_ts
     print(f"[TIMING] Deep_Research elapsed: {elapsed:.2f}s", flush=True)
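The final hunk appends a numbered Sources section to the report body before it is written to disk. With two hypothetical URLs, the appended text looks like this:

pages = {"https://example.com/a": "...", "https://example.org/b": "..."}
sources_section = "\n\n## Sources\n" + "\n".join(f"[{i+1}] {url}" for i, url in enumerate(pages))
print(sources_section)
# ## Sources
# [1] https://example.com/a
# [2] https://example.org/b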
 