Spaces:
Sleeping
Sleeping
import json
import re
import time
import urllib.error
import urllib.request

import streamlit as st
def _fetch_existing_card(repo_id):
    """Return the raw README.md text for *repo_id* from the Hugging Face Hub.

    Returns "" when the repo has no README (HTTP 404). Other network errors
    propagate to the caller, which reports them as a non-fatal warning.

    NOTE(review): the original code routed this GET through the private
    pagination helper ``client.api._get_paginated``, which is not a
    general-purpose HTTP client. A plain stdlib request is used instead;
    private repos (which would need auth headers) simply yield "" here.
    """
    url = f"https://huggingface.co/{repo_id}/raw/main/README.md"
    try:
        with urllib.request.urlopen(url, timeout=10) as response:
            return response.read().decode("utf-8")
    except urllib.error.HTTPError:
        # Missing README etc. — treat as "no existing card".
        return ""


def _split_model_card(model_card_content):
    """Split a model card into its (yaml_frontmatter, markdown_body) parts.

    Either part is "" when not present.
    """
    yaml_content = ""
    markdown_content = ""
    yaml_match = re.search(r"---\s+(.*?)\s+---", model_card_content, re.DOTALL)
    if yaml_match:
        yaml_content = yaml_match.group(1)
    # Everything after the closing frontmatter delimiter is the body.
    markdown_match = re.search(r"---\s+.*?\s+---\s*(.*)", model_card_content, re.DOTALL)
    if markdown_match:
        markdown_content = markdown_match.group(1).strip()
    return yaml_content, markdown_content


def _parse_existing_tags(yaml_content):
    """Best-effort extraction of the ``tags:`` list from YAML frontmatter."""
    if not yaml_content:
        return []
    tags_match = re.search(r"tags:\s*((?:- .*?\n)+)", yaml_content, re.DOTALL)
    if not tags_match:
        return []
    return [
        line.strip("- \n")
        for line in tags_match.group(1).split("\n")
        if line.strip().startswith("-")
    ]


def _build_yaml_frontmatter(selected_tags, license_type, language, model_type):
    """Build the YAML frontmatter block (without the ``---`` delimiters)."""
    lines = ["tags:"]
    lines.extend(f"- {tag}" for tag in selected_tags)
    lines.append(f"license: {license_type}")
    if language and language != "Other":
        lines.append(f"language: {language.lower()}")
    if model_type and model_type != "Other":
        # Hub pipeline tags are kebab-case, e.g. "text-classification".
        lines.append(f"pipeline_tag: {model_type.lower().replace(' ', '-')}")
    return "\n".join(lines)


def _merge_custom_sections(md_content, existing_markdown):
    """Append ``## ...`` sections from *existing_markdown* that are not part of
    the standard generated layout, preserving user-written custom content."""
    existing_sections = re.findall(
        r"^## (.+?)\n\n(.*?)(?=^## |\Z)", existing_markdown, re.MULTILINE | re.DOTALL
    )
    standard_sections = {
        "Model Description", "Training and Evaluation Data", "Model Performance",
        "Limitations and Biases", "Intended Uses & Limitations", "How to Use", "Citation",
    }
    for section_title, section_content in existing_sections:
        if section_title.strip() not in standard_sections:
            md_content += f"## {section_title}\n\n{section_content}\n\n"
    return md_content


