Spaces:
Sleeping
Sleeping
Update core/image_generator.py
Browse files- core/image_generator.py +22 -12
core/image_generator.py
CHANGED
|
@@ -116,7 +116,7 @@
|
|
| 116 |
# return images
|
| 117 |
|
| 118 |
|
| 119 |
-
|
| 120 |
import os
|
| 121 |
import torch
|
| 122 |
from diffusers import StableDiffusionXLPipeline
|
|
@@ -127,15 +127,15 @@ from io import BytesIO
|
|
| 127 |
import base64
|
| 128 |
from PIL import Image
|
| 129 |
|
| 130 |
-
# ---------------- CACHE & MODEL
|
| 131 |
HF_CACHE_DIR = Path("/tmp/hf_cache")
|
| 132 |
MODEL_DIR = Path("/tmp/models/realvisxl_v4")
|
| 133 |
|
|
|
|
| 134 |
for d in [HF_CACHE_DIR, MODEL_DIR]:
|
| 135 |
d.mkdir(parents=True, exist_ok=True)
|
| 136 |
-
os.chmod(d, 0o777)
|
| 137 |
|
| 138 |
-
#
|
| 139 |
os.environ.update({
|
| 140 |
"HF_HOME": str(HF_CACHE_DIR),
|
| 141 |
"TRANSFORMERS_CACHE": str(HF_CACHE_DIR),
|
|
@@ -150,7 +150,10 @@ MODEL_FILENAME = "realvisxlV40_v40LightningBakedvae.safetensors"
|
|
| 150 |
|
| 151 |
# ---------------- MODEL DOWNLOAD ----------------
|
| 152 |
def download_model() -> Path:
|
| 153 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 154 |
model_path = MODEL_DIR / MODEL_FILENAME
|
| 155 |
if not model_path.exists():
|
| 156 |
print("[ImageGen] Downloading RealVisXL V4.0 model...")
|
|
@@ -161,7 +164,7 @@ def download_model() -> Path:
|
|
| 161 |
local_dir=str(MODEL_DIR),
|
| 162 |
cache_dir=str(HF_CACHE_DIR),
|
| 163 |
force_download=False,
|
| 164 |
-
resume_download=True,
|
| 165 |
)
|
| 166 |
)
|
| 167 |
print(f"[ImageGen] Model downloaded to: {model_path}")
|
|
@@ -171,7 +174,9 @@ def download_model() -> Path:
|
|
| 171 |
|
| 172 |
# ---------------- PIPELINE LOAD ----------------
|
| 173 |
def load_pipeline() -> StableDiffusionXLPipeline:
|
| 174 |
-
"""
|
|
|
|
|
|
|
| 175 |
model_path = download_model()
|
| 176 |
print("[ImageGen] Loading model into pipeline...")
|
| 177 |
|
|
@@ -185,8 +190,10 @@ def load_pipeline() -> StableDiffusionXLPipeline:
|
|
| 185 |
else:
|
| 186 |
pipe.to("cpu")
|
| 187 |
|
| 188 |
-
|
| 189 |
-
pipe.
|
|
|
|
|
|
|
| 190 |
|
| 191 |
print("[ImageGen] Model ready.")
|
| 192 |
return pipe
|
|
@@ -194,8 +201,11 @@ def load_pipeline() -> StableDiffusionXLPipeline:
|
|
| 194 |
# ---------------- GLOBAL PIPELINE CACHE ----------------
|
| 195 |
pipe: StableDiffusionXLPipeline | None = None
|
| 196 |
|
| 197 |
-
# ---------------- UTILITY: PIL β
|
| 198 |
def pil_to_base64(img: Image.Image) -> str:
|
|
|
|
|
|
|
|
|
|
| 199 |
buffered = BytesIO()
|
| 200 |
img.save(buffered, format="PNG")
|
| 201 |
return f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode()}"
|
|
@@ -204,13 +214,13 @@ def pil_to_base64(img: Image.Image) -> str:
|
|
| 204 |
def generate_images(prompt: str, seed: int | None = None, num_images: int = 3) -> List[str]:
|
| 205 |
"""
|
| 206 |
Generates high-quality images using RealVisXL V4.0.
|
| 207 |
-
Returns list of base64-encoded PNGs.
|
| 208 |
"""
|
| 209 |
global pipe
|
| 210 |
if pipe is None:
|
| 211 |
pipe = load_pipeline()
|
| 212 |
|
| 213 |
-
print(f"[ImageGen] Generating {num_images} image(s) for prompt: '{prompt}'
|
| 214 |
images: List[str] = []
|
| 215 |
|
| 216 |
for i in range(num_images):
|
|
|
|
| 116 |
# return images
|
| 117 |
|
| 118 |
|
| 119 |
+
# core/image_generator.py
|
| 120 |
import os
|
| 121 |
import torch
|
| 122 |
from diffusers import StableDiffusionXLPipeline
|
|
|
|
| 127 |
import base64
|
| 128 |
from PIL import Image
|
| 129 |
|
| 130 |
+
# ---------------- CACHE & MODEL DIRECTORIES ----------------
|
| 131 |
HF_CACHE_DIR = Path("/tmp/hf_cache")
|
| 132 |
MODEL_DIR = Path("/tmp/models/realvisxl_v4")
|
| 133 |
|
| 134 |
+
# Create directories safely (no chmod)
|
| 135 |
for d in [HF_CACHE_DIR, MODEL_DIR]:
|
| 136 |
d.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 137 |
|
| 138 |
+
# Apply environment variables BEFORE any Hugging Face usage
|
| 139 |
os.environ.update({
|
| 140 |
"HF_HOME": str(HF_CACHE_DIR),
|
| 141 |
"TRANSFORMERS_CACHE": str(HF_CACHE_DIR),
|
|
|
|
| 150 |
|
| 151 |
# ---------------- MODEL DOWNLOAD ----------------
|
| 152 |
def download_model() -> Path:
|
| 153 |
+
"""
|
| 154 |
+
Downloads RealVisXL V4.0 model if not present.
|
| 155 |
+
Returns local path.
|
| 156 |
+
"""
|
| 157 |
model_path = MODEL_DIR / MODEL_FILENAME
|
| 158 |
if not model_path.exists():
|
| 159 |
print("[ImageGen] Downloading RealVisXL V4.0 model...")
|
|
|
|
| 164 |
local_dir=str(MODEL_DIR),
|
| 165 |
cache_dir=str(HF_CACHE_DIR),
|
| 166 |
force_download=False,
|
| 167 |
+
resume_download=True, # safer if download interrupted
|
| 168 |
)
|
| 169 |
)
|
| 170 |
print(f"[ImageGen] Model downloaded to: {model_path}")
|
|
|
|
| 174 |
|
| 175 |
# ---------------- PIPELINE LOAD ----------------
|
| 176 |
def load_pipeline() -> StableDiffusionXLPipeline:
|
| 177 |
+
"""
|
| 178 |
+
Loads the RealVisXL V4.0 model for image generation.
|
| 179 |
+
"""
|
| 180 |
model_path = download_model()
|
| 181 |
print("[ImageGen] Loading model into pipeline...")
|
| 182 |
|
|
|
|
| 190 |
else:
|
| 191 |
pipe.to("cpu")
|
| 192 |
|
| 193 |
+
# Optional: skip safety checker to save memory/performance
|
| 194 |
+
pipe.safety_checker = None
|
| 195 |
+
# Enable attention slicing for memory-efficient CPU usage
|
| 196 |
+
pipe.enable_attention_slicing()
|
| 197 |
|
| 198 |
print("[ImageGen] Model ready.")
|
| 199 |
return pipe
|
|
|
|
| 201 |
# ---------------- GLOBAL PIPELINE CACHE ----------------
|
| 202 |
pipe: StableDiffusionXLPipeline | None = None
|
| 203 |
|
| 204 |
+
# ---------------- UTILITY: PIL β BASE64 ----------------
|
| 205 |
def pil_to_base64(img: Image.Image) -> str:
|
| 206 |
+
"""
|
| 207 |
+
Converts PIL image to base64 string for frontend.
|
| 208 |
+
"""
|
| 209 |
buffered = BytesIO()
|
| 210 |
img.save(buffered, format="PNG")
|
| 211 |
return f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode()}"
|
|
|
|
| 214 |
def generate_images(prompt: str, seed: int | None = None, num_images: int = 3) -> List[str]:
|
| 215 |
"""
|
| 216 |
Generates high-quality images using RealVisXL V4.0.
|
| 217 |
+
Returns a list of base64-encoded PNGs.
|
| 218 |
"""
|
| 219 |
global pipe
|
| 220 |
if pipe is None:
|
| 221 |
pipe = load_pipeline()
|
| 222 |
|
| 223 |
+
print(f"[ImageGen] Generating {num_images} image(s) for prompt: '{prompt}' seed={seed}")
|
| 224 |
images: List[str] = []
|
| 225 |
|
| 226 |
for i in range(num_images):
|