Spaces:
Sleeping
Sleeping
| # # core/image_generator.py | |
| # import os | |
| # import torch | |
| # from diffusers import StableDiffusionXLPipeline | |
| # from huggingface_hub import hf_hub_download | |
| # from pathlib import Path | |
| # from typing import List | |
| # from io import BytesIO | |
| # import base64 | |
| # from PIL import Image | |
| # # Set cache and model directories early | |
| # HF_CACHE_DIR = Path("/tmp/hf_cache") | |
| # HF_CACHE_DIR.mkdir(parents=True, exist_ok=True) | |
| # os.chmod(HF_CACHE_DIR, 0o777) | |
| # os.environ["HF_HOME"] = str(HF_CACHE_DIR) | |
| # os.environ["TRANSFORMERS_CACHE"] = str(HF_CACHE_DIR) | |
| # os.environ["XDG_CACHE_HOME"] = str(HF_CACHE_DIR) | |
| # os.environ["HF_DATASETS_CACHE"] = str(HF_CACHE_DIR) | |
| # os.environ["HF_MODULES_CACHE"] = str(HF_CACHE_DIR) | |
| # MODEL_DIR = Path("/tmp/models/realvisxl_v4") | |
| # MODEL_DIR.mkdir(parents=True, exist_ok=True) | |
| # os.chmod(MODEL_DIR, 0o777) | |
| # # ---------------- MODEL CONFIG ---------------- | |
| # MODEL_REPO = "SG161222/RealVisXL_V4.0" | |
| # MODEL_FILENAME = "realvisxlV40_v40LightningBakedvae.safetensors" | |
| # MODEL_DIR = Path("/tmp/models/realvisxl_v4") | |
| # os.makedirs(MODEL_DIR, exist_ok=True) | |
| # # ---------------- MODEL DOWNLOAD ---------------- | |
| # def download_model() -> Path: | |
| # """ | |
| # Downloads RealVisXL V4.0 model if not present. | |
| # Returns the local model path. | |
| # """ | |
| # model_path = MODEL_DIR / MODEL_FILENAME | |
| # if not model_path.exists(): | |
| # print("[ImageGen] Downloading RealVisXL V4.0 model...") | |
| # model_path = hf_hub_download( | |
| # repo_id=MODEL_REPO, | |
| # filename=MODEL_FILENAME, | |
| # local_dir=str(MODEL_DIR), | |
| # cache_dir=str(HF_CACHE_DIR), # ensure writable cache is used | |
| # force_download=False, | |
| # ) | |
| # print(f"[ImageGen] Model downloaded to: {model_path}") | |
| # else: | |
| # print("[ImageGen] Model already exists. Skipping download.") | |
| # return model_path | |
| # # ---------------- PIPELINE LOAD ---------------- | |
| # def load_pipeline() -> StableDiffusionXLPipeline: | |
| # """ | |
| # Loads the RealVisXL V4.0 model for image generation. | |
| # """ | |
| # model_path = download_model() | |
| # print("[ImageGen] Loading model into pipeline...") | |
| # pipe = StableDiffusionXLPipeline.from_single_file( | |
| # str(model_path), | |
| # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, | |
| # ) | |
| # if torch.cuda.is_available(): | |
| # pipe.to("cuda") | |
| # print("[ImageGen] Model ready.") | |
| # return pipe | |
| # # ---------------- GLOBAL PIPELINE CACHE ---------------- | |
| # pipe: StableDiffusionXLPipeline | None = None | |
| # # ---------------- UTILITY: PIL TO BASE64 ---------------- | |
| # def pil_to_base64(img: Image.Image) -> str: | |
| # """ | |
| # Converts a PIL image to a base64 string for frontend display. | |
| # """ | |
| # buffered = BytesIO() | |
| # img.save(buffered, format="PNG") | |
| # img_bytes = buffered.getvalue() | |
| # img_b64 = base64.b64encode(img_bytes).decode("utf-8") | |
| # return f"data:image/png;base64,{img_b64}" | |
| # # ---------------- IMAGE GENERATION ---------------- | |
| # def generate_images(prompt: str, seed: int = None, num_images: int = 3) -> List[str]: | |
| # """ | |
| # Generates high-quality images using RealVisXL V4.0. | |
| # Supports deterministic generation using a seed. | |
| # Args: | |
| # prompt (str): Text prompt for image generation. | |
| # seed (int, optional): Seed for deterministic generation. | |
| # num_images (int): Number of images to generate. | |
| # Returns: | |
| # List[str]: List of base64-encoded images. | |
| # """ | |
| # global pipe | |
| # if pipe is None: | |
| # pipe = load_pipeline() | |
| # print(f"[ImageGen] Generating {num_images} image(s) for prompt: '{prompt}' with seed={seed}") | |
| # images: List[str] = [] | |
| # for i in range(num_images): | |
| # generator = None | |
| # if seed is not None: | |
| # device = "cuda" if torch.cuda.is_available() else "cpu" | |
| # generator = torch.Generator(device).manual_seed(seed + i) | |
| # result = pipe(prompt, num_inference_steps=30, generator=generator).images[0] | |
| # images.append(pil_to_base64(result)) | |
| # print(f"[ImageGen] Generated {len(images)} images successfully.") | |
| # return images | |
| # core/image_generator.py | |
| import os | |
| import torch | |
| from diffusers import StableDiffusionXLPipeline | |
| from huggingface_hub import hf_hub_download | |
| from pathlib import Path | |
| from typing import List | |
| from io import BytesIO | |
| import base64 | |
| from PIL import Image | |
# ---------------- CACHE & MODEL DIRECTORIES ----------------
HF_CACHE_DIR = Path("/tmp/hf_cache")
MODEL_DIR = Path("/tmp/models/realvisxl_v4")

