"""Event handlers for the TuRTLe leaderboard."""
import gradio as gr
from config.constants import (
CC_BENCHMARKS,
LC_BENCHMARKS,
NON_RTL_METRICS,
RTL_METRICS,
S2R_BENCHMARKS,
)
from utils import handle_special_cases


def create_leaderboard_handlers(
    filter_leaderboard_fn,
    generate_scatter_plot_fn,
    task_radio,
    benchmark_radio,
    model_type_dropdown,
    search_box,
    params_slider,
    bubble_benchmark,
    bubble_metric,
    scatter_plot,
    leaderboard,
    simulator_radio,
    state,
    name,
):
    """Wire the leaderboard's Gradio event handlers to the given components."""
    def update_benchmarks_by_task(task):
        # Each task exposes its own benchmark set; Line Completion is the
        # only one without an "All" aggregate option.
        if task == "Spec-to-RTL":
            new_benchmarks = ["All"] + S2R_BENCHMARKS
        elif task == "Code Completion":
            new_benchmarks = ["All"] + CC_BENCHMARKS
        elif task == "Line Completion †":
            new_benchmarks = LC_BENCHMARKS
        else:
            new_benchmarks = ["All"]
        benchmark_value = "All" if "All" in new_benchmarks else new_benchmarks[0]
        filtered = filter_leaderboard_fn(
            task,
            benchmark_value,
            model_type_dropdown.value,
            search_box.value,
            params_slider.value,
            state,
            name,
        )
        return gr.update(value=benchmark_value, choices=new_benchmarks), filtered
    def on_benchmark_change(benchmark, _metric):
        # RTL-Repo is scored by exact match; every other benchmark uses the
        # shared metric list, excluding its final entry.
        if benchmark == "RTL-Repo":
            metric = "Exact Matching (EM)"
            choices = RTL_METRICS
        else:
            metric = NON_RTL_METRICS[0]
            choices = NON_RTL_METRICS[:-1]
        fig = generate_scatter_plot_fn(benchmark, metric, state)
        return gr.update(choices=choices, value=metric), fig
    def on_metric_change(benchmark, metric):
        # Let handle_special_cases coerce benchmark/metric combinations that
        # need special treatment before the plot is rebuilt.
        benchmark, metric = handle_special_cases(benchmark, metric)
        fig = generate_scatter_plot_fn(benchmark, metric, state)
        return gr.update(value=benchmark), fig
    def on_simulator_change(
        simulator,
        task,
        benchmark,
        model_type,
        search,
        max_params,
        plot_bench,
        plot_metric,
    ):
        # Switching simulators changes the underlying scores, so both the
        # table and the scatter plot must be rebuilt.
        state.set_simulator(simulator)
        leaderboard_df = filter_leaderboard_fn(
            task, benchmark, model_type, search, max_params, state, name
        )
        fig = generate_scatter_plot_fn(plot_bench, plot_metric, state)
        return leaderboard_df, fig
    task_radio.change(
        fn=update_benchmarks_by_task,
        inputs=[task_radio],
        outputs=[benchmark_radio, leaderboard],
    )
    def filter_with_state(task, benchmark, model_type, search, max_params):
        # Bind the shared state and leaderboard name so the Gradio inputs map
        # one-to-one onto the remaining parameters.
        return filter_leaderboard_fn(
            task, benchmark, model_type, search, max_params, state, name
        )
    # All four filter controls trigger the same re-filter with the same
    # inputs, so wire them up in a loop.
    filter_inputs = [
        task_radio,
        benchmark_radio,
        model_type_dropdown,
        search_box,
        params_slider,
    ]
    for control in (benchmark_radio, model_type_dropdown, search_box, params_slider):
        control.change(
            fn=filter_with_state,
            inputs=filter_inputs,
            outputs=leaderboard,
        )
    # Scroll-preservation JS for plot updates.
    scroll_preserve_js = """
    (benchmark, metric) => {
        // Avoid resetting the user's scroll each time the plot is re-generated:
        // restore the saved position as soon as the plot container mutates.
        let scrollY = window.scrollY;
        const observer = new MutationObserver(() => {
            window.scrollTo(0, scrollY);
            observer.disconnect();
        });
        observer.observe(document.getElementById('full-width-plot'), { childList: true });
        return [benchmark, metric];
    }
    """
    bubble_benchmark.change(
        fn=on_benchmark_change,
        inputs=[bubble_benchmark, bubble_metric],
        outputs=[bubble_metric, scatter_plot],
        js=scroll_preserve_js,
    )
    bubble_metric.change(
        fn=on_metric_change,
        inputs=[bubble_benchmark, bubble_metric],
        outputs=[bubble_benchmark, scatter_plot],
        js=scroll_preserve_js,
    )
    simulator_radio.change(
        fn=on_simulator_change,
        inputs=[
            simulator_radio,
            task_radio,
            benchmark_radio,
            model_type_dropdown,
            search_box,
            params_slider,
            bubble_benchmark,
            bubble_metric,
        ],
        outputs=[leaderboard, scatter_plot],
    )
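

# Minimal usage sketch (hypothetical: the real component construction lives in
# the Space's app module, and `filter_leaderboard`, `generate_scatter_plot`,
# and `app_state` are assumed names, not confirmed by this file):
#
#   with gr.Blocks() as demo:
#       task_radio = gr.Radio(
#           ["Spec-to-RTL", "Code Completion", "Line Completion †"]
#       )
#       ...  # build the remaining controls, the plot, and the leaderboard table
#       create_leaderboard_handlers(
#           filter_leaderboard,
#           generate_scatter_plot,
#           task_radio, benchmark_radio, model_type_dropdown, search_box,
#           params_slider, bubble_benchmark, bubble_metric, scatter_plot,
#           leaderboard, simulator_radio, app_state, name="TuRTLe",
#       )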