Spaces:
Running
Running
| import os | |
| from typing import Any | |
| import gradio as gr | |
| import pandas as pd | |
| try: | |
| from trackio.sqlite_storage import SQLiteStorage | |
| from trackio.utils import RESERVED_KEYS, TRACKIO_LOGO_PATH | |
| except: # noqa: E722 | |
| from sqlite_storage import SQLiteStorage | |
| from utils import RESERVED_KEYS, TRACKIO_LOGO_PATH | |
# Custom CSS handed to gr.Blocks(css=...). The "#run-cb" selectors target the
# run CheckboxGroup (created with elem_id="run-cb") and tighten the gap,
# line height and padding of its entries.
css = """
#run-cb .wrap {
gap: 2px;
}
#run-cb .wrap label {
line-height: 1;
padding: 6px;
}
"""
# Hex colors assigned to runs by their position in the selected-run list.
# get_color_mapping indexes this list modulo its length, so colors repeat once
# there are more runs than entries.
COLOR_PALETTE = [
    "#3B82F6",
    "#EF4444",
    "#10B981",
    "#F59E0B",
    "#8B5CF6",
    "#EC4899",
    "#06B6D4",
    "#84CC16",
    "#F97316",
    "#6366F1",
]
def get_color_mapping(runs: list[str], smoothing: bool) -> dict[str, str]:
    """Assign every run a palette color, cycling when runs outnumber the palette.

    Without smoothing the run name itself is the key. With smoothing each run
    contributes two keys: ``"<run>_smoothed"`` in the plain palette color and
    ``"<run>_original"`` in the same color with ``"4D"`` appended.
    """
    palette_size = len(COLOR_PALETTE)
    mapping = {}
    for position, run_name in enumerate(runs):
        shade = COLOR_PALETTE[position % palette_size]
        if not smoothing:
            mapping[run_name] = shade
            continue
        mapping[f"{run_name}_smoothed"] = shade
        # "4D" extends #RRGGBB to #RRGGBBAA, i.e. a translucent variant of the
        # same color for the raw (unsmoothed) series.
        mapping[f"{run_name}_original"] = shade + "4D"
    return mapping
def get_projects(request: gr.Request):
    """Build the project dropdown shown in the sidebar.

    Args:
        request: The incoming Gradio request; its ``project`` query parameter,
            when present and non-empty, pins the dropdown to that project.

    Returns:
        A ``gr.Dropdown`` listing the projects known to ``SQLiteStorage``.
        If a project was supplied in the URL it is selected and the dropdown
        is made non-interactive; otherwise the first known project (or
        ``None`` when there are none) is selected. When the
        ``TRACKIO_DATASET_ID`` environment variable is set, the dropdown's
        info line links to that dataset on the Hugging Face Hub.
    """
    dataset_id = os.environ.get("TRACKIO_DATASET_ID")
    projects = SQLiteStorage.get_projects()

    project = request.query_params.get("project")
    # A project given in the URL locks the dropdown; otherwise fall back to
    # the first known project and let the user switch freely.
    interactive = not project
    if interactive:
        project = projects[0] if projects else None

    # Dataset repositories live under /datasets/ on the Hub; without that
    # path segment the link would resolve to a model repo of the same name.
    sync_info = (
        f"↻ Synced to <a href='https://huggingface.co/datasets/{dataset_id}' target='_blank'>{dataset_id}</a> every 5 min"
        if dataset_id
        else None
    )

    return gr.Dropdown(
        label="Project",
        choices=projects,
        value=project,
        allow_custom_value=True,
        interactive=interactive,
        info=sync_info,
    )
def get_runs(project):
    """Return ``SQLiteStorage.get_runs(project)``, or ``[]`` for a falsy project."""
    return SQLiteStorage.get_runs(project) if project else []
