# admaker / core/image_generator.py
# (Header captured from the Hugging Face file viewer: uploaded by karthikeya1212,
#  commit eb8c5e1 "Upload 24 files" (verified), 2.99 kB. Kept as comments so the module parses.)
# core/image_generator.py
import os
import torch
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from pathlib import Path
from typing import List
# ---------------- MODEL CONFIG ----------------
# Hugging Face repo and checkpoint file to fetch (the filename suggests the
# Lightning variant with a baked-in VAE — confirm against the model card).
MODEL_REPO = "SG161222/RealVisXL_V4.0"
MODEL_FILENAME = "realvisxlV40_v40LightningBakedvae.safetensors"

# Local directory the checkpoint is downloaded into. It is created at import
# time (including missing parents; no error if it already exists).
MODEL_DIR = Path("/tmp/models/realvisxl_v4")
MODEL_DIR.mkdir(parents=True, exist_ok=True)
# ---------------- MODEL DOWNLOAD ----------------
def download_model() -> Path:
    """
    Download the RealVisXL V4.0 checkpoint into MODEL_DIR unless it is already there.

    Returns:
        Path: Local path of the checkpoint file. This is always a ``pathlib.Path``.
        ``hf_hub_download`` returns a ``str``, so its result is wrapped here to
        make both branches honor the annotated return type.
    """
    model_path = MODEL_DIR / MODEL_FILENAME

    # Fast path: a previous run already placed the checkpoint in MODEL_DIR.
    if model_path.exists():
        print("[ImageGen] Model already exists. Skipping download.")
        return model_path

    print("[ImageGen] Downloading RealVisXL V4.0 model...")
    downloaded = hf_hub_download(
        repo_id=MODEL_REPO,
        filename=MODEL_FILENAME,
        local_dir=str(MODEL_DIR),
        force_download=False,
    )
    # Normalize str -> Path so callers get the same type on every call.
    model_path = Path(downloaded)
    print(f"[ImageGen] Model downloaded to: {model_path}")
    return model_path
# ---------------- PIPELINE LOAD ----------------
def load_pipeline() -> StableDiffusionXLPipeline:
    """
    Build a StableDiffusionXLPipeline from the local RealVisXL V4.0 checkpoint.

    The checkpoint path comes from download_model(). With CUDA available, the
    weights are loaded as float16 and the pipeline is moved to "cuda". Without
    CUDA, float32 is used and the pipeline stays wherever from_single_file put it.

    Returns:
        StableDiffusionXLPipeline: The loaded pipeline.
    """
    checkpoint = download_model()
    print("[ImageGen] Loading model into pipeline...")

    has_cuda = torch.cuda.is_available()
    dtype = torch.float16 if has_cuda else torch.float32
    pipeline = StableDiffusionXLPipeline.from_single_file(str(checkpoint), torch_dtype=dtype)

    if has_cuda:
        pipeline.to("cuda")

    print("[ImageGen] Model ready.")
    return pipeline
# ---------------- GLOBAL PIPELINE CACHE ----------------
# Module-level pipeline cache. It stays None until generate_images() loads it
# on first use, and is reused afterwards.
pipe: StableDiffusionXLPipeline | None
pipe = None
# ---------------- IMAGE GENERATION ----------------
def generate_images(prompt: str, seed: int = None, num_images: int = 3) -> List:
"""
Generates high-quality images using RealVisXL V4.0.
Supports deterministic generation using a seed.
Args:
prompt (str): Text prompt for image generation.
seed (int, optional): Seed for deterministic generation.
num_images (int): Number of images to generate.
Returns:
List: Generated PIL images.
"""
global pipe
if pipe is None:
pipe = load_pipeline()
print(f"[ImageGen] Generating {num_images} image(s) for prompt: '{prompt}' with seed={seed}")
images = []
for i in range(num_images):
generator = None
if seed is not None:
device = "cuda" if torch.cuda.is_available() else "cpu"
generator = torch.Generator(device).manual_seed(seed + i) # slightly vary keyframes
result = pipe(prompt, num_inference_steps=30, generator=generator).images[0]
images.append(result)
print(f"[ImageGen] Generated {len(images)} images successfully.")
return images