Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Commit
·
36bf409
1
Parent(s):
8e47868
Updated plotted models to exclude flagged models
Browse files
src/display_models/plot_results.py
CHANGED
|
@@ -4,6 +4,7 @@ from plotly.graph_objs import Figure
|
|
| 4 |
import pickle
|
| 5 |
from datetime import datetime, timezone
|
| 6 |
from typing import List, Dict, Tuple, Any
|
|
|
|
| 7 |
|
| 8 |
# Average ⬆️ human baseline is 0.897 (source: averaging human baselines below)
|
| 9 |
# ARC human baseline is 0.80 (source: https://lab42.global/arc/)
|
|
@@ -42,6 +43,9 @@ def join_model_info_with_results(results_df: pd.DataFrame) -> pd.DataFrame:
|
|
| 42 |
# copy dataframe to avoid modifying the original
|
| 43 |
df = results_df.copy(deep=True)
|
| 44 |
|
|
|
|
|
|
|
|
|
|
| 45 |
# load cache from disk
|
| 46 |
try:
|
| 47 |
with open("model_info_cache.pkl", "rb") as f:
|
|
@@ -216,4 +220,4 @@ def create_metric_plot_obj(
|
|
| 216 |
|
| 217 |
# Example Usage:
|
| 218 |
# human_baselines dictionary is defined.
|
| 219 |
-
# chart = create_metric_plot_obj(scores_df, ["ARC", "HellaSwag", "MMLU", "TruthfulQA"], human_baselines, "Graph Title")
|
|
|
|
| 4 |
import pickle
|
| 5 |
from datetime import datetime, timezone
|
| 6 |
from typing import List, Dict, Tuple, Any
|
| 7 |
+
from src.display_models.model_metadata_flags import FLAGGED_MODELS
|
| 8 |
|
| 9 |
# Average ⬆️ human baseline is 0.897 (source: averaging human baselines below)
|
| 10 |
# ARC human baseline is 0.80 (source: https://lab42.global/arc/)
|
|
|
|
| 43 |
# copy dataframe to avoid modifying the original
|
| 44 |
df = results_df.copy(deep=True)
|
| 45 |
|
| 46 |
+
# Filter out FLAGGED_MODELS to ensure graph is not skewed by mistakes
|
| 47 |
+
df = df[~df["model_name_for_query"].isin(FLAGGED_MODELS.keys())].reset_index(drop=True)
|
| 48 |
+
|
| 49 |
# load cache from disk
|
| 50 |
try:
|
| 51 |
with open("model_info_cache.pkl", "rb") as f:
|
|
|
|
| 220 |
|
| 221 |
# Example Usage:
|
| 222 |
# human_baselines dictionary is defined.
|
| 223 |
+
# chart = create_metric_plot_obj(scores_df, ["ARC", "HellaSwag", "MMLU", "TruthfulQA"], human_baselines, "Graph Title")
|