Spaces:
Runtime error
Runtime error
| import json | |
| from pathlib import Path | |
| import gradio as gr | |
| from collections import defaultdict | |
| import fsspec.config | |
| import math | |
| from datatrove.io import DataFolder, get_datafolder | |
| from datatrove.utils.stats import MetricStatsDict | |
# Root of the stats bucket. The discovery helpers and load_stats resolve every
# run/grouping/stat path relative to this folder.
BASE_DATA_FOLDER = get_datafolder("s3://fineweb-stats/summary/")
def find_folders(base_folder, path):
    """Return the sorted names of the immediate sub-directories of ``path``.

    Args:
        base_folder: Folder object exposing an fsspec-style
            ``ls(path, detail=True)`` that yields dicts with ``name`` and
            ``type`` keys (a ``DataFolder`` in this app).
        path: Directory to list, relative to ``base_folder``. A trailing
            ``/`` is accepted.

    Returns:
        Sorted list of the ``name`` values of entries whose ``type`` is
        ``"directory"``, excluding the entry for ``path`` itself (which
        ``ls`` may include).
    """
    own_path = path.rstrip("/")
    return sorted(
        entry["name"]
        for entry in base_folder.ls(path, detail=True)
        # Strip trailing slashes on both sides so "x" and "x/" both
        # exclude the directory's own entry.
        if entry["type"] == "directory" and entry["name"].rstrip("/") != own_path
    )
def find_stats_folders(base_folder: "DataFolder"):
    """Return the run folders that contain at least one merged stats file.

    Merged stats live at ``<run>/<grouping>/<stat_name>/stats-merged.json``
    (the same layout ``load_stats`` reads), so the run folder is what remains
    after stripping the last three path components.

    Args:
        base_folder: Folder object exposing ``glob(pattern)``.

    Returns:
        Sorted list of unique run folder paths, as strings derived from the
        paths returned by ``base_folder.glob``.
    """
    stats_files = base_folder.glob("**/stats-merged.json")
    # parent x3 drops "stats-merged.json", <stat_name> and <grouping>.
    run_folders = {str(Path(stats_file).parent.parent.parent) for stats_file in stats_files}
    # Sort so the result order is deterministic (a set has no stable order).
    return sorted(run_folders)
# Every run folder that has merged stats, in sorted order.
RUNS = sorted(find_stats_folders(BASE_DATA_FOLDER))
# Groupings and stat names are discovered from the first run (and its first
# grouping) only. NOTE(review): this assumes every run shares the same layout.
# Fall back to empty lists instead of raising IndexError at import time when
# no stats (or no groupings) exist yet.
GROUPS = (
    [Path(x).name for x in find_folders(BASE_DATA_FOLDER, RUNS[0])] if RUNS else []
)
STATS = (
    [
        Path(x).name
        for x in find_folders(BASE_DATA_FOLDER, str(Path(RUNS[0], GROUPS[0])))
    ]
    if GROUPS
    else []
)
def load_stats(path, stat_name, group_by):
    """Load one merged stats file and wrap it in a MetricStatsDict.

    Reads ``{path}/{group_by}/{stat_name}/stats-merged.json`` through
    BASE_DATA_FOLDER, passing a ``filecache`` option whose cache storage is
    ``/tmp/files``.

    Args:
        path: Run folder.
        stat_name: Name of the stat folder.
        group_by: Name of the grouping folder.

    Returns:
        A MetricStatsDict built from the decoded JSON.
    """
    json_path = f"{path}/{group_by}/{stat_name}/stats-merged.json"
    cache_options = {"cache_storage": "/tmp/files"}
    with BASE_DATA_FOLDER.open(json_path, filecache=cache_options) as handle:
        raw_stats = json.load(handle)
    # Adding to an empty MetricStatsDict is an empirically required workaround:
    # without it the resulting MetricStatsDict is malformed.
    # NOTE(review): root cause unknown.
    return MetricStatsDict() + MetricStatsDict(init=raw_stats)
def prepare_non_grouped_data(stats: "MetricStatsDict", normalize: bool = False):
    """Collapse an ungrouped (histogram) stat into ``{bucket value: total}``.

    Keys of ``stats`` are converted with ``float``, so keys that parse to the
    same number (e.g. ``"1"`` and ``"1.0"``) are merged and their ``total``
    values summed.

    Args:
        stats: Mapping of bucket key -> object with a ``total`` attribute.
        normalize: If True, divide every total by the sum of all totals so
            the values form a frequency distribution. Defaults to False,
            which returns the raw totals (as floats).

    Returns:
        Dict mapping float bucket value to its (optionally normalized) total.
    """
    totals = defaultdict(int)
    for key, value in stats.items():
        totals[float(key)] += value.total
    grand_total = sum(totals.values())
    # Divide by 1 when not normalizing (keeps the float conversion) or when
    # every total is zero, which would otherwise raise ZeroDivisionError.
    normalizer = grand_total if normalize and grand_total else 1
    return {bucket: count / normalizer for bucket, count in totals.items()}
