|
|
import hashlib |
|
|
import json |
|
|
import pickle |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
|
|
|
import gradio as gr |
|
|
import pandas as pd |
|
|
import plotly.graph_objects as go |
|
|
from datasets import load_dataset |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
# Root of the on-disk cache, created next to the working directory.
CACHE_DIR = Path("./pwc_cache")
CACHE_DIR.mkdir(exist_ok=True)

# Index files are JSON; per-task and per-dataset payloads live in their own
# sub-directories as pickles.
TASKS_INDEX_FILE = CACHE_DIR / "tasks_index.json"
METRICS_INDEX_FILE = CACHE_DIR / "metrics_index.json"
TASK_DATA_DIR = CACHE_DIR / "task_data"
DATASET_DATA_DIR = CACHE_DIR / "dataset_data"

for _cache_subdir in (TASK_DATA_DIR, DATASET_DATA_DIR):
    _cache_subdir.mkdir(exist_ok=True)
|
|
|
|
|
|
|
|
def sanitize_filename(name):
    """Return ``name`` rewritten so it is usable as a file name.

    Path separators, Windows-reserved punctuation, spaces and dots each become
    ``_``. Runs of underscores are collapsed and leading/trailing underscores
    are dropped. A result longer than 200 characters is cut to its first 150
    characters followed by ``_`` and the first 8 hex digits of the MD5 of the
    original ``name``.
    """
    unsafe_chars = '/\\:*?"<>| .'
    swapped = name.translate(str.maketrans(unsafe_chars, '_' * len(unsafe_chars)))

    # Splitting on '_' and discarding the empty pieces both collapses repeated
    # underscores and trims them from either end.
    pieces = [piece for piece in swapped.split('_') if piece]
    cleaned = '_'.join(pieces)

    if len(cleaned) <= 200:
        return cleaned

    digest = hashlib.md5(name.encode()).hexdigest()[:8]
    return cleaned[:150] + '_' + digest
|
|
|
|
|
|
|
|
def get_task_filename(task):
    """Return the pickle path inside ``TASK_DATA_DIR`` used for ``task``."""
    return TASK_DATA_DIR / ("task_" + sanitize_filename(task) + ".pkl")
|
|
|
|
|
|
|
|
def get_dataset_filename(task, dataset_name):
    """Return the pickle path inside ``DATASET_DATA_DIR`` for a task/dataset pair.

    The name is ``data_<task>_<dataset>.pkl`` built from the sanitized parts.
    If that exceeds 255 characters, each part is cut to 50 characters and the
    first 8 hex digits of the MD5 of ``"<task>||<dataset_name>"`` are appended.
    """
    task_part = sanitize_filename(task)
    dataset_part = sanitize_filename(dataset_name)

    candidate = "data_" + task_part + "_" + dataset_part + ".pkl"
    if len(candidate) <= 255:
        return DATASET_DATA_DIR / candidate

    digest = hashlib.md5(f'{task}||{dataset_name}'.encode()).hexdigest()[:8]
    shortened = f"data_{task_part[:50]}_{dataset_part[:50]}_{digest}.pkl"
    return DATASET_DATA_DIR / shortened
|
|
|
|
|
|
|
|
def cache_exists():
    """Report whether both index files of the cache are present on disk.

    The paths and their existence flags are printed first as a diagnostic.
    """
    print(f"{TASKS_INDEX_FILE =}")
    print(f"{METRICS_INDEX_FILE =}")
    print(f"{TASKS_INDEX_FILE.exists() =}")
    print(f"{METRICS_INDEX_FILE.exists() =}")

    if not TASKS_INDEX_FILE.exists():
        return False
    return METRICS_INDEX_FILE.exists()
|
|
|
|
|
|
|
|
def _load_task_record(task_file):
    """Return the mutable per-task record stored in ``task_file``, or a fresh one.

    Records are written to disk with sorted *lists* for ``categories`` and
    ``datasets``; they are converted back to sets here so the caller can keep
    accumulating into them when the same task shows up again.
    """
    if not task_file.exists():
        return {
            'categories': set(),
            'datasets': set(),
            'date_range': {'min': None, 'max': None},
        }

    # NOTE: pickle must only be used on files this module wrote itself;
    # never point the cache directory at untrusted data.
    with open(task_file, 'rb') as f:
        stored = pickle.load(f)

    return {
        'categories': set(stored.get('categories') or ()),
        'datasets': set(stored.get('datasets') or ()),
        'date_range': stored.get('date_range') or {'min': None, 'max': None},
    }


