|  | import json | 
					
						
						|  | import tempfile | 
					
						
						|  | import zipfile | 
					
						
						|  | from datetime import datetime | 
					
						
						|  | from pathlib import Path | 
					
						
						|  | from uuid import uuid4 | 
					
						
						|  |  | 
					
						
						|  | import gradio as gr | 
					
						
						|  | import numpy as np | 
					
						
						|  | from PIL import Image | 
					
						
						|  |  | 
					
						
						|  | from huggingface_hub import CommitScheduler, InferenceClient | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Local staging directory for generated images. A fresh UUID per process keeps
# concurrent app replicas (or restarts) from writing into the same folder.
IMAGE_DATASET_DIR = Path("image_dataset_1M") / f"train-{uuid4()}"
IMAGE_DATASET_DIR.mkdir(parents=True, exist_ok=True)
# JSONL sidecar: one metadata record (prompt, file name, timestamp) per saved image.
IMAGE_JSONL_PATH = IMAGE_DATASET_DIR / "metadata.jsonl"
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class ZipScheduler(CommitScheduler): | 
					
						
						|  | """ | 
					
						
						|  | Example of a custom CommitScheduler with overwritten `push_to_hub` to zip images before pushing them to the Hub. | 
					
						
						|  |  | 
					
						
						|  | Workflow: | 
					
						
						|  | 1. Read metadata + list PNG files. | 
					
						
						|  | 2. Zip png files in a single archive. | 
					
						
						|  | 3. Create commit (metadata + archive). | 
					
						
						|  | 4. Delete local png files to avoid re-uploading them later. | 
					
						
						|  |  | 
					
						
						|  | Only step 1 requires to activate the lock. Once the metadata is read, the lock is released and the rest of the | 
					
						
						|  | process can be done without blocking the Gradio app. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def push_to_hub(self): | 
					
						
						|  |  | 
					
						
						|  | with self.lock: | 
					
						
						|  | png_files = list(self.folder_path.glob("*.png")) | 
					
						
						|  | if len(png_files) == 0: | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | metadata = IMAGE_JSONL_PATH.read_text() | 
					
						
						|  | try: | 
					
						
						|  | IMAGE_JSONL_PATH.unlink() | 
					
						
						|  | except Exception: | 
					
						
						|  | pass | 
					
						
						|  |  | 
					
						
						|  | with tempfile.TemporaryDirectory() as tmpdir: | 
					
						
						|  |  | 
					
						
						|  | archive_path = Path(tmpdir) / "train.zip" | 
					
						
						|  | with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as zip: | 
					
						
						|  |  | 
					
						
						|  | for png_file in png_files: | 
					
						
						|  | zip.write(filename=png_file, arcname=png_file.name) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | tmp_metadata = Path(tmpdir) / "metadata.jsonl" | 
					
						
						|  | tmp_metadata.write_text(metadata) | 
					
						
						|  | zip.write(filename=tmp_metadata, arcname="metadata.jsonl") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.api.upload_file( | 
					
						
						|  | repo_id=self.repo_id, | 
					
						
						|  | repo_type=self.repo_type, | 
					
						
						|  | revision=self.revision, | 
					
						
						|  | path_in_repo=f"train-{uuid4()}.zip", | 
					
						
						|  | path_or_fileobj=archive_path, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for png_file in png_files: | 
					
						
						|  | try: | 
					
						
						|  | png_file.unlink() | 
					
						
						|  | except Exception: | 
					
						
						|  | pass | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Background scheduler: periodically zips whatever landed in IMAGE_DATASET_DIR
# and commits it to the dataset repo on the Hub.
scheduler = ZipScheduler(
    repo_id="example-space-to-dataset-image-zip",
    repo_type="dataset",
    folder_path=IMAGE_DATASET_DIR,
)

# Inference API client used for text-to-image generation.
client = InferenceClient()
					
						
						|  |  | 
					
						
						|  |  | 
					
						
def generate_image(prompt: str) -> Image.Image:
    """Generate an image for *prompt* via the Inference API (network call)."""
    return client.text_to_image(prompt)
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def save_image(prompt: str, image_array: np.ndarray) -> None: | 
					
						
						|  | print("Saving: " + prompt) | 
					
						
						|  | image_path = IMAGE_DATASET_DIR / f"{uuid4()}.png" | 
					
						
						|  |  | 
					
						
						|  | with scheduler.lock: | 
					
						
						|  | Image.fromarray(image_array).save(image_path) | 
					
						
						|  | with IMAGE_JSONL_PATH.open("a") as f: | 
					
						
						|  | json.dump({"prompt": prompt, "file_name": image_path.name, "datetime": datetime.now().isoformat()}, f) | 
					
						
						|  | f.write("\n") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_demo(): | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | prompt_value = gr.Textbox(label="Prompt") | 
					
						
						|  | image_value = gr.Image(label="Generated image") | 
					
						
						|  | text_to_image_btn = gr.Button("Generate") | 
					
						
						|  | text_to_image_btn.click(fn=generate_image, inputs=prompt_value, outputs=image_value).success( | 
					
						
						|  | fn=save_image, | 
					
						
						|  | inputs=[prompt_value, image_value], | 
					
						
						|  | outputs=None, | 
					
						
						|  | ) | 
					
						
						|  |  |