def prepare_grouped_data(stats: MetricStatsDict, top_k=100):
    """Return the ``top_k`` groups with the highest mean, highest first.

    Args:
        stats: Mapping of group key -> object with a ``mean`` attribute.
        top_k: Maximum number of groups to keep.

    Returns:
        Dict of group key -> mean, ordered by descending mean. Ties keep the
        iteration order of ``stats``.
    """
    mean_by_key = {}
    for group, metric in stats.items():
        mean_by_key[group] = metric.mean
    # Ranked by mean value, not by count.
    ranked = sorted(mean_by_key, key=mean_by_key.get, reverse=True)
    return {group: mean_by_key[group] for group in ranked[:top_k]}
| import math | |
| import plotly.graph_objects as go | |
| from plotly.offline import plot | |
def plot_scatter(histograms: dict[str, dict[float, float]], stat_name: str):
    """Build a line plot with one trace per run on a log-scaled x axis.

    Args:
        histograms: Mapping of run name -> {x value: y value}. If every key of
            a histogram is a string, its points are ordered by ascending
            value; otherwise they are ordered by key.
        stat_name: Used in the title and as the x-axis label.

    Returns:
        The plotly Figure.
    """
    palette = [
        "rgba(31, 119, 180, 0.5)",
        "rgba(255, 127, 14, 0.5)",
        "rgba(44, 160, 44, 0.5)",
        "rgba(214, 39, 40, 0.5)",
        "rgba(148, 103, 189, 0.5)",
    ]
    fig = go.Figure()
    for idx, (name, histogram) in enumerate(histograms.items()):
        if all(isinstance(k, str) for k in histogram):
            x = [k for k, _ in sorted(histogram.items(), key=lambda item: item[1])]
        else:
            x = sorted(histogram)
        y = [histogram[k] for k in x]
        # Cycle through the palette: a one-shot iterator raised StopIteration
        # once more than len(palette) runs were selected.
        color = palette[idx % len(palette)]
        fig.add_trace(
            go.Scatter(x=x, y=y, mode="lines", name=name, line=dict(color=color))
        )
    fig.update_layout(
        title=f"Line Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title="Frequency",
        xaxis_type="log",
        width=1000,
        height=600,
    )
    return fig
def plot_bars(histograms: dict[str, dict[float, float]], stat_name: str):
    """Build a bar chart with one ``go.Bar`` trace per run.

    Within each trace the bars are ordered by ascending value.

    Args:
        histograms: Mapping of run name -> {category: value}.
        stat_name: Used in the title and as the x-axis label.

    Returns:
        The plotly Figure.
    """
    figure = go.Figure()
    for run_name, values in histograms.items():
        ordered = sorted(values.items(), key=lambda pair: pair[1])
        figure.add_trace(
            go.Bar(
                x=[label for label, _ in ordered],
                y=[amount for _, amount in ordered],
                name=run_name,
            )
        )
    layout = dict(
        title=f"Bar Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title="Frequency",
        autosize=True,
        width=600,
        height=600,
    )
    figure.update_layout(**layout)
    return figure
def update_graph(multiselect_crawls, stat_name, grouping):
    """Load the selected stat for every selected run and build a figure.

    Args:
        multiselect_crawls: Run folders chosen in the UI; may be None or empty.
        stat_name: Stat to load.
        grouping: Grouping folder. ``"histogram"`` produces a line plot of
            per-value totals; any other grouping produces a bar plot of
            per-group means.

    Returns:
        A plotly Figure, or None when any input is missing.
    """
    # Test truthiness rather than len() so a None selection is handled too.
    if not multiselect_crawls or not stat_name or not grouping:
        return None
    is_histogram = grouping == "histogram"
    prepare_fc = prepare_non_grouped_data if is_histogram else prepare_grouped_data
    graph_fc = plot_scatter if is_histogram else plot_bars
    print("Loading stats")
    histograms = {
        path: prepare_fc(load_stats(path, stat_name, grouping))
        for path in multiselect_crawls
    }
    print("Plotting")
    return graph_fc(histograms, stat_name)
# Create the Gradio interface
# Layout: run multiselect on the left; stat, grouping and the update button on
# the right; the plot below. The plot is rendered by update_graph on click.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2):
            # Define the multiselect for crawls
            # Choices are the run folders discovered at module load (RUNS).
            multiselect_crawls = gr.Dropdown(
                choices=RUNS,
                label="Multiselect for crawls",
                multiselect=True,
            )
        with gr.Column(scale=1):
            # Define the dropdown for stat_name
            # Choices come from the first run's first grouping only (STATS).
            stat_name_dropdown = gr.Dropdown(
                choices=STATS,
                label="Stat name",
                multiselect=False,
            )
            # Define the dropdown for grouping
            # Choices are the groupings found in the first run (GROUPS).
            grouping_dropdown = gr.Dropdown(
                choices=GROUPS,
                label="Grouping",
                multiselect=False,
            )
            update_button = gr.Button("Update Graph", variant="primary")
    with gr.Row():
        # Define the graph output
        graph_output = gr.Plot(label="Graph")
    # Re-render only when the button is clicked; the three selections are
    # passed positionally to update_graph and its return value fills the plot.
    update_button.click(
        fn=update_graph,
        inputs=[multiselect_crawls, stat_name_dropdown, grouping_dropdown],
        outputs=graph_output,
    )
# Launch the application
if __name__ == "__main__":
    demo.launch()