# TuRTLe-Leaderboard / data_processing.py
import pandas as pd
import plotly.express as px
from config.constants import (
CC_BENCHMARKS,
LC_BENCHMARKS,
NON_RTL_METRICS,
RTL_METRICS,
S2R_BENCHMARKS,
SCATTER_PLOT_X_TICKS,
TYPE_COLORS,
Y_AXIS_LIMITS,
DISCARDED_MODELS,
)
from utils import filter_bench, filter_bench_all, filter_RTLRepo, handle_special_cases


class Simulator:
    """Hold the per-simulator leaderboard data and serve whichever simulator is currently selected."""

    def __init__(self, icarus_df, icarus_agg, verilator_df, verilator_agg):
self.icarus_df = icarus_df
self.icarus_agg = icarus_agg
self.verilator_df = verilator_df
self.verilator_agg = verilator_agg
self.current_simulator = "Icarus"

    def get_current_df(self):
        """Return the detailed results dataframe for the active simulator."""
if self.current_simulator == "Icarus":
return self.icarus_df
else:
return self.verilator_df

    def get_current_agg(self):
        """Return the aggregated-scores dataframe for the active simulator."""
if self.current_simulator == "Icarus":
return self.icarus_agg
else:
return self.verilator_agg

    def set_simulator(self, simulator):
        """Select the active simulator ("Icarus" or "Verilator")."""
self.current_simulator = simulator
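
# Minimal usage sketch (illustrative; the dataframes are loaded elsewhere in the app):
#
#     state = Simulator(icarus_df, icarus_agg, verilator_df, verilator_agg)
#     state.set_simulator("Verilator")
#     df = state.get_current_df()  # now returns verilator_df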


# Main filtering entry point for the leaderboard body.
def filter_leaderboard(task, benchmark, model_type, search_query, max_params, state, name):
"""Filter leaderboard data based on user selections."""
subset = state.get_current_df().copy()
    # Restrict to the task-specific benchmarks when 'All' benchmarks is selected.
if task == "Spec-to-RTL":
valid_benchmarks = S2R_BENCHMARKS
if benchmark == "All":
subset = subset[subset["Benchmark"].isin(valid_benchmarks)]
elif task == "Code Completion":
valid_benchmarks = CC_BENCHMARKS
if benchmark == "All":
subset = subset[subset["Benchmark"].isin(valid_benchmarks)]
elif task == "Line Completion †":
valid_benchmarks = LC_BENCHMARKS
if benchmark == "All":
subset = subset[subset["Benchmark"].isin(valid_benchmarks)]
if benchmark != "All":
subset = state.get_current_df()[state.get_current_df()["Benchmark"] == benchmark]
    if model_type != "All":
        # The dropdown labels carry an emoji suffix; compare against the bare type name.
        subset = subset[subset["Model Type"] == model_type.split(" ")[0]]
if search_query:
subset = subset[subset["Model"].str.contains(search_query, case=False, na=False)]
max_params = float(max_params)
subset = subset[subset["Params"] <= max_params]
if name == "Other Models":
subset = subset[subset["Model"].isin(DISCARDED_MODELS)]
else:
subset = subset[~subset["Model"].isin(DISCARDED_MODELS)]
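    # Dispatch to the aggregation helper matching the task/benchmark selection.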
if benchmark == "All":
if task == "Spec-to-RTL":
return filter_bench_all(subset, state.get_current_agg(), agg_column="Agg S2R", name=name)
elif task == "Code Completion":
return filter_bench_all(subset, state.get_current_agg(), agg_column="Agg MC", name=name)
elif task == "Line Completion †":
return filter_RTLRepo(subset, name=name)
elif benchmark == "RTL-Repo":
return filter_RTLRepo(subset, name=name)
    else:
        agg_columns = {
            "VerilogEval S2R": "Agg VerilogEval S2R",
            "VerilogEval MC": "Agg VerilogEval MC",
            "RTLLM": "Agg RTLLM",
            "VeriGen": "Agg VeriGen",
        }
        return filter_bench(subset, state.get_current_agg(), agg_columns.get(benchmark), name=name)
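

# Example call (values are illustrative assumptions mirroring the UI controls
# this function is wired to):
#
#     table = filter_leaderboard(
#         task="Spec-to-RTL",
#         benchmark="All",
#         model_type="All",
#         search_query="",
#         max_params=1000,
#         state=state,
#         name="Other Models",  # any other name yields the main leaderboard
#     )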


def generate_scatter_plot(benchmark, metric, state):
"""Generate a scatter plot for the given benchmark and metric."""
benchmark, metric = handle_special_cases(benchmark, metric)
subset = state.get_current_df()[state.get_current_df()["Benchmark"] == benchmark]
subset = subset[~subset["Model"].isin(DISCARDED_MODELS)]
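    # RTL-Repo exposes a single Exact Matching (EM) score, averaged per model;
    # all other benchmarks are pivoted so each metric becomes its own column.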
if benchmark == "RTL-Repo":
subset = subset[subset["Metric"].str.contains("EM", case=False, na=False)]
detailed_scores = subset.groupby("Model", as_index=False)["Score"].mean()
detailed_scores.rename(columns={"Score": "Exact Matching (EM)"}, inplace=True)
else:
detailed_scores = subset.pivot_table(index="Model", columns="Metric", values="Score").reset_index()
details = state.get_current_df()[["Model", "Params", "Model Type"]].drop_duplicates("Model")
scatter_data = pd.merge(detailed_scores, details, on="Model", how="left").dropna(
subset=["Params", metric]
)
scatter_data["x"] = scatter_data["Params"]
scatter_data["y"] = scatter_data[metric]
scatter_data["size"] = (scatter_data["x"] ** 0.3) * 40
scatter_data["color"] = scatter_data["Model Type"].map(TYPE_COLORS).fillna("gray")
y_range = Y_AXIS_LIMITS.get(metric, [0, 80])
fig = px.scatter(
scatter_data,
x="x",
y="y",
log_x=True,
size="size",
color="Model Type",
text="Model",
hover_data={metric: ":.2f"},
title=f"Params vs. {metric} for {benchmark}",
labels={"x": "# Params (Log Scale)", "y": metric},
template="plotly_white",
height=600,
width=1200,
)
fig.update_traces(
textposition="top center",
textfont_size=10,
marker=dict(opacity=0.8, line=dict(width=0.5, color="black")),
)
fig.update_layout(
xaxis=dict(
showgrid=True,
type="log",
tickmode="array",
tickvals=SCATTER_PLOT_X_TICKS["tickvals"],
ticktext=SCATTER_PLOT_X_TICKS["ticktext"],
),
showlegend=False,
yaxis=dict(range=y_range),
margin=dict(l=50, r=50, t=50, b=50),
plot_bgcolor="white",
)
return fig
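

# Example (illustrative; handle_special_cases may remap the benchmark/metric pair):
#
#     fig = generate_scatter_plot("RTL-Repo", "Exact Matching (EM)", state)
#     fig.show()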