Spaces:
Runtime error
Runtime error
| from functools import partial | |
| import json | |
| from pathlib import Path | |
| import gradio as gr | |
| from collections import defaultdict | |
| import fsspec.config | |
| import math | |
| from datatrove.io import DataFolder, get_datafolder | |
| from datatrove.utils.stats import MetricStatsDict | |
# Root folder of the merged statistics. Run discovery, folder listing and
# stats loading in this module all go through this DataFolder, so the paths
# they use are relative to it.
BASE_DATA_FOLDER = get_datafolder("s3://fineweb-stats/summary/")
# Stats whose histogram x-axis is drawn on a logarithmic scale.
LOG_SCALE_STATS = {
    "length",
    "n_lines",
    "n_docs",
    "n_words",
    "avg_words_per_line",
    "pages_with_lorem_ipsum",
}

# Half-transparent trace colours; plots cycle through them by trace index.
colors = [
    f"rgba({red}, {green}, {blue}, 0.5)"
    for red, green, blue in (
        (31, 119, 180),
        (255, 127, 14),
        (44, 160, 44),
        (214, 39, 40),
        (148, 103, 189),
        (227, 119, 194),
        (127, 127, 127),
        (188, 189, 34),
        (23, 190, 207),
        (255, 193, 7),
        (40, 167, 69),
        (23, 162, 184),
        (108, 117, 125),
        (0, 123, 255),
        (220, 53, 69),
        (255, 159, 67),
        (255, 87, 34),
        (41, 182, 246),
        (142, 36, 170),
        (0, 188, 212),
        (255, 235, 59),
        (156, 39, 176),
    )
]
def find_folders(base_folder, path):
    """Return the sorted names of the directories listed under ``path``.

    Args:
        base_folder: Object exposing ``ls(path, detail=True)`` that yields
            dicts with at least ``"name"`` and ``"type"`` keys.
        path: Folder to list, in the same form ``ls`` reports names.

    Returns:
        A sorted list of the ``"name"`` values of every entry whose type is
        ``"directory"``, excluding an entry that names ``path`` itself.
    """
    subfolders = []
    for entry in base_folder.ls(path, detail=True):
        if entry["type"] != "directory":
            continue
        # The listing may contain the queried folder itself (with or without
        # a trailing slash); it is not one of its own children.
        if entry["name"].rstrip("/") == path:
            continue
        subfolders.append(entry["name"])
    subfolders.sort()
    return subfolders
def find_stats_folders(base_folder: DataFolder):
    """Return the sorted, de-duplicated folders that contain merged stats.

    Every ``stats-merged.json`` found below ``base_folder`` is mapped to the
    folder three levels above it, i.e. the file name and its two enclosing
    folders are stripped (presumably ``<grouping>/<stat_name>`` - confirm
    against the data layout).

    Args:
        base_folder: Folder object whose ``glob`` method is used for the search.

    Returns:
        A sorted list of unique folder paths as strings.
    """
    run_folders = {
        str(Path(merged_file).parent.parent.parent)
        for merged_file in base_folder.glob("**/stats-merged.json")
    }
    return sorted(run_folders)
# Run folders offered in the crawl selector, discovered once at import time.
# find_stats_folders already returns a sorted list, so no extra sort is needed.
RUNS = find_stats_folders(BASE_DATA_FOLDER)
def fetch_groups(runs, old_groups):
    """Build the update for the grouping dropdown from the selected runs.

    Only groupings available in *every* selected run are offered.

    Args:
        runs: Selected run paths (may be empty or ``None``).
        old_groups: Currently selected grouping; it is kept as the value when
            it is still among the shared groupings, otherwise cleared.

    Returns:
        A ``gr.update`` carrying the new ``choices`` and ``value``.
    """
    # Nothing selected (or the selection was cleared): offer no groupings.
    if not runs:
        return gr.update(choices=[], value=None)

    groups_per_run = [
        {Path(folder).name for folder in find_folders(BASE_DATA_FOLDER, run)}
        for run in runs
    ]
    # Keep only the groupings present in all selected runs.
    shared_groups = set.intersection(*groups_per_run)

    value = old_groups if old_groups and old_groups in shared_groups else None
    # Sort so the dropdown order is stable instead of following set iteration
    # order, which varies between processes.
    return gr.update(choices=sorted(shared_groups), value=value)
