import pandas as pd
import plotly.express as px

from config.constants import (
    CC_BENCHMARKS,
    LC_BENCHMARKS,
    NON_RTL_METRICS,
    RTL_METRICS,
    S2R_BENCHMARKS,
    SCATTER_PLOT_X_TICKS,
    TYPE_COLORS,
    Y_AXIS_LIMITS,
    DISCARDED_MODELS,
)
from utils import filter_bench, filter_bench_all, filter_RTLRepo, handle_special_cases
# Small state holder that returns the dataframes for whichever simulator is
# currently selected (Icarus by default).
class Simulator:
    def __init__(self, icarus_df, icarus_agg, verilator_df, verilator_agg):
        self.icarus_df = icarus_df
        self.icarus_agg = icarus_agg
        self.verilator_df = verilator_df
        self.verilator_agg = verilator_agg
        self.current_simulator = "Icarus"

    def get_current_df(self):
        if self.current_simulator == "Icarus":
            return self.icarus_df
        else:
            return self.verilator_df

    def get_current_agg(self):
        if self.current_simulator == "Icarus":
            return self.icarus_agg
        else:
            return self.verilator_agg

    def set_simulator(self, simulator):
        self.current_simulator = simulator
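
# Usage sketch (illustrative only; the result dataframes are loaded elsewhere in the app):
#
#     state = Simulator(icarus_df, icarus_agg, verilator_df, verilator_agg)
#     state.get_current_df()            # Icarus results (default simulator)
#     state.set_simulator("Verilator")
#     state.get_current_df()            # now returns the Verilator results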

# Main filtering function for the leaderboard body.
def filter_leaderboard(task, benchmark, model_type, search_query, max_params, state, name):
    """Filter leaderboard data based on user selections."""
    subset = state.get_current_df().copy()

    # Restrict to the task-specific benchmarks when 'All' benchmarks is selected.
    if task == "Spec-to-RTL":
        valid_benchmarks = S2R_BENCHMARKS
        if benchmark == "All":
            subset = subset[subset["Benchmark"].isin(valid_benchmarks)]
    elif task == "Code Completion":
        valid_benchmarks = CC_BENCHMARKS
        if benchmark == "All":
            subset = subset[subset["Benchmark"].isin(valid_benchmarks)]
    elif task == "Line Completion †":
        valid_benchmarks = LC_BENCHMARKS
        if benchmark == "All":
            subset = subset[subset["Benchmark"].isin(valid_benchmarks)]

    # A specific benchmark selection replaces the task-level subset.
    if benchmark != "All":
        subset = state.get_current_df()[state.get_current_df()["Benchmark"] == benchmark]

    if model_type != "All":
        # Strip the emoji suffix from the dropdown label before matching.
        subset = subset[subset["Model Type"] == model_type.split(" ")[0]]

    if search_query:
        subset = subset[subset["Model"].str.contains(search_query, case=False, na=False)]

    max_params = float(max_params)
    subset = subset[subset["Params"] <= max_params]

    # The "Other Models" view shows only the discarded models; every other view hides them.
    if name == "Other Models":
        subset = subset[subset["Model"].isin(DISCARDED_MODELS)]
    else:
        subset = subset[~subset["Model"].isin(DISCARDED_MODELS)]

    if benchmark == "All":
        if task == "Spec-to-RTL":
            return filter_bench_all(subset, state.get_current_agg(), agg_column="Agg S2R", name=name)
        elif task == "Code Completion":
            return filter_bench_all(subset, state.get_current_agg(), agg_column="Agg MC", name=name)
        elif task == "Line Completion †":
            return filter_RTLRepo(subset, name=name)
    elif benchmark == "RTL-Repo":
        return filter_RTLRepo(subset, name=name)
    else:
        agg_column = None
        if benchmark == "VerilogEval S2R":
            agg_column = "Agg VerilogEval S2R"
        elif benchmark == "VerilogEval MC":
            agg_column = "Agg VerilogEval MC"
        elif benchmark == "RTLLM":
            agg_column = "Agg RTLLM"
        elif benchmark == "VeriGen":
            agg_column = "Agg VeriGen"
        return filter_bench(subset, state.get_current_agg(), agg_column, name=name)
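
# Example call (illustrative; in the app these values come from the UI controls):
#
#     table = filter_leaderboard(
#         task="Spec-to-RTL",
#         benchmark="All",        # routes to filter_bench_all with agg_column="Agg S2R"
#         model_type="All",
#         search_query="",
#         max_params=1000,
#         state=state,            # the shared Simulator instance
#         name="Leaderboard",     # hypothetical label; anything but "Other Models" hides DISCARDED_MODELS
#     )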

def generate_scatter_plot(benchmark, metric, state):
    """Generate a scatter plot for the given benchmark and metric."""
    benchmark, metric = handle_special_cases(benchmark, metric)

    subset = state.get_current_df()[state.get_current_df()["Benchmark"] == benchmark]
    subset = subset[~subset["Model"].isin(DISCARDED_MODELS)]

    if benchmark == "RTL-Repo":
        # RTL-Repo is scored with exact matching (EM): average the EM rows per model.
        subset = subset[subset["Metric"].str.contains("EM", case=False, na=False)]
        detailed_scores = subset.groupby("Model", as_index=False)["Score"].mean()
        detailed_scores.rename(columns={"Score": "Exact Matching (EM)"}, inplace=True)
    else:
        # Every other benchmark gets one column per metric.
        detailed_scores = subset.pivot_table(index="Model", columns="Metric", values="Score").reset_index()

    details = state.get_current_df()[["Model", "Params", "Model Type"]].drop_duplicates("Model")
    scatter_data = pd.merge(detailed_scores, details, on="Model", how="left").dropna(
        subset=["Params", metric]
    )

    scatter_data["x"] = scatter_data["Params"]
    scatter_data["y"] = scatter_data[metric]
    scatter_data["size"] = (scatter_data["x"] ** 0.3) * 40
    scatter_data["color"] = scatter_data["Model Type"].map(TYPE_COLORS).fillna("gray")
    y_range = Y_AXIS_LIMITS.get(metric, [0, 80])

    fig = px.scatter(
        scatter_data,
        x="x",
        y="y",
        log_x=True,
        size="size",
        color="Model Type",
        text="Model",
        hover_data={metric: ":.2f"},
        title=f"Params vs. {metric} for {benchmark}",
        labels={"x": "# Params (Log Scale)", "y": metric},
        template="plotly_white",
        height=600,
        width=1200,
    )
    fig.update_traces(
        textposition="top center",
        textfont_size=10,
        marker=dict(opacity=0.8, line=dict(width=0.5, color="black")),
    )
    fig.update_layout(
        xaxis=dict(
            showgrid=True,
            type="log",
            tickmode="array",
            tickvals=SCATTER_PLOT_X_TICKS["tickvals"],
            ticktext=SCATTER_PLOT_X_TICKS["ticktext"],
        ),
        showlegend=False,
        yaxis=dict(range=y_range),
        margin=dict(l=50, r=50, t=50, b=50),
        plot_bgcolor="white",
    )
    return fig
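
# Example call (illustrative): handle_special_cases may remap the pair first, and
# the metric must match a column built above, e.g. the "Exact Matching (EM)" column
# created for RTL-Repo; other benchmarks expose one column per value of "Metric",
# whose exact names live in the data.
#
#     fig = generate_scatter_plot("RTL-Repo", "Exact Matching (EM)", state)
#     fig.show()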