def _coerce_release_date(paper_date):
    """Parse ``paper_date``; fall back to 2020-01-01 when missing or unparseable."""
    fallback = pd.to_datetime('2020-01-01')
    if not (paper_date and isinstance(paper_date, str)):
        return fallback
    try:
        return pd.to_datetime(paper_date)
    except (ValueError, TypeError, OverflowError):
        return fallback


def build_disk_based_cache():
    """Build the on-disk cache from the ``pwc-archive/evaluation-tables`` dataset.

    For every item with a task name this writes/updates one task pickle
    (categories, dataset names, date range) and, for every dataset of that
    task that has SOTA rows and no file yet, one pickled DataFrame of model
    rows sorted by release date. Datasets whose file already exists are only
    recorded in the task record and not re-read. The two JSON index files are
    written last, so ``cache_exists()`` only reports a cache after the
    per-task and per-dataset files have been produced.

    Returns:
        The sorted list of task names found in the dataset.
    """
    print("=" * 60)
    print("Building disk-based cache (one-time operation)...")
    print("=" * 60)

    tasks_set = set()
    metrics_index = {}
    dataset_count = 0

    print("\n[1/4] Streaming dataset and building cache...")

    # streaming=False: the split is materialised by `datasets`, which is what
    # makes len() available for the progress bar.
    ds = load_dataset("pwc-archive/evaluation-tables", split="train", streaming=False)

    for item in tqdm(ds, total=len(ds)):
        task = item['task']
        if not task:
            continue

        tasks_set.add(task)

        task_file = get_task_filename(task)
        task_data = _load_task_record(task_file)

        if item['categories']:
            task_data['categories'].update(item['categories'])

        for dataset in item['datasets'] or []:
            if not isinstance(dataset, dict) or 'dataset' not in dataset:
                continue

            dataset_name = dataset['dataset']
            if not isinstance(dataset_name, str):
                # A non-string name cannot be turned into a file name.
                continue

            dataset_file = get_dataset_filename(task, dataset_name)
            task_data['datasets'].add(dataset_name)

            # Already cached by an earlier item or an earlier (partial) run.
            if dataset_file.exists():
                continue

            if 'sota' not in dataset or 'rows' not in dataset['sota']:
                continue

            models_data = []
            for row in dataset['sota']['rows']:
                if not isinstance(row, dict):
                    continue

                metrics = {}
                row_metrics = row.get('metrics')
                if isinstance(row_metrics, dict):
                    for metric_name, metric_value in row_metrics.items():
                        if metric_value is None:
                            continue
                        metrics[metric_name] = metric_value

                        if metric_name not in metrics_index:
                            metrics_index[metric_name] = {
                                'count': 0,
                                # Name-based guess of the optimisation direction.
                                'is_lower_better': any(
                                    kw in metric_name.lower()
                                    for kw in ['error', 'loss', 'time', 'cost']
                                ),
                            }
                        metrics_index[metric_name]['count'] += 1

                release_date = _coerce_release_date(row.get('paper_date'))

                date_range = task_data['date_range']
                if date_range['min'] is None or release_date < date_range['min']:
                    date_range['min'] = release_date
                if date_range['max'] is None or release_date > date_range['max']:
                    date_range['max'] = release_date

                code_links = row.get('code_links')
                models_data.append({
                    'model_name': row.get('model_name', 'Unknown Model'),
                    'release_date': release_date,
                    'paper_date': row.get('paper_date', ''),
                    'paper_url': row.get('paper_url', ''),
                    'paper_title': row.get('paper_title', ''),
                    'code_url': code_links[0] if code_links else '',
                    **metrics,
                })

            if models_data:
                df = pd.DataFrame(models_data).sort_values('release_date')
                with open(dataset_file, 'wb') as f:
                    pickle.dump(df, f, protocol=pickle.HIGHEST_PROTOCOL)
                dataset_count += 1

        # Persist with sorted lists so the stored record is deterministic.
        task_data_to_save = {
            'categories': sorted(task_data['categories']),
            'datasets': sorted(task_data['datasets']),
            'date_range': task_data['date_range'],
        }
        with open(task_file, 'wb') as f:
            pickle.dump(task_data_to_save, f, protocol=pickle.HIGHEST_PROTOCOL)

    print(f"\nβ Processed {len(tasks_set)} tasks and {dataset_count} datasets")

    print("\n[2/4] Saving index files...")

    tasks_list = sorted(tasks_set)
    with open(TASKS_INDEX_FILE, 'w') as f:
        json.dump(tasks_list, f)
    print(f" β Saved tasks index ({len(tasks_list)} tasks)")

    with open(METRICS_INDEX_FILE, 'w') as f:
        json.dump(metrics_index, f, indent=2)
    print(f" β Saved metrics index ({len(metrics_index)} metrics)")

    print("\n[3/4] Calculating cache statistics...")

    task_files = list(TASK_DATA_DIR.glob("*.pkl"))
    dataset_files = list(DATASET_DATA_DIR.glob("*.pkl"))
    total_size = sum(file.stat().st_size for file in task_files + dataset_files)

    print(f" β Total cache size: {total_size / 1024 / 1024:.1f} MB")
    print(f" β Task files: {len(task_files)}")
    print(f" β Dataset files: {len(dataset_files)}")

    print("\n[4/4] Cache building complete!")
    print("=" * 60)

    return tasks_list
