Spaces:
Runtime error
Runtime error
import heapq
import json
import re
import tempfile
from collections import defaultdict
from typing import Dict, List, Literal, Optional, Tuple

import gradio as gr
from datatrove.utils.stats import MetricStatsDict
# Allowed values for the `direction` argument of prepare_for_group_plotting:
# rank groups by highest mean ("Top"), lowest mean ("Bottom"), or by document count.
PARTITION_OPTIONS = Literal["Top", "Bottom", "Most frequent (n_docs)"]
def prepare_for_non_grouped_plotting(
    metric: Dict[str, MetricStatsDict], normalization: bool, rounding: int
) -> Dict[float, float]:
    """Bucket a value -> stats mapping into rounded numeric bins for plotting.

    Each key of ``metric`` is parsed with ``float`` and rounded to ``rounding``
    decimals; the ``.total`` of every entry falling into the same rounded bin is
    summed.

    Args:
        metric: Mapping whose keys are numeric strings and whose values expose a
            ``.total`` attribute.
        normalization: When True, divide every bin by the grand total so the
            bins sum to 1.
        rounding: Number of decimals passed to ``round`` when binning keys.

    Returns:
        A plain ``dict`` mapping rounded key -> summed (optionally normalized)
        total. An empty ``metric`` yields ``{}``; if the grand total is 0,
        normalization is skipped (there is nothing meaningful to divide by).

    Raises:
        ValueError: If a key cannot be parsed by ``float``.
    """
    binned: Dict[float, float] = defaultdict(int)
    for key, value in metric.items():
        binned[round(float(key), rounding)] += value.total
    # Return a plain dict: a defaultdict would silently insert missing keys on lookup.
    binned = dict(binned)

    if not normalization:
        return binned

    normalizer = sum(binned.values())
    if normalizer == 0:
        # Empty input or all-zero totals: avoid ZeroDivisionError / a failing sanity check.
        return binned
    # With a non-zero normalizer the bins sum to 1 by construction.
    return {bin_key: total / normalizer for bin_key, total in binned.items()}
def prepare_for_group_plotting(
    metric: Dict[str, MetricStatsDict],
    top_k: int,
    direction: PARTITION_OPTIONS,
    regex: Optional[str],
    rounding: int,
) -> Tuple[List[str], List[float], List[float]]:
    """Select the groups to plot and return their keys, means and standard deviations.

    Args:
        metric: Mapping of group key to a stats object; each value is read for
            ``.mean``, ``.n`` and ``.standard_deviation``.
        top_k: Maximum number of groups to return. Coerced to ``int`` because UI
            number inputs may deliver floats such as ``10.0``, which ``heapq``
            rejects.
        direction: ``"Top"`` keeps the highest rounded means,
            ``"Most frequent (n_docs)"`` keeps the largest ``.n``; any other value
            (e.g. ``"Bottom"``) keeps the lowest rounded means.
        regex: Optional pattern; when non-empty, only keys it matches at the
            start (``re.match`` semantics) are kept.
        rounding: Number of decimals the means are rounded to before ranking.

    Returns:
        ``(keys, means, stds)`` as parallel lists in ranking order; ``means``
        are the rounded means.

    Raises:
        re.error: If ``regex`` is not a valid pattern.
    """
    top_k = int(top_k)
    pattern = re.compile(regex) if regex else None
    selected = {
        key: value
        for key, value in metric.items()
        if pattern is None or pattern.match(key)
    }
    rounded_means = {key: round(float(value.mean), rounding) for key, value in selected.items()}

    if direction == "Top":
        keys = heapq.nlargest(top_k, rounded_means, key=rounded_means.get)
    elif direction == "Most frequent (n_docs)":
        doc_counts = {key: int(value.n) for key, value in selected.items()}
        keys = heapq.nlargest(top_k, doc_counts, key=doc_counts.get)
    else:
        keys = heapq.nsmallest(top_k, rounded_means, key=rounded_means.get)

    means = [rounded_means[key] for key in keys]
    stds = [selected[key].standard_deviation for key in keys]
    return keys, means, stds
def export_data(exported_data: Dict[str, MetricStatsDict], metric_name: str):
    """Write the given stats to a temporary JSON file and expose it for download.

    Args:
        exported_data: Mapping of name -> stats object exposing ``to_dict()``,
            whose result maps a key to a dict of fields.
        metric_name: Used as the temporary file's name prefix.

    Returns:
        ``None`` when ``exported_data`` is empty/falsy; otherwise
        ``gr.update(visible=True, value=<path>)`` pointing at the JSON file.
        The file is created with ``delete=False`` and is not removed here.
    """
    if not exported_data:
        return None

    def _sorted_rows(stats):
        # One row per key, with the key itself stored under "value"; ordered by that key.
        rows = [{"value": key, **fields} for key, fields in stats.to_dict().items()]
        rows.sort(key=lambda row: row["value"])
        return rows

    with tempfile.NamedTemporaryFile(mode="w", delete=False, prefix=metric_name, suffix=".json") as handle:
        payload = {name: _sorted_rows(stats) for name, stats in exported_data.items()}
        json.dump(payload, handle, indent=2)
        file_path = handle.name

    return gr.update(visible=True, value=file_path)