Spaces:
Sleeping
Sleeping
| # core/image_generator.py | |
| import os | |
| import torch | |
| from diffusers import StableDiffusionXLPipeline | |
| from huggingface_hub import hf_hub_download | |
| from pathlib import Path | |
| from typing import List | |
# ---------------- MODEL CONFIG ----------------
# Hugging Face repo id and checkpoint file passed to hf_hub_download().
MODEL_REPO = "SG161222/RealVisXL_V4.0"
MODEL_FILENAME = "realvisxlV40_v40LightningBakedvae.safetensors"
# Local directory the checkpoint is stored in; created (with parents) at import time.
MODEL_DIR = Path("/tmp/models/realvisxl_v4")
MODEL_DIR.mkdir(parents=True, exist_ok=True)
# ---------------- MODEL DOWNLOAD ----------------
def download_model() -> Path:
    """
    Downloads the RealVisXL V4.0 checkpoint into MODEL_DIR if not present.

    Returns:
        Path: Local path to the checkpoint file. The value returned by
        ``hf_hub_download`` (a ``str``) is converted to ``Path`` so that both
        the cached and freshly-downloaded branches return the annotated type.
    """
    model_path = MODEL_DIR / MODEL_FILENAME
    if model_path.exists():
        print("[ImageGen] Model already exists. Skipping download.")
        return model_path

    print("[ImageGen] Downloading RealVisXL V4.0 model...")
    downloaded = hf_hub_download(
        repo_id=MODEL_REPO,
        filename=MODEL_FILENAME,
        local_dir=str(MODEL_DIR),
        force_download=False,
    )
    # hf_hub_download returns a str; normalise so callers always get a Path.
    model_path = Path(downloaded)
    print(f"[ImageGen] Model downloaded to: {model_path}")
    return model_path
# ---------------- PIPELINE LOAD ----------------
def load_pipeline() -> StableDiffusionXLPipeline:
    """
    Build a StableDiffusionXLPipeline from the local RealVisXL V4.0 checkpoint.

    The checkpoint is obtained first via download_model(). When CUDA is
    available the weights are loaded as float16 and the pipeline is moved to
    "cuda"; otherwise float32 is used and the pipeline is not moved.

    Returns:
        StableDiffusionXLPipeline: The loaded pipeline.
    """
    checkpoint = download_model()
    print("[ImageGen] Loading model into pipeline...")

    has_gpu = torch.cuda.is_available()
    dtype = torch.float16 if has_gpu else torch.float32
    pipeline = StableDiffusionXLPipeline.from_single_file(
        str(checkpoint),
        torch_dtype=dtype,
    )
    if has_gpu:
        pipeline.to("cuda")

    print("[ImageGen] Model ready.")
    return pipeline
# ---------------- GLOBAL PIPELINE CACHE ----------------
# Module-level pipeline cache: generate_images() loads it on first call
# (when it is still None) and reuses it on later calls.
pipe: StableDiffusionXLPipeline | None
pipe = None
| # ---------------- IMAGE GENERATION ---------------- | |
| def generate_images(prompt: str, seed: int = None, num_images: int = 3) -> List: | |
| """ | |
| Generates high-quality images using RealVisXL V4.0. | |
| Supports deterministic generation using a seed. | |
| Args: | |
| prompt (str): Text prompt for image generation. | |
| seed (int, optional): Seed for deterministic generation. | |
| num_images (int): Number of images to generate. | |
| Returns: | |
| List: Generated PIL images. | |
| """ | |
| global pipe | |
| if pipe is None: | |
| pipe = load_pipeline() | |
| print(f"[ImageGen] Generating {num_images} image(s) for prompt: '{prompt}' with seed={seed}") | |
| images = [] | |
| for i in range(num_images): | |
| generator = None | |
| if seed is not None: | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| generator = torch.Generator(device).manual_seed(seed + i) # slightly vary keyframes | |
| result = pipe(prompt, num_inference_steps=30, generator=generator).images[0] | |
| images.append(result) | |
| print(f"[ImageGen] Generated {len(images)} images successfully.") | |
| return images | |