|
|
|
|
|
|
|
|
def load_tasks_index():
    """Read the JSON task index and return its decoded content."""
    with open(TASKS_INDEX_FILE, 'r') as index_handle:
        task_names = json.load(index_handle)
    return task_names
|
|
|
|
|
|
|
|
def load_task_data(task):
    """Unpickle the record stored for ``task``; ``None`` when no file exists."""
    record_path = get_task_filename(task)
    if not record_path.exists():
        return None

    with open(record_path, 'rb') as handle:
        return pickle.load(handle)
|
|
|
|
|
|
|
|
def load_dataset_data(task, dataset_name):
    """Unpickle the data stored for a task/dataset pair.

    Returns an empty ``DataFrame`` when no file exists for the pair.
    """
    data_path = get_dataset_filename(task, dataset_name)
    if not data_path.exists():
        return pd.DataFrame()

    with open(data_path, 'rb') as handle:
        return pickle.load(handle)
|
|
|
|
|
|
|
|
def load_metrics_index():
    """Read the JSON metrics index; an empty dict when the file is absent."""
    if not METRICS_INDEX_FILE.exists():
        return {}

    with open(METRICS_INDEX_FILE, 'r') as index_handle:
        return json.load(index_handle)
|
|
|
|
|
|
|
|
# Build the cache on first run; afterwards only the lightweight task index
# is read back at start-up.
if not cache_exists():
    TASKS = build_disk_based_cache()
else:
    print("Loading task index from disk...")
    TASKS = load_tasks_index()
    print(f"β Loaded {len(TASKS)} tasks")

# Metric name -> metadata; read after the cache is guaranteed to exist.
METRICS_INDEX = load_metrics_index()
|
|
|
|
|
|
|
|
|
|
|
def get_tasks():
    """Return the module-level task index loaded at start-up."""
    all_tasks = TASKS
    return all_tasks
|
|
|
|
|
|
|
|
def get_task_data(task):
    """Fetch the stored record for ``task`` from disk (delegates to ``load_task_data``)."""
    record = load_task_data(task)
    return record
|
|
|
|
|
|
|
|
def get_categories(task):
    """Return the ``categories`` entry of the task's record, or ``[]`` if the record is missing or empty."""
    record = get_task_data(task)
    if not record:
        return []
    return record['categories']
