# John6666's picture
# Upload app.py
# 7d7fb2b verified
import spaces
import gradio as gr
from huggingface_hub import snapshot_download
import os, platform, shutil
from pathlib import Path
from importlib import metadata as md
from typing import List
# Log the interpreter, OS and every installed distribution once at startup.
print("Python:", platform.python_version())
print("OS:", platform.uname().system)
print(
    "\n".join(
        sorted(
            f"{dist.metadata['Name']}=={dist.version}"
            for dist in md.distributions()
        )
    )
)

# Hub repository whose files are downloaded.
repo_id = "stabilityai/stable-diffusion-3.5-medium"
# Download target and Hub cache locations (relative to the working directory).
local_dir = "./test"
cache_dir = "./test-cache"
# allow_patterns offered (and pre-selected) in the UI: weight files per component.
ALLOW_PATTERNS = [
    "transformer/diffusion_pytorch_model*",
    "vae/diffusion_pytorch_model*",
    "text_encoder/model*",
    "text_encoder_2/model*",
    "text_encoder_3/model*",
    "text_encoder/pytorch_model*",
    "text_encoder_2/pytorch_model*",
    "text_encoder_3/pytorch_model*",
]
# ignore_patterns pre-selected in the UI: skip any file whose path contains "fp16".
IGNORE_PATTERNS = ["*fp16*", "**/*fp16*"]
# Extra ignore_patterns offered as choices but not selected by default.
IGNORE_PATTERNS_EX = ["**/model*", "**/*.safetensors"]
# Optional Hub access token from the environment; None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
@spaces.GPU
def test(
    repo_id: str,
    allow: List[str],
    ignore: List[str],
    del_cache: bool,
    progress=gr.Progress(track_tqdm=True),
) -> str:
    """Download a filtered snapshot of *repo_id* and report which files landed.

    The local download directory is wiped before every run so the report lists
    only files present after this call; the cache directory is wiped only when
    requested.

    Args:
        repo_id: Hub repository id passed to ``snapshot_download``.
        allow: Glob patterns forwarded as ``allow_patterns``. An empty (or
            None) selection means no allow filter is passed at all.
        ignore: Glob patterns forwarded as ``ignore_patterns``. An empty (or
            None) selection means no ignore filter is passed at all.
        del_cache: When true, delete ``cache_dir`` before downloading.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets the UI
            mirror tqdm progress bars emitted during the download.

    Returns:
        A multi-line report: Python/OS versions, the repo id, the patterns
        used, and the sorted paths of all files under ``local_dir``.
    """
    # Tolerate a None selection from the UI and work on plain lists.
    allow = list(allow or [])
    ignore = list(ignore or [])

    if del_cache:
        shutil.rmtree(cache_dir, ignore_errors=True)
    # Always start from an empty target dir so the listing reflects this run only.
    shutil.rmtree(local_dir, ignore_errors=True)

    # Only pass filters that were actually selected; an empty allow list would
    # presumably match nothing rather than mean "no filter".
    kwargs = {}
    if allow:
        kwargs["allow_patterns"] = allow
    if ignore:
        kwargs["ignore_patterns"] = ignore
    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        cache_dir=cache_dir,
        token=HF_TOKEN,
        **kwargs,
    )

    # Join outside the f-strings: reusing the enclosing quote type or a
    # backslash inside a replacement field is a SyntaxError before Python 3.12.
    allow_str = ", ".join(allow)
    ignore_str = ", ".join(ignore)
    downloaded = "\n".join(
        sorted(str(p) for p in Path(local_dir).rglob("*") if p.is_file())
    )
    return (
        f"Python:{platform.python_version()}\n"
        f"OS:{platform.uname().system}\n\n"
        f"repo_id:{repo_id}\n"
        f"allow_patterns:{allow_str}\n"
        f"ignore_patterns:{ignore_str}\n\n"
        f"Downloaded:\n{downloaded}"
    )
# Build the UI: pattern pickers, a cache toggle and a result box wired to `test`,
# then start the queued Gradio server.
with gr.Blocks() as demo:
    # Hidden input carrying the fixed repo id. Bound to its own name so the
    # module-level `repo_id` string is not rebound to a component.
    repo_id_box = gr.Textbox(label="Repo ID", value=repo_id, lines=1, visible=False)
    allow = gr.CheckboxGroup(label="allow_patterns", choices=ALLOW_PATTERNS, value=ALLOW_PATTERNS)
    # All ignore patterns are offered; only IGNORE_PATTERNS are pre-selected.
    ignore = gr.CheckboxGroup(
        label="ignore_patterns",
        choices=IGNORE_PATTERNS + IGNORE_PATTERNS_EX,
        value=IGNORE_PATTERNS,
    )
    del_cache = gr.Checkbox(label="Delete cache", value=False)
    run_button = gr.Button("Download", variant="primary")
    info = gr.Textbox(label="Result", value="", show_copy_button=True)
    run_button.click(test, [repo_id_box, allow, ignore, del_cache], [info])
demo.queue().launch()