Spaces:
Build error
Build error
| import json | |
| import tempfile | |
| import zipfile | |
| from datetime import datetime | |
| from pathlib import Path | |
| from uuid import uuid4 | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| from huggingface_hub import CommitScheduler, InferenceClient | |
# Local staging area for generated images. The per-process uuid4 suffix keeps
# concurrent app replicas from colliding on the same folder.
IMAGE_DATASET_DIR = Path("image_dataset_1M", f"train-{uuid4()}")
IMAGE_DATASET_DIR.mkdir(parents=True, exist_ok=True)
# Metadata rows (one JSON object per line) accompanying the PNG files.
IMAGE_JSONL_PATH = IMAGE_DATASET_DIR.joinpath("metadata.jsonl")
class ZipScheduler(CommitScheduler):
    """CommitScheduler with a custom `push_to_hub` that zips images before upload.

    Workflow:
      1. Read metadata + list PNG files (under the lock).
      2. Zip the PNG files and metadata into a single archive.
      3. Create one commit (the archive).
      4. Delete the local PNG files to avoid re-uploading them later.

    Only step 1 requires the lock. Once the metadata snapshot is taken, the
    lock is released and the rest of the process runs without blocking the
    Gradio app.
    """

    def push_to_hub(self):
        """Archive pending PNGs + metadata and upload them as one zip file."""
        # 1. Snapshot pending work while holding the lock so that writers
        #    (see `save_image`) cannot add files mid-scan.
        with self.lock:
            png_files = list(self.folder_path.glob("*.png"))
            if len(png_files) == 0:
                return None  # return early if nothing to commit

            # Read then delete the metadata file. Guarding with `exists()`
            # avoids a crash when a previous push deleted the metadata but
            # failed before cleaning up the PNGs (e.g. upload error).
            metadata = IMAGE_JSONL_PATH.read_text() if IMAGE_JSONL_PATH.exists() else ""
            try:
                IMAGE_JSONL_PATH.unlink()
            except Exception:
                pass  # best-effort: a leftover file only means duplicate rows later

        with tempfile.TemporaryDirectory() as tmpdir:
            # 2. Zip png files + metadata in a single archive.
            archive_path = Path(tmpdir) / "train.zip"
            # Named `archive` (not `zip`) so the builtin is not shadowed.
            with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as archive:
                # PNG files
                for png_file in png_files:
                    archive.write(filename=png_file, arcname=png_file.name)
                # Metadata
                tmp_metadata = Path(tmpdir) / "metadata.jsonl"
                tmp_metadata.write_text(metadata)
                archive.write(filename=tmp_metadata, arcname="metadata.jsonl")

            # 3. Create commit. The uuid4 in the repo path keeps successive
            #    pushes from overwriting each other.
            self.api.upload_file(
                repo_id=self.repo_id,
                repo_type=self.repo_type,
                revision=self.revision,
                path_in_repo=f"train-{uuid4()}.zip",
                path_or_fileobj=archive_path,
            )

        # 4. Delete local png files to avoid re-uploading them later.
        #    Best-effort: a file that fails to delete is re-pushed next round.
        for png_file in png_files:
            try:
                png_file.unlink()
            except Exception:
                pass
# Background scheduler: periodically (at CommitScheduler's default interval —
# presumably every few minutes; confirm against huggingface_hub docs) calls
# `push_to_hub` to upload the pending images in IMAGE_DATASET_DIR as a zip.
scheduler = ZipScheduler(
    repo_id="example-space-to-dataset-image-zip",
    repo_type="dataset",
    folder_path=IMAGE_DATASET_DIR,
)

# Inference client used for remote text-to-image generation.
client = InferenceClient()
def generate_image(prompt: str) -> Image.Image:
    """Generate an image from `prompt` via the remote inference client.

    Annotation fixed from `Image` (the PIL module) to `Image.Image` (the class
    actually returned by `InferenceClient.text_to_image`).
    """
    return client.text_to_image(prompt)
def save_image(prompt: str, image_array: np.ndarray) -> None:
    """Write the generated image to disk and append its metadata row.

    Runs under the scheduler lock so a background push never observes a
    half-written image or a metadata row without its PNG.
    """
    print("Saving: " + prompt)
    target_path = IMAGE_DATASET_DIR / f"{uuid4()}.png"
    record = {
        "prompt": prompt,
        "file_name": target_path.name,
        "datetime": datetime.now().isoformat(),
    }
    with scheduler.lock:
        Image.fromarray(image_array).save(target_path)
        with IMAGE_JSONL_PATH.open("a") as jsonl_file:
            jsonl_file.write(json.dumps(record))
            jsonl_file.write("\n")
def get_demo():
    """Build the text-to-image UI.

    NOTE(review): creates components without opening a Blocks context and
    returns nothing — presumably meant to be called from inside an active
    `gr.Blocks()` block elsewhere; confirm against the caller.
    """
    with gr.Row():
        prompt_box = gr.Textbox(label="Prompt")
        output_image = gr.Image(label="Generated image")
    generate_btn = gr.Button("Generate")
    # Save only after a successful generation (the `.success` chain).
    generate_event = generate_btn.click(
        fn=generate_image,
        inputs=prompt_box,
        outputs=output_image,
    )
    generate_event.success(
        fn=save_image,
        inputs=[prompt_box, output_image],
        outputs=None,
    )