Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import time | |
def render_batch_operations():
    """Render the batch operations page.

    Lists the models in ``st.session_state.models`` in an editable table with
    one selection checkbox per row. Once at least one model is selected, four
    tabs are shown:

    * **Update Tags** - rewrites the ``tags`` list in the YAML front matter of
      each selected model's README.md and saves it through
      ``st.session_state.client.update_model_card``.
    * **Update Visibility** / **Add Collaborators** - not implemented yet;
      they collect input but change nothing, and say so.
    * **Delete** - deletes each selected repository through
      ``st.session_state.client.delete_model_repository`` after the user
      types ``DELETE``.

    If no models are loaded, a button leading back to the dashboard is shown
    instead.
    """
    st.title("🔄 Batch Operations")

    if "models" not in st.session_state or not st.session_state.models:
        st.info("No models found. Please create repositories first.")
        if st.button("Go to Dashboard", use_container_width=True):
            st.session_state.page = "home"
            # ``st.experimental_rerun`` is deprecated (and removed in recent
            # Streamlit releases) in favour of ``st.rerun``; only fall back to
            # it on versions that predate ``st.rerun``.
            rerun = getattr(st, "rerun", None) or st.experimental_rerun
            rerun()
        return

    df = _build_models_table(st.session_state.models)
    if df is None:
        st.error("Failed to process model data.")
        return

    st.markdown("### Select Models for Batch Operations")
    st.markdown("Use the checkboxes to select models you want to operate on.")

    # Only the "Select" column is editable; the data columns are read-only.
    edited_df = st.data_editor(
        df,
        column_config={
            "Select": st.column_config.CheckboxColumn(
                "Select",
                help="Select for batch operations",
                default=False,
            ),
            "Model Name": st.column_config.TextColumn(
                "Model Name",
                disabled=True,
            ),
            "Full ID": st.column_config.TextColumn(
                "Repository ID",
                help="Full repository ID",
                disabled=True,
            ),
            "Downloads": st.column_config.NumberColumn(
                "Downloads",
                help="Number of downloads",
                disabled=True,
            ),
            "Likes": st.column_config.NumberColumn(
                "Likes",
                help="Number of likes",
                disabled=True,
            ),
            "Private": st.column_config.CheckboxColumn(
                "Private",
                help="Repository visibility",
                disabled=True,
            ),
            "Tags": st.column_config.TextColumn(
                "Tags",
                help="Current tags",
                disabled=True,
            ),
        },
        hide_index=True,
        use_container_width=True,
    )

    selected_models = edited_df[edited_df["Select"].eq(True)]
    selected_count = len(selected_models)
    if selected_count == 0:
        st.info("Please select at least one model to perform batch operations.")
        return
    st.success(f"Selected {selected_count} models for batch operations.")

    tab1, tab2, tab3, tab4 = st.tabs(
        ["Update Tags", "Update Visibility", "Add Collaborators", "Delete"]
    )
    with tab1:
        _render_tags_tab(selected_models)
    with tab2:
        _render_visibility_tab(selected_count)
    with tab3:
        _render_collaborators_tab(selected_count)
    with tab4:
        _render_delete_tab(selected_models)


def _build_models_table(models):
    """Return a DataFrame with one row per model, or None if no row could be built.

    Models whose attributes cannot be read are skipped with a warning.
    """
    rows = []
    for model in models:
        try:
            rows.append({
                "Select": False,  # checkbox column
                "Model Name": model.modelId.split("/")[-1],
                "Full ID": model.modelId,
                "Downloads": getattr(model, "downloads", 0),
                "Likes": getattr(model, "likes", 0),
                "Private": getattr(model, "private", False),
                "Tags": ", ".join(getattr(model, "tags", []) or []),
            })
        except Exception as e:
            st.warning(f"Error processing model {getattr(model, 'modelId', 'unknown')}: {str(e)}")
    return pd.DataFrame(rows) if rows else None