# Create directories safely (no chmod); exist_ok makes re-imports harmless.
for d in [HF_CACHE_DIR, MODEL_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# Point Hugging Face tooling at the writable /tmp cache.
# NOTE: huggingface_hub and diffusers are imported at the top of this module,
# i.e. BEFORE these variables are set. huggingface_hub resolves its default
# cache directories when it is imported, so these assignments do not move the
# default cache of those already-imported libraries. They only affect code that
# reads the environment later. Hub calls that must use HF_CACHE_DIR should pass
# cache_dir=HF_CACHE_DIR explicitly. Alternatively, set these variables in the
# process environment, or before the Hugging Face imports.
os.environ.update({
    "HF_HOME": str(HF_CACHE_DIR),
    "TRANSFORMERS_CACHE": str(HF_CACHE_DIR),
    "XDG_CACHE_HOME": str(HF_CACHE_DIR),
    "HF_DATASETS_CACHE": str(HF_CACHE_DIR),
    "HF_MODULES_CACHE": str(HF_CACHE_DIR),
})

# ---------------- MODEL CONFIG ----------------
# Hub repository and checkpoint file loaded by download_model()/load_pipeline().
MODEL_REPO = "SG161222/RealVisXL_V4.0"
MODEL_FILENAME = "RealVisXL_V4.0.safetensors"
| # ---------------- MODEL DOWNLOAD ---------------- | |
def download_model() -> Path:
    """
    Download the RealVisXL V4.0 checkpoint into MODEL_DIR unless it is already there.

    Returns:
        Local path of the ``.safetensors`` checkpoint. This is
        ``MODEL_DIR / MODEL_FILENAME`` when the file already exists;
        otherwise it is the path reported by ``hf_hub_download``.
    """
    model_path = MODEL_DIR / MODEL_FILENAME
    if model_path.exists():
        print("[ImageGen] Model already exists. Skipping download.")
        return model_path

    print("[ImageGen] Downloading RealVisXL V4.0 model...")
    # local_dir places the file at MODEL_DIR / MODEL_FILENAME, which is exactly
    # what the exists() check above looks for. Without it the file only lives
    # in the hub cache, and every call would go back to the Hub.
    # `resume_download` is intentionally not passed: it is deprecated in
    # recent huggingface_hub releases, where interrupted downloads resume by
    # default.
    downloaded = hf_hub_download(
        repo_id=MODEL_REPO,
        filename=MODEL_FILENAME,
        local_dir=str(MODEL_DIR),
        cache_dir=str(HF_CACHE_DIR),
        force_download=False,
    )
    model_path = Path(downloaded)
    print(f"[ImageGen] Model downloaded to: {model_path}")
    return model_path
| # ---------------- PIPELINE LOAD ---------------- | |
def load_pipeline() -> StableDiffusionXLPipeline:
    """
    Build a StableDiffusionXLPipeline from the RealVisXL V4.0 single-file checkpoint.

    The checkpoint is downloaded first if needed (via download_model()).
    The pipeline is loaded in float16 when CUDA is available and in float32
    otherwise. It is then moved to that device, with attention slicing enabled.

    Returns:
        The ready-to-use pipeline.
    """
    model_path = download_model()
    print("[ImageGen] Loading model into pipeline...")

    use_cuda = torch.cuda.is_available()
    pipe = StableDiffusionXLPipeline.from_single_file(
        str(model_path),
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        # from_single_file may fetch component configs from the Hub. The HF_*
        # env vars in this module are set after huggingface_hub is imported, so
        # they do not change its default cache. Pass the writable cache explicitly.
        cache_dir=str(HF_CACHE_DIR),
    )
    pipe.to("cuda" if use_cuda else "cpu")

    # Attention slicing lowers peak memory at some cost in speed.
    # It is enabled on both CPU and CUDA.
    pipe.enable_attention_slicing()

    print("[ImageGen] Model ready.")
    return pipe
# ---------------- GLOBAL PIPELINE CACHE ----------------
# Lazily created pipeline shared by all generate_images() calls.
# It stays None until the first generate_images() call assigns it from
# load_pipeline(). Later calls reuse the same object.
# NOTE(review): generate_images() checks for None and then loads without any
# locking, so concurrent first calls could each load the model. Confirm that
# callers invoke it serially.
pipe: StableDiffusionXLPipeline | None = None
| # ---------------- UTILITY: PIL → BASE64 ---------------- | |
def pil_to_base64(img: Image.Image) -> str:
    """
    Serialize ``img`` as PNG and return it as a base64 ``data:`` URL.

    The returned string can be used directly as an image source on the frontend.
    """
    with BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes).decode()
    return "data:image/png;base64," + encoded