|
|
|
|
|
|
|
|
def get_datasets_for_task(task):
    """Return the ``datasets`` entry of the task's record, or ``[]`` if the record is missing or empty."""
    record = get_task_data(task)
    if not record:
        return []
    return record['datasets']
|
|
|
|
|
|
|
|
def get_cached_model_data(task, dataset_name):
    """Fetch the stored data for a task/dataset pair from disk (delegates to ``load_dataset_data``)."""
    model_frame = load_dataset_data(task, dataset_name)
    return model_frame
|
|
|
|
|
|
|
|
def parse_paper_date(paper_date, paper_title="", paper_url=""):
    """Derive a publication timestamp for a paper, trying several strategies.

    1. Parse ``paper_date`` with a list of explicit formats, then with
       pandas' free-form parser.
    2. Otherwise take the largest four-digit year (1950-2099) found in
       ``paper_title``, then in ``paper_url``, and use January 1st of it.

    Args:
        paper_date: Date string; anything that is not a non-blank string is
            ignored.
        paper_title: Title searched for a year when the date is unusable.
        paper_url: URL searched for a year when the title yields nothing.

    Returns:
        A ``pandas.Timestamp``, or ``None`` when no strategy produced a date.
    """
    import re

    # Errors pandas raises for unparseable or out-of-range input; catching
    # only these keeps KeyboardInterrupt/SystemExit from being swallowed.
    parse_errors = (ValueError, TypeError, OverflowError)

    def _try_parse(text, fmt=None):
        """Return a Timestamp for ``text``, or None if it cannot be parsed."""
        try:
            parsed = pd.to_datetime(text) if fmt is None else pd.to_datetime(text, format=fmt)
        except parse_errors:
            return None
        # Strings such as "NaT" parse "successfully" into NaT; treat that as
        # a failure so the title/URL fallbacks still get a chance.
        return None if pd.isna(parsed) else parsed

    if paper_date and isinstance(paper_date, str) and paper_date.strip():
        text = paper_date.strip()
        date_formats = (
            '%Y-%m-%d',
            '%Y/%m/%d',
            '%d-%m-%Y',
            '%d/%m/%Y',
            '%Y-%m',
            '%Y/%m',
            '%Y',
        )
        # `None` last: fall back to pandas' free-form parser.
        for fmt in (*date_formats, None):
            parsed = _try_parse(text, fmt)
            if parsed is not None:
                return parsed

    year_pattern = re.compile(r'\b(19[5-9]\d|20[0-9]\d)\b')

    for source in (paper_title, paper_url):
        if not source:
            continue
        years = year_pattern.findall(str(source))
        if years:
            # All matches are four digits, so string max() is numeric max.
            return pd.to_datetime(f'{max(years)}-01-01')

    return None
|
|
|
|
|
|
|
|
def get_task_statistics(task):
    """Placeholder: currently returns an empty dict for every ``task``."""
    statistics = {}
    return statistics
|
|
|
|
|
|
|
|
def _message_figure(text, title, font_size=20):
    """Build an otherwise empty 600px white figure showing a centred message."""
    fig = go.Figure()
    fig.add_annotation(
        text=text,
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        showarrow=False,
        font=dict(size=font_size),
    )
    fig.update_layout(
        title=title,
        height=600,
        plot_bgcolor='white',
        paper_bgcolor='white',
    )
    return fig


def _too_close(current_x, previous_xs, min_separation):
    """Return True if ``current_x`` is within ``min_separation`` days of any earlier label."""
    for prev_x in previous_xs:
        try:
            if abs((current_x - prev_x).days) < min_separation:
                return True
        except (TypeError, AttributeError):
            # Values that do not subtract to a timedelta cannot be compared.
            continue
    return False