def load_run_data(project: str | None, run: str | None, smoothing: bool):
    """Load one run's logged metrics into a DataFrame ready for plotting.

    Args:
        project: Project name; a falsy value yields ``None``.
        run: Run name; a falsy value yields ``None``.
        smoothing: When true, return the raw rows followed by an
            exponentially smoothed (``alpha=0.1``) copy of them.

    Returns:
        ``None`` if ``project``/``run`` is missing or the storage returns no
        metrics. Otherwise a DataFrame that always has a ``step`` column
        (synthesized from row order when absent) plus two added columns:
        ``run`` and ``data_type``. Without smoothing, ``run`` is the run name
        and ``data_type`` is ``"original"``. With smoothing, the frame holds
        the raw rows (``run="<run>_original"``, ``data_type="original"``)
        followed by the smoothed rows (``run="<run>_smoothed"``,
        ``data_type="smoothed"``).
    """
    if not project or not run:
        return None
    metrics = SQLiteStorage.get_metrics(project, run)
    if not metrics:
        return None

    df = pd.DataFrame(metrics)
    if "step" not in df.columns:
        df["step"] = range(len(df))

    if not smoothing:
        df["run"] = run
        df["data_type"] = "original"
        return df

    # "step" is the x-axis and must never be smoothed. Exclude it explicitly
    # (as update_dashboard does) instead of relying on RESERVED_KEYS to
    # contain it.
    numeric_cols = [
        c
        for c in df.select_dtypes(include="number").columns
        if c not in RESERVED_KEYS and c != "step"
    ]

    df_original = df.copy()
    df_original["run"] = f"{run}_original"
    df_original["data_type"] = "original"

    df_smoothed = df.copy()
    if numeric_cols:
        raw = df[numeric_cols]
        # ewm().mean() carries the running average into rows where the input
        # is NaN (metric not logged at that step). Mask those rows back to
        # NaN so a later dropna on the metric does not plot points that were
        # never logged.
        df_smoothed[numeric_cols] = raw.ewm(alpha=0.1).mean().where(raw.notna())
    df_smoothed["run"] = f"{run}_smoothed"
    df_smoothed["data_type"] = "smoothed"

    return pd.concat([df_original, df_smoothed], ignore_index=True)
def update_runs(project, filter_text, user_interacted_with_runs=False):
    """Refresh the run checkbox list and the run-count label.

    Returns a ``(gr.CheckboxGroup, gr.Textbox)`` pair. The checkbox choices
    are the project's runs that contain ``filter_text`` (all runs when the
    filter is empty); the textbox label shows the unfiltered run count.
    Unless ``user_interacted_with_runs`` is true, the selection is reset to
    the first visible run (or to nothing when there are none); otherwise the
    current selection is left untouched.
    """
    all_runs = [] if project is None else get_runs(project)
    total = len(all_runs)

    visible = all_runs
    if filter_text:
        visible = [name for name in all_runs if filter_text in name]

    if user_interacted_with_runs:
        # Only refresh the choices; keep whatever the user has ticked.
        checkbox_update = gr.CheckboxGroup(choices=visible)
    else:
        checkbox_update = gr.CheckboxGroup(
            choices=visible, value=[visible[0]] if visible else []
        )
    return checkbox_update, gr.Textbox(label=f"Runs ({total})")
def filter_runs(project, filter_text):
    """Return a CheckboxGroup whose choices and selection are the matching runs.

    A run matches when ``filter_text`` is a substring of its name.
    """
    matching = list(filter(lambda name: filter_text in name, get_runs(project)))
    return gr.CheckboxGroup(choices=matching, value=matching)
def toggle_timer(cb_value):
    """Return a ``gr.Timer`` update that is active exactly when ``cb_value`` is truthy."""
    return gr.Timer(active=bool(cb_value))
