# Source: admaker / core/image_generator.py
# Last commit: "Update core/image_generator.py" (aff9467, verified) by karthikeya1212
# File size: 8.14 kB
# # core/image_generator.py
# import os
# import torch
# from diffusers import StableDiffusionXLPipeline
# from huggingface_hub import hf_hub_download
# from pathlib import Path
# from typing import List
# from io import BytesIO
# import base64
# from PIL import Image
# # Set cache and model directories early
# HF_CACHE_DIR = Path("/tmp/hf_cache")
# HF_CACHE_DIR.mkdir(parents=True, exist_ok=True)
# os.chmod(HF_CACHE_DIR, 0o777)
# os.environ["HF_HOME"] = str(HF_CACHE_DIR)
# os.environ["TRANSFORMERS_CACHE"] = str(HF_CACHE_DIR)
# os.environ["XDG_CACHE_HOME"] = str(HF_CACHE_DIR)
# os.environ["HF_DATASETS_CACHE"] = str(HF_CACHE_DIR)
# os.environ["HF_MODULES_CACHE"] = str(HF_CACHE_DIR)
# MODEL_DIR = Path("/tmp/models/realvisxl_v4")
# MODEL_DIR.mkdir(parents=True, exist_ok=True)
# os.chmod(MODEL_DIR, 0o777)
# # ---------------- MODEL CONFIG ----------------
# MODEL_REPO = "SG161222/RealVisXL_V4.0"
# MODEL_FILENAME = "realvisxlV40_v40LightningBakedvae.safetensors"
# MODEL_DIR = Path("/tmp/models/realvisxl_v4")
# os.makedirs(MODEL_DIR, exist_ok=True)
# # ---------------- MODEL DOWNLOAD ----------------
# def download_model() -> Path:
# """
# Downloads RealVisXL V4.0 model if not present.
# Returns the local model path.
# """
# model_path = MODEL_DIR / MODEL_FILENAME
# if not model_path.exists():
# print("[ImageGen] Downloading RealVisXL V4.0 model...")
# model_path = hf_hub_download(
# repo_id=MODEL_REPO,
# filename=MODEL_FILENAME,
# local_dir=str(MODEL_DIR),
# cache_dir=str(HF_CACHE_DIR), # ensure writable cache is used
# force_download=False,
# )
# print(f"[ImageGen] Model downloaded to: {model_path}")
# else:
# print("[ImageGen] Model already exists. Skipping download.")
# return model_path
# # ---------------- PIPELINE LOAD ----------------
# def load_pipeline() -> StableDiffusionXLPipeline:
# """
# Loads the RealVisXL V4.0 model for image generation.
# """
# model_path = download_model()
# print("[ImageGen] Loading model into pipeline...")
# pipe = StableDiffusionXLPipeline.from_single_file(
# str(model_path),
# torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
# )
# if torch.cuda.is_available():
# pipe.to("cuda")
# print("[ImageGen] Model ready.")
# return pipe
# # ---------------- GLOBAL PIPELINE CACHE ----------------
# pipe: StableDiffusionXLPipeline | None = None
# # ---------------- UTILITY: PIL TO BASE64 ----------------
# def pil_to_base64(img: Image.Image) -> str:
# """
# Converts a PIL image to a base64 string for frontend display.
# """
# buffered = BytesIO()
# img.save(buffered, format="PNG")
# img_bytes = buffered.getvalue()
# img_b64 = base64.b64encode(img_bytes).decode("utf-8")
# return f"data:image/png;base64,{img_b64}"
# # ---------------- IMAGE GENERATION ----------------
# def generate_images(prompt: str, seed: int = None, num_images: int = 3) -> List[str]:
# """
# Generates high-quality images using RealVisXL V4.0.
# Supports deterministic generation using a seed.
# Args:
# prompt (str): Text prompt for image generation.
# seed (int, optional): Seed for deterministic generation.
# num_images (int): Number of images to generate.
# Returns:
# List[str]: List of base64-encoded images.
# """
# global pipe
# if pipe is None:
# pipe = load_pipeline()
# print(f"[ImageGen] Generating {num_images} image(s) for prompt: '{prompt}' with seed={seed}")
# images: List[str] = []
# for i in range(num_images):
# generator = None
# if seed is not None:
# device = "cuda" if torch.cuda.is_available() else "cpu"
# generator = torch.Generator(device).manual_seed(seed + i)
# result = pipe(prompt, num_inference_steps=30, generator=generator).images[0]
# images.append(pil_to_base64(result))
# print(f"[ImageGen] Generated {len(images)} images successfully.")
# return images
# core/image_generator.py
import os
import torch
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from pathlib import Path
from typing import List
from io import BytesIO
import base64
from PIL import Image
# ---------------- CACHE & MODEL DIRECTORIES ----------------
# Writable locations for the Hugging Face cache and the model checkpoint.
HF_CACHE_DIR = Path("/tmp/hf_cache")
MODEL_DIR = Path("/tmp/models/realvisxl_v4")

# Make sure both directories exist; permissions are left at their defaults.
for d in (HF_CACHE_DIR, MODEL_DIR):
    d.mkdir(parents=True, exist_ok=True)