def create_sota_plot(df, metric):
    """Create a plot showing model performance evolution over time.

    Every model is drawn as a marker, the running best value as a line, and
    each model that matched the running best at its release is labelled.
    Dates come from ``parse_paper_date`` when a ``paper_date`` column exists,
    falling back to the ``release_date`` column.

    Args:
        df: DataFrame with model data. ``model_name`` and ``release_date``
            columns are required; ``paper_date``, ``paper_title``,
            ``paper_url`` and ``code_url`` are used when present.
        metric: Metric name to plot on y-axis. String values with a trailing
            ``%`` are converted; non-numeric values are dropped.

    Returns:
        A plotly ``Figure``. When nothing can be plotted the figure contains
        only an explanatory message.
    """
    if df.empty or metric not in df.columns:
        return _message_figure("No data available for this metric", "No Data Available")

    df_clean = df.dropna(subset=[metric]).copy()
    if df_clean.empty:
        return _message_figure("No valid data points for this metric", "No Data Available")

    try:
        df_clean[metric] = pd.to_numeric(
            df_clean[metric].apply(
                lambda x: x.strip()[:-1] if isinstance(x, str) and x.strip().endswith("%") else x
            ),
            errors='coerce',
        )
        df_clean = df_clean.dropna(subset=[metric])
    except Exception as e:
        # UI boundary: show the problem in the plot area instead of raising.
        return _message_figure(
            f"Error processing metric data: {str(e)}",
            "Data Processing Error",
            font_size=16,
        )

    if df_clean.empty:
        return _message_figure(
            f"No numeric data available for metric: {metric}",
            "No Numeric Data Available",
        )

    if 'paper_date' in df_clean.columns:
        df_clean['dynamic_release_date'] = df_clean.apply(
            lambda row: parse_paper_date(
                row.get('paper_date', ''),
                row.get('paper_title', ''),
                row.get('paper_url', ''),
            ),
            axis=1,
        )
        df_clean['final_release_date'] = df_clean['dynamic_release_date'].fillna(
            df_clean['release_date']
        )
    else:
        df_clean['final_release_date'] = df_clean['release_date']

    df_with_dates = df_clean[df_clean['final_release_date'].notna()]
    if df_with_dates.empty:
        return _message_figure(
            "No valid dates available for this dataset",
            "No Date Data Available",
        )

    df_sorted = df_with_dates.sort_values('final_release_date').copy()

    # Prefer the direction recorded in the metrics index; otherwise guess it
    # from the metric name.
    if metric in METRICS_INDEX:
        is_lower_better = METRICS_INDEX[metric].get('is_lower_better', False)
    else:
        is_lower_better = any(
            keyword in metric.lower() for keyword in ['error', 'loss', 'time', 'cost']
        )

    if is_lower_better:
        df_sorted['cumulative_best'] = df_sorted[metric].cummin()
    else:
        df_sorted['cumulative_best'] = df_sorted[metric].cummax()
    # A model is SOTA when it equals the running best at its release (ties included).
    df_sorted['is_sota'] = df_sorted[metric] == df_sorted['cumulative_best']

    sota_df = df_sorted[df_sorted['is_sota']]

    x_values = df_sorted['final_release_date']
    x_axis_title = 'Release Date'

    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=x_values,
        y=df_sorted[metric],
        mode='markers',
        name='All models',
        marker=dict(
            color=['#00CED1' if is_sota else 'lightgray' for is_sota in df_sorted['is_sota']],
            size=8,
            opacity=0.7,
        ),
        text=df_sorted['model_name'],
        # reindex() yields empty columns instead of raising when one is absent.
        customdata=df_sorted.reindex(columns=['paper_title', 'paper_url', 'code_url']),
        hovertemplate='<b>%{text}</b><br>' +
                      f'{metric}: %{{y:.4f}}<br>' +
                      'Date: %{x}<br>' +
                      'Paper: %{customdata[0]}<br>' +
                      '<extra></extra>',
    ))

    fig.add_trace(go.Scatter(
        x=x_values,
        y=df_sorted['cumulative_best'],
        mode='lines',
        name=f'SOTA (cumulative {"min" if is_lower_better else "max"})',
        line=dict(color='#00CED1', width=2, dash='solid'),
        hovertemplate=f'SOTA {metric}: %{{y:.4f}}<br>{x_axis_title}: %{{x}}<extra></extra>',
    ))

    if not sota_df.empty:
        y_range = df_sorted[metric].max() - df_sorted[metric].min()
        # 3% of the value range (the former clamp to [1%, 8%] never altered
        # that value); 1 when every value is identical.
        label_offset = y_range * 0.03 if y_range > 0 else 1

        try:
            date_range = (
                df_sorted['final_release_date'].max() - df_sorted['final_release_date'].min()
            ).days
            min_separation = max(30, date_range * 0.05)
        except (TypeError, AttributeError):
            min_separation = 30

        # Offsets are negated for lower-is-better metrics so labels go to the
        # opposite side of the line.
        direction = -1 if is_lower_better else 1
        base_ay_offset = direction * label_offset
        base_yshift = direction * 8

        previous_xs = []
        for i, (_, row) in enumerate(sota_df.iterrows()):
            current_x = row['final_release_date']

            if _too_close(current_x, previous_xs, min_separation):
                # Stagger every second label when it crowds an earlier one.
                ay_offset = base_ay_offset + (direction * label_offset * 0.7 * (i % 2))
                yshift = base_yshift + (direction * 12 * (i % 2))
            else:
                ay_offset = base_ay_offset
                yshift = base_yshift

            # str() guards against missing / non-string model names.
            label = str(row['model_name'])
            if len(label) > 25:
                label = label[:25] + '...'

            fig.add_annotation(
                x=current_x,
                y=row[metric],
                text=label,
                showarrow=True,
                arrowhead=2,
                arrowsize=1,
                arrowwidth=1,
                arrowcolor='#00CED1',
                ax=0,
                ay=ay_offset,
                yshift=yshift,
                font=dict(size=8, color='#333333'),
                bgcolor='rgba(255, 255, 255, 0.9)',
                borderwidth=0,
            )

            previous_xs.append(current_x)

    fig.update_layout(
        title=f'SOTA Evolution: {metric}',
        xaxis_title=x_axis_title,
        yaxis_title=metric,
        xaxis=dict(showgrid=True, gridcolor='lightgray'),
        yaxis=dict(showgrid=True, gridcolor='lightgray'),
        plot_bgcolor='white',
        paper_bgcolor='white',
        height=600,
        legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
        hovermode='closest',
    )

    return fig
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: task -> dataset -> metric selection driving a SOTA evolution plot.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# π Papers with Code - SOTA Evolution Visualizer")
    gr.Markdown(
        "Navigate through ML tasks and datasets to visualize the evolution of state-of-the-art models over time.")
    gr.Markdown("*Optimized for low memory usage - data is loaded on-demand from disk*")

    with gr.Row():
        gr.Markdown(f"""
<div style="background-color: #f0f9ff; border-left: 4px solid #00CED1; padding: 10px; margin: 10px 0;">
<b>πΎ Disk-Based Storage Active</b><br>
β’ <b>{len(TASKS)}</b> tasks indexed<br>
β’ <b>{len(METRICS_INDEX)}</b> unique metrics tracked<br>
β’ Data loaded on-demand to minimize RAM usage
</div>
""")

    # Per-session state: the currently loaded dataset and the selected task.
    current_df = gr.State(pd.DataFrame())
    current_task = gr.State(None)

    with gr.Row():
        task_dropdown = gr.Dropdown(
            choices=get_tasks(),
            label="Select Task",
            interactive=True
        )
        category_dropdown = gr.Dropdown(
            choices=[],
            label="Categories (info only)",
            interactive=False
        )

    with gr.Row():
        dataset_dropdown = gr.Dropdown(
            choices=[],
            label="Select Dataset",
            interactive=True
        )
        metric_dropdown = gr.Dropdown(
            choices=[],
            label="Select Metric",
            interactive=True
        )

    info_text = gr.Markdown("π Please select a task to begin")

    plot = gr.Plot(label="SOTA Evolution")

    with gr.Row():
        show_data_btn = gr.Button("π Show/Hide Model Data")
        export_btn = gr.Button("πΎ Export Current Data (CSV)")
        clear_memory_btn = gr.Button("π§Ή Clear Memory", variant="secondary")

    df_display = gr.Dataframe(
        label="Model Data",
        visible=False
    )

    # Columns that describe a model rather than measure it.
    non_metric_cols = ['model_name', 'release_date', 'paper_date',
                       'paper_url', 'paper_title', 'code_url']

    def empty_dropdown():
        """Return a dropdown update with no choices and no selection."""
        return gr.Dropdown(choices=[], value=None)

    def update_task_selection(task):
        """Update dropdowns when task is selected.

        Returns updates for (categories, datasets, metrics, info text,
        current dataframe, plot, current task).
        """
        if not task:
            return (
                empty_dropdown(),
                empty_dropdown(),
                empty_dropdown(),
                "π Please select a task to begin",
                pd.DataFrame(),
                None,
                None,
            )

        categories = get_categories(task)
        datasets = get_datasets_for_task(task)

        info = f"### π **Task:** {task}\n"
        if categories:
            info += f"- **Categories:** {', '.join(categories[:3])}{'...' if len(categories) > 3 else ''} ({len(categories)} total)\n"

        return (
            gr.Dropdown(choices=categories, value=categories[0] if categories else None),
            gr.Dropdown(choices=datasets, value=None),
            empty_dropdown(),
            info,
            pd.DataFrame(),
            None,
            task
        )

    def update_dataset_selection(task, dataset_name):
        """Update when dataset is selected - loads from disk.

        Returns updates for (metrics, info text, current dataframe, plot).
        """
        if not task or not dataset_name:
            return empty_dropdown(), "", pd.DataFrame(), None

        df = get_cached_model_data(task, dataset_name)

        if df.empty:
            return empty_dropdown(), f"β οΈ No models found for dataset: {dataset_name}", df, None

        metric_cols = [col for col in df.columns if col not in non_metric_cols]

        info = f"### π **Dataset:** {dataset_name}\n"
        info += f"- **Models:** {len(df)} models\n"
        info += f"- **Metrics:** {len(metric_cols)} metrics available\n"
        info += f"- **Date Range:** {df['release_date'].min().strftime('%Y-%m-%d')} to {df['release_date'].max().strftime('%Y-%m-%d')}\n"

        if metric_cols:
            info += f"- **Available Metrics:** {', '.join(metric_cols[:5])}{'...' if len(metric_cols) > 5 else ''}"

        return (
            gr.Dropdown(choices=metric_cols, value=metric_cols[0] if metric_cols else None),
            info,
            df,
            None
        )

    def update_plot(df, metric):
        """Update plot when metric is selected."""
        if df.empty or not metric:
            return None
        return create_sota_plot(df, metric)

    def toggle_dataframe(df):
        """Show the model table for the current dataset.

        NOTE: despite the button label, this handler only hides the table
        when there is no data; it never hides a table that is already shown.
        """
        if df.empty:
            return gr.Dataframe(value=pd.DataFrame(), visible=False)

        display_cols = ['model_name', 'release_date'] + [
            col for col in df.columns if col not in non_metric_cols
        ]
        display_df = df[display_cols].copy()
        display_df['release_date'] = display_df['release_date'].dt.strftime('%Y-%m-%d')
        return gr.Dataframe(value=display_df, visible=True)

    def export_data(df):
        """Export current dataframe to a timestamped CSV in the working directory."""
        if df.empty:
            return "β οΈ No data to export"

        filename = f"sota_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
        df.to_csv(filename, index=False)
        return f"✅ Data exported to {filename} ({len(df)} models)"

    def clear_memory():
        """Clear memory by forcing garbage collection."""
        import gc
        gc.collect()
        return "✅ Memory cleared"

    # Event wiring. These calls must stay inside the Blocks context.
    task_dropdown.change(
        fn=update_task_selection,
        inputs=task_dropdown,
        outputs=[category_dropdown, dataset_dropdown,
                 metric_dropdown, info_text, current_df, plot, current_task]
    )

    dataset_dropdown.change(
        fn=update_dataset_selection,
        inputs=[task_dropdown, dataset_dropdown],
        outputs=[metric_dropdown, info_text, current_df, plot]
    )

    metric_dropdown.change(
        fn=update_plot,
        inputs=[current_df, metric_dropdown],
        outputs=plot
    )

    show_data_btn.click(
        fn=toggle_dataframe,
        inputs=current_df,
        outputs=df_display
    )

    export_btn.click(
        fn=export_data,
        inputs=current_df,
        outputs=info_text
    )

    clear_memory_btn.click(
        fn=clear_memory,
        inputs=[],
        outputs=info_text
    )

    gr.Markdown("""
---
### π How to Use
1. **Select a Task** from the first dropdown
2. **Select a Dataset** to analyze
3. **Select a Metric** to visualize
4. The plot shows SOTA model evolution over time with dynamically calculated dates

### πΎ Memory Optimization
- Data is stored on disk and loaded on-demand
- Only the current task and dataset are kept in memory
- Use "Clear Memory" button if needed
- Infinite disk space is utilized for permanent caching

### π¨ Plot Features
- **π΅ Cyan dots**: SOTA models when released
- **βͺ Gray dots**: Other models
- **π Cyan line**: SOTA progression
- **π Hover**: View model details
- **π·οΈ Smart Labels**: SOTA model labels positioned close to the line with intelligent collision detection
""")
|
|
|
|
|
|
|
|
def test_sota_label_positioning():
    """Smoke-test ``create_sota_plot`` on a handful of synthetic frames.

    Each case only checks that a figure is produced without raising; the
    outcome of every case is printed.

    Returns:
        True when every case ran without an exception, False otherwise.
    """
    print("π§ͺ Testing SOTA label positioning...")

    df_test = pd.DataFrame({
        'model_name': ['Model A', 'Model B', 'Model C', 'Model D'],
        'release_date': [
            datetime(2020, 1, 1),
            datetime(2020, 6, 1),
            datetime(2021, 1, 1),
            datetime(2021, 6, 1)
        ],
        'paper_title': ['Paper A', 'Paper B', 'Paper C', 'Paper D'],
        'paper_url': ['http://example.com/a', 'http://example.com/b', 'http://example.com/c', 'http://example.com/d'],
        'code_url': ['http://github.com/a', 'http://github.com/b', 'http://github.com/c', 'http://github.com/d'],
        'accuracy': [0.85, 0.87, 0.90, 0.92],
        'error_rate': [0.15, 0.13, 0.10, 0.08]
    })

    def with_column(name, values):
        """Return a copy of the base frame with one extra column."""
        frame = df_test.copy()
        frame[name] = values
        return frame

    # (announcement, success message, failure label, frame factory, metric)
    cases = [
        (" Testing with higher-better metric (accuracy)...",
         " ✅ Higher-better metric test passed",
         "Higher-better metric",
         lambda: df_test, 'accuracy'),
        (" Testing with lower-better metric (error_rate)...",
         " ✅ Lower-better metric test passed",
         "Lower-better metric",
         lambda: df_test, 'error_rate'),
        (" Testing with empty dataframe...",
         " ✅ Empty data test passed",
         "Empty data",
         lambda: pd.DataFrame(), 'test_metric'),
        (" Testing with string metric data...",
         " ✅ String metric test passed (handled gracefully)",
         "String metric",
         lambda: with_column('string_metric', ['low', 'medium', 'high', 'very_high']),
         'string_metric'),
        (" Testing with mixed data types...",
         " ✅ Mixed data test passed",
         "Mixed data",
         lambda: with_column('mixed_metric', [0.85, 'N/A', 0.90, 0.92]),
         'mixed_metric'),
        (" Testing with paper_date column...",
         " ✅ Paper date parsing test passed",
         "Paper date parsing",
         lambda: with_column('paper_date', ['2015-03-15', '2018-invalid', '2021-12-01', '2022']),
         'accuracy'),
    ]

    all_passed = True
    for announcement, success_message, failure_label, make_frame, metric in cases:
        print(announcement)
        try:
            create_sota_plot(make_frame(), metric)
            print(success_message)
        except Exception as e:
            # Report and keep going so one failure does not hide the others.
            all_passed = False
            print(f" β {failure_label} test failed: {e}")

    print("π SOTA label positioning tests completed!")
    return all_passed
|
|
|
|
|
# Start the web server only when run as a script, so importing this module
# (e.g. for testing or reload tooling) does not block on the server.
if __name__ == "__main__":
    demo.launch()