def log(project: str, run: str, metrics: dict[str, Any], dataset_id: str) -> None:
    """Record ``metrics`` for ``run`` in ``project`` through ``SQLiteStorage``.

    The storage object is created with an empty config dict and the given
    ``dataset_id``.
    """
    # NOTE: dataset_id would ideally be annotated `str | None`, but gr.api does
    # not support that annotation, see:
    # https://github.com/gradio-app/gradio/issues/11175#issuecomment-2920203317
    SQLiteStorage(project, run, {}, dataset_id=dataset_id).log(metrics)
def sort_metrics_by_prefix(metrics: list[str]) -> list[str]:
    """
    Order metric names so that names sharing a prefix sit together.

    Names without a "/" come first in alphabetical order. Names with a "/"
    follow, ordered by the text before the first "/" and then alphabetically
    within each prefix.

    Example:
    Input: ["train/loss", "loss", "train/acc", "val/loss"]
    Output: ["loss", "train/acc", "train/loss", "val/loss"]
    """

    def _sort_key(metric: str) -> tuple[bool, str, str]:
        prefix, slash, _ = metric.partition("/")
        # False sorts before True, so un-prefixed names lead. Comparing the
        # prefix before the full name keeps e.g. "a/..." ahead of "a-b/...",
        # which a plain string sort would reverse.
        return (bool(slash), prefix, metric)

    return sorted(metrics, key=_sort_key)
def configure(request: gr.Request):
    """Read the ``metrics`` query parameter as a comma-separated list.

    Returns the parameter split on ``","``, or ``[]`` when it is missing or
    empty.
    """
    requested = request.query_params.get("metrics")
    return requested.split(",") if requested else []
# Dashboard layout and event wiring. Components are created in layout order and
# the handlers below refer to them by name, so statement order matters here.
# NOTE(review): indentation is not visible in this view; the nesting below is
# reconstructed from the component names — confirm against the original file.
with gr.Blocks(theme="citrus", title="Trackio Dashboard", css=css) as demo:
    # Left sidebar: logo, project picker, run filter box and run checkboxes.
    with gr.Sidebar() as sidebar:
        gr.Markdown(
            f"<div style='display: flex; align-items: center; gap: 8px;'><img src='/gradio_api/file={TRACKIO_LOGO_PATH}' width='32' height='32'><span style='font-size: 2em; font-weight: bold;'>Trackio</span></div>"
        )
        project_dd = gr.Dropdown(label="Project")
        run_tb = gr.Textbox(label="Runs", placeholder="Type to filter...")
        # elem_id="run-cb" is what the module-level `css` rules select on.
        run_cb = gr.CheckboxGroup(
            label="Runs", choices=[], interactive=True, elem_id="run-cb"
        )
    # Right sidebar, closed by default: display settings.
    with gr.Sidebar(position="right", open=False) as settings_sidebar:
        gr.Markdown("### ⚙️ Settings")
        realtime_cb = gr.Checkbox(label="Refresh realtime", value=True)
        smoothing_cb = gr.Checkbox(label="Smoothing", value=True)
    # Timer driving the periodic run-list refresh; toggled by realtime_cb below.
    timer = gr.Timer(value=1)
    # Metric names requested via the URL (filled by `configure` on page load).
    metrics_subset = gr.State([])
    # Set to True once the user changes the run selection, so timer refreshes
    # stop resetting it (see update_runs).
    user_interacted_with_run_cb = gr.State(False)
    # On page load: read the optional metric subset from the query string.
    gr.on(
        [demo.load],
        fn=configure,
        outputs=metrics_subset,
    )
    # On page load: populate the project dropdown.
    gr.on(
        [demo.load],
        fn=get_projects,
        outputs=project_dd,
        show_progress="hidden",
    )
    # On every timer tick: refresh the run list, passing the interaction flag
    # so an explicit user selection is preserved.
    gr.on(
        [timer.tick],
        fn=update_runs,
        inputs=[project_dd, run_tb, user_interacted_with_run_cb],
        outputs=[run_cb, run_tb],
        show_progress="hidden",
    )
    # On page load and project change: refresh the run list without the
    # interaction flag, so update_runs uses its default (False) and resets the
    # selection to the first run.
    gr.on(
        [demo.load, project_dd.change],
        fn=update_runs,
        inputs=[project_dd, run_tb],
        outputs=[run_cb, run_tb],
        show_progress="hidden",
    )
    # The "Refresh realtime" checkbox activates/deactivates the timer.
    realtime_cb.change(
        fn=toggle_timer,
        inputs=realtime_cb,
        outputs=timer,
        api_name="toggle_timer",
    )
    # Any direct user input on the run checkboxes marks the selection as
    # user-owned.
    run_cb.input(
        fn=lambda: True,
        outputs=user_interacted_with_run_cb,
    )
    # Typing in the filter box narrows the run list (and selects all matches).
    run_tb.input(
        fn=filter_runs,
        inputs=[project_dd, run_tb],
        outputs=run_cb,
    )
    # Expose `log` as an API endpoint named "log".
    gr.api(
        fn=log,
        api_name="log",
    )
    # Shared x-axis range for all plots; None means no explicit range. Written
    # by update_x_lim and by the double-click handler in update_dashboard.
    x_lim = gr.State(None)
