File size: 9,797 Bytes
bbf45d0
 
 
961c6fe
b06975a
961c6fe
 
b06975a
addb03f
8c60635
 
 
 
 
 
 
 
afd7356
 
 
d858aa5
8c60635
afd7356
 
 
c0b7e37
 
 
961c6fe
addb03f
 
e65f153
 
 
addb03f
8c60635
 
 
9c451ee
8c60635
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c6bf95
8c60635
addb03f
e65f153
 
 
 
8c60635
e65f153
addb03f
 
813c7cf
e65f153
addb03f
 
e65f153
addb03f
e65f153
 
9c451ee
 
e65f153
 
f0e2fd8
961c6fe
9c451ee
8c60635
e65f153
8c60635
961c6fe
8c60635
 
 
 
bbf45d0
 
813c7cf
8c60635
 
f0e2fd8
d858aa5
eec69ec
8c60635
eec69ec
98b7de8
f0e2fd8
813c7cf
 
 
 
 
afd7356
f0e2fd8
961c6fe
f0e2fd8
 
c0b7e37
8c60635
 
961c6fe
813c7cf
 
b06975a
813c7cf
961c6fe
8c60635
961c6fe
813c7cf
fa2c2d2
b06975a
 
fa2c2d2
8c60635
813c7cf
 
 
 
4517d15
961c6fe
 
813c7cf
 
c0b7e37
813c7cf
 
8c60635
813c7cf
 
 
 
 
 
 
 
 
 
 
 
 
 
4d0811f
8c60635
addb03f
8c60635
813c7cf
d858aa5
addb03f
 
f0e2fd8
addb03f
9c451ee
b06975a
961c6fe
8c60635
961c6fe
afd7356
961c6fe
b06975a
961c6fe
d858aa5
e65f153
813c7cf
 
 
8c60635
961c6fe
 
813c7cf
 
 
0c6bf95
813c7cf
0c6bf95
813c7cf
47e0cf9
afd7356
8c60635
 
 
 
 
 
bbf45d0
961c6fe
8c60635
addb03f
f0e2fd8
bbf45d0
 
 
8c60635
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import gradio as gr
import pandas as pd
import plotly.express as px
import time
from datasets import load_dataset

# --- Constants ---
# Options for the "Number of Top Organizations" dropdown: 5, 10, ..., 50.
TOP_K_CHOICES = list(range(5, 51, 5))

# Hugging Face Hub dataset that holds the pre-processed datasets table.
HF_DATASET_ID = "evijit/dataverse_daily_data"

# Labels for the tag filter dropdown; "None" means no tag filtering.
TAG_FILTER_CHOICES = [
    "None",
    "Audio & Speech",
    "Time series",
    "Robotics",
    "Music",
    "Video",
    "Images",
    "Text",
    "Biomedical",
    "Sciences",
]

def load_datasets_data():
    """Fetch the pre-processed datasets table from the Hugging Face Hub.

    Only the first split of ``HF_DATASET_ID`` is used.

    Returns:
        tuple: ``(df, ok, message)``.
        - ``df`` is the split as a pandas DataFrame, or an empty DataFrame on failure.
        - ``ok`` is True on success.
        - ``message`` is a status line that is also printed to stdout.
    """
    t0 = time.time()
    print(f"Attempting to load dataset from Hugging Face Hub: {HF_DATASET_ID}")
    try:
        splits = load_dataset(HF_DATASET_ID)
        first_split = list(splits)[0]
        frame = splits[first_split].to_pandas()
        success = f"Successfully loaded dataset in {time.time() - t0:.2f}s."
        print(success)
        return frame, True, success
    except Exception as exc:
        # Any failure (network, missing dataset, no splits) becomes a status tuple
        # so the UI can report it instead of crashing.
        failure = f"Failed to load dataset. Error: {exc}"
        print(failure)
        return pd.DataFrame(), False, failure

