Luigi committed on
Commit
fc989b4
·
1 Parent(s): ab92e0d

Add dynamic GPU time estimate indicator to UI

Browse files

- Shows estimated GPU seconds for current inference settings
- Updates in real-time when model, search settings, or max tokens change
- Displays model size, AOT status, and search status
- Helps users manage precious ZeroGPU time budget effectively

Files changed (1) hide show
  1. app.py +48 -0
app.py CHANGED
@@ -614,6 +614,28 @@ def cancel_generation():
614
  def update_default_prompt(enable_search):
615
  return f"You are a helpful assistant."
616
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
617
  # ------------------------------
618
  # Gradio UI
619
  # ------------------------------
@@ -625,6 +647,12 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
625
  model_dd = gr.Dropdown(label="Select Model", choices=list(MODELS.keys()), value="Qwen3-1.7B")
626
  search_chk = gr.Checkbox(label="Enable Web Search", value=False)
627
  sys_prompt = gr.Textbox(label="System Prompt", lines=3, value=update_default_prompt(search_chk.value))
 
 
 
 
 
 
628
  gr.Markdown("### Generation Parameters")
629
  max_tok = gr.Slider(64, 16384, value=1024, step=32, label="Max Tokens")
630
  temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
@@ -642,6 +670,26 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
642
  txt = gr.Textbox(placeholder="Type your message and press Enter...")
643
  dbg = gr.Markdown()
644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
645
  search_chk.change(fn=update_default_prompt, inputs=search_chk, outputs=sys_prompt)
646
  clr.click(fn=lambda: ([], "", ""), outputs=[chat, txt, dbg])
647
  cnl.click(fn=cancel_generation, outputs=dbg)
 
614
def update_default_prompt(enable_search):
    """Return the default system prompt for the chat UI.

    Parameters
    ----------
    enable_search : bool
        Accepted to satisfy the Gradio ``change``-callback signature
        (wired to ``search_chk``); it currently does not alter the prompt.

    Returns
    -------
    str
        The default system prompt.
    """
    # Plain literal: the original used an f-string with no placeholders.
    return "You are a helpful assistant."
616
 
def update_duration_estimate(model_name, enable_search, max_results, max_chars, max_tokens, search_timeout):
    """Render a markdown summary of the estimated GPU time for the current settings.

    Forwards the settings that influence the estimate to ``get_duration()``;
    the conversational inputs (message, history, system prompt) do not affect
    it, so placeholder values are supplied for them.
    """
    try:
        # get_duration() expects the full chat signature; blanks stand in
        # for the inputs that do not change the estimate.
        placeholder_message = ""
        placeholder_history = []
        placeholder_system = ""

        estimate = get_duration(
            placeholder_message, placeholder_history, placeholder_system,
            enable_search, max_results, max_chars, model_name,
            max_tokens, 0.7, 40, 0.9, 1.2, search_timeout,
        )

        # Size drives the AOT-compilation heuristic (>= 2B parameters).
        size_b = MODELS[model_name].get("params_b", 4.0)
        aot_enabled = size_b >= 2

        report_lines = [
            f"⏱️ **Estimated GPU Time: {estimate:.1f} seconds**",
            "",
            f"📊 **Model Size:** {size_b:.1f}B parameters",
            f"⚡ **AOT Compilation:** {'Enabled' if aot_enabled else 'Disabled'}",
            f"🔍 **Web Search:** {'Enabled' if enable_search else 'Disabled'}",
        ]
        return "\n".join(report_lines)
    except Exception as exc:
        # UI boundary: surface the failure in the markdown panel rather
        # than letting the callback raise.
        return f"⚠️ Error calculating estimate: {exc}"
638
+
639
  # ------------------------------
640
  # Gradio UI
641
  # ------------------------------
 
647
  model_dd = gr.Dropdown(label="Select Model", choices=list(MODELS.keys()), value="Qwen3-1.7B")
648
  search_chk = gr.Checkbox(label="Enable Web Search", value=False)
649
  sys_prompt = gr.Textbox(label="System Prompt", lines=3, value=update_default_prompt(search_chk.value))
650
+
651
+ # GPU Time Estimate Display
652
+ duration_display = gr.Markdown(value=update_duration_estimate(
653
+ "Qwen3-1.7B", False, 4, 50, 1024, 5.0
654
+ ))
655
+
656
  gr.Markdown("### Generation Parameters")
657
  max_tok = gr.Slider(64, 16384, value=1024, step=32, label="Max Tokens")
658
  temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
 
670
  txt = gr.Textbox(placeholder="Type your message and press Enter...")
671
  dbg = gr.Markdown()
672
 
673
+ # Update duration estimate when relevant inputs change
674
+ model_dd.change(fn=update_duration_estimate,
675
+ inputs=[model_dd, search_chk, mr, mc, max_tok, st],
676
+ outputs=duration_display)
677
+ search_chk.change(fn=update_duration_estimate,
678
+ inputs=[model_dd, search_chk, mr, mc, max_tok, st],
679
+ outputs=duration_display)
680
+ max_tok.change(fn=update_duration_estimate,
681
+ inputs=[model_dd, search_chk, mr, mc, max_tok, st],
682
+ outputs=duration_display)
683
+ mr.change(fn=update_duration_estimate,
684
+ inputs=[model_dd, search_chk, mr, mc, max_tok, st],
685
+ outputs=duration_display)
686
+ mc.change(fn=update_duration_estimate,
687
+ inputs=[model_dd, search_chk, mr, mc, max_tok, st],
688
+ outputs=duration_display)
689
+ st.change(fn=update_duration_estimate,
690
+ inputs=[model_dd, search_chk, mr, mc, max_tok, st],
691
+ outputs=duration_display)
692
+
693
  search_chk.change(fn=update_default_prompt, inputs=search_chk, outputs=sys_prompt)
694
  clr.click(fn=lambda: ([], "", ""), outputs=[chat, txt, dbg])
695
  cnl.click(fn=cancel_generation, outputs=dbg)