def _report_results(successes, failures, success_prefix, failure_verb):
    """Show the outcome of a batch operation.

    ``failures`` is a list of ``(repo_id, error_message)`` pairs.
    """
    if successes > 0:
        st.success(f"{success_prefix} {successes} models")
    if failures:
        st.error(f"Failed to {failure_verb} {len(failures)} models")
        for repo_id, error in failures:
            st.warning(f"Failed to {failure_verb} {repo_id}: {error}")


def _render_tags_tab(selected_models):
    """Render the "Update Tags" tab and apply the chosen tags on request."""
    st.subheader("Update Tags")
    client = st.session_state.client
    available_tags = client.get_model_tags()

    selected_tags = st.multiselect(
        "Select tags to add to all selected models",
        options=available_tags,
        help="These tags will be added to all selected models",
    )
    tags_action = st.radio(
        "Tag Operation",
        ["Add tags (keep existing)", "Replace tags (remove existing)"],
        index=0,
    )

    if not st.button("Apply Tags", use_container_width=True, type="primary"):
        return
    if not selected_tags:
        st.warning("Please select at least one tag to add.")
        return

    replace = tags_action == "Replace tags (remove existing)"
    successes = 0
    failures = []
    with st.spinner(f"Updating tags for {len(selected_models)} models..."):
        for repo_id in selected_models["Full ID"]:
            try:
                error = _apply_tags_to_model(client, repo_id, selected_tags, replace)
            except Exception as e:  # UI boundary: report per model, keep going
                error = str(e)
            if error is None:
                successes += 1
            else:
                failures.append((repo_id, error))

    _report_results(successes, failures, "Successfully updated tags for", "update")
    # Refresh models after batch operation
    st.session_state.models = client.get_user_models()
    st.info("Model list refreshed. You may need to wait a few minutes for all changes to propagate.")


def _apply_tags_to_model(client, repo_id, tags, replace):
    """Update the README tags of one repository.

    Returns None on success, otherwise a short error message.
    """
    if not client.get_model_info(repo_id):
        return "Failed to fetch model info"
    card = _fetch_model_card(client, repo_id)
    if card is None:
        return "Failed to fetch model card"
    success, _ = client.update_model_card(repo_id, _update_card_tags(card, tags, replace))
    return None if success else "Failed to update model card"


def _fetch_model_card(client, repo_id):
    """Return the raw README.md text of ``repo_id`` on the main branch, or None on HTTP/network failure.

    If ``client.api`` exposes a non-empty string ``token``, it is sent as a
    bearer token so private repositories can be read.
    NOTE(review): assumes ``client.api`` is a ``huggingface_hub.HfApi`` (which
    stores the token it was created with as ``.token``) - confirm.
    """
    import urllib.error
    import urllib.parse
    import urllib.request

    url = f"https://huggingface.co/{urllib.parse.quote(repo_id, safe='/')}/raw/main/README.md"
    request = urllib.request.Request(url)
    token = getattr(getattr(client, "api", None), "token", None)
    if isinstance(token, str) and token:
        request.add_header("Authorization", f"Bearer {token}")
    try:
        with urllib.request.urlopen(request, timeout=30) as response:
            return response.read().decode("utf-8")
    except urllib.error.URLError:  # includes HTTPError (404, 401, ...)
        return None