def make_treemap_data(df, count_by, top_k=25, tag_filter=None, skip_cats=None):
    """
    Filter data and prepare it for a multi-level treemap.

    - Preserves individual datasets (rows) for the top ``top_k`` organizations,
      ranked by the summed ``count_by`` metric.
    - Groups all other organizations into a single "Other" row. That row holds
      their combined total and is added only when the total is > 0.

    Args:
        df: Datasets table. It must have an ``organization`` column; the
            treemap path downstream also uses ``id``. It is never mutated.
        count_by: Name of the metric column to aggregate. A missing column is
            treated as all zeros, and non-numeric values are coerced to 0.
        top_k: Number of organizations kept at dataset granularity.
        tag_filter: Tag label (e.g. "Text"). ``None``, "None" or an unknown
            label disables filtering, as does a label whose flag column is absent.
        skip_cats: Organization names (may include "Other") to drop from the
            result. Accepts a list-like or a single string.

    Returns:
        A new DataFrame with a constant ``root`` column added, or an empty
        DataFrame when nothing matches.
    """
    if df is None or df.empty:
        return pd.DataFrame()

    col_map = {
        "Audio & Speech": "is_audio_speech", "Music": "has_music", "Robotics": "has_robot",
        "Biomedical": "is_biomed", "Time series": "has_series", "Sciences": "has_science",
        "Video": "has_video", "Images": "has_image", "Text": "has_text",
    }

    filtered_df = df
    tag_col = col_map.get(tag_filter) if tag_filter and tag_filter != "None" else None
    if tag_col is not None and tag_col in filtered_df.columns:
        flags = filtered_df[tag_col]
        if flags.dtype != bool:
            # Flag columns may arrive as object/nullable/int dtype (e.g. with
            # None/NA entries). Boolean indexing needs plain bools, so missing
            # values mean "tag absent" and anything else is judged by truthiness.
            flags = flags.map(lambda value: False if pd.isna(value) else bool(value)).astype(bool)
        filtered_df = filtered_df[flags]

    if filtered_df.empty:
        return pd.DataFrame()

    # Work on a private copy. The caller's frame is never mutated, and writing a
    # column into a boolean-indexed slice would raise SettingWithCopyWarning.
    filtered_df = filtered_df.copy()
    if count_by not in filtered_df.columns:
        filtered_df[count_by] = 0.0
    filtered_df[count_by] = pd.to_numeric(filtered_df[count_by], errors='coerce').fillna(0.0)

    org_totals = filtered_df.groupby("organization")[count_by].sum()
    top_org_names = org_totals.nlargest(top_k, keep='first').index.tolist()

    plot_df = filtered_df[filtered_df['organization'].isin(top_org_names)].copy()
    other_total = org_totals[~org_totals.index.isin(top_org_names)].sum()

    if other_total > 0:
        other_row = pd.DataFrame([{'organization': 'Other', 'id': 'Other', count_by: other_total}])
        plot_df = pd.concat([plot_df, other_row], ignore_index=True)

    if isinstance(skip_cats, str):
        skip_cats = [skip_cats]
    if skip_cats is not None and len(skip_cats) > 0:
        # Copy so the "root" assignment below writes to an owned frame, not a slice.
        plot_df = plot_df[~plot_df['organization'].isin(skip_cats)].copy()

    plot_df["root"] = "datasets"
    return plot_df

def create_treemap(treemap_data, count_by, title=None):
    """Render prepared treemap rows as a Plotly treemap figure.

    Args:
        treemap_data: Frame with ``root``, ``organization``, ``id`` and the
            ``count_by`` column (as produced by the data-prep step).
        count_by: Name of the column that sizes the tiles.
        title: Optional figure title.

    Returns:
        A Plotly figure. It is a single-tile placeholder when the frame is
        empty or its metric total is not positive.
    """
    margins = dict(t=50, l=25, r=25, b=25)

    nothing_to_plot = treemap_data.empty or treemap_data[count_by].sum() <= 0
    if nothing_to_plot:
        placeholder = px.treemap(names=["No data matches filters"], parents=[""], values=[1])
        placeholder.update_layout(title="No data matches the selected filters", margin=margins)
        return placeholder

    hover = "<b>%{label}</b><br>%{value:,} " + count_by + "<br>%{percentRoot:.2%} of total<extra></extra>"
    figure = px.treemap(
        treemap_data,
        path=["root", "organization", "id"],
        values=count_by,
        title=title,
        color_discrete_sequence=px.colors.qualitative.Plotly,
    )
    figure.update_layout(margin=margins)
    figure.update_traces(textinfo="label+value+percent root", hovertemplate=hover)
    return figure