def update_x_lim(select_data: gr.SelectData):
    """Return the ``index`` of a plot selection event, to be stored as the new x-axis range."""
    selected_range = select_data.index
    return selected_range
def update_dashboard(project, runs, smoothing, metrics_subset, x_lim_value):
    """Render one line plot per numeric metric for the selected runs.

    Loads every run in ``runs`` with ``load_run_data``, concatenates the
    results and, for each numeric column that is neither reserved nor
    ``step`` (optionally restricted to ``metrics_subset``), creates a
    ``gr.LinePlot`` inside a single ``gr.Row``. Nothing is rendered when no
    data could be loaded. Selecting a range on any plot stores it in the
    shared ``x_lim`` state; double-clicking resets that state to ``None``.

    NOTE(review): indentation is not visible in this view. The ``key=``
    arguments suggest this function is nested in the Blocks context and
    driven by ``gr.render``, but no decorator is visible here — confirm.
    """
    selected_runs = runs.copy()
    frames = [
        frame
        for frame in (load_run_data(project, name, smoothing) for name in runs)
        if frame is not None
    ]
    master_df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
    if master_df.empty:
        return

    plottable = [
        col
        for col in master_df.select_dtypes(include="number").columns
        if col not in RESERVED_KEYS and col != "step"
    ]
    if metrics_subset:
        plottable = [col for col in plottable if col in metrics_subset]
    plottable = sort_metrics_by_prefix(list(plottable))

    color_map = get_color_mapping(selected_runs, smoothing)

    with gr.Row(key="row"):
        for metric_idx, metric_name in enumerate(plottable):
            # Only rows where this metric was actually recorded are plotted.
            metric_df = master_df.dropna(subset=[metric_name])
            if metric_df.empty:
                continue
            values = metric_df[metric_name]
            plot = gr.LinePlot(
                metric_df,
                x="step",
                y=metric_name,
                color="run" if "run" in metric_df.columns else None,
                color_map=color_map,
                title=metric_name,
                key=f"plot-{metric_idx}",
                preserved_by_key=None,
                x_lim=x_lim_value,
                y_lim=[values.min(), values.max()],
                show_fullscreen_button=True,
                min_width=400,
            )
            # Range selection zooms every plot via the shared x_lim state;
            # double-click clears it again.
            plot.select(update_x_lim, outputs=x_lim, key=f"select-{metric_idx}")
            plot.double_click(
                lambda: None, outputs=x_lim, key=f"double-{metric_idx}"
            )
# Script entry point: start the dashboard when this file is run directly.
if __name__ == "__main__":
    # The logo path is passed in allowed_paths because the sidebar Markdown
    # loads it through /gradio_api/file=<path>.
    demo.launch(allowed_paths=[TRACKIO_LOGO_PATH], show_api=False)