Restore cancel generation feature with improved UI integration
app.py
CHANGED
@@ -13,6 +13,9 @@ from ddgs import DDGS
 import spaces  # Import spaces early to enable ZeroGPU support
 from torch.utils._pytree import tree_map
 
+# Global event to signal cancellation from the UI thread to the generation thread
+cancel_event = threading.Event()
+
 access_token=os.environ['HF_TOKEN']
 
 # Optional: Disable GPU visibility if you wish to force CPU usage
@@ -402,7 +405,11 @@ def chat_response(user_msg, chat_history, system_prompt,
                   top_k, top_p, repeat_penalty, search_timeout):
     """
     Generates streaming chat responses, optionally with background web search.
+    This version includes cancellation support.
     """
+    # Clear the cancellation event at the start of a new generation
+    cancel_event.clear()
+
     history = list(chat_history or [])
     history.append({'role': 'user', 'content': user_msg})
 
@@ -505,6 +512,12 @@ def chat_response(user_msg, chat_history, system_prompt,
 
         # Stream tokens
         for chunk in streamer:
+            # Check for cancellation signal
+            if cancel_event.is_set():
+                history[-1]['content'] += " [Generation Canceled]"
+                yield history, debug
+                break
+
             text = chunk
 
             # Detect start of thinking
@@ -560,6 +573,7 @@ def chat_response(user_msg, chat_history, system_prompt,
         history.append({'role': 'assistant', 'content': f"Error: {e}"})
         yield history, debug
     finally:
+        # Final cleanup
         gc.collect()
 
 
@@ -616,34 +630,70 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
             st = gr.Slider(minimum=0.0, maximum=30.0, step=0.5, value=5.0, label="Search Timeout (s)")
             clr = gr.Button("Clear Chat")
         with gr.Column(scale=7):
-            chat = gr.Chatbot(type="messages")
-            txt = gr.Textbox(placeholder="Type your message...")
+            chat = gr.Chatbot(type="messages", height=600)
+            with gr.Row():
+                txt = gr.Textbox(placeholder="Type your message...", scale=8, container=False)
+                submit_btn = gr.Button("Submit", variant="primary", scale=1)
+                cancel_btn = gr.Button("⏹️ Cancel", variant="stop", visible=False, scale=1)
             dbg = gr.Markdown()
 
+    # --- Event Listeners ---
+
+    # Group all inputs for cleaner event handling
+    chat_inputs = [txt, chat, sys_prompt, search_chk, mr, mc, model_dd, max_tok, temp, k, p, rp, st]
+
+    def start_generation_and_update_ui(*args):
+        # Update UI to "generating" state
+        yield {
+            submit_btn: gr.update(interactive=False),
+            cancel_btn: gr.update(visible=True),
+            txt: gr.update(interactive=False, value=""),  # Clear textbox and disable
+        }
+        # Call the actual chat response generator
+        for output in chat_response(*args):
+            yield {
+                chat: output[0],
+                dbg: output[1]
+            }
+
+    def reset_ui_after_generation():
+        # Update UI back to "idle" state
+        return {
+            submit_btn: gr.update(interactive=True),
+            cancel_btn: gr.update(visible=False),
+            txt: gr.update(interactive=True),  # Re-enable textbox
+        }
+
+    def set_cancel_flag():
+        cancel_event.set()
+        print("Cancellation signal sent.")
+
+    # When the user submits their message (via button or Enter)
+    submit_event = txt.submit(
+        fn=start_generation_and_update_ui,
+        inputs=chat_inputs,
+        outputs=[chat, dbg, submit_btn, cancel_btn, txt]
+    ).then(fn=reset_ui_after_generation, outputs=[submit_btn, cancel_btn, txt])
+
+    submit_btn.click(
+        fn=start_generation_and_update_ui,
+        inputs=chat_inputs,
+        outputs=[chat, dbg, submit_btn, cancel_btn, txt]
+    ).then(fn=reset_ui_after_generation, outputs=[submit_btn, cancel_btn, txt])
+
+    # When the user clicks the cancel button
+    cancel_btn.click(
+        fn=set_cancel_flag,
+        cancels=[submit_event]  # This tells Gradio to stop the running `submit_event`
+    )
+
     # Update duration estimate when relevant inputs change
-    model_dd.change(fn=update_duration_estimate,
-                    inputs=[model_dd, search_chk, mr, mc, max_tok, st],
-                    outputs=duration_display)
-    search_chk.change(fn=update_duration_estimate,
-                      inputs=[model_dd, search_chk, mr, mc, max_tok, st],
-                      outputs=duration_display)
-    max_tok.change(fn=update_duration_estimate,
-                   inputs=[model_dd, search_chk, mr, mc, max_tok, st],
-                   outputs=duration_display)
-    mr.change(fn=update_duration_estimate,
-              inputs=[model_dd, search_chk, mr, mc, max_tok, st],
-              outputs=duration_display)
-    mc.change(fn=update_duration_estimate,
-              inputs=[model_dd, search_chk, mr, mc, max_tok, st],
-              outputs=duration_display)
-    st.change(fn=update_duration_estimate,
-              inputs=[model_dd, search_chk, mr, mc, max_tok, st],
-              outputs=duration_display)
+    duration_inputs = [model_dd, search_chk, mr, mc, max_tok, st]
+    for component in duration_inputs:
+        component.change(fn=update_duration_estimate, inputs=duration_inputs, outputs=duration_display)
 
+    # Other event listeners
     search_chk.change(fn=update_default_prompt, inputs=search_chk, outputs=sys_prompt)
     clr.click(fn=lambda: ([], "", ""), outputs=[chat, txt, dbg])
-    txt.submit(fn=chat_response,
-               inputs=[txt, chat, sys_prompt, search_chk, mr, mc,
-                       model_dd, max_tok, temp, k, p, rp, st],
-               outputs=[chat, dbg])
-demo.launch()
+
+demo.launch()
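
The core of the change is cooperative cancellation: a module-level threading.Event is set by the UI thread and polled between tokens by the generation loop. Below is a minimal, self-contained sketch of that pattern; the fake_streamer and generate names are illustrative stand-ins, not code from app.py.

import threading
import time

# Shared flag: the UI thread sets it, the generation loop polls it.
cancel_event = threading.Event()

def fake_streamer():
    # Stand-in for a model token streamer: yields tokens slowly.
    for token in ["Once", " upon", " a", " time", "..."]:
        time.sleep(0.5)
        yield token

def generate():
    cancel_event.clear()  # Reset any stale signal from a previous run
    output = ""
    for chunk in fake_streamer():
        if cancel_event.is_set():  # Polled between tokens: cooperative, not forced
            yield output + " [Generation Canceled]"
            return
        output += chunk
        yield output

# Simulate a user pressing Cancel 0.8 s into the stream.
threading.Timer(0.8, cancel_event.set).start()
for partial in generate():
    print(partial)

Because the flag is only checked between chunks, cancellation takes effect at the next token boundary rather than instantly, which is why the diff pairs it with Gradio's cancels= mechanism.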
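The UI wiring layers Gradio's cancels= feature on top of that flag: the handle returned by one event listener can be passed to another listener's cancels= argument to stop the queued job. A minimal standalone sketch of just that mechanism, under the assumption of a current Gradio version; slow_count and the component names here are illustrative, not from app.py.

import time
import gradio as gr

def slow_count():
    # A long-running generator whose stream Gradio can cut off mid-flight.
    for i in range(100):
        time.sleep(0.1)
        yield str(i)

with gr.Blocks() as demo:
    out = gr.Textbox(label="Output")
    start = gr.Button("Start")
    stop = gr.Button("Stop", variant="stop")

    # Keep a handle to the event so another listener can cancel it.
    run_event = start.click(fn=slow_count, outputs=out)

    # fn=None makes this a pure cancellation listener.
    stop.click(fn=None, cancels=[run_event])

demo.launch()

Note that in the diff only the txt.submit chain is passed to cancels=; a run started via submit_btn.click is stopped solely by the threading.Event flag, which still ends the stream cleanly at the next token boundary.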