Luigi committed on
Commit a94befb · 1 Parent(s): efa3ae6

Restore cancel generation feature with improved UI integration

Files changed (1):
  1. app.py +75 -25
app.py CHANGED
@@ -13,6 +13,9 @@ from ddgs import DDGS
 import spaces # Import spaces early to enable ZeroGPU support
 from torch.utils._pytree import tree_map
 
+# Global event to signal cancellation from the UI thread to the generation thread
+cancel_event = threading.Event()
+
 access_token=os.environ['HF_TOKEN']
 
 # Optional: Disable GPU visibility if you wish to force CPU usage
@@ -402,7 +405,11 @@ def chat_response(user_msg, chat_history, system_prompt,
                   top_k, top_p, repeat_penalty, search_timeout):
     """
     Generates streaming chat responses, optionally with background web search.
+    This version includes cancellation support.
     """
+    # Clear the cancellation event at the start of a new generation
+    cancel_event.clear()
+
     history = list(chat_history or [])
     history.append({'role': 'user', 'content': user_msg})
 
@@ -505,6 +512,12 @@ def chat_response(user_msg, chat_history, system_prompt,
 
         # Stream tokens
         for chunk in streamer:
+            # Check for cancellation signal
+            if cancel_event.is_set():
+                history[-1]['content'] += " [Generation Canceled]"
+                yield history, debug
+                break
+
             text = chunk
 
             # Detect start of thinking
@@ -560,6 +573,7 @@ def chat_response(user_msg, chat_history, system_prompt,
         history.append({'role': 'assistant', 'content': f"Error: {e}"})
         yield history, debug
     finally:
+        # Final cleanup
         gc.collect()
 
 
@@ -616,34 +630,70 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
             st = gr.Slider(minimum=0.0, maximum=30.0, step=0.5, value=5.0, label="Search Timeout (s)")
             clr = gr.Button("Clear Chat")
         with gr.Column(scale=7):
-            chat = gr.Chatbot(type="messages")
-            txt = gr.Textbox(placeholder="Type your message and press Enter...")
+            chat = gr.Chatbot(type="messages", height=600)
+            with gr.Row():
+                txt = gr.Textbox(placeholder="Type your message...", scale=8, container=False)
+                submit_btn = gr.Button("Submit", variant="primary", scale=1)
+                cancel_btn = gr.Button("⏹️ Cancel", variant="stop", visible=False, scale=1)
             dbg = gr.Markdown()
 
+    # --- Event Listeners ---
+
+    # Group all inputs for cleaner event handling
+    chat_inputs = [txt, chat, sys_prompt, search_chk, mr, mc, model_dd, max_tok, temp, k, p, rp, st]
+
+    def start_generation_and_update_ui(*args):
+        # Update UI to "generating" state
+        yield {
+            submit_btn: gr.update(interactive=False),
+            cancel_btn: gr.update(visible=True),
+            txt: gr.update(interactive=False, value=""),  # Clear textbox and disable
+        }
+        # Call the actual chat response generator
+        for output in chat_response(*args):
+            yield {
+                chat: output[0],
+                dbg: output[1]
+            }
+
+    def reset_ui_after_generation():
+        # Update UI back to "idle" state
+        return {
+            submit_btn: gr.update(interactive=True),
+            cancel_btn: gr.update(visible=False),
+            txt: gr.update(interactive=True),  # Re-enable textbox
+        }
+
+    def set_cancel_flag():
+        cancel_event.set()
+        print("Cancellation signal sent.")
+
+    # When the user submits their message (via button or Enter)
+    submit_event = txt.submit(
+        fn=start_generation_and_update_ui,
+        inputs=chat_inputs,
+        outputs=[chat, dbg, submit_btn, cancel_btn, txt]
+    ).then(fn=reset_ui_after_generation, outputs=[submit_btn, cancel_btn, txt])
+
+    submit_btn.click(
+        fn=start_generation_and_update_ui,
+        inputs=chat_inputs,
+        outputs=[chat, dbg, submit_btn, cancel_btn, txt]
+    ).then(fn=reset_ui_after_generation, outputs=[submit_btn, cancel_btn, txt])
+
+    # When the user clicks the cancel button
+    cancel_btn.click(
+        fn=set_cancel_flag,
+        cancels=[submit_event]  # This tells Gradio to stop the running `submit_event`
+    )
+
     # Update duration estimate when relevant inputs change
-    model_dd.change(fn=update_duration_estimate,
-                    inputs=[model_dd, search_chk, mr, mc, max_tok, st],
-                    outputs=duration_display)
-    search_chk.change(fn=update_duration_estimate,
-                      inputs=[model_dd, search_chk, mr, mc, max_tok, st],
-                      outputs=duration_display)
-    max_tok.change(fn=update_duration_estimate,
-                   inputs=[model_dd, search_chk, mr, mc, max_tok, st],
-                   outputs=duration_display)
-    mr.change(fn=update_duration_estimate,
-              inputs=[model_dd, search_chk, mr, mc, max_tok, st],
-              outputs=duration_display)
-    mc.change(fn=update_duration_estimate,
-              inputs=[model_dd, search_chk, mr, mc, max_tok, st],
-              outputs=duration_display)
-    st.change(fn=update_duration_estimate,
-              inputs=[model_dd, search_chk, mr, mc, max_tok, st],
-              outputs=duration_display)
+    duration_inputs = [model_dd, search_chk, mr, mc, max_tok, st]
+    for component in duration_inputs:
+        component.change(fn=update_duration_estimate, inputs=duration_inputs, outputs=duration_display)
 
+    # Other event listeners
     search_chk.change(fn=update_default_prompt, inputs=search_chk, outputs=sys_prompt)
    clr.click(fn=lambda: ([], "", ""), outputs=[chat, txt, dbg])
-    txt.submit(fn=chat_response,
-               inputs=[txt, chat, sys_prompt, search_chk, mr, mc,
-                       model_dd, max_tok, temp, k, p, rp, st],
-               outputs=[chat, dbg])
-    demo.launch()
+
+demo.launch()
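
For reference, the cancellation pattern this diff restores reduces to a small standalone sketch: a module-level threading.Event that the streaming generator polls between chunks, plus a cancel button that both sets the flag and passes `cancels=` so Gradio stops the queued event. This is illustrative only, not the app itself; `token_stream` is a hypothetical stand-in for the real model streamer used by app.py.

# Minimal standalone sketch of the cancel pattern (assumes gradio is installed).
import threading
import time

import gradio as gr

cancel_event = threading.Event()

def token_stream():
    # Hypothetical stand-in for a model streamer: yields words slowly
    # so that cancellation is observable.
    for word in "this is a slow streaming reply".split():
        time.sleep(0.5)
        yield word + " "

def respond(message, history):
    # Clear any stale cancellation signal before starting a new generation
    cancel_event.clear()
    history = list(history or [])
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": ""})
    for chunk in token_stream():
        # Poll the flag between chunks, mirroring the check in the diff
        if cancel_event.is_set():
            history[-1]["content"] += " [Generation Canceled]"
            yield history
            return
        history[-1]["content"] += chunk
        yield history

with gr.Blocks() as demo:
    chat = gr.Chatbot(type="messages")
    txt = gr.Textbox(placeholder="Type your message...")
    cancel_btn = gr.Button("⏹️ Cancel", variant="stop")

    submit_event = txt.submit(fn=respond, inputs=[txt, chat], outputs=chat)
    # Setting the flag lets the generator exit its loop cleanly; `cancels=`
    # additionally tells Gradio to stop streaming the queued submit event.
    cancel_btn.click(fn=cancel_event.set, cancels=[submit_event])

demo.launch()

Both mechanisms matter: `cancels=` stops the queued event from the client's perspective, while the event flag lets the worker loop break out server-side and reach the `finally: gc.collect()` cleanup, which appears to be why the commit wires up both.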