Spaces:
Running
Running
| """Event handlers for the TuRTLe leaderboard.""" | |
| import gradio as gr | |
| from config.constants import ( | |
| CC_BENCHMARKS, | |
| LC_BENCHMARKS, | |
| NON_RTL_METRICS, | |
| RTL_METRICS, | |
| S2R_BENCHMARKS, | |
| ) | |
| from utils import handle_special_cases | |
def create_leaderboard_handlers(
    filter_leaderboard_fn,
    generate_scatter_plot_fn,
    task_radio,
    benchmark_radio,
    model_type_dropdown,
    search_box,
    params_slider,
    bubble_benchmark,
    bubble_metric,
    scatter_plot,
    leaderboard,
    simulator_radio,
    state,
    name,
):
    """Wire the leaderboard's Gradio components to their event handlers.

    Registers ``.change`` listeners that:

    * re-filter the leaderboard table whenever a filter control changes;
    * swap the benchmark choices when the task changes;
    * redraw the bubble/scatter plot when its benchmark or metric changes;
    * recompute both the table and the plot when the simulator changes.

    Args:
        filter_leaderboard_fn: Called as ``(task, benchmark, model_type, search,
            max_params, state, name)``; its return value is sent to
            ``leaderboard``.
        generate_scatter_plot_fn: Called as ``(benchmark, metric, state)``; its
            return value is sent to ``scatter_plot``.
        task_radio, benchmark_radio, model_type_dropdown, search_box,
            params_slider: Components that drive the leaderboard filter.
        bubble_benchmark, bubble_metric: Benchmark/metric selectors for the plot.
        scatter_plot: Output component for the plot.
        leaderboard: Output component for the filtered table.
        simulator_radio: Simulator selector.
        state: Object forwarded to both callbacks; must provide
            ``set_simulator(simulator)``.
        name: Forwarded unchanged to ``filter_leaderboard_fn``.

    Returns:
        None. Listeners are registered as a side effect.
    """
    # Filter controls in the positional order ``filter_with_state`` expects.
    filter_inputs = [
        task_radio,
        benchmark_radio,
        model_type_dropdown,
        search_box,
        params_slider,
    ]

    def filter_with_state(task, benchmark, model_type, search, max_params):
        """Re-filter the leaderboard from the given filter values."""
        return filter_leaderboard_fn(task, benchmark, model_type, search, max_params, state, name)

    def update_benchmarks_by_task(task, model_type, search, max_params):
        """Swap the benchmark choices for ``task`` and re-filter the leaderboard.

        ``model_type``, ``search`` and ``max_params`` arrive as event inputs so
        they reflect what the user currently has selected. Reading
        ``component.value`` instead would return the value each component was
        constructed with.

        Returns:
            A ``gr.update`` for the benchmark selector, and the filtered table.
        """
        if task == "Spec-to-RTL":
            new_benchmarks = ["All"] + S2R_BENCHMARKS
        elif task == "Code Completion":
            new_benchmarks = ["All"] + CC_BENCHMARKS
        elif task == "Line Completion †":
            new_benchmarks = LC_BENCHMARKS
        else:
            new_benchmarks = ["All"]
        benchmark_value = "All" if "All" in new_benchmarks else new_benchmarks[0]
        filtered = filter_with_state(task, benchmark_value, model_type, search, max_params)
        return gr.update(value=benchmark_value, choices=new_benchmarks), filtered

    def on_benchmark_change(benchmark, _):
        """Reset the plot's metric selector for ``benchmark`` and redraw the plot.

        RTL-Repo offers ``RTL_METRICS`` and defaults to "Exact Matching (EM)".
        Every other benchmark offers ``NON_RTL_METRICS`` minus its last entry
        and defaults to the first entry.
        """
        if benchmark == "RTL-Repo":
            metric = "Exact Matching (EM)"
            choices = RTL_METRICS
        else:
            metric = NON_RTL_METRICS[0]
            choices = NON_RTL_METRICS[:-1]
        fig = generate_scatter_plot_fn(benchmark, metric, state)
        return gr.update(choices=choices, value=metric), fig

    def on_metric_change(benchmark, metric):
        """Redraw the plot for ``metric``.

        The (benchmark, metric) pair is first normalized by
        ``handle_special_cases``; the possibly adjusted benchmark is pushed back
        to the benchmark selector.
        """
        benchmark, metric = handle_special_cases(benchmark, metric)
        fig = generate_scatter_plot_fn(benchmark, metric, state)
        return gr.update(value=benchmark), fig

    def on_simulator_change(
        simulator,
        task,
        benchmark,
        model_type,
        search,
        max_params,
        plot_bench,
        plot_metric,
    ):
        """Switch simulator, then recompute the filtered table and the plot."""
        # NOTE(review): ``state`` is one object captured by this closure. If it
        # is not per-session, this changes the simulator for every visitor —
        # confirm how ``state`` is constructed.
        state.set_simulator(simulator)
        leaderboard_df = filter_with_state(task, benchmark, model_type, search, max_params)
        fig = generate_scatter_plot_fn(plot_bench, plot_metric, state)
        return leaderboard_df, fig

    task_radio.change(
        fn=update_benchmarks_by_task,
        inputs=[task_radio, model_type_dropdown, search_box, params_slider],
        outputs=[benchmark_radio, leaderboard],
    )

    # Every other filter control re-filters the table using all current values.
    for filter_component in filter_inputs[1:]:
        filter_component.change(
            fn=filter_with_state,
            inputs=list(filter_inputs),
            outputs=leaderboard,
        )

    # Scroll preservation JS for plot updates. The element lookup is guarded so
    # that a missing plot container cannot make observe() throw and abort the
    # event.
    scroll_preserve_js = """
    // This is to avoid resetting user scroll each time a plot is re-generated
    (benchmark, metric) => {
        let scrollY = window.scrollY;
        const plot = document.getElementById('full-width-plot');
        if (plot) {
            const observer = new MutationObserver(() => {
                window.scrollTo(0, scrollY);
                observer.disconnect();
            });
            observer.observe(plot, { childList: true });
        }
        return [benchmark, metric];
    }
    """

    bubble_benchmark.change(
        fn=on_benchmark_change,
        inputs=[bubble_benchmark, bubble_metric],
        outputs=[bubble_metric, scatter_plot],
        js=scroll_preserve_js,
    )

    bubble_metric.change(
        fn=on_metric_change,
        inputs=[bubble_benchmark, bubble_metric],
        outputs=[bubble_benchmark, scatter_plot],
        js=scroll_preserve_js,
    )

    simulator_radio.change(
        fn=on_simulator_change,
        inputs=[
            simulator_radio,
            task_radio,
            benchmark_radio,
            model_type_dropdown,
            search_box,
            params_slider,
            bubble_benchmark,
            bubble_metric,
        ],
        outputs=[leaderboard, scatter_plot],
    )