# --- Gradio UI Blocks ---
# App layout. Components are laid out in the order they are created inside
# these context managers, so the statements below are order-sensitive.
with gr.Blocks(title="🤗 Dataverse Explorer", fill_width=True) as demo:
    # Hidden state: the loaded datasets table and the load-success flag.
    # Both are filled by the page-load handler.
    datasets_data_state = gr.State(pd.DataFrame())
    loading_complete_state = gr.State(False)
    
    with gr.Row():
        gr.Markdown("# 🤗 Dataverse Explorer")

    with gr.Row():
        # Left column: plot controls.
        with gr.Column(scale=1):
            # Choices are (label, value) pairs; the value is the metric column name.
            count_by_dropdown = gr.Dropdown(label="Metric", choices=[("Downloads (last 30 days)", "downloads"), ("Downloads (All Time)", "downloadsAllTime"), ("Likes", "likes")], value="downloads")
            tag_filter_dropdown = gr.Dropdown(label="Filter by Tag", choices=TAG_FILTER_CHOICES, value="None")
            top_k_dropdown = gr.Dropdown(label="Number of Top Organizations", choices=TOP_K_CHOICES, value=25)
            # Comma-separated organization names to hide. The default hides the aggregated "Other" tile.
            skip_cats_textbox = gr.Textbox(label="Organizations to Skip from the plot", value="Other")
            # Starts disabled; its interactivity is later driven by loading_complete_state.
            generate_plot_button = gr.Button(value="Generate Plot", variant="primary", interactive=False)

        # Right column (3x the width of the controls): treemap plus status/info panels.
        with gr.Column(scale=3):
            plot_output = gr.Plot()
            status_message_md = gr.Markdown("Initializing...")
            data_info_md = gr.Markdown("")
    
    def _update_button_interactivity(is_loaded_flag):
        """Return a component update whose ``interactive`` setting mirrors the load flag."""
        button_update = gr.update(interactive=is_loaded_flag)
        return button_update

    # Page-load handler: load the data and render the default plot in one pass.
    def load_and_generate_initial_plot(progress=gr.Progress()):
        """Load the Hub dataset and render the default treemap when the page opens.

        Args:
            progress: Gradio progress tracker (injected by Gradio at call time).

        Returns:
            tuple: ``(df, load_success_flag, data_info_markdown, status_markdown, figure)``,
            in the order of the page-load outputs list. ``df`` is empty on failure.
        """
        progress(0, desc=f"Loading dataset '{HF_DATASET_ID}'...")
        # --- Part 1: Data Loading ---
        try:
            current_df, load_success_flag, status_msg_from_load = load_datasets_data()
            if load_success_flag:
                progress(0.5, desc="Processing data...")
                date_display = "Pre-processed (date unavailable)"
                # Parse the snapshot timestamp on its own. A missing or malformed
                # value should only degrade the "Data as of" line; it must not
                # reach the outer handler and throw away data that loaded fine.
                if 'data_download_timestamp' in current_df.columns and not current_df.empty:
                    raw_ts = current_df['data_download_timestamp'].iloc[0]
                    if pd.notna(raw_ts):
                        try:
                            ts = pd.to_datetime(raw_ts, utc=True)
                            date_display = ts.strftime('%B %d, %Y, %H:%M:%S %Z')
                        except (ValueError, TypeError, OverflowError) as ts_err:
                            print(f"Could not parse data_download_timestamp {raw_ts!r}: {ts_err}")

                data_info_text = (f"### Data Information\n- Source: `{HF_DATASET_ID}`\n"
                                  f"- Status: {status_msg_from_load}\n"
                                  f"- Total datasets loaded: {len(current_df):,}\n"
                                  f"- Data as of: {date_display}\n")
            else:
                data_info_text = f"### Data Load Failed\n- {status_msg_from_load}"
        except Exception as e:
            # Last-resort guard so the handler still returns a value for every output.
            status_msg_from_load = f"An unexpected error occurred: {str(e)}"
            data_info_text = f"### Critical Error\n- {status_msg_from_load}"
            load_success_flag = False
            current_df = pd.DataFrame()  # Ensure df is empty on failure
            print(f"Critical error in load_and_generate_initial_plot: {e}")

        # --- Part 2: Generate Initial Plot ---
        progress(0.6, desc="Generating initial plot...")
        # These mirror the initial values of the controls in the layout. They are
        # hard-coded, NOT read from the components, so keep them in sync by hand.
        default_metric = "downloads"
        default_tag = "None"
        default_k = 25
        default_skip_cats = "Other"

        def _plot_progress(fraction, desc=None):
            # The controller reports its own progress on a 0..1 scale. Map that onto
            # the remaining 0.6..1.0 range so the bar never moves backwards.
            progress(0.6 + 0.4 * fraction, desc=desc)

        # Reuse the button's controller function for the initial plot.
        initial_plot, initial_status = ui_generate_plot_controller(
            default_metric, default_tag, default_k, default_skip_cats, current_df, _plot_progress
        )

        return current_df, load_success_flag, data_info_text, initial_status, initial_plot

    def ui_generate_plot_controller(metric_choice, tag_choice, k_orgs, 
                                   skip_cats_input, df_current_datasets, progress=gr.Progress()):
        """Turn the current control values into a treemap figure plus a stats summary.

        Args:
            metric_choice: Metric column used to size the tiles (e.g. "downloads").
            tag_choice: Tag filter label; "None" disables filtering.
            k_orgs: Number of top organizations shown at dataset granularity.
            skip_cats_input: Comma-separated organization names to hide.
            df_current_datasets: Loaded datasets table; may be None or empty.
            progress: Progress callback. Gradio injects a tracker by default.

        Returns:
            tuple: ``(figure, markdown)`` for the plot and status components.
        """
        if df_current_datasets is None or df_current_datasets.empty:
            placeholder = create_treemap(pd.DataFrame(), metric_choice)
            return placeholder, "Dataset data is not loaded. Cannot generate plot."

        progress(0.1, desc="Aggregating data...")
        # Split on commas, trim whitespace and drop blank entries.
        skip_names = list(filter(None, (part.strip() for part in skip_cats_input.split(','))))
        treemap_df = make_treemap_data(df_current_datasets, metric_choice, k_orgs, tag_choice, skip_names)

        progress(0.7, desc="Generating plot...")
        metric_labels = {
            "downloads": "Downloads (last 30 days)",
            "downloadsAllTime": "Downloads (All Time)",
            "likes": "Likes",
        }
        pretty_metric = metric_labels.get(metric_choice, metric_choice)
        plotly_fig = create_treemap(
            treemap_df, metric_choice, f"HuggingFace Datasets - {pretty_metric} by Organization"
        )

        if not treemap_df.empty:
            plotted_total = treemap_df[metric_choice].sum()
            # The synthetic "Other" row is not an individual dataset.
            real_rows = treemap_df[treemap_df['id'] != 'Other']
            dataset_count = real_rows['id'].nunique()
            org_count = treemap_df['organization'].nunique()
            summary = (
                f"## Plot Statistics\n- **Organizations/Categories Shown**: {org_count:,}\n"
                f"- **Individual Datasets Shown**: {dataset_count:,}\n"
                f"- **Total {metric_choice} in plot**: {int(plotted_total):,}"
            )
        else:
            summary = "No data matches the selected filters. Please try different options."

        return plotly_fig, summary

    # --- Event Wiring ---

    # On page load: fetch the data, fill the state and info panels, and draw the default plot.
    demo.load(
        fn=load_and_generate_initial_plot,
        inputs=[],
        outputs=[
            datasets_data_state,
            loading_complete_state,
            data_info_md,
            status_message_md,
            plot_output,
        ],
    )

    # Keep the button's interactivity in step with the load-success flag.
    loading_complete_state.change(
        fn=_update_button_interactivity,
        inputs=loading_complete_state,
        outputs=generate_plot_button,
    )

    # Rebuild the treemap from the current control values.
    generate_plot_button.click(
        fn=ui_generate_plot_controller,
        inputs=[
            count_by_dropdown,
            tag_filter_dropdown,
            top_k_dropdown,
            skip_cats_textbox,
            datasets_data_state,
        ],
        outputs=[plot_output, status_message_md],
    )

if __name__ == "__main__":
    print("Application starting...")
    # Enable request queuing on the Blocks app, then start the server.
    queued_app = demo.queue()
    queued_app.launch()