Corey Morris
committed on
Commit
·
9695a47
1
Parent(s):
b9b6115
Added radar chart. Compares a model to the 5 models that have the closest performance on MMLU_average
Browse files
app.py
CHANGED
|
@@ -4,6 +4,7 @@ import plotly.express as px
|
|
| 4 |
from result_data_processor import ResultDataProcessor
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
import numpy as np
|
|
|
|
| 7 |
|
| 8 |
st.set_page_config(layout="wide")
|
| 9 |
|
|
@@ -47,6 +48,46 @@ def plot_top_n(df, target_column, n=10):
|
|
| 47 |
# Show the plot
|
| 48 |
st.pyplot(fig)
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
data_provider = ResultDataProcessor()
|
| 51 |
|
| 52 |
# st.title('Model Evaluation Results including MMLU by task')
|
|
@@ -131,6 +172,7 @@ st.download_button(
|
|
| 131 |
mime="text/csv",
|
| 132 |
)
|
| 133 |
|
|
|
|
| 134 |
def create_plot(df, x_values, y_values, models=None, title=None):
|
| 135 |
if models is not None:
|
| 136 |
df = df[df.index.isin(models)]
|
|
@@ -215,6 +257,21 @@ if selected_x_column != selected_y_column: # Avoid creating a plot with the s
|
|
| 215 |
else:
|
| 216 |
st.write("Please select different columns for the x and y axes.")
|
| 217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
# end of custom scatter plots
|
| 219 |
st.markdown("## Notable findings and plots")
|
| 220 |
|
|
|
|
| 4 |
from result_data_processor import ResultDataProcessor
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
import numpy as np
|
| 7 |
+
import plotly.graph_objects as go
|
| 8 |
|
| 9 |
st.set_page_config(layout="wide")
|
| 10 |
|
|
|
|
| 48 |
# Show the plot
|
| 49 |
st.pyplot(fig)
|
| 50 |
|
| 51 |
+
# Function to create an unfilled radar chart
|
| 52 |
+
def create_radar_chart_unfilled(df, model_names, metrics):
    """Build a radar (polar) chart with one outline trace per model.

    Args:
        df: Table whose row labels include every entry of ``model_names``
            and whose column labels include every entry of ``metrics``.
        model_names: Row labels of the models to draw, one trace each.
        metrics: Column labels used as the angular categories.

    Returns:
        A ``plotly.graph_objects.Figure`` whose radial axis spans the
        smallest to the largest value found among the selected models
        and metrics.
    """
    # Slice once; both radial-axis bounds come from the same sub-table.
    selected_scores = df.loc[model_names, metrics]
    lowest = selected_scores.min().min()
    highest = selected_scores.max().max()

    chart = go.Figure()
    for name in model_names:
        trace = go.Scatterpolar(
            r=df.loc[name, metrics],
            theta=metrics,
            name=name,
        )
        chart.add_trace(trace)

    # Fit the radial axis to the plotted data instead of Plotly's default range.
    radial_axis = {"visible": True, "range": [lowest, highest]}
    chart.update_layout(polar={"radialaxis": radial_axis}, showlegend=True)
    return chart
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# Function to create a line chart
|
| 76 |
+
def create_line_chart(df, model_names, metrics):
    """Build a line chart comparing the given models across metrics.

    Args:
        df: Table whose row labels include every entry of ``model_names``
            and whose column labels include every entry of ``metrics``.
        model_names: Row labels of the models to plot, one line each.
        metrics: Column labels placed along the x axis.

    Returns:
        A Plotly Express line figure titled 'Comparison of Models' with
        one solid line per model and the legend shown.
    """
    # Reshape to long form: one record per (model, metric) pair.
    records = [
        {'Model': name, 'Metric': metric, 'Value': score}
        for name in model_names
        for metric, score in zip(metrics, df.loc[name, metrics])
    ]
    long_df = pd.DataFrame(records)

    chart = px.line(
        long_df,
        x='Metric',
        y='Value',
        color='Model',
        title='Comparison of Models',
        line_dash_sequence=['solid'],
    )
    chart.update_layout(showlegend=True)
    return chart
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
|
| 91 |
data_provider = ResultDataProcessor()
|
| 92 |
|
| 93 |
# st.title('Model Evaluation Results including MMLU by task')
|
|
|
|
| 172 |
mime="text/csv",
|
| 173 |
)
|
| 174 |
|
| 175 |
+
|
| 176 |
def create_plot(df, x_values, y_values, models=None, title=None):
|
| 177 |
if models is not None:
|
| 178 |
df = df[df.index.isin(models)]
|
|
|
|
| 257 |
else:
|
| 258 |
st.write("Please select different columns for the x and y axes.")
|
| 259 |
|
| 260 |
+
|
| 261 |
+
# "Compare Models" section: pick one model, then chart it together with the
# models whose MMLU_average is nearest to it (radar chart, then line chart).
st.header("Compare Models")
selected_model_name = st.selectbox("Select a Model:", filtered_data.index.tolist())
metrics_to_compare = [
    'MMLU_abstract_algebra',
    'MMLU_astronomy',
    'MMLU_business_ethics',
    'MMLU_average',
    'MMLU_moral_scenarios',
]

# Rank every model by absolute distance from the selected model's MMLU_average
# and keep the five nearest. The selected model is at distance 0 from itself,
# so it is normally one of the five.
selected_mmlu_average = filtered_data.loc[selected_model_name, 'MMLU_average']
mmlu_distance = filtered_data['MMLU_average'].sub(selected_mmlu_average).abs()
closest_models = mmlu_distance.nsmallest(5).index.tolist()

# Build both figures before rendering either of them.
fig_radar = create_radar_chart_unfilled(filtered_data, closest_models, metrics_to_compare)
fig_line = create_line_chart(filtered_data, closest_models, metrics_to_compare)

st.plotly_chart(fig_radar)
st.plotly_chart(fig_line)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
|
| 275 |
# end of custom scatter plots
|
| 276 |
st.markdown("## Notable findings and plots")
|
| 277 |
|