| 
							 | 
import json
import logging
import tempfile
import zipfile
from datetime import datetime
from pathlib import Path
from uuid import uuid4

import gradio as gr
import numpy as np
from huggingface_hub import CommitScheduler, InferenceClient
from PIL import Image
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						IMAGE_DATASET_DIR = Path("image_dataset_1M") / f"train-{uuid4()}" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						IMAGE_DATASET_DIR.mkdir(parents=True, exist_ok=True) | 
					
					
						
						| 
							 | 
						IMAGE_JSONL_PATH = IMAGE_DATASET_DIR / "metadata.jsonl" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class ZipScheduler(CommitScheduler): | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						    Example of a custom CommitScheduler with overwritten `push_to_hub` to zip images before pushing them to the Hub. | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    Workflow: | 
					
					
						
						| 
							 | 
						    1. Read metadata + list PNG files. | 
					
					
						
						| 
							 | 
						    2. Zip png files in a single archive. | 
					
					
						
						| 
							 | 
						    3. Create commit (metadata + archive). | 
					
					
						
						| 
							 | 
						    4. Delete local png files to avoid re-uploading them later. | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    Only step 1 requires to activate the lock. Once the metadata is read, the lock is released and the rest of the | 
					
					
						
						| 
							 | 
						    process can be done without blocking the Gradio app. | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def push_to_hub(self): | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        with self.lock: | 
					
					
						
						| 
							 | 
						            png_files = list(self.folder_path.glob("*.png")) | 
					
					
						
						| 
							 | 
						            if len(png_files) == 0: | 
					
					
						
						| 
							 | 
						                return None   | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            metadata = IMAGE_JSONL_PATH.read_text() | 
					
					
						
						| 
							 | 
						            try: | 
					
					
						
						| 
							 | 
						                IMAGE_JSONL_PATH.unlink() | 
					
					
						
						| 
							 | 
						            except Exception: | 
					
					
						
						| 
							 | 
						                pass | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        with tempfile.TemporaryDirectory() as tmpdir: | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            archive_path = Path(tmpdir) / "train.zip" | 
					
					
						
						| 
							 | 
						            with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as zip: | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                for png_file in png_files: | 
					
					
						
						| 
							 | 
						                    zip.write(filename=png_file, arcname=png_file.name) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                tmp_metadata = Path(tmpdir) / "metadata.jsonl" | 
					
					
						
						| 
							 | 
						                tmp_metadata.write_text(metadata) | 
					
					
						
						| 
							 | 
						                zip.write(filename=tmp_metadata, arcname="metadata.jsonl") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            self.api.upload_file( | 
					
					
						
						| 
							 | 
						                repo_id=self.repo_id, | 
					
					
						
						| 
							 | 
						                repo_type=self.repo_type, | 
					
					
						
						| 
							 | 
						                revision=self.revision, | 
					
					
						
						| 
							 | 
						                path_in_repo=f"train-{uuid4()}.zip", | 
					
					
						
						| 
							 | 
						                path_or_fileobj=archive_path, | 
					
					
						
						| 
							 | 
						            ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        for png_file in png_files: | 
					
					
						
						| 
							 | 
						            try: | 
					
					
						
						| 
							 | 
						                png_file.unlink() | 
					
					
						
						| 
							 | 
						            except Exception: | 
					
					
						
						| 
							 | 
						                pass | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						scheduler = ZipScheduler( | 
					
					
						
						| 
							 | 
						    repo_id="example-space-to-dataset-image-zip", | 
					
					
						
						| 
							 | 
						    repo_type="dataset", | 
					
					
						
						| 
							 | 
						    folder_path=IMAGE_DATASET_DIR, | 
					
					
						
						| 
							 | 
						) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						client = InferenceClient() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def generate_image(prompt: str) -> Image: | 
					
					
						
						| 
							 | 
						    return client.text_to_image(prompt) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def save_image(prompt: str, image_array: np.ndarray) -> None: | 
					
					
						
						| 
							 | 
						    print("Saving: " + prompt) | 
					
					
						
						| 
							 | 
						    image_path = IMAGE_DATASET_DIR / f"{uuid4()}.png" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    with scheduler.lock: | 
					
					
						
						| 
							 | 
						        Image.fromarray(image_array).save(image_path) | 
					
					
						
						| 
							 | 
						        with IMAGE_JSONL_PATH.open("a") as f: | 
					
					
						
						| 
							 | 
						            json.dump({"prompt": prompt, "file_name": image_path.name, "datetime": datetime.now().isoformat()}, f) | 
					
					
						
						| 
							 | 
						            f.write("\n") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def get_demo(): | 
					
					
						
						| 
							 | 
						    with gr.Row(): | 
					
					
						
						| 
							 | 
						        prompt_value = gr.Textbox(label="Prompt") | 
					
					
						
						| 
							 | 
						        image_value = gr.Image(label="Generated image") | 
					
					
						
						| 
							 | 
						    text_to_image_btn = gr.Button("Generate") | 
					
					
						
						| 
							 | 
						    text_to_image_btn.click(fn=generate_image, inputs=prompt_value, outputs=image_value).success( | 
					
					
						
						| 
							 | 
						        fn=save_image, | 
					
					
						
						| 
							 | 
						        inputs=[prompt_value, image_value], | 
					
					
						
						| 
							 | 
						        outputs=None, | 
					
					
						
						| 
							 | 
						    ) | 
					
					
						
						| 
							 | 
						
 |