Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -135,7 +135,7 @@ def extract_org_from_id(model_id):
|
|
| 135 |
return model_id.split("/")[0]
|
| 136 |
return "unaffiliated"
|
| 137 |
|
| 138 |
-
def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None):
|
| 139 |
"""Process DataFrame into treemap format with filters applied"""
|
| 140 |
# Create a copy to avoid modifying the original
|
| 141 |
filtered_df = df.copy()
|
|
@@ -158,6 +158,10 @@ def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=N
|
|
| 158 |
# Add organization column
|
| 159 |
filtered_df["organization"] = filtered_df["id"].apply(extract_org_from_id)
|
| 160 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
# Aggregate by organization
|
| 162 |
org_totals = filtered_df.groupby("organization")[count_by].sum().reset_index()
|
| 163 |
org_totals = org_totals.sort_values(by=count_by, ascending=False)
|
|
@@ -215,7 +219,7 @@ def create_treemap(treemap_data, count_by, title=None):
|
|
| 215 |
|
| 216 |
return fig
|
| 217 |
|
| 218 |
-
def load_models_csv():
|
| 219 |
# Read the CSV file
|
| 220 |
df = pd.read_csv('models.csv')
|
| 221 |
|
|
@@ -419,6 +423,12 @@ with gr.Blocks() as demo:
|
|
| 419 |
step=5,
|
| 420 |
info="Number of top organizations to include"
|
| 421 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
|
| 423 |
generate_plot_button = gr.Button("Generate Plot", variant="primary")
|
| 424 |
|
|
@@ -426,7 +436,7 @@ with gr.Blocks() as demo:
|
|
| 426 |
plot_output = gr.Plot()
|
| 427 |
stats_output = gr.Markdown("*Generate a plot to see statistics*")
|
| 428 |
|
| 429 |
-
def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, top_k, data_df):
|
| 430 |
print(f"Generating plot with: Metric={count_by}, Filter={filter_choice}, Tag={tag_filter}, Pipeline={pipeline_filter}, Size={size_filter}, Top K={top_k}")
|
| 431 |
|
| 432 |
if data_df is None or len(data_df) == 0:
|
|
@@ -444,6 +454,12 @@ with gr.Blocks() as demo:
|
|
| 444 |
if size_filter != "None":
|
| 445 |
selected_size_filter = size_filter
|
| 446 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 447 |
# Process data for treemap
|
| 448 |
treemap_data = make_treemap_data(
|
| 449 |
df=data_df,
|
|
@@ -451,7 +467,8 @@ with gr.Blocks() as demo:
|
|
| 451 |
top_k=top_k,
|
| 452 |
tag_filter=selected_tag_filter,
|
| 453 |
pipeline_filter=selected_pipeline_filter,
|
| 454 |
-
size_filter=selected_size_filter
|
|
|
|
| 455 |
)
|
| 456 |
|
| 457 |
# Create plot
|
|
@@ -484,6 +501,10 @@ with gr.Blocks() as demo:
|
|
| 484 |
for org, value in top_5_orgs.items():
|
| 485 |
percentage = (value / total_value) * 100
|
| 486 |
stats_md += f"\n| {org} | {int(value):,} | {percentage:.2f}% |"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
|
| 488 |
return fig, stats_md
|
| 489 |
|
|
@@ -518,6 +539,7 @@ with gr.Blocks() as demo:
|
|
| 518 |
pipeline_filter_dropdown,
|
| 519 |
size_filter_dropdown,
|
| 520 |
top_k_slider,
|
|
|
|
| 521 |
models_data
|
| 522 |
],
|
| 523 |
outputs=[plot_output, stats_output]
|
|
|
|
| 135 |
return model_id.split("/")[0]
|
| 136 |
return "unaffiliated"
|
| 137 |
|
| 138 |
+
def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None, skip_orgs=None):
|
| 139 |
"""Process DataFrame into treemap format with filters applied"""
|
| 140 |
# Create a copy to avoid modifying the original
|
| 141 |
filtered_df = df.copy()
|
|
|
|
| 158 |
# Add organization column
|
| 159 |
filtered_df["organization"] = filtered_df["id"].apply(extract_org_from_id)
|
| 160 |
|
| 161 |
+
# Skip organizations if specified
|
| 162 |
+
if skip_orgs and len(skip_orgs) > 0:
|
| 163 |
+
filtered_df = filtered_df[~filtered_df["organization"].isin(skip_orgs)]
|
| 164 |
+
|
| 165 |
# Aggregate by organization
|
| 166 |
org_totals = filtered_df.groupby("organization")[count_by].sum().reset_index()
|
| 167 |
org_totals = org_totals.sort_values(by=count_by, ascending=False)
|
|
|
|
| 219 |
|
| 220 |
return fig
|
| 221 |
|
| 222 |
+
def load_models_csv():
|
| 223 |
# Read the CSV file
|
| 224 |
df = pd.read_csv('models.csv')
|
| 225 |
|
|
|
|
| 423 |
step=5,
|
| 424 |
info="Number of top organizations to include"
|
| 425 |
)
|
| 426 |
+
|
| 427 |
+
skip_orgs_textbox = gr.Textbox(
|
| 428 |
+
label="Organizations to Skip (comma-separated)",
|
| 429 |
+
placeholder="e.g., openai, meta, huggingface",
|
| 430 |
+
info="Enter names of organizations to exclude from the visualization"
|
| 431 |
+
)
|
| 432 |
|
| 433 |
generate_plot_button = gr.Button("Generate Plot", variant="primary")
|
| 434 |
|
|
|
|
| 436 |
plot_output = gr.Plot()
|
| 437 |
stats_output = gr.Markdown("*Generate a plot to see statistics*")
|
| 438 |
|
| 439 |
+
def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, top_k, skip_orgs_text, data_df):
|
| 440 |
print(f"Generating plot with: Metric={count_by}, Filter={filter_choice}, Tag={tag_filter}, Pipeline={pipeline_filter}, Size={size_filter}, Top K={top_k}")
|
| 441 |
|
| 442 |
if data_df is None or len(data_df) == 0:
|
|
|
|
| 454 |
if size_filter != "None":
|
| 455 |
selected_size_filter = size_filter
|
| 456 |
|
| 457 |
+
# Process skip organizations list
|
| 458 |
+
skip_orgs = []
|
| 459 |
+
if skip_orgs_text and skip_orgs_text.strip():
|
| 460 |
+
skip_orgs = [org.strip() for org in skip_orgs_text.split(',') if org.strip()]
|
| 461 |
+
print(f"Skipping organizations: {skip_orgs}")
|
| 462 |
+
|
| 463 |
# Process data for treemap
|
| 464 |
treemap_data = make_treemap_data(
|
| 465 |
df=data_df,
|
|
|
|
| 467 |
top_k=top_k,
|
| 468 |
tag_filter=selected_tag_filter,
|
| 469 |
pipeline_filter=selected_pipeline_filter,
|
| 470 |
+
size_filter=selected_size_filter,
|
| 471 |
+
skip_orgs=skip_orgs
|
| 472 |
)
|
| 473 |
|
| 474 |
# Create plot
|
|
|
|
| 501 |
for org, value in top_5_orgs.items():
|
| 502 |
percentage = (value / total_value) * 100
|
| 503 |
stats_md += f"\n| {org} | {int(value):,} | {percentage:.2f}% |"
|
| 504 |
+
|
| 505 |
+
# Add note about skipped organizations if any
|
| 506 |
+
if skip_orgs:
|
| 507 |
+
stats_md += f"\n\n*Note: {len(skip_orgs)} organization(s) excluded: {', '.join(skip_orgs)}*"
|
| 508 |
|
| 509 |
return fig, stats_md
|
| 510 |
|
|
|
|
| 539 |
pipeline_filter_dropdown,
|
| 540 |
size_filter_dropdown,
|
| 541 |
top_k_slider,
|
| 542 |
+
skip_orgs_textbox,
|
| 543 |
models_data
|
| 544 |
],
|
| 545 |
outputs=[plot_output, stats_output]
|