Spaces:
Sleeping
Sleeping
Update core/image_generator.py
Browse files- core/image_generator.py +29 -57
core/image_generator.py
CHANGED
|
@@ -272,12 +272,11 @@
|
|
| 272 |
|
| 273 |
|
| 274 |
|
| 275 |
-
|
| 276 |
import os
|
| 277 |
from pathlib import Path
|
| 278 |
import gc
|
| 279 |
import torch
|
| 280 |
-
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
|
| 281 |
from huggingface_hub import hf_hub_download
|
| 282 |
from typing import Dict, Any
|
| 283 |
from PIL import Image
|
|
@@ -322,7 +321,6 @@ def safe_expanduser(path):
|
|
| 322 |
return os.path.expanduser_original(path)
|
| 323 |
|
| 324 |
os.path.expanduser = safe_expanduser
|
| 325 |
-
|
| 326 |
tempfile.tempdir = str(HF_CACHE_DIR)
|
| 327 |
|
| 328 |
print("[DEBUG] β
Hugging Face, Diffusers, Datasets and Torch cache fully redirected to:", HF_CACHE_DIR)
|
|
@@ -344,34 +342,11 @@ print("[DEBUG] β
Seed directory:", SEED_DIR)
|
|
| 344 |
# --------------------------------------------------------------
|
| 345 |
# MODEL CONFIG
|
| 346 |
# --------------------------------------------------------------
|
| 347 |
-
MODEL_REPO = "
|
| 348 |
-
MODEL_FILENAME = "dreamshaper_8.safetensors"
|
| 349 |
-
|
| 350 |
# ---------------- GLOBAL PIPELINE CACHE ----------------
|
| 351 |
-
pipe: StableDiffusionXLPipeline | None = None
|
| 352 |
img2img_pipe: StableDiffusionXLImg2ImgPipeline | None = None
|
| 353 |
|
| 354 |
-
# --------------------------------------------------------------
|
| 355 |
-
# MODEL DOWNLOAD
|
| 356 |
-
# --------------------------------------------------------------
|
| 357 |
-
def download_model() -> Path:
|
| 358 |
-
model_path = MODEL_DIR / MODEL_FILENAME
|
| 359 |
-
if not model_path.exists():
|
| 360 |
-
print("[ImageGen] Downloading DreamShaper SD1.5 model...")
|
| 361 |
-
model_path = Path(
|
| 362 |
-
hf_hub_download(
|
| 363 |
-
repo_id=MODEL_REPO,
|
| 364 |
-
filename=MODEL_FILENAME,
|
| 365 |
-
cache_dir=str(HF_CACHE_DIR),
|
| 366 |
-
force_download=False,
|
| 367 |
-
resume_download=True,
|
| 368 |
-
)
|
| 369 |
-
)
|
| 370 |
-
print(f"[ImageGen] β
Model downloaded to: {model_path}")
|
| 371 |
-
else:
|
| 372 |
-
print("[ImageGen] β
Model already exists at:", model_path)
|
| 373 |
-
return model_path
|
| 374 |
-
|
| 375 |
# --------------------------------------------------------------
|
| 376 |
# MEMORY-SAFE PIPELINE MANAGER
|
| 377 |
# --------------------------------------------------------------
|
|
@@ -399,18 +374,22 @@ def unload_pipelines(target="all"):
|
|
| 399 |
torch.cuda.empty_cache()
|
| 400 |
print("[ImageGen] β
Memory cleared.")
|
| 401 |
|
| 402 |
-
def safe_load_pipeline(
|
| 403 |
-
"""
|
| 404 |
try:
|
| 405 |
-
print(f"[ImageGen] π Loading
|
| 406 |
-
pipe =
|
| 407 |
-
|
| 408 |
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
|
|
|
| 409 |
)
|
| 410 |
-
|
|
|
|
|
|
|
|
|
|
| 411 |
return pipe
|
| 412 |
except Exception as e:
|
| 413 |
-
print(f"[ImageGen] β Failed to load {
|
| 414 |
unload_pipelines()
|
| 415 |
gc.collect()
|
| 416 |
if torch.cuda.is_available():
|
|
@@ -420,26 +399,17 @@ def safe_load_pipeline(pipeline_class, model_path):
|
|
| 420 |
def load_pipeline():
|
| 421 |
global pipe
|
| 422 |
unload_pipelines(target="pipe")
|
| 423 |
-
model_path = download_model()
|
| 424 |
print("[ImageGen] Loading main (txt2img) pipeline...")
|
| 425 |
-
pipe = safe_load_pipeline(
|
| 426 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 427 |
-
pipe.to(device)
|
| 428 |
-
pipe.safety_checker = None
|
| 429 |
-
pipe.enable_attention_slicing()
|
| 430 |
print("[ImageGen] β
Text-to-image pipeline ready.")
|
| 431 |
return pipe
|
| 432 |
|
| 433 |
def load_img2img_pipeline():
|
| 434 |
global img2img_pipe
|
| 435 |
unload_pipelines(target="img2img_pipe")
|
| 436 |
-
model_path = download_model()
|
| 437 |
print("[ImageGen] Loading img2img pipeline...")
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
img2img_pipe.to(device)
|
| 441 |
-
img2img_pipe.safety_checker = None
|
| 442 |
-
img2img_pipe.enable_attention_slicing()
|
| 443 |
print("[ImageGen] β
Img2Img pipeline ready.")
|
| 444 |
return img2img_pipe
|
| 445 |
|
|
@@ -529,13 +499,15 @@ async def generate_images(prompt_or_json, seed: int | None = None, num_images: i
|
|
| 529 |
pipe = load_pipeline()
|
| 530 |
images = []
|
| 531 |
for i in range(num_images):
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
|
|
|
|
|
|
|
|
| 272 |
|
| 273 |
|
| 274 |
|
|
|
|
| 275 |
import os
|
| 276 |
from pathlib import Path
|
| 277 |
import gc
|
| 278 |
import torch
|
| 279 |
+
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, AutoPipelineForText2Image
|
| 280 |
from huggingface_hub import hf_hub_download
|
| 281 |
from typing import Dict, Any
|
| 282 |
from PIL import Image
|
|
|
|
| 321 |
return os.path.expanduser_original(path)
|
| 322 |
|
| 323 |
os.path.expanduser = safe_expanduser
|
|
|
|
| 324 |
tempfile.tempdir = str(HF_CACHE_DIR)
|
| 325 |
|
| 326 |
print("[DEBUG] β
Hugging Face, Diffusers, Datasets and Torch cache fully redirected to:", HF_CACHE_DIR)
|
|
|
|
| 342 |
# --------------------------------------------------------------
|
| 343 |
# MODEL CONFIG
|
| 344 |
# --------------------------------------------------------------
|
| 345 |
+
MODEL_REPO = "lykon/dreamshaper-8" # Use Hugging Face repo
|
|
|
|
|
|
|
| 346 |
# ---------------- GLOBAL PIPELINE CACHE ----------------
|
| 347 |
+
pipe: StableDiffusionXLPipeline | AutoPipelineForText2Image | None = None
|
| 348 |
img2img_pipe: StableDiffusionXLImg2ImgPipeline | None = None
|
| 349 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
# --------------------------------------------------------------
|
| 351 |
# MEMORY-SAFE PIPELINE MANAGER
|
| 352 |
# --------------------------------------------------------------
|
|
|
|
| 374 |
torch.cuda.empty_cache()
|
| 375 |
print("[ImageGen] β
Memory cleared.")
|
| 376 |
|
| 377 |
+
def safe_load_pipeline(pretrained_model_name):
|
| 378 |
+
"""Load DreamShaper SD1.5 safely via from_pretrained."""
|
| 379 |
try:
|
| 380 |
+
print(f"[ImageGen] π Loading model {pretrained_model_name} ...")
|
| 381 |
+
pipe = AutoPipelineForText2Image.from_pretrained(
|
| 382 |
+
pretrained_model_name,
|
| 383 |
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
| 384 |
+
variant="fp16" # use fp16 if possible
|
| 385 |
)
|
| 386 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 387 |
+
pipe = pipe.to(device)
|
| 388 |
+
pipe.enable_attention_slicing()
|
| 389 |
+
print(f"[ImageGen] β
Successfully loaded {pretrained_model_name}.")
|
| 390 |
return pipe
|
| 391 |
except Exception as e:
|
| 392 |
+
print(f"[ImageGen] β Failed to load {pretrained_model_name}: {e}")
|
| 393 |
unload_pipelines()
|
| 394 |
gc.collect()
|
| 395 |
if torch.cuda.is_available():
|
|
|
|
| 399 |
def load_pipeline():
|
| 400 |
global pipe
|
| 401 |
unload_pipelines(target="pipe")
|
|
|
|
| 402 |
print("[ImageGen] Loading main (txt2img) pipeline...")
|
| 403 |
+
pipe = safe_load_pipeline(MODEL_REPO)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
print("[ImageGen] β
Text-to-image pipeline ready.")
|
| 405 |
return pipe
|
| 406 |
|
| 407 |
def load_img2img_pipeline():
|
| 408 |
global img2img_pipe
|
| 409 |
unload_pipelines(target="img2img_pipe")
|
|
|
|
| 410 |
print("[ImageGen] Loading img2img pipeline...")
|
| 411 |
+
# For DreamShaper, img2img uses the same pipeline
|
| 412 |
+
img2img_pipe = safe_load_pipeline(MODEL_REPO)
|
|
|
|
|
|
|
|
|
|
| 413 |
print("[ImageGen] β
Img2Img pipeline ready.")
|
| 414 |
return img2img_pipe
|
| 415 |
|
|
|
|
| 499 |
pipe = load_pipeline()
|
| 500 |
images = []
|
| 501 |
for i in range(num_images):
|
| 502 |
+
gen = torch.Generator(device).manual_seed(seed + i) if seed is not None else None
|
| 503 |
+
try:
|
| 504 |
+
img = pipe(prompt_or_json, num_inference_steps=30, generator=gen).images[0]
|
| 505 |
+
img_path = TMP_DIR / f"prompt_{i}.png"
|
| 506 |
+
img.save(img_path)
|
| 507 |
+
images.append(pil_to_base64(img))
|
| 508 |
+
except Exception as e:
|
| 509 |
+
print(f"[ImageGen] β οΈ Failed on image {i}: {e}")
|
| 510 |
+
|
| 511 |
+
unload_pipelines(target="pipe")
|
| 512 |
+
print(f"[ImageGen] β
Generated {len(images)} image(s) successfully.")
|
| 513 |
+
return images
|