def model_documentation_generator(model_info):
    """Render a Streamlit tool that generates a comprehensive model card from
    Hub metadata plus user input, and optionally publishes it.

    Parameters
    ----------
    model_info : object
        Hub model metadata; must expose ``modelId`` and may expose ``tags``
        and ``description``.

    Side effects: draws Streamlit widgets; stores the generated card in
    ``st.session_state["generated_model_card"]``; on "Update Model Card"
    calls ``st.session_state.client.update_model_card`` and reruns the app.
    """
    if not model_info:
        st.error("Model information not found")
        return

    st.subheader("π Automated Model Documentation Generator")
    st.markdown("This tool generates a comprehensive model card based on model metadata and your input.")

    # Load the existing card (if any) so its tags and custom sections can be
    # carried over into the regenerated card.
    yaml_content = ""
    markdown_content = ""
    try:
        model_card_content = _fetch_existing_card(model_info.modelId)
        if model_card_content:
            yaml_content, markdown_content = _split_model_card(model_card_content)
    except Exception as e:
        st.warning(f"Couldn't load model card: {str(e)}")

    with st.form("model_doc_form"):
        st.markdown("### Model Metadata")

        st.markdown("#### Basic Information")
        col1, col2 = st.columns(2)
        with col1:
            # Repo IDs look like "org/name"; the last segment is the name.
            model_name = model_info.modelId.split("/")[-1]
            model_title = st.text_input("Model Title", value=model_name.replace("-", " ").title())
        with col2:
            model_type_options = [
                "Text Classification",
                "Token Classification",
                "Question Answering",
                "Summarization",
                "Translation",
                "Text Generation",
                "Image Classification",
                "Object Detection",
                "Other",
            ]
            # Pre-select the task whose kebab- or snake-case form appears in
            # the model's Hub tags (e.g. "text-classification").
            default_type_index = 0
            tags = getattr(model_info, "tags", [])
            for i, option in enumerate(model_type_options):
                option_key = option.lower().replace(" ", "-")
                if option_key in tags or option_key.replace("-", "_") in tags:
                    default_type_index = i
                    break
            model_type = st.selectbox(
                "Model Type",
                model_type_options,
                index=default_type_index,
            )

        description = st.text_area(
            "Model Description",
            value=getattr(model_info, "description", "") or "",
            height=100,
            help="A brief overview of what the model does",
        )

        st.markdown("#### Technical Information")
        col1, col2 = st.columns(2)
        with col1:
            architecture_options = [
                "BERT", "GPT-2", "T5", "RoBERTa", "DeBERTa", "DistilBERT",
                "BART", "ResNet", "YOLO", "Other",
            ]
            architecture = st.selectbox("Model Architecture", architecture_options)
            framework_options = ["PyTorch", "TensorFlow", "JAX", "Other"]
            framework = st.selectbox("Framework", framework_options)
        with col2:
            model_size = st.text_input("Model Size (e.g., 110M parameters)")
            language_options = ["English", "French", "German", "Spanish", "Chinese", "Japanese", "Multilingual", "Other"]
            language = st.selectbox("Language", language_options)

        st.markdown("#### Training Information")
        col1, col2 = st.columns(2)
        with col1:
            training_data = st.text_input("Training Dataset(s)")
            training_compute = st.text_input("Training Infrastructure (e.g., TPU v3-8, 4x A100)")
        with col2:
            eval_data = st.text_input("Evaluation Dataset(s)")
            training_time = st.text_input("Training Time (e.g., 3 days, 12 hours)")

        st.markdown("#### Performance Metrics")
        metrics_data = st.text_area(
            "Performance Metrics (one per line, e.g., 'Accuracy: 0.92')",
            height=100,
            help="Key metrics and their values",
        )

        st.markdown("#### Limitations and Biases")
        limitations = st.text_area(
            "Known Limitations and Biases",
            height=100,
            help="Document any known limitations, biases, or ethical considerations",
        )

        st.markdown("#### Usage Information")
        use_cases = st.text_area(
            "Intended Use Cases",
            height=100,
            help="Describe how the model should be used",
        )
        code_example = st.text_area(
            "Code Example",
            height=150,
            value=f"""
```python
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("{model_info.modelId}")
model = AutoModel.from_pretrained("{model_info.modelId}")
inputs = tokenizer("Hello, world!", return_tensors="pt")
outputs = model(**inputs)
```
""",
            help="Provide a simple code example showing how to use the model",
        )

        st.markdown("#### License and Citation")
        license_options = ["MIT", "Apache-2.0", "GPL-3.0", "CC-BY-SA-4.0", "CC-BY-4.0", "Proprietary", "Other"]
        license_type = st.selectbox("License", license_options)
        citation = st.text_area(
            "Citation Information",
            height=100,
            help="Provide citation information if applicable",
        )

        st.markdown("#### Tags")
        available_tags = st.session_state.client.get_model_tags()
        # Default the multiselect to the tags already declared on the Hub.
        existing_tags = _parse_existing_tags(yaml_content)
        selected_tags = st.multiselect(
            "Select tags for your model",
            options=available_tags,
            default=existing_tags,
            help="Tags help others discover your model",
        )

        with st.expander("Advanced Options"):
            keep_existing_content = st.checkbox(
                "Keep existing custom content",
                value=True,
                help="If checked, we'll try to preserve custom sections from your existing model card",
            )
            additional_sections = st.text_area(
                "Additional Custom Sections (in Markdown)",
                height=200,
                help="Add any additional custom sections in Markdown format",
            )

        submitted = st.form_submit_button("Generate Model Card", use_container_width=True)

    if submitted:
        with st.spinner("Generating comprehensive model card..."):
            try:
                # One metric per non-blank line, whitespace trimmed.
                metrics_list = [line.strip() for line in metrics_data.split("\n") if line.strip()]

                yaml_frontmatter = _build_yaml_frontmatter(
                    selected_tags, license_type, language, model_type
                )

                # Blank-line separation reconstructed to match the "\n\n"
                # convention used by every section appended below.
                md_content = f"""# {model_title}

{description}

## Model Description

This model is a {architecture}-based model for {model_type} tasks. It was developed using {framework} and consists of {model_size if model_size else "multiple"} parameters.

"""
                if training_data or eval_data or training_compute or training_time:
                    md_content += "## Training and Evaluation Data\n\n"
                    if training_data:
                        md_content += f"The model was trained on {training_data}. "
                    if training_compute:
                        md_content += f"Training was performed using {training_compute}. "
                    if training_time:
                        md_content += f"The total training time was approximately {training_time}."
                    md_content += "\n\n"
                    if eval_data:
                        md_content += f"Evaluation was performed on {eval_data}.\n\n"

                if metrics_list:
                    md_content += "## Model Performance\n\n"
                    md_content += "The model achieves the following performance metrics:\n\n"
                    for metric in metrics_list:
                        md_content += f"- {metric}\n"
                    md_content += "\n"

                if limitations:
                    md_content += "## Limitations and Biases\n\n"
                    md_content += f"{limitations}\n\n"

                if use_cases:
                    md_content += "## Intended Uses & Limitations\n\n"
                    md_content += f"{use_cases}\n\n"

                if code_example:
                    md_content += "## How to Use\n\n"
                    md_content += "Here's an example of how to use this model:\n\n"
                    md_content += f"{code_example}\n\n"

                if citation:
                    md_content += "## Citation\n\n"
                    md_content += f"{citation}\n\n"

                if keep_existing_content and markdown_content:
                    md_content = _merge_custom_sections(md_content, markdown_content)

                if additional_sections:
                    md_content += f"\n{additional_sections}\n"

                final_model_card = f"---\n{yaml_frontmatter}\n---\n\n{md_content.strip()}"

                # BUG FIX: the original rendered the preview and an "Update
                # Model Card" st.button inside this `if submitted:` branch.
                # Clicking that button triggers a rerun in which `submitted`
                # is False, so the update code could never execute. Persist
                # the card in session_state and render the update UI
                # unconditionally below instead.
                st.session_state["generated_model_card"] = final_model_card
            except Exception as e:
                st.error(f"Error generating model card: {str(e)}")

    # Preview + publish controls survive reruns via session_state.
    final_model_card = st.session_state.get("generated_model_card")
    if final_model_card:
        st.markdown("### Generated Model Card")
        st.code(final_model_card, language="markdown")
        if st.button("Update Model Card", use_container_width=True, type="primary"):
            with st.spinner("Updating model card..."):
                try:
                    success, _ = st.session_state.client.update_model_card(
                        model_info.modelId, final_model_card
                    )
                    if success:
                        st.success("Model card updated successfully!")
                        time.sleep(1)  # Give API time to update
                        st.rerun()
                    else:
                        st.error("Failed to update model card")
                except Exception as e:
                    st.error(f"Error updating model card: {str(e)}")