Update app.py
app.py (CHANGED)
@@ -38,30 +38,52 @@ def load_model():
         logger.error(f"Failed to load model: {e}")
         return False

-def get_card_info(hub_id: str) -> Tuple[str, str]:
+def get_card_info(hub_id: str, repo_type: str = "auto") -> Tuple[str, str]:
     """Get card information from a Hugging Face hub_id."""
     model_exists = False
     dataset_exists = False
     model_text = None
     dataset_text = None

-    # … (removed card-lookup code, old lines 48-64, not legible in the captured diff)
+    # Handle based on repo type
+    if repo_type == "auto":
+        # Try getting model card
+        try:
+            info = model_info(hub_id)
+            card = ModelCard.load(hub_id)
+            model_exists = True
+            model_text = card.text
+        except Exception as e:
+            logger.debug(f"No model card found for {hub_id}: {e}")
+
+        # Try getting dataset card
+        try:
+            info = dataset_info(hub_id)
+            card = DatasetCard.load(hub_id)
+            dataset_exists = True
+            dataset_text = card.text
+        except Exception as e:
+            logger.debug(f"No dataset card found for {hub_id}: {e}")
+    elif repo_type == "model":
+        try:
+            info = model_info(hub_id)
+            card = ModelCard.load(hub_id)
+            model_exists = True
+            model_text = card.text
+        except Exception as e:
+            logger.error(f"Failed to get model card for {hub_id}: {e}")
+            raise ValueError(f"Could not find model with id {hub_id}")
+    elif repo_type == "dataset":
+        try:
+            info = dataset_info(hub_id)
+            card = DatasetCard.load(hub_id)
+            dataset_exists = True
+            dataset_text = card.text
+        except Exception as e:
+            logger.error(f"Failed to get dataset card for {hub_id}: {e}")
+            raise ValueError(f"Could not find dataset with id {hub_id}")
+    else:
+        raise ValueError(f"Invalid repo_type: {repo_type}. Must be 'auto', 'model', or 'dataset'")

     # Handle different cases
     if model_exists and dataset_exists:
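The new branches rely on model_info, dataset_info, ModelCard, and DatasetCard from huggingface_hub. The imports sit outside this hunk, so the line below is only a sketch of what the top of app.py presumably already contains:

    # Assumed imports for the card lookups used in get_card_info (not part of this diff).
    from huggingface_hub import DatasetCard, ModelCard, dataset_info, model_info

Within the hunk, info is assigned in every branch but not read again; it appears to act purely as an existence check before the card is loaded.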
@@ -115,12 +137,12 @@ def generate_summary(card_text: str, card_type: str) -> str:
     """Cached wrapper for generate_summary with TTL."""
     return _generate_summary_gpu(card_text, card_type)

-def summarize(hub_id: str = "") -> str:
+def summarize(hub_id: str = "", repo_type: str = "auto") -> str:
     """Interface function for Gradio. Returns JSON format."""
     try:
         if hub_id:
-            # Fetch …
-            card_type, card_text = get_card_info(hub_id)
+            # Fetch card information with specified repo_type
+            card_type, card_text = get_card_info(hub_id, repo_type)

             if card_type == "both":
                 model_text, dataset_text = card_text
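With the extra parameter, summarize keeps its old one-argument behaviour (repo_type defaults to "auto") while allowing the type to be forced. A quick local smoke test might look like the following, assuming load_model() has already succeeded; the hub IDs are only examples:

    print(summarize("squad", repo_type="dataset"))   # explicit dataset lookup
    print(summarize("bert-base-uncased"))            # falls back to auto-detection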
@@ -148,7 +170,15 @@ def summarize(hub_id: str = "") -> str:
 def create_interface():
     interface = gr.Interface(
         fn=summarize,
-        inputs=…
+        inputs=[
+            gr.Textbox(label="Hub ID", placeholder="e.g., huggingface/llama-7b"),
+            gr.Radio(
+                choices=["auto", "model", "dataset"],
+                value="auto",
+                label="Repository Type",
+                info="Choose 'auto' to detect automatically, or specify the repository type"
+            )
+        ],
         outputs=gr.JSON(label="Output"),
         title="Hugging Face Hub TLDR Generator",
         description="Generate concise summaries of model and dataset cards from the Hugging Face Hub.",
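gr.Interface passes its inputs positionally, so the textbox feeds hub_id and the radio feeds repo_type. Once the app is running it can also be exercised remotely with gradio_client; this is a sketch assuming the default local launch address and the default /predict endpoint that gr.Interface exposes:

    from gradio_client import Client

    client = Client("http://127.0.0.1:7860")   # assumed local launch address
    result = client.predict(
        "squad",      # Hub ID textbox (example value)
        "dataset",    # Repository Type radio
        api_name="/predict",
    )
    print(result)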
@@ -160,4 +190,4 @@ if __name__ == "__main__":
         interface = create_interface()
         interface.launch()
     else:
-        print("Failed to load model. Please check the logs for details.")
+        print("Failed to load model. Please check the logs for details.")