def _update_card_tags(card_content, new_tags, replace=False):
    """Return ``card_content`` with the ``tags`` list of its YAML front matter updated.

    With ``replace=False`` the existing tags are kept and ``new_tags`` are
    appended; otherwise the list becomes exactly ``new_tags``. Duplicates are
    dropped while preserving order. Block lists (``- tag`` lines, indented or
    not) and simple inline ``[a, b]`` lists are recognised; the result is
    always written as an unindented block list. When the card has no front
    matter, one containing only the tags is prepended.
    """
    import re

    match = re.match(
        r"---[ \t]*\r?\n(.*?)^---[ \t]*\r?$\n?",
        card_content,
        re.DOTALL | re.MULTILINE,
    )
    if match is None:
        section = "\n".join(f"- {tag}" for tag in dict.fromkeys(new_tags))
        return f"---\ntags:\n{section}\n---\n\n{card_content}"

    header = match.group(1).splitlines()
    body = card_content[match.end():]

    # Locate a top-level (column 0) ``tags:`` key and the list items under it.
    start = next((i for i, line in enumerate(header) if re.match(r"tags\s*:", line)), None)
    existing = []
    if start is None:
        start = end = len(header)
    else:
        end = start + 1
        while end < len(header) and re.match(r"\s*-(?:\s|$)", header[end]):
            end += 1
        inline = header[start].split(":", 1)[1].strip()
        if inline.startswith("#"):
            inline = ""  # only a trailing comment
        if inline.startswith("[") and inline.endswith("]"):
            existing = [t.strip() for t in inline[1:-1].split(",") if t.strip()]
        elif inline and inline not in ("null", "~"):
            existing = [inline]
        existing += [line.split("-", 1)[1].strip() for line in header[start + 1:end]]
        existing = [tag for tag in existing if tag]

    tags = list(new_tags) if replace else existing + list(new_tags)
    section = ["tags:"] + [f"- {tag}" for tag in dict.fromkeys(tags)]
    new_header = header[:start] + section + header[end:]
    return "---\n" + "\n".join(new_header) + "\n---\n" + body


def _render_visibility_tab(selected_count):
    """Render the "Update Visibility" tab (not implemented: nothing is changed)."""
    st.subheader("Update Visibility")
    visibility = st.radio(
        "Set visibility for selected models",
        ["Public", "Private"],
        index=0,
        help="Change the visibility of all selected models",
    )
    if st.button("Update Visibility", use_container_width=True, type="primary"):
        # Placeholder: report honestly instead of claiming success.
        st.warning(
            f"Setting {selected_count} models to {visibility} is not implemented yet; "
            "no changes were made."
        )


def _render_collaborators_tab(selected_count):
    """Render the "Add Collaborators" tab (not implemented: nothing is changed)."""
    st.subheader("Add Collaborators")
    collaborators = st.text_area(
        "Enter usernames of collaborators (one per line)",
        help="These users will be added as collaborators to all selected models",
    )
    role = st.selectbox(
        "Collaborator role",
        ["read", "write", "admin"],
        index=0,
    )
    if not st.button("Add Collaborators", use_container_width=True, type="primary"):
        return
    collaborator_list = [c.strip() for c in collaborators.splitlines() if c.strip()]
    if not collaborator_list:
        st.warning("Please enter at least one collaborator username.")
        return
    # Placeholder: report honestly instead of claiming success.
    st.warning(
        f"Adding {len(collaborator_list)} collaborators with '{role}' role to "
        f"{selected_count} models is not implemented yet; no changes were made."
    )


def _render_delete_tab(selected_models):
    """Render the "Delete" tab and delete the selected repositories once confirmed."""
    st.subheader("⚠️ Delete Models")
    st.warning(
        "This operation is irreversible. All selected models will be permanently deleted."
    )
    confirmation = st.text_input(
        "Type 'DELETE' to confirm deletion of all selected models",
        key="batch_delete_confirm",
    )
    if not st.button("Delete Selected Models", use_container_width=True, type="primary"):
        return
    if confirmation != "DELETE":
        st.error("Please type 'DELETE' to confirm.")
        return

    client = st.session_state.client
    successes = 0
    failures = []
    with st.spinner(f"Deleting {len(selected_models)} models..."):
        for repo_id in selected_models["Full ID"]:
            try:
                success, message = client.delete_model_repository(repo_id)
            except Exception as e:  # UI boundary: report per model, keep going
                success, message = False, str(e)
            if success:
                successes += 1
            else:
                failures.append((repo_id, message))

    _report_results(successes, failures, "Successfully deleted", "delete")
    # Refresh models after batch operation
    st.session_state.models = client.get_user_models()