import pandas as pd import plotly.express as px from config.constants import ( CC_BENCHMARKS, LC_BENCHMARKS, NON_RTL_METRICS, RTL_METRICS, S2R_BENCHMARKS, SCATTER_PLOT_X_TICKS, TYPE_COLORS, Y_AXIS_LIMITS, DISCARDED_MODELS, ) from utils import filter_bench, filter_bench_all, filter_RTLRepo, handle_special_cases # this is just a simple class to load the correct data depending on which sim we are at class Simulator: def __init__(self, icarus_df, icarus_agg, verilator_df, verilator_agg): self.icarus_df = icarus_df self.icarus_agg = icarus_agg self.verilator_df = verilator_df self.verilator_agg = verilator_agg self.current_simulator = "Icarus" def get_current_df(self): if self.current_simulator == "Icarus": return self.icarus_df else: return self.verilator_df def get_current_agg(self): if self.current_simulator == "Icarus": return self.icarus_agg else: return self.verilator_agg def set_simulator(self, simulator): self.current_simulator = simulator # filtering main function for the leaderboard body def filter_leaderboard(task, benchmark, model_type, search_query, max_params, state, name): """Filter leaderboard data based on user selections.""" subset = state.get_current_df().copy() # Filter by task specific benchmarks when 'All' benchmarks is selected if task == "Spec-to-RTL": valid_benchmarks = S2R_BENCHMARKS if benchmark == "All": subset = subset[subset["Benchmark"].isin(valid_benchmarks)] elif task == "Code Completion": valid_benchmarks = CC_BENCHMARKS if benchmark == "All": subset = subset[subset["Benchmark"].isin(valid_benchmarks)] elif task == "Line Completion †": valid_benchmarks = LC_BENCHMARKS if benchmark == "All": subset = subset[subset["Benchmark"].isin(valid_benchmarks)] if benchmark != "All": subset = state.get_current_df()[state.get_current_df()["Benchmark"] == benchmark] if model_type != "All": # without emojis subset = subset[subset["Model Type"] == model_type.split(" ")[0]] if search_query: subset = subset[subset["Model"].str.contains(search_query, case=False, na=False)] max_params = float(max_params) subset = subset[subset["Params"] <= max_params] if name == "Other Models": subset = subset[subset["Model"].isin(DISCARDED_MODELS)] else: subset = subset[~subset["Model"].isin(DISCARDED_MODELS)] if benchmark == "All": if task == "Spec-to-RTL": return filter_bench_all(subset, state.get_current_agg(), agg_column="Agg S2R", name=name) elif task == "Code Completion": return filter_bench_all(subset, state.get_current_agg(), agg_column="Agg MC", name=name) elif task == "Line Completion †": return filter_RTLRepo(subset, name=name) elif benchmark == "RTL-Repo": return filter_RTLRepo(subset, name=name) else: agg_column = None if benchmark == "VerilogEval S2R": agg_column = "Agg VerilogEval S2R" elif benchmark == "VerilogEval MC": agg_column = "Agg VerilogEval MC" elif benchmark == "RTLLM": agg_column = "Agg RTLLM" elif benchmark == "VeriGen": agg_column = "Agg VeriGen" return filter_bench(subset, state.get_current_agg(), agg_column, name=name) def generate_scatter_plot(benchmark, metric, state): """Generate a scatter plot for the given benchmark and metric.""" benchmark, metric = handle_special_cases(benchmark, metric) subset = state.get_current_df()[state.get_current_df()["Benchmark"] == benchmark] subset = subset[~subset["Model"].isin(DISCARDED_MODELS)] if benchmark == "RTL-Repo": subset = subset[subset["Metric"].str.contains("EM", case=False, na=False)] detailed_scores = subset.groupby("Model", as_index=False)["Score"].mean() detailed_scores.rename(columns={"Score": "Exact Matching (EM)"}, inplace=True) else: detailed_scores = subset.pivot_table(index="Model", columns="Metric", values="Score").reset_index() details = state.get_current_df()[["Model", "Params", "Model Type"]].drop_duplicates("Model") scatter_data = pd.merge(detailed_scores, details, on="Model", how="left").dropna( subset=["Params", metric] ) scatter_data["x"] = scatter_data["Params"] scatter_data["y"] = scatter_data[metric] scatter_data["size"] = (scatter_data["x"] ** 0.3) * 40 scatter_data["color"] = scatter_data["Model Type"].map(TYPE_COLORS).fillna("gray") y_range = Y_AXIS_LIMITS.get(metric, [0, 80]) fig = px.scatter( scatter_data, x="x", y="y", log_x=True, size="size", color="Model Type", text="Model", hover_data={metric: ":.2f"}, title=f"Params vs. {metric} for {benchmark}", labels={"x": "# Params (Log Scale)", "y": metric}, template="plotly_white", height=600, width=1200, ) fig.update_traces( textposition="top center", textfont_size=10, marker=dict(opacity=0.8, line=dict(width=0.5, color="black")), ) fig.update_layout( xaxis=dict( showgrid=True, type="log", tickmode="array", tickvals=SCATTER_PLOT_X_TICKS["tickvals"], ticktext=SCATTER_PLOT_X_TICKS["ticktext"], ), showlegend=False, yaxis=dict(range=y_range), margin=dict(l=50, r=50, t=50, b=50), plot_bgcolor="white", ) return fig