"""Gradio Space that tests huggingface_hub `snapshot_download` pattern filtering.

The user picks allow/ignore glob patterns for the stabilityai/stable-diffusion-3.5-medium
repo; the app downloads the matching files into ``./test`` and reports which
files ended up on disk, together with the Python version and OS.
"""
import os
import platform
import shutil
from importlib import metadata as md
from pathlib import Path
from typing import List

# `spaces` was imported first in the original file; keep it ahead of gradio /
# huggingface_hub (the ZeroGPU `spaces` package is conventionally imported early).
import spaces
import gradio as gr
from huggingface_hub import snapshot_download

# Environment diagnostics written to the Space's log at startup.
print("Python:", platform.python_version())
print("OS:", platform.uname().system)
print("\n".join(sorted(f"{d.metadata['Name']}=={d.version}" for d in md.distributions())))

repo_id = "stabilityai/stable-diffusion-3.5-medium"
local_dir = "./test"
cache_dir = "./test-cache"
ALLOW_PATTERNS = [
    "transformer/diffusion_pytorch_model*",
    "vae/diffusion_pytorch_model*",
    "text_encoder/model*",
    "text_encoder_2/model*",
    "text_encoder_3/model*",
    "text_encoder/pytorch_model*",
    "text_encoder_2/pytorch_model*",
    "text_encoder_3/pytorch_model*",
]
IGNORE_PATTERNS = ["*fp16*", "**/*fp16*"]
# Extra ignore patterns offered in the UI but not selected by default.
IGNORE_PATTERNS_EX = ["**/model*", "**/*.safetensors"]
HF_TOKEN = os.getenv("HF_TOKEN", None)


def _format_result(repo_id: str, allow: List[str], ignore: List[str]) -> str:
    """Build the report shown in the Result textbox: environment, patterns, files in local_dir."""
    # The joins are computed outside the f-string on purpose: reusing the
    # enclosing quote character or a backslash inside an f-string replacement
    # field is a SyntaxError before Python 3.12 (PEP 701).
    allow_text = ", ".join(allow)
    ignore_text = ", ".join(ignore)
    # Sorted so the listing is stable regardless of filesystem iteration order.
    downloaded = "\n".join(
        sorted(str(p) for p in Path(local_dir).rglob("*") if p.is_file())
    )
    return (
        f"Python:{platform.python_version()}\n"
        f"OS:{platform.uname().system}\n\n"
        f"repo_id:{repo_id}\n"
        f"allow_patterns:{allow_text}\n"
        f"ignore_patterns:{ignore_text}\n\n"
        f"Downloaded:\n{downloaded}"
    )


@spaces.GPU
def test(repo_id: str, allow: List[str], ignore: List[str], del_cache: bool,
         progress=gr.Progress(track_tqdm=True)) -> str:
    """Download `repo_id` with the given pattern filters and report the result.

    Args:
        repo_id: Hub repository to snapshot.
        allow: Glob patterns passed as ``allow_patterns``; omitted when empty.
        ignore: Glob patterns passed as ``ignore_patterns``; omitted when empty.
        del_cache: When true, remove ``cache_dir`` and ``local_dir`` first.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio
            mirror the tqdm bars shown during the download.

    Returns:
        A multi-line text report including every file under ``local_dir``.
    """
    # Normalize a possible None selection so the joins/truthiness checks below work.
    allow = list(allow or [])
    ignore = list(ignore or [])

    # NOTE(review): the original file's indentation was lost; both rmtree calls
    # are assumed to sit under `if del_cache` — confirm against the Space.
    if del_cache:
        shutil.rmtree(cache_dir, ignore_errors=True)
        shutil.rmtree(local_dir, ignore_errors=True)

    # Only pass the filters that were actually selected.
    kwargs = {}
    if allow:
        kwargs["allow_patterns"] = allow
    if ignore:
        kwargs["ignore_patterns"] = ignore

    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        cache_dir=cache_dir,
        token=HF_TOKEN,
        **kwargs,
    )
    return _format_result(repo_id, allow, ignore)


with gr.Blocks() as demo:
    # Hidden input carrying the fixed repo id; named so it does not rebind the
    # module-level `repo_id` string.
    repo_id_box = gr.Textbox(label="Repo ID", value=repo_id, lines=1, visible=False)
    allow = gr.CheckboxGroup(label="allow_patterns", choices=ALLOW_PATTERNS, value=ALLOW_PATTERNS)
    ignore = gr.CheckboxGroup(label="ignore_patterns", choices=IGNORE_PATTERNS + IGNORE_PATTERNS_EX,
                              value=IGNORE_PATTERNS)
    del_cache = gr.Checkbox(label="Delete cache", value=False)
    run_button = gr.Button("Download", variant="primary")
    info = gr.Textbox(label="Result", value="", show_copy_button=True)
    run_button.click(test, [repo_id_box, allow, ignore, del_cache], [info])

demo.queue().launch()