| # ---------------- IMAGE GENERATION ---------------- | |
def generate_images(prompt: str, seed: int | None = None, num_images: int = 3) -> List[str]:
    """
    Generate up to ``num_images`` images for ``prompt`` with RealVisXL V4.0.

    The pipeline is loaded on first use and cached in the module-level ``pipe``.
    When ``seed`` is given, image ``i`` uses seed ``seed + i``. This makes a
    batch reproducible while keeping its images distinct.

    Generation is best-effort. If producing or encoding an image raises, the
    error is logged (with traceback) and that image is skipped. The result can
    therefore be shorter than ``num_images``, or even empty.

    Args:
        prompt: Text prompt for image generation.
        seed: Optional base seed for deterministic generation.
        num_images: Number of images to attempt.

    Returns:
        Base64 PNG data URLs, one per successfully generated image.
    """
    import traceback  # only needed on the failure path below

    global pipe
    if pipe is None:
        pipe = load_pipeline()

    print(f"[ImageGen] Generating {num_images} image(s) for prompt: '{prompt}' seed={seed}")

    # The device cannot change between iterations, so resolve it once.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    images: List[str] = []
    for i in range(num_images):
        generator = None
        if seed is not None:
            generator = torch.Generator(device).manual_seed(seed + i)
        try:
            result = pipe(prompt, num_inference_steps=30, generator=generator).images[0]
            images.append(pil_to_base64(result))
        except Exception as e:
            # Best-effort: carry on with the remaining images. Print the
            # traceback so that real faults (OOM, device errors) can be diagnosed.
            print(f"[ImageGen] ⚠️ Generation failed on image {i}: {e}")
            traceback.print_exc()
            continue

    # Report successes against the requested count, so partial failure is visible.
    print(f"[ImageGen] Generated {len(images)}/{num_images} image(s) successfully.")
    return images