from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import heapq
import json
from pathlib import Path
import re
import tempfile
from typing import Literal

import gradio as gr
import plotly.express as px
import plotly.graph_objects as go
from datatrove.io import get_datafolder
from datatrove.utils.stats import MetricStatsDict

PARTITION_OPTIONS = Literal["Top", "Bottom", "Most frequent (n_docs)"]

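# Stats whose values span several orders of magnitude; plot_scatter switches to a log-scale x-axis for these.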
LOG_SCALE_STATS = {
    "length",
    "n_lines",
    "n_docs",
    "n_words",
    "avg_words_per_line",
    "pages_with_lorem_ipsum",
}

def find_folders(base_folder, path):
    base_folder = get_datafolder(base_folder)
    if not base_folder.exists(path):
        return []
    return sorted(
        [
            folder["name"]
            for folder in base_folder.ls(path, detail=True)
            if folder["type"] == "directory" and not folder["name"].rstrip("/") == path
        ]
    )

def find_stats_folders(base_folder: str):
    base_data_folder = get_datafolder(base_folder)
    # First find all stats-merged.json files by globbing
    stats_merged = base_data_folder.glob("**/stats-merged.json")
    # Then strip the last three path components (grouping/stat_name/stats-merged.json) to get the dataset folder
    stats_folders = [str(Path(x).parent.parent.parent) for x in stats_merged]
    # Finally deduplicate the paths
    return sorted(list(set(stats_folders)))

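# find_stats_folders assumes the layout <dataset>/<grouping>/<stat_name>/stats-merged.json
# (the same layout stat_exists and load_stats build paths from),
# e.g. a hypothetical path "cc-main-2024-10/fqdn/length/stats-merged.json" yields the dataset "cc-main-2024-10".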
def fetch_datasets(base_folder: str):
    datasets = sorted(find_stats_folders(base_folder))
    return datasets, gr.update(choices=datasets, value=None), fetch_groups(base_folder, datasets, None, "union")

def export_data(exported_data):
    if not exported_data:
        return None
    # Assuming exported_data is a dictionary where the key is the dataset name and the value is the data to be exported
    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as temp:
        json.dump(exported_data, temp)
        temp_path = temp.name
    return gr.update(visible=True, value=temp_path)

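# export_data uses delete=False so the temp file survives past the handler's return,
# letting the gr.File component serve it for download afterwards.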
def fetch_groups(base_folder, datasets, old_groups, type="intersection"):
    if not datasets:
        return gr.update(choices=[], value=None)

    with ThreadPoolExecutor() as executor:
        GROUPS = list(executor.map(lambda run: [Path(x).name for x in find_folders(base_folder, run)], datasets))
    if len(GROUPS) == 0:
        return gr.update(choices=[], value=None)

    # Take the intersection (or union) of the groups available across all selected datasets
    if type == "intersection":
        new_choices = set.intersection(*(set(g) for g in GROUPS))
    else:
        new_choices = set.union(*(set(g) for g in GROUPS))

    # Keep the previous selection if it is still among the new choices
    value = None
    if old_groups:
        value = list(new_choices.intersection({old_groups}))
        value = value[0] if value else None
    return gr.update(choices=sorted(list(new_choices)), value=value)

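# fetch_datasets calls fetch_groups with type="union" so the reverse-search grouping dropdown offers
# groups present in any dataset, while the main dropdown keeps the default intersection so only
# groupings shared by every selected dataset are offered.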
def fetch_stats(base_folder, datasets, group, old_stats, type="intersection"):
    print("Fetching stats")
    with ThreadPoolExecutor() as executor:
        STATS = list(executor.map(lambda run: [Path(x).name for x in find_folders(base_folder, f"{run}/{group}")], datasets))
    if len(STATS) == 0:
        return gr.update(choices=[], value=None)

    if type == "intersection":
        new_possible_choices = set.intersection(*(set(s) for s in STATS))
    else:
        new_possible_choices = set.union(*(set(s) for s in STATS))

    # Keep the previous selection if it is still among the new choices
    value = None
    if old_stats:
        value = list(new_possible_choices.intersection({old_stats}))
        value = value[0] if value else None
    return gr.update(choices=sorted(list(new_possible_choices)), value=value)

def reverse_search(base_folder, possible_datasets, grouping, stat_name):
    with ThreadPoolExecutor() as executor:
        found_datasets = list(
            executor.map(
                lambda dataset: dataset if stat_exists(base_folder, dataset, stat_name, grouping) else None,
                possible_datasets,
            )
        )
    found_datasets = [dataset for dataset in found_datasets if dataset is not None]
    return "\n".join(found_datasets)

def reverse_search_add(datasets, reverse_search_results):
    datasets = datasets or []
    return sorted(list(set(datasets + reverse_search_results.strip().split("\n"))))

def stat_exists(base_folder, path, stat_name, group_by):
    base_folder = get_datafolder(base_folder)
    return base_folder.exists(f"{path}/{group_by}/{stat_name}/stats-merged.json")

def load_stats(base_folder, path, stat_name, group_by):
    base_folder = get_datafolder(base_folder)
    with base_folder.open(f"{path}/{group_by}/{stat_name}/stats-merged.json") as f:
        json_stat = json.load(f)
        # No idea why this is necessary, but it is; otherwise the MetricStatsDict is malformed
        return MetricStatsDict() + MetricStatsDict(init=json_stat)

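# (Presumably adding to an empty MetricStatsDict routes the deserialized entries through its merge
# logic, coercing raw JSON values into proper stat objects; construction from init= alone appears not to.)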
def prepare_non_grouped_data(dataset_path, base_folder, grouping, stat_name, normalization):
    stats = load_stats(base_folder, dataset_path, stat_name, grouping)
    stats_rounded = defaultdict(lambda: 0)
    for key, value in stats.items():
        stats_rounded[float(key)] += value.total
    if normalization:
        normalizer = sum(stats_rounded.values())
        stats_rounded = {k: v / normalizer for k, v in stats_rounded.items()}
    return stats_rounded

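# prepare_non_grouped_data casts keys to float because JSON object keys are always strings;
# this lets numeric bins merge and sort correctly. With normalization on, the returned values
# are frequencies summing to 1.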
def prepare_grouped_data(dataset_path, base_folder, grouping, stat_name, top_k, direction: PARTITION_OPTIONS):
    stats = load_stats(base_folder, dataset_path, stat_name, grouping)
    means = {key: value.mean for key, value in stats.items()}
    # Use a heap to get the top_k keys
    if direction == "Top":
        keys = heapq.nlargest(top_k, means, key=means.get)
    elif direction == "Most frequent (n_docs)":
        totals = {key: value.n for key, value in stats.items()}
        keys = heapq.nlargest(top_k, totals, key=totals.get)
    else:
        keys = heapq.nsmallest(top_k, means, key=means.get)
    return [(key, means[key]) for key in keys]

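# heapq.nlargest/nsmallest run in roughly O(n log k) instead of the O(n log n) of a full sort,
# which matters for groupings like fqdn that can have a very large number of keys.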
def set_alpha(color, alpha):
    """Takes a hex color "#rrggbb" and returns the equivalent "rgba(r, g, b, a)" string."""
    if color.startswith("#"):
        r, g, b = int(color[1:3], 16), int(color[3:5], 16), int(color[5:7], 16)
    else:
        r, g, b = 0, 0, 0  # Fallback to black if the color format is not recognized
    return f"rgba({r}, {g}, {b}, {alpha})"

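# e.g. set_alpha("#636EFA", 0.5) -> "rgba(99, 110, 250, 0.5)"
# ("#636EFA" is the first color of the px.colors.qualitative.Plotly palette used below).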
def plot_scatter(
    histograms: dict[str, dict[float, float]],
    stat_name: str,
    normalization: bool,
    progress: gr.Progress,
):
    fig = go.Figure()
    for i, (name, histogram) in enumerate(progress.tqdm(histograms.items(), total=len(histograms), desc="Plotting...")):
        if all(isinstance(k, str) for k in histogram.keys()):
            x = [k for k, v in sorted(histogram.items(), key=lambda item: item[1])]
        else:
            x = sorted(histogram.keys())
        y = [histogram[k] for k in x]
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines",
                name=name,
                marker=dict(color=set_alpha(px.colors.qualitative.Plotly[i % len(px.colors.qualitative.Plotly)], 0.5)),
            )
        )

    xaxis_scale = "log" if stat_name in LOG_SCALE_STATS else "linear"
    yaxis_title = "Frequency" if normalization else "Total"

    fig.update_layout(
        title=f"Line Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title=yaxis_title,
        xaxis_type=xaxis_scale,
        width=1200,
        height=600,
        showlegend=True,
    )
    return fig

def plot_bars(
    histograms: dict[str, list[tuple[str, float]]],
    stat_name: str,
    progress: gr.Progress,
):
    fig = go.Figure()
    for i, (name, histogram) in enumerate(progress.tqdm(histograms.items(), total=len(histograms), desc="Plotting...")):
        x = [k for k, v in histogram]
        y = [v for k, v in histogram]
        fig.add_trace(go.Bar(x=x, y=y, name=name, marker=dict(color=set_alpha(px.colors.qualitative.Plotly[i % len(px.colors.qualitative.Plotly)], 0.5))))

    fig.update_layout(
        title=f"Bar Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title="Mean value",
        autosize=True,
        width=1200,
        height=600,
        showlegend=True,
    )
    return fig

def update_graph(
    base_folder,
    datasets,
    stat_name,
    grouping,
    normalization,
    top_k,
    direction,
    progress=gr.Progress(),
):
    if len(datasets) <= 0 or not stat_name or not grouping:
        return None
    # Pick the data-preparation and plotting functions based on the grouping
    prepare_fc = (
        partial(prepare_non_grouped_data, normalization=normalization)
        if grouping == "histogram"
        else partial(prepare_grouped_data, top_k=top_k, direction=direction)
    )
    graph_fc = (
        partial(plot_scatter, normalization=normalization)
        if grouping == "histogram"
        else plot_bars
    )

    with ThreadPoolExecutor() as pool:
        data = list(
            progress.tqdm(
                pool.map(
                    partial(prepare_fc, base_folder=base_folder, stat_name=stat_name, grouping=grouping),
                    datasets,
                ),
                total=len(datasets),
                desc="Loading data...",
            )
        )

    histograms = {path: result for path, result in zip(datasets, data)}
    return graph_fc(histograms=histograms, stat_name=stat_name, progress=progress), histograms, gr.update(visible=True)

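# update_graph returns the raw histograms alongside the figure; they are stored in gr.State so
# "Export data" can dump exactly what was plotted without re-reading the stats from storage.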
# Create the Gradio interface
with gr.Blocks() as demo:
    datasets = gr.State([])
    exported_data = gr.State([])
    stats_headline = gr.Markdown(value="# Stats Exploration")
    with gr.Row():
        with gr.Column(scale=2):
            # Define the multiselect for crawls
            with gr.Row():
                with gr.Column(scale=1):
                    base_folder = gr.Textbox(
                        label="Stats Location",
                        value="s3://fineweb-stats/summary/",
                    )
                    datasets_refetch = gr.Button("Fetch Datasets")
                with gr.Column(scale=1):
                    regex_select = gr.Text(label="Regex select datasets", value=".*")
                    regex_button = gr.Button("Filter")
            with gr.Row():
                datasets_selected = gr.Dropdown(
                    choices=[],
                    label="Datasets",
                    multiselect=True,
                )
            # Add a readme description
            readme_description = gr.Markdown(
                label="Readme",
                value="""
Explanation of the tool:

Groupings:
- histogram: creates a line plot of values with their occurrences. If normalization is on, the values are frequencies summing to 1.
- (fqdn/suffix): creates a bar plot of the mean values of the stat, grouped by fully qualified domain name / domain suffix.
    * k: the number of groups to show
    * Top/Bottom: whether the top or bottom k groups are shown
- summary: simply shows the average value of the given stat for the selected crawls
""",
            )
        with gr.Column(scale=1):
            # Define the dropdown for grouping
            grouping_dropdown = gr.Dropdown(
                choices=[],
                label="Grouping",
                multiselect=False,
            )
            # Define the dropdown for stat_name
            stat_name_dropdown = gr.Dropdown(
                choices=[],
                label="Stat name",
                multiselect=False,
            )
            with gr.Row(visible=False) as histogram_choices:
                normalization_checkbox = gr.Checkbox(
                    label="Normalize",
                    value=False,  # Default value
                )
            with gr.Row(visible=False) as group_choices:
                top_select = gr.Number(
                    label="K",
                    value=100,
                    interactive=True,
                )
                direction_checkbox = gr.Radio(
                    label="Partition",
                    choices=[
                        "Top",
                        "Bottom",
                        "Most frequent (n_docs)",
                    ],
                    value="Top",
                )
    update_button = gr.Button("Update Graph", variant="primary")
    with gr.Row():
        export_data_button = gr.Button("Export data", visible=False)
        export_data_json = gr.File(visible=False)
    with gr.Row():
        # Define the graph output
        graph_output = gr.Plot(label="Graph")
    with gr.Row():
        reverse_search_headline = gr.Markdown(value="# Reverse stats search")
    with gr.Row():
        with gr.Column(scale=1):
            # Define the dropdown for grouping
            reverse_grouping_dropdown = gr.Dropdown(
                choices=[],
                label="Grouping",
                multiselect=False,
            )
            # Define the dropdown for stat_name
            reverse_stat_name_dropdown = gr.Dropdown(
                choices=[],
                label="Stat name",
                multiselect=False,
            )
        with gr.Column(scale=1):
            reverse_search_button = gr.Button("Search")
            reverse_search_add_button = gr.Button("Add to selection")
        with gr.Column(scale=2):
            reverse_search_results = gr.Textbox(
                label="Found datasets",
                lines=10,
            )

    update_button.click(
        fn=update_graph,
        inputs=[
            base_folder,
            datasets_selected,
            stat_name_dropdown,
            grouping_dropdown,
            normalization_checkbox,
            top_select,
            direction_checkbox,
        ],
        outputs=[graph_output, exported_data, export_data_button],
    )

    export_data_button.click(
        fn=export_data,
        inputs=[exported_data],
        outputs=export_data_json,
    )

    datasets_selected.change(
        fn=fetch_groups,
        inputs=[base_folder, datasets_selected, grouping_dropdown],
        outputs=grouping_dropdown,
    )

    grouping_dropdown.select(
        fn=fetch_stats,
        inputs=[base_folder, datasets_selected, grouping_dropdown, stat_name_dropdown],
        outputs=stat_name_dropdown,
    )

    reverse_grouping_dropdown.select(
        fn=partial(fetch_stats, type="union"),
        inputs=[base_folder, datasets, reverse_grouping_dropdown, reverse_stat_name_dropdown],
        outputs=reverse_stat_name_dropdown,
    )

    reverse_search_button.click(
        fn=reverse_search,
        inputs=[base_folder, datasets, reverse_grouping_dropdown, reverse_stat_name_dropdown],
        outputs=reverse_search_results,
    )

    reverse_search_add_button.click(
        fn=reverse_search_add,
        inputs=[datasets_selected, reverse_search_results],
        outputs=datasets_selected,
    )

    datasets_refetch.click(
        fn=fetch_datasets,
        inputs=[base_folder],
        outputs=[datasets, datasets_selected, reverse_grouping_dropdown],
    )

    def update_datasets_with_regex(regex, selected_runs, all_runs):
        if not regex:
            return
        new_dsts = {run for run in all_runs if re.search(regex, run)}
        dst_union = new_dsts.union(selected_runs)
        return gr.update(value=list(dst_union))

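    # update_datasets_with_regex adds every dataset matching the regex to the current selection
    # without dropping datasets already picked, e.g. a pattern like "cc-main-2023.*"
    # (hypothetical dataset naming) would select all 2023 crawls at once.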
    regex_button.click(
        fn=update_datasets_with_regex,
        inputs=[regex_select, datasets_selected, datasets],
        outputs=datasets_selected,
    )

    def update_grouping_options(grouping):
        # Show the normalization option for histograms, and the K/partition options for grouped stats
        if grouping == "histogram":
            return {
                histogram_choices: gr.Column(visible=True),
                group_choices: gr.Column(visible=False),
            }
        else:
            return {
                histogram_choices: gr.Column(visible=False),
                group_choices: gr.Column(visible=True),
            }

    grouping_dropdown.select(
        fn=update_grouping_options,
        inputs=[grouping_dropdown],
        outputs=[histogram_choices, group_choices],
    )

# Launch the application
if __name__ == "__main__":
    demo.launch()