# Point every Hugging Face / XDG cache variable at the same directory.
# NOTE(review): this runs after the `diffusers` / `huggingface_hub` imports at
# the top of the file; any library that reads these variables at import time
# may not pick them up -- confirm the ordering is sufficient.
os.environ.update(
    dict.fromkeys(
        (
            "HF_HOME",
            "TRANSFORMERS_CACHE",
            "XDG_CACHE_HOME",
            "HF_DATASETS_CACHE",
            "HF_MODULES_CACHE",
        ),
        str(HF_CACHE_DIR),
    )
)

# ---------------- MODEL CONFIG ----------------
# Hub repository and checkpoint filename used by the image generator.
MODEL_REPO = "SG161222/RealVisXL_V4.0"
MODEL_FILENAME = "RealVisXL_V4.0.safetensors"
# ---------------- MODEL DOWNLOAD ----------------
def download_model() -> Path:
    """
    Ensure the RealVisXL V4.0 checkpoint is available locally.

    The checkpoint is stored directly under ``MODEL_DIR`` so that the
    existence check at the top of this function recognises a previous
    download and skips the Hub call on later invocations.

    Returns:
        Path: Local path of the checkpoint file.
    """
    model_path = MODEL_DIR / MODEL_FILENAME
    if model_path.exists():
        print("[ImageGen] Model already exists. Skipping download.")
        return model_path

    print("[ImageGen] Downloading RealVisXL V4.0 model...")
    # `local_dir` makes the file land at MODEL_DIR / MODEL_FILENAME, i.e. the
    # exact path tested above. Without it the file only lives inside the Hub
    # cache layout and the "already exists" branch can never be taken.
    # `resume_download` is intentionally not passed: recent huggingface_hub
    # releases deprecate it and resume interrupted downloads on their own.
    model_path = Path(
        hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILENAME,
            local_dir=str(MODEL_DIR),
            cache_dir=str(HF_CACHE_DIR),
            force_download=False,
        )
    )
    print(f"[ImageGen] Model downloaded to: {model_path}")
    return model_path
# ---------------- PIPELINE LOAD ----------------
def load_pipeline() -> StableDiffusionXLPipeline:
    """
    Build a ready-to-use RealVisXL V4.0 pipeline.

    Obtains the checkpoint through ``download_model``, loads it as a
    single-file SDXL pipeline (float16 when CUDA is available, float32
    otherwise), moves it to the matching device and enables attention slicing.

    Returns:
        StableDiffusionXLPipeline: The configured pipeline.
    """
    checkpoint = download_model()
    print("[ImageGen] Loading model into pipeline...")

    # Half precision on the GPU, full precision on the CPU.
    if torch.cuda.is_available():
        dtype, device = torch.float16, "cuda"
    else:
        dtype, device = torch.float32, "cpu"

    sdxl = StableDiffusionXLPipeline.from_single_file(
        str(checkpoint),
        torch_dtype=dtype,
    )
    sdxl.to(device)

    # Clear the safety checker attribute to save memory/performance.
    # NOTE(review): confirm the SDXL pipeline actually carries a safety
    # checker; if it does not, this assignment has no practical effect.
    sdxl.safety_checker = None

    # Attention slicing lowers peak memory use during inference.
    sdxl.enable_attention_slicing()

    print("[ImageGen] Model ready.")
    return sdxl
# ---------------- GLOBAL PIPELINE CACHE ----------------
# Module-level pipeline cache: stays None until generate_images() builds the
# pipeline via load_pipeline() on first use, then is reused by later calls.
pipe: StableDiffusionXLPipeline | None = None
# ---------------- UTILITY: PIL → BASE64 ----------------
def pil_to_base64(img: Image.Image) -> str:
    """
    Encode a PIL image as a PNG data URI.

    Args:
        img: Image to serialise.

    Returns:
        str: A ``data:image/png;base64,...`` string for the frontend.
    """
    with BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return "data:image/png;base64," + encoded.decode()
# ---------------- IMAGE GENERATION ----------------
def generate_images(
    prompt: str,
    seed: int | None = None,
    num_images: int = 3,
    *,
    num_inference_steps: int = 30,
) -> List[str]:
    """
    Generate images for a text prompt using RealVisXL V4.0.

    The pipeline is created on the first call and cached in the module-level
    ``pipe`` variable. Each image is generated independently; a failure on
    one image is logged and skipped, so the returned list may contain fewer
    than ``num_images`` entries (or be empty).

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: Optional base seed. When given, image ``i`` uses ``seed + i``
            so repeated calls reproduce the same set of images.
        num_images: Number of images to attempt.
        num_inference_steps: Denoising steps per image (keyword-only;
            defaults to the previously hard-coded value of 30).

    Returns:
        List[str]: Base64-encoded PNG data URIs, one per successful image.
    """
    global pipe
    if pipe is None:
        pipe = load_pipeline()

    print(f"[ImageGen] Generating {num_images} image(s) for prompt: '{prompt}' seed={seed}")

    # Loop-invariant: same CUDA-or-CPU choice that load_pipeline() makes.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    images: List[str] = []
    for i in range(num_images):
        generator = None
        if seed is not None:
            # Offset the seed per image so the images differ but stay
            # reproducible.
            generator = torch.Generator(device).manual_seed(seed + i)
        try:
            result = pipe(
                prompt,
                num_inference_steps=num_inference_steps,
                generator=generator,
            ).images[0]
            images.append(pil_to_base64(result))
        except Exception as e:
            # Deliberate best-effort: keep the images that did succeed.
            print(f"[ImageGen] ⚠️ Generation failed on image {i}: {e}")
            continue

    print(f"[ImageGen] Generated {len(images)} image(s) successfully.")
    return images