Update app.py
Browse files
app.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import pandas as pd
|
| 3 |
import plotly.express as px
|
|
@@ -27,11 +29,12 @@ def load_datasets_data():
|
|
| 27 |
print(err_msg)
|
| 28 |
return pd.DataFrame(), False, err_msg
|
| 29 |
|
| 30 |
-
# ---
|
| 31 |
def make_treemap_data(df, count_by, top_k=25, tag_filter=None, skip_cats=None):
|
| 32 |
"""
|
| 33 |
-
Filter data and prepare it for
|
| 34 |
-
|
|
|
|
| 35 |
"""
|
| 36 |
if df is None or df.empty:
|
| 37 |
return pd.DataFrame()
|
|
@@ -55,40 +58,46 @@ def make_treemap_data(df, count_by, top_k=25, tag_filter=None, skip_cats=None):
|
|
| 55 |
filtered_df[count_by] = 0.0
|
| 56 |
filtered_df[count_by] = pd.to_numeric(filtered_df[count_by], errors='coerce').fillna(0.0)
|
| 57 |
|
| 58 |
-
# 1.
|
| 59 |
all_org_totals = filtered_df.groupby("organization")[count_by].sum()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
-
#
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
# 3. Calculate the sum for the "Other" category
|
| 65 |
-
other_total = all_org_totals.sum() - top_org_totals.sum()
|
| 66 |
|
| 67 |
-
# 4. Create the
|
| 68 |
-
|
| 69 |
|
| 70 |
-
# 5. Add the "Other" row if its value is greater than zero
|
| 71 |
if other_total > 0:
|
| 72 |
-
other_row = pd.DataFrame([{
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
-
# 6. Apply the skip filter
|
| 76 |
if skip_cats and len(skip_cats) > 0:
|
| 77 |
-
|
| 78 |
|
| 79 |
-
|
| 80 |
-
return
|
| 81 |
|
| 82 |
-
# ---
|
| 83 |
def create_treemap(treemap_data, count_by, title=None):
|
| 84 |
-
"""Generate the Plotly treemap figure from
|
| 85 |
-
if treemap_data.empty or treemap_data[count_by].sum()
|
| 86 |
fig = px.treemap(names=["No data matches filters"], parents=[""], values=[1])
|
| 87 |
fig.update_layout(title="No data matches the selected filters", margin=dict(t=50, l=25, r=25, b=25))
|
| 88 |
return fig
|
| 89 |
|
| 90 |
-
# The path is
|
| 91 |
-
|
|
|
|
| 92 |
title=title, color_discrete_sequence=px.colors.qualitative.Plotly)
|
| 93 |
fig.update_layout(margin=dict(t=50, l=25, r=25, b=25))
|
| 94 |
fig.update_traces(
|
|
@@ -97,7 +106,7 @@ def create_treemap(treemap_data, count_by, title=None):
|
|
| 97 |
)
|
| 98 |
return fig
|
| 99 |
|
| 100 |
-
# --- Gradio UI Blocks ---
|
| 101 |
with gr.Blocks(title="🤗 Dataverse Explorer", fill_width=True) as demo:
|
| 102 |
datasets_data_state = gr.State(pd.DataFrame())
|
| 103 |
loading_complete_state = gr.State(False)
|
|
@@ -125,7 +134,6 @@ with gr.Blocks(title="🤗 Dataverse Explorer", fill_width=True) as demo:
|
|
| 125 |
value=25
|
| 126 |
)
|
| 127 |
|
| 128 |
-
# --- MODIFIED: UI updated to reflect the new functionality ---
|
| 129 |
skip_cats_textbox = gr.Textbox(
|
| 130 |
label="Categories to Skip (e.g., Other)",
|
| 131 |
value="Other"
|
|
@@ -174,7 +182,7 @@ with gr.Blocks(title="🤗 Dataverse Explorer", fill_width=True) as demo:
|
|
| 174 |
|
| 175 |
return current_df, load_success_flag, data_info_text, status_msg_ui
|
| 176 |
|
| 177 |
-
# ---
|
| 178 |
def ui_generate_plot_controller(metric_choice, tag_choice, k_orgs,
|
| 179 |
skip_cats_input, df_current_datasets, progress=gr.Progress()):
|
| 180 |
if df_current_datasets is None or df_current_datasets.empty:
|
|
@@ -190,19 +198,21 @@ with gr.Blocks(title="🤗 Dataverse Explorer", fill_width=True) as demo:
|
|
| 190 |
chart_title = f"HuggingFace Datasets - {title_labels.get(metric_choice, metric_choice)} by Organization"
|
| 191 |
plotly_fig = create_treemap(treemap_df, metric_choice, chart_title)
|
| 192 |
|
| 193 |
-
# Update plot statistics to be more accurate for the new view
|
| 194 |
if treemap_df.empty:
|
| 195 |
plot_stats_md = "No data matches the selected filters. Please try different options."
|
| 196 |
else:
|
| 197 |
total_value_in_plot = treemap_df[metric_choice].sum()
|
|
|
|
|
|
|
| 198 |
plot_stats_md = (
|
| 199 |
-
f"## Plot Statistics\n- **
|
|
|
|
| 200 |
f"- **Total {metric_choice} in plot**: {int(total_value_in_plot):,}"
|
| 201 |
)
|
| 202 |
|
| 203 |
return plotly_fig, plot_stats_md
|
| 204 |
|
| 205 |
-
# --- Event Wiring (no changes needed
|
| 206 |
demo.load(
|
| 207 |
fn=ui_load_data_controller,
|
| 208 |
inputs=[],
|
|
|
|
| 1 |
+
# --- app.py (Dataverse Explorer - Corrected with drill-down) ---
|
| 2 |
+
|
| 3 |
import gradio as gr
|
| 4 |
import pandas as pd
|
| 5 |
import plotly.express as px
|
|
|
|
| 29 |
print(err_msg)
|
| 30 |
return pd.DataFrame(), False, err_msg
|
| 31 |
|
| 32 |
+
# --- Treemap data preparation: keeps individual datasets for the top-K organizations and folds the rest into "Other" ---
|
| 33 |
def make_treemap_data(df, count_by, top_k=25, tag_filter=None, skip_cats=None):
|
| 34 |
"""
|
| 35 |
+
Filter data and prepare it for a multi-level treemap.
|
| 36 |
+
- Preserves individual datasets for the top K organizations.
|
| 37 |
+
- Groups all other organizations into a single "Other" category.
|
| 38 |
"""
|
| 39 |
if df is None or df.empty:
|
| 40 |
return pd.DataFrame()
|
|
|
|
| 58 |
filtered_df[count_by] = 0.0
|
| 59 |
filtered_df[count_by] = pd.to_numeric(filtered_df[count_by], errors='coerce').fillna(0.0)
|
| 60 |
|
| 61 |
+
# 1. Get total for every organization to determine the top K
|
| 62 |
all_org_totals = filtered_df.groupby("organization")[count_by].sum()
|
| 63 |
+
top_org_names = all_org_totals.nlargest(top_k, keep='first').index.tolist()
|
| 64 |
+
|
| 65 |
+
# 2. Get the full data for the individual datasets belonging to the top organizations
|
| 66 |
+
top_orgs_df = filtered_df[filtered_df['organization'].isin(top_org_names)].copy()
|
| 67 |
|
| 68 |
+
# 3. Calculate the total for the "Other" category
|
| 69 |
+
other_total = all_org_totals[~all_org_totals.index.isin(top_org_names)].sum()
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
+
# 4. Create the final DataFrame for the plot
|
| 72 |
+
final_df_for_plot = top_orgs_df
|
| 73 |
|
| 74 |
+
# 5. Add the "Other" row as a single entry if its value is greater than zero
|
| 75 |
if other_total > 0:
|
| 76 |
+
other_row = pd.DataFrame([{
|
| 77 |
+
'organization': 'Other',
|
| 78 |
+
'id': 'Other', # The 'id' for the "Other" category must be defined for the path
|
| 79 |
+
count_by: other_total
|
| 80 |
+
}])
|
| 81 |
+
final_df_for_plot = pd.concat([final_df_for_plot, other_row], ignore_index=True)
|
| 82 |
|
| 83 |
+
# 6. Apply the skip filter to the organization/category level
|
| 84 |
if skip_cats and len(skip_cats) > 0:
|
| 85 |
+
final_df_for_plot = final_df_for_plot[~final_df_for_plot['organization'].isin(skip_cats)]
|
| 86 |
|
| 87 |
+
final_df_for_plot["root"] = "datasets"
|
| 88 |
+
return final_df_for_plot
|
| 89 |
|
| 90 |
+
# --- Treemap rendering: the root -> organization -> id path enables drill-down ---
|
| 91 |
def create_treemap(treemap_data, count_by, title=None):
|
| 92 |
+
"""Generate the Plotly treemap figure from the prepared data."""
|
| 93 |
+
if treemap_data.empty or treemap_data[count_by].sum() <= 0:
|
| 94 |
fig = px.treemap(names=["No data matches filters"], parents=[""], values=[1])
|
| 95 |
fig.update_layout(title="No data matches the selected filters", margin=dict(t=50, l=25, r=25, b=25))
|
| 96 |
return fig
|
| 97 |
|
| 98 |
+
# The path is restored to `["root", "organization", "id"]` to enable drill-down.
|
| 99 |
+
# The "Other" row with id='Other' will correctly be displayed as a single block.
|
| 100 |
+
fig = px.treemap(treemap_data, path=["root", "organization", "id"], values=count_by,
|
| 101 |
title=title, color_discrete_sequence=px.colors.qualitative.Plotly)
|
| 102 |
fig.update_layout(margin=dict(t=50, l=25, r=25, b=25))
|
| 103 |
fig.update_traces(
|
|
|
|
| 106 |
)
|
| 107 |
return fig
|
| 108 |
|
| 109 |
+
# --- Gradio UI Blocks (no changes needed here) ---
|
| 110 |
with gr.Blocks(title="🤗 Dataverse Explorer", fill_width=True) as demo:
|
| 111 |
datasets_data_state = gr.State(pd.DataFrame())
|
| 112 |
loading_complete_state = gr.State(False)
|
|
|
|
| 134 |
value=25
|
| 135 |
)
|
| 136 |
|
|
|
|
| 137 |
skip_cats_textbox = gr.Textbox(
|
| 138 |
label="Categories to Skip (e.g., Other)",
|
| 139 |
value="Other"
|
|
|
|
| 182 |
|
| 183 |
return current_df, load_success_flag, data_info_text, status_msg_ui
|
| 184 |
|
| 185 |
+
# --- Plot controller: builds the treemap figure and reports organization, dataset, and metric totals ---
|
| 186 |
def ui_generate_plot_controller(metric_choice, tag_choice, k_orgs,
|
| 187 |
skip_cats_input, df_current_datasets, progress=gr.Progress()):
|
| 188 |
if df_current_datasets is None or df_current_datasets.empty:
|
|
|
|
| 198 |
chart_title = f"HuggingFace Datasets - {title_labels.get(metric_choice, metric_choice)} by Organization"
|
| 199 |
plotly_fig = create_treemap(treemap_df, metric_choice, chart_title)
|
| 200 |
|
|
|
|
| 201 |
if treemap_df.empty:
|
| 202 |
plot_stats_md = "No data matches the selected filters. Please try different options."
|
| 203 |
else:
|
| 204 |
total_value_in_plot = treemap_df[metric_choice].sum()
|
| 205 |
+
# Count datasets, excluding our placeholder "Other" id
|
| 206 |
+
total_datasets_in_plot = treemap_df[treemap_df['id'] != 'Other']['id'].nunique()
|
| 207 |
plot_stats_md = (
|
| 208 |
+
f"## Plot Statistics\n- **Organizations/Categories Shown**: {treemap_df['organization'].nunique():,}\n"
|
| 209 |
+
f"- **Individual Datasets Shown**: {total_datasets_in_plot:,}\n"
|
| 210 |
f"- **Total {metric_choice} in plot**: {int(total_value_in_plot):,}"
|
| 211 |
)
|
| 212 |
|
| 213 |
return plotly_fig, plot_stats_md
|
| 214 |
|
| 215 |
+
# --- Event Wiring (no changes needed) ---
|
| 216 |
demo.load(
|
| 217 |
fn=ui_load_data_controller,
|
| 218 |
inputs=[],
|