def fetch_stats(runs, group, old_stats):
    """Build the update for the stat-name dropdown.

    Only stats available under ``<run>/<group>`` for *every* selected run are
    offered.

    Args:
        runs: Selected run paths (may be empty or ``None``).
        group: Selected grouping (may be empty or ``None``).
        old_stats: Currently selected stat name; it is kept as the value when
            it is still among the shared stats, otherwise cleared.

    Returns:
        A ``gr.update`` carrying the new ``choices`` and ``value``.
    """
    # Without both a run and a grouping there is no folder to list; listing
    # "<run>/None" would only fail.
    if not runs or not group:
        return gr.update(choices=[], value=None)

    stats_per_run = [
        {
            Path(folder).name
            for folder in find_folders(BASE_DATA_FOLDER, f"{run}/{group}")
        }
        for run in runs
    ]
    # Keep only the stats present in all selected runs.
    shared_stats = set.intersection(*stats_per_run)

    value = old_stats if old_stats and old_stats in shared_stats else None
    # Sort so the dropdown order is stable instead of following set iteration
    # order, which varies between processes.
    return gr.update(choices=sorted(shared_stats), value=value)
def load_stats(path, stat_name, group_by):
    """Load ``<path>/<group_by>/<stat_name>/stats-merged.json`` as a MetricStatsDict.

    Args:
        path: Run folder, relative to ``BASE_DATA_FOLDER``.
        stat_name: Name of the stat folder.
        group_by: Name of the grouping folder.

    Returns:
        A ``MetricStatsDict`` built from the parsed JSON content.
    """
    stats_file = f"{path}/{group_by}/{stat_name}/stats-merged.json"
    cache_options = {"cache_storage": "/tmp/files"}
    with BASE_DATA_FOLDER.open(stats_file, filecache=cache_options) as handle:
        raw_stats = json.load(handle)
    # Adding to an empty MetricStatsDict is required: per the original author,
    # the MetricStatsDict is otherwise malformed (reason unknown).
    return MetricStatsDict() + MetricStatsDict(init=raw_stats)
def prepare_non_grouped_data(path, stat_name, grouping, normalization):
    """Build the histogram of a stat: numeric key -> summed ``total``.

    Args:
        path: Run folder to load the stat from.
        stat_name: Name of the stat.
        grouping: Grouping folder the stat is read from.
        normalization: When truthy, each total is divided by the sum of all
            totals so the values add up to 1.

    Returns:
        A mapping from ``float`` key to total (or to frequency when
        normalized).
    """
    totals = defaultdict(int)
    for raw_key, metric in load_stats(path, stat_name, grouping).items():
        # Keys are converted to float, so keys that parse to the same number
        # are merged into one bucket.
        totals[float(raw_key)] += metric.total

    if not normalization:
        return totals

    grand_total = sum(totals.values())
    return {bucket: count / grand_total for bucket, count in totals.items()}
def prepare_grouped_data(path, stat_name, grouping, top_k, direction):
    """Select ``top_k`` groups of a stat and return them with their mean value.

    Args:
        path: Run folder to load the stat from.
        stat_name: Name of the stat whose per-group means are reported.
        grouping: Grouping folder the stat is read from.
        top_k: Number of groups to keep; converted to ``int``.
        direction: Selection strategy:
            ``"Top"`` - groups with the largest mean;
            ``"Most frequent (n_docs)"`` - groups with the largest ``n_docs``
            total;
            ``"Most frequent (length)"`` - groups with the largest ``length``
            total;
            anything else (including ``None``) - groups with the smallest mean.

    Returns:
        A list of ``(group_key, mean)`` tuples in selection order.
    """
    import heapq

    stats = load_stats(path, stat_name, grouping)
    means = {key: value.mean for key, value in stats.items()}
    # heapq.nlargest/nsmallest need an integer count, but a numeric UI input
    # may deliver a float.
    top_k = int(top_k)

    # Stat whose totals rank the groups for the "Most frequent" strategies.
    frequency_stat = {
        "Most frequent (n_docs)": "n_docs",
        "Most frequent (length)": "length",
    }.get(direction)

    if direction == "Top":
        keys = heapq.nlargest(top_k, means, key=means.get)
    elif frequency_stat is not None:
        ranking_stats = load_stats(path, frequency_stat, grouping)
        # Rank only groups that also exist in the requested stat, so the
        # lookup in `means` below cannot fail.
        totals = {
            key: value.total for key, value in ranking_stats.items() if key in means
        }
        keys = heapq.nlargest(top_k, totals, key=totals.get)
    else:
        keys = heapq.nsmallest(top_k, means, key=means.get)

    return [(key, means[key]) for key in keys]
