import altair as alt
import fev
import pandas as pd
import pandas.io.formats.style

# Color constants - all colors defined in one place
COLORS = {
    "dl_text": "#5A7FA5",
    "st_text": "#666666",
    "bar_fill": "#8d5eb7",
    "error_bar": "#222222",
    "point": "#111111",
    "text_white": "white",
    "text_black": "black",
    "text_default": "#111",
    "gold": "#F7D36B",
    "silver": "#E5E7EB",
    "bronze": "#E6B089",
    "leakage_impute": "#3B82A0",
    "failure_impute": "#E07B39",
}

HEATMAP_COLOR_SCHEME = "purplegreen"

# Model configuration: (url, org, zero_shot, model_type)
MODEL_CONFIG = {
    # Chronos Models
    "chronos_tiny": ("amazon/chronos-t5-tiny", "AWS", True, "DL"),
    "chronos_mini": ("amazon/chronos-t5-mini", "AWS", True, "DL"),
    "chronos_small": ("amazon/chronos-t5-small", "AWS", True, "DL"),
    "chronos_base": ("amazon/chronos-t5-base", "AWS", True, "DL"),
    "chronos_large": ("amazon/chronos-t5-large", "AWS", True, "DL"),
    "chronos_bolt_tiny": ("amazon/chronos-bolt-tiny", "AWS", True, "DL"),
    "chronos_bolt_mini": ("amazon/chronos-bolt-mini", "AWS", True, "DL"),
    "chronos_bolt_small": ("amazon/chronos-bolt-small", "AWS", True, "DL"),
    "chronos_bolt_base": ("amazon/chronos-bolt-base", "AWS", True, "DL"),
    "chronos-bolt": ("amazon/chronos-bolt-base", "AWS", True, "DL"),
    # Moirai Models
    "moirai_large": ("Salesforce/moirai-1.1-R-large", "Salesforce", True, "DL"),
    "moirai_base": ("Salesforce/moirai-1.1-R-base", "Salesforce", True, "DL"),
    "moirai_small": ("Salesforce/moirai-1.1-R-small", "Salesforce", True, "DL"),
    "moirai-2.0": ("Salesforce/moirai-2.0-R-small", "Salesforce", True, "DL"),
    # TimesFM Models
    "timesfm": ("google/timesfm-1.0-200m-pytorch", "Google", True, "DL"),
    "timesfm-2.0": ("google/timesfm-2.0-500m-pytorch", "Google", True, "DL"),
    "timesfm-2.5": ("google/timesfm-2.5-200m-pytorch", "Google", True, "DL"),
    # Toto Models
    "toto-1.0": ("Datadog/Toto-Open-Base-1.0", "Datadog", True, "DL"),
    # Other Models
    "tirex": ("NX-AI/TiRex", "NX-AI", True, "DL"),
    "tabpfn-ts": ("Prior-Labs/TabPFN-v2-reg", "Prior Labs", True, "DL"),
    "sundial-base": ("thuml/sundial-base-128m", "Tsinghua University", True, "DL"),
    "ttm-r2": ("ibm-granite/granite-timeseries-ttm-r2", "IBM", True, "DL"),
    # Task-specific models
    "stat. ensemble": ("https://nixtlaverse.nixtla.io/statsforecast/", "—", False, "ST"),
    "autoarima": ("https://nixtlaverse.nixtla.io/statsforecast/", "—", False, "ST"),
    "autotheta": ("https://nixtlaverse.nixtla.io/statsforecast/", "—", False, "ST"),
    "autoets": ("https://nixtlaverse.nixtla.io/statsforecast/", "—", False, "ST"),
    "seasonalnaive": ("https://nixtlaverse.nixtla.io/statsforecast/", "—", False, "ST"),
    "seasonal naive": ("https://nixtlaverse.nixtla.io/statsforecast/", "—", False, "ST"),
    "drift": ("https://nixtlaverse.nixtla.io/statsforecast/", "—", False, "ST"),
    "naive": ("https://nixtlaverse.nixtla.io/statsforecast/", "—", False, "ST"),
}
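
# Each key is a lowercase model name (lookups below go through model_name.lower());
# each value is a (url, org, zero_shot, model_type) tuple,
# e.g. MODEL_CONFIG["timesfm-2.0"][1] == "Google".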
ALL_METRICS = {
    "SQL": (
        "SQL: Scaled Quantile Loss",
        "The [Scaled Quantile Loss (SQL)](https://auto.gluon.ai/dev/tutorials/timeseries/forecasting-metrics.html#autogluon.timeseries.metrics.SQL) is a scale-invariant metric for evaluating probabilistic forecasts.",
    ),
    "MASE": (
        "MASE: Mean Absolute Scaled Error",
        "The [Mean Absolute Scaled Error (MASE)](https://auto.gluon.ai/dev/tutorials/timeseries/forecasting-metrics.html#autogluon.timeseries.metrics.MASE) is a scale-invariant metric for evaluating point forecasts.",
    ),
    "WQL": (
        "WQL: Weighted Quantile Loss",
        "The [Weighted Quantile Loss (WQL)](https://auto.gluon.ai/dev/tutorials/timeseries/forecasting-metrics.html#autogluon.timeseries.metrics.WQL) is a scale-dependent metric for evaluating probabilistic forecasts.",
    ),
    "WAPE": (
        "WAPE: Weighted Absolute Percentage Error",
        "The [Weighted Absolute Percentage Error (WAPE)](https://auto.gluon.ai/dev/tutorials/timeseries/forecasting-metrics.html#autogluon.timeseries.metrics.WAPE) is a scale-dependent metric for evaluating point forecasts.",
    ),
}


def format_metric_name(metric_name: str) -> str:
    return ALL_METRICS[metric_name][0]


def get_metric_description(metric_name: str) -> str:
    return ALL_METRICS[metric_name][1]
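
# Example (values taken from ALL_METRICS above):
#   format_metric_name("WQL") -> "WQL: Weighted Quantile Loss"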

def get_model_link(model_name):
    config = MODEL_CONFIG.get(model_name.lower())
    if not config or not config[0]:
        return ""
    url = config[0]
    return url if url.startswith("https:") else f"https://huggingface.co/{url}"


def get_model_organization(model_name):
    config = MODEL_CONFIG.get(model_name.lower())
    return config[1] if config else "—"


def get_zero_shot_status(model_name):
    config = MODEL_CONFIG.get(model_name.lower())
    return "✓" if config and config[2] else "×"


def get_model_type(model_name):
    config = MODEL_CONFIG.get(model_name.lower())
    return config[3] if config else "—"
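
# Illustrative usage of the lookup helpers above (model names taken from MODEL_CONFIG;
# unknown names fall back to "" / "—" / "×"):
#   get_model_link("chronos_bolt_base")   -> "https://huggingface.co/amazon/chronos-bolt-base"
#   get_model_organization("tirex")       -> "NX-AI"
#   get_zero_shot_status("seasonalnaive") -> "×"
#   get_model_type("autoets")             -> "ST"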

def highlight_model_type_color(cell):
    config = MODEL_CONFIG.get(cell.lower())
    if config:
        color = COLORS["dl_text"] if config[3] == "DL" else COLORS["st_text"]
        return f"font-weight: bold; color: {color}"
    return "font-weight: bold"

def format_leaderboard(df: pd.DataFrame):
    df = df.copy()
    df["skill_score"] = df["skill_score"].round(1)
    df["win_rate"] = df["win_rate"].round(1)
    df["zero_shot"] = df["model_name"].apply(get_zero_shot_status)
    # Format leakage column: express the overlap as an integer percentage; non-zero-shot models get 0
    df["training_corpus_overlap"] = df.apply(
        lambda row: int(round(row["training_corpus_overlap"] * 100)) if row["zero_shot"] == "✓" else 0, axis=1
    )
    df["link"] = df["model_name"].apply(get_model_link)
    df["org"] = df["model_name"].apply(get_model_organization)
    df = df[
        [
            "model_name",
            "win_rate",
            "skill_score",
            "median_inference_time_s",
            "training_corpus_overlap",
            "num_failures",
            "zero_shot",
            "org",
            "link",
        ]
    ]
    return (
        df.style.map(highlight_model_type_color, subset=["model_name"])
        .map(lambda x: "font-weight: bold", subset=["zero_shot"])
        .apply(lambda x: ["background-color: #f8f9fa" if i % 2 == 1 else "" for i in range(len(x))], axis=0)
    )
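
# Illustrative usage (sketch): `leaderboard_df` is a hypothetical DataFrame that already
# contains the columns referenced above (model_name, win_rate, skill_score,
# median_inference_time_s, training_corpus_overlap, num_failures):
#   styled = format_leaderboard(leaderboard_df)
#   styled.to_html()  # e.g. for embedding in the app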

def construct_bar_chart(df: pd.DataFrame, col: str, metric_name: str):
    label = "Skill Score" if col == "skill_score" else "Win Rate"
    tooltip = [
        alt.Tooltip("model_name:N"),
        alt.Tooltip(f"{col}:Q", format=".2f"),
        alt.Tooltip(f"{col}_lower:Q", title="95% CI Lower", format=".2f"),
        alt.Tooltip(f"{col}_upper:Q", title="95% CI Upper", format=".2f"),
    ]
    base_encode = {"y": alt.Y("model_name:N", title="Forecasting Model", sort=None), "tooltip": tooltip}
    bars = (
        alt.Chart(df)
        .mark_bar(color=COLORS["bar_fill"], cornerRadius=4)
        .encode(x=alt.X(f"{col}:Q", title=f"{label} (%)", scale=alt.Scale(zero=False)), **base_encode)
    )
    error_bars = (
        alt.Chart(df)
        .mark_errorbar(ticks={"height": 5}, color=COLORS["error_bar"])
        .encode(
            y=alt.Y("model_name:N", title=None, sort=None),
            x=alt.X(f"{col}_lower:Q", title=f"{label} (%)"),
            x2=alt.X2(f"{col}_upper:Q"),
            tooltip=tooltip,
        )
    )
    points = (
        alt.Chart(df)
        .mark_point(filled=True, color=COLORS["point"])
        .encode(x=alt.X(f"{col}:Q", title=f"{label} (%)"), **base_encode)
    )
    return (
        (bars + error_bars + points)
        .properties(height=500, title=f"{label} ({metric_name}) with 95% CIs")
        .configure_title(fontSize=16)
    )
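
# Illustrative usage (sketch): `agg_df` is a hypothetical DataFrame with columns
# model_name, win_rate, win_rate_lower, win_rate_upper (values already in percent):
#   chart = construct_bar_chart(agg_df, col="win_rate", metric_name="SQL")
#   chart.save("win_rate.html")  # or render the chart object directly in the app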

def construct_pairwise_chart(df: pd.DataFrame, col: str, metric_name: str):
    config = {
        "win_rate": ("Win Rate", [0, 100], 50, f"abs(datum.{col} - 50) > 30"),
        "skill_score": ("Skill Score", [-15, 15], 0, f"abs(datum.{col}) > 10"),
    }
    cbar_label, domain, domain_mid, text_condition = config[col]
    df = df.copy()
    for c in [col, f"{col}_lower", f"{col}_upper"]:
        df[c] *= 100
    model_order = df.groupby("model_1")[col].mean().sort_values(ascending=False).index.tolist()
    tooltip = [
        alt.Tooltip("model_1:N", title="Model 1"),
        alt.Tooltip("model_2:N", title="Model 2"),
        alt.Tooltip(f"{col}:Q", title=cbar_label.split(" ")[0], format=".1f"),
        alt.Tooltip(f"{col}_lower:Q", title="95% CI Lower", format=".1f"),
        alt.Tooltip(f"{col}_upper:Q", title="95% CI Upper", format=".1f"),
    ]
    base = alt.Chart(df).encode(
        x=alt.X("model_2:N", sort=model_order, title="Model 2", axis=alt.Axis(orient="top", labelAngle=-90)),
        y=alt.Y("model_1:N", sort=model_order, title="Model 1"),
    )
    heatmap = base.mark_rect().encode(
        color=alt.Color(
            f"{col}:Q",
            legend=alt.Legend(title=f"{cbar_label} (%)", direction="vertical", orient="right"),
            scale=alt.Scale(scheme=HEATMAP_COLOR_SCHEME, domain=domain, domainMid=domain_mid, clamp=True),
        ),
        tooltip=tooltip,
    )
    text_main = base.mark_text(dy=-8, fontSize=8, baseline="top", yOffset=5).encode(
        text=alt.Text(f"{col}:Q", format=".1f"),
        color=alt.condition(text_condition, alt.value(COLORS["text_white"]), alt.value(COLORS["text_black"])),
        tooltip=tooltip,
    )
    return (
        (heatmap + text_main)
        .properties(height=550, title={"text": f"Pairwise {cbar_label} ({metric_name}) with 95% CIs", "fontSize": 16})
        .configure_axis(labelFontSize=11, titleFontSize=13, titleFontWeight="bold")
        .resolve_scale(color="independent")
    )
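
# Illustrative usage (sketch): `pairwise_df` is a hypothetical DataFrame with columns
# model_1, model_2, win_rate, win_rate_lower, win_rate_upper given as fractions in [0, 1]
# (the function rescales them to percent):
#   heatmap = construct_pairwise_chart(pairwise_df, col="win_rate", metric_name="MASE")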

def construct_pivot_table(
    summaries: pd.DataFrame, metric_name: str, baseline_model: str, leakage_imputation_model: str
) -> pd.io.formats.style.Styler:
    errors = fev.pivot_table(summaries=summaries, metric_column=metric_name, task_columns=["task_name"])
    train_overlap = (
        fev.pivot_table(summaries=summaries, metric_column="trained_on_this_dataset", task_columns=["task_name"])
        .fillna(False)
        .astype(bool)
    )
    # Remember which cells get imputed so they can be highlighted later
    is_imputed_baseline = errors.isna()
    is_leakage_imputed = train_overlap
    # Handle imputations: replace leaked scores with the leakage-imputation model's score,
    # then fill failed runs with the baseline model's score
    errors = errors.mask(train_overlap, errors[leakage_imputation_model], axis=0)
    for col in errors.columns:
        if col != baseline_model:
            errors[col] = errors[col].fillna(errors[baseline_model])
    # Order columns by average rank across tasks (best models first)
    errors = errors[errors.rank(axis=1).mean().sort_values().index]
    errors.index.rename("Task name", inplace=True)

    def highlight_by_position(styler):
        rank_colors = {1: COLORS["gold"], 2: COLORS["silver"], 3: COLORS["bronze"]}
        for row_idx in errors.index:
            row_ranks = errors.loc[row_idx].rank(method="min")
            for col_idx in errors.columns:
                rank = row_ranks[col_idx]
                style_parts = []
                # Rank background colors
                if rank <= 3:
                    style_parts.append(f"background-color: {rank_colors[rank]}")
                # Imputation text colors take precedence over the default text color
                if is_leakage_imputed.loc[row_idx, col_idx]:
                    style_parts.append(f"color: {COLORS['leakage_impute']}")
                elif is_imputed_baseline.loc[row_idx, col_idx]:
                    style_parts.append(f"color: {COLORS['failure_impute']}")
                elif not style_parts:
                    style_parts.append(f"color: {COLORS['text_default']}")
                if style_parts:
                    styler = styler.map(
                        lambda x, s="; ".join(style_parts): s, subset=pd.IndexSlice[row_idx:row_idx, col_idx:col_idx]
                    )
        return styler

    return highlight_by_position(errors.style).format(precision=3)
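
# Illustrative usage (sketch): `summaries` is a per-task results DataFrame as produced by fev;
# the baseline and imputation model names below are placeholders that must match columns of
# the pivoted error table:
#   table = construct_pivot_table(
#       summaries,
#       metric_name="MASE",
#       baseline_model="seasonalnaive",
#       leakage_imputation_model="chronos_bolt_base",
#   )
#   table.to_html()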