| import math | |
| import plotly.graph_objects as go | |
| from plotly.offline import plot | |
def plot_scatter(
    histograms: dict[str, dict[float, float]], stat_name: str, normalization: bool
):
    """Draw one line trace per histogram on a shared figure.

    Args:
        histograms: Mapping from trace name to its ``{x: y}`` histogram.
        stat_name: Stat being plotted; used for the title and x-axis label, and
            it selects a log x-axis when listed in ``LOG_SCALE_STATS``.
        normalization: Whether the values are frequencies; only affects the
            y-axis title.

    Returns:
        The populated ``go.Figure``.
    """
    fig = go.Figure()
    for index, (name, histogram) in enumerate(histograms.items()):
        if all(isinstance(key, str) for key in histogram.keys()):
            # String keys have no natural order: order them by their value.
            by_value = sorted(histogram.items(), key=lambda item: item[1])
            x = [key for key, _ in by_value]
        else:
            x = sorted(histogram.keys())
        y = [histogram[key] for key in x]

        trace = go.Scatter(
            x=x,
            y=y,
            mode="lines",
            name=name,
            # Wrap around the palette when there are more traces than colours.
            line=dict(color=colors[index % len(colors)]),
        )
        fig.add_trace(trace)

    if stat_name in LOG_SCALE_STATS:
        xaxis_scale = "log"
    else:
        xaxis_scale = "linear"
    if normalization:
        yaxis_title = "Frequency"
    else:
        yaxis_title = "Total"

    fig.update_layout(
        title=f"Line Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title=yaxis_title,
        xaxis_type=xaxis_scale,
        width=1200,
        height=600,
        showlegend=True,
    )
    return fig
def plot_bars(histograms: dict[str, list[tuple[str, float]]], stat_name: str):
    """Draw one bar trace per histogram on a shared figure.

    Args:
        histograms: Mapping from trace name to its list of ``(label, value)``
            pairs.
        stat_name: Stat being plotted; used for the title and x-axis label.

    Returns:
        The populated ``go.Figure``.
    """
    fig = go.Figure()
    for index, (name, pairs) in enumerate(histograms.items()):
        labels = [label for label, _ in pairs]
        values = [value for _, value in pairs]
        # Wrap around the palette when there are more traces than colours.
        bar_color = colors[index % len(colors)]
        fig.add_trace(go.Bar(x=labels, y=values, name=name, marker_color=bar_color))

    fig.update_layout(
        title=f"Bar Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title="Mean value",
        autosize=True,
        width=1200,
        height=600,
        showlegend=True,
    )
    return fig
def update_graph(
    multiselect_crawls, stat_name, grouping, normalization, top_k, direction
):
    """Load the selected stat for every selected crawl and plot it.

    Args:
        multiselect_crawls: Selected run paths (may be empty or ``None``).
        stat_name: Selected stat name.
        grouping: Selected grouping; ``"histogram"`` produces a line plot, any
            other grouping produces a bar plot.
        normalization: Histogram only - plot frequencies instead of totals.
        top_k: Grouped data only - number of groups to show.
        direction: Grouped data only - how the groups are selected.

    Returns:
        The figure to display, or ``None`` when the selection is incomplete.
    """
    # Truthiness also covers a ``None`` selection, on which ``len()`` would raise.
    if not multiselect_crawls or not stat_name or not grouping:
        return None

    # Choose the matching data preparation and plotting functions together.
    if grouping == "histogram":
        prepare_fc = partial(prepare_non_grouped_data, normalization=normalization)
        graph_fc = partial(plot_scatter, normalization=normalization)
    else:
        prepare_fc = partial(prepare_grouped_data, top_k=top_k, direction=direction)
        graph_fc = plot_bars

    print("Loading stats")
    histograms = {
        path: prepare_fc(path, stat_name, grouping) for path in multiselect_crawls
    }
    print("Plotting")
    return graph_fc(histograms, stat_name)
# Create the Gradio interface
# NOTE(review): the layout nesting below was reconstructed from a source whose
# indentation was lost - confirm it matches the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2):
            # Define the multiselect for crawls
            multiselect_crawls = gr.Dropdown(
                choices=RUNS,
                label="Multiselect for crawls",
                multiselect=True,
            )
            # add a readme description
            readme_description = gr.Markdown(
                label="Readme",
                value="""
Explaination of the tool:
Groupings:
- histogram: creates a line plot of values with their occurences. If normalization is on, the values are frequencies summing to 1.
- (fqdn/suffix): creates a bar plot of the mean values of the stats for full qualied domain name/suffix of domain
* k: the number of groups to show
* Top/Bottom: the top/bottom k groups are shown
- summary: simply shows the average value of given stat for selected crawls
""",
            )
        with gr.Column(scale=1):
            # Define the dropdown for grouping
            grouping_dropdown = gr.Dropdown(
                choices=[],
                label="Grouping",
                multiselect=False,
            )
            # Define the dropdown for stat_name
            stat_name_dropdown = gr.Dropdown(
                choices=[],
                label="Stat name",
                multiselect=False,
            )
            # Options that only apply to the "histogram" grouping.
            with gr.Row(visible=False) as histogram_choices:
                normalization_checkbox = gr.Checkbox(
                    label="Normalize",
                    value=False,  # Default value
                )
            # Options that only apply to the other (grouped) groupings.
            with gr.Row(visible=False) as group_choices:
                top_select = gr.Number(
                    label="K",
                    value=100,
                    # K is a count: round to an integer so it can be passed to
                    # heapq.nlargest/nsmallest.
                    precision=0,
                    interactive=True,
                )
                direction_checkbox = gr.Radio(
                    label="Partition",
                    choices=[
                        "Top",
                        "Bottom",
                        "Most frequent (n_docs)",
                        "Most frequent (length)",
                    ],
                    # Explicit default: with nothing selected the value is None,
                    # which prepare_grouped_data treats like "Bottom".
                    value="Top",
                )
            update_button = gr.Button("Update Graph", variant="primary")
    with gr.Row():
        # Define the graph output
        graph_output = gr.Plot(label="Graph")

    # Redraw the graph on demand from the current selection.
    update_button.click(
        fn=update_graph,
        inputs=[
            multiselect_crawls,
            stat_name_dropdown,
            grouping_dropdown,
            normalization_checkbox,
            top_select,
            direction_checkbox,
        ],
        outputs=graph_output,
    )
    # Selecting crawls refreshes the groupings they have in common.
    multiselect_crawls.select(
        fn=fetch_groups,
        inputs=[multiselect_crawls, grouping_dropdown],
        outputs=grouping_dropdown,
    )
    # Selecting a grouping refreshes the stats available for it.
    grouping_dropdown.select(
        fn=fetch_stats,
        inputs=[multiselect_crawls, grouping_dropdown, stat_name_dropdown],
        outputs=stat_name_dropdown,
    )

    def update_grouping_options(grouping):
        """Show the option row that matches ``grouping`` and hide the other."""
        show_histogram = grouping == "histogram"
        # Both targets are gr.Row containers; gr.update changes visibility
        # without naming a (mismatching) component class.
        return {
            histogram_choices: gr.update(visible=show_histogram),
            group_choices: gr.update(visible=not show_histogram),
        }

    grouping_dropdown.select(
        fn=update_grouping_options,
        inputs=[grouping_dropdown],
        outputs=[histogram_choices, group_choices],
    )
# Launch the application
# Only start the Gradio server when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    demo.launch()