karthikeya1212 commited on
Commit
d88d679
Β·
verified Β·
1 Parent(s): 1fa5af3

Update core/image_generator.py

Browse files
Files changed (1) hide show
  1. core/image_generator.py +29 -57
core/image_generator.py CHANGED
@@ -272,12 +272,11 @@
272
 
273
 
274
 
275
-
276
  import os
277
  from pathlib import Path
278
  import gc
279
  import torch
280
- from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
281
  from huggingface_hub import hf_hub_download
282
  from typing import Dict, Any
283
  from PIL import Image
@@ -322,7 +321,6 @@ def safe_expanduser(path):
322
  return os.path.expanduser_original(path)
323
 
324
  os.path.expanduser = safe_expanduser
325
-
326
  tempfile.tempdir = str(HF_CACHE_DIR)
327
 
328
  print("[DEBUG] βœ… Hugging Face, Diffusers, Datasets and Torch cache fully redirected to:", HF_CACHE_DIR)
@@ -344,34 +342,11 @@ print("[DEBUG] βœ… Seed directory:", SEED_DIR)
344
  # --------------------------------------------------------------
345
  # MODEL CONFIG
346
  # --------------------------------------------------------------
347
- MODEL_REPO = "Lykon/dreamshaper-8"
348
- MODEL_FILENAME = "dreamshaper_8.safetensors"
349
-
350
  # ---------------- GLOBAL PIPELINE CACHE ----------------
351
- pipe: StableDiffusionXLPipeline | None = None
352
  img2img_pipe: StableDiffusionXLImg2ImgPipeline | None = None
353
 
354
- # --------------------------------------------------------------
355
- # MODEL DOWNLOAD
356
- # --------------------------------------------------------------
357
- def download_model() -> Path:
358
- model_path = MODEL_DIR / MODEL_FILENAME
359
- if not model_path.exists():
360
- print("[ImageGen] Downloading DreamShaper SD1.5 model...")
361
- model_path = Path(
362
- hf_hub_download(
363
- repo_id=MODEL_REPO,
364
- filename=MODEL_FILENAME,
365
- cache_dir=str(HF_CACHE_DIR),
366
- force_download=False,
367
- resume_download=True,
368
- )
369
- )
370
- print(f"[ImageGen] βœ… Model downloaded to: {model_path}")
371
- else:
372
- print("[ImageGen] βœ… Model already exists at:", model_path)
373
- return model_path
374
-
375
  # --------------------------------------------------------------
376
  # MEMORY-SAFE PIPELINE MANAGER
377
  # --------------------------------------------------------------
@@ -399,18 +374,22 @@ def unload_pipelines(target="all"):
399
  torch.cuda.empty_cache()
400
  print("[ImageGen] βœ… Memory cleared.")
401
 
402
- def safe_load_pipeline(pipeline_class, model_path):
403
- """Safely load a pipeline with retry logic and memory handling."""
404
  try:
405
- print(f"[ImageGen] πŸ”„ Loading {pipeline_class.__name__} from {model_path} ...")
406
- pipe = pipeline_class.from_single_file(
407
- model_path,
408
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
 
409
  )
410
- print(f"[ImageGen] βœ… Successfully loaded {pipeline_class.__name__}.")
 
 
 
411
  return pipe
412
  except Exception as e:
413
- print(f"[ImageGen] ❌ Failed to load {pipeline_class.__name__}: {e}")
414
  unload_pipelines()
415
  gc.collect()
416
  if torch.cuda.is_available():
@@ -420,26 +399,17 @@ def safe_load_pipeline(pipeline_class, model_path):
420
  def load_pipeline():
421
  global pipe
422
  unload_pipelines(target="pipe")
423
- model_path = download_model()
424
  print("[ImageGen] Loading main (txt2img) pipeline...")
425
- pipe = safe_load_pipeline(StableDiffusionXLPipeline, model_path)
426
- device = "cuda" if torch.cuda.is_available() else "cpu"
427
- pipe.to(device)
428
- pipe.safety_checker = None
429
- pipe.enable_attention_slicing()
430
  print("[ImageGen] βœ… Text-to-image pipeline ready.")
431
  return pipe
432
 
433
  def load_img2img_pipeline():
434
  global img2img_pipe
435
  unload_pipelines(target="img2img_pipe")
436
- model_path = download_model()
437
  print("[ImageGen] Loading img2img pipeline...")
438
- img2img_pipe = safe_load_pipeline(StableDiffusionXLImg2ImgPipeline, model_path)
439
- device = "cuda" if torch.cuda.is_available() else "cpu"
440
- img2img_pipe.to(device)
441
- img2img_pipe.safety_checker = None
442
- img2img_pipe.enable_attention_slicing()
443
  print("[ImageGen] βœ… Img2Img pipeline ready.")
444
  return img2img_pipe
445
 
@@ -529,13 +499,15 @@ async def generate_images(prompt_or_json, seed: int | None = None, num_images: i
529
  pipe = load_pipeline()
530
  images = []
531
  for i in range(num_images):
532
- gen = torch.Generator(device).manual_seed(seed + i) if seed is not None else None
533
- try:
534
- img = pipe(prompt_or_json, num_inference_steps=30, generator=gen).images[0]
535
- img_path = TMP_DIR / f"prompt_{i}.png"
536
- img.save(img_path)
537
- images.append(pil_to_base64(img))
538
- except Exception as e:
539
- print(f"[ImageGen] ⚠️ Failed on image {i}: {e}")
540
-
541
-
 
 
 
272
 
273
 
274
 
 
275
  import os
276
  from pathlib import Path
277
  import gc
278
  import torch
279
+ from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, AutoPipelineForText2Image
280
  from huggingface_hub import hf_hub_download
281
  from typing import Dict, Any
282
  from PIL import Image
 
321
  return os.path.expanduser_original(path)
322
 
323
  os.path.expanduser = safe_expanduser
 
324
  tempfile.tempdir = str(HF_CACHE_DIR)
325
 
326
  print("[DEBUG] βœ… Hugging Face, Diffusers, Datasets and Torch cache fully redirected to:", HF_CACHE_DIR)
 
342
  # --------------------------------------------------------------
343
  # MODEL CONFIG
344
  # --------------------------------------------------------------
345
+ MODEL_REPO = "lykon/dreamshaper-8" # Use Hugging Face repo
 
 
346
  # ---------------- GLOBAL PIPELINE CACHE ----------------
347
+ pipe: StableDiffusionXLPipeline | AutoPipelineForText2Image | None = None
348
  img2img_pipe: StableDiffusionXLImg2ImgPipeline | None = None
349
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
  # --------------------------------------------------------------
351
  # MEMORY-SAFE PIPELINE MANAGER
352
  # --------------------------------------------------------------
 
374
  torch.cuda.empty_cache()
375
  print("[ImageGen] βœ… Memory cleared.")
376
 
377
+ def safe_load_pipeline(pretrained_model_name):
378
+ """Load DreamShaper SD1.5 safely via from_pretrained."""
379
  try:
380
+ print(f"[ImageGen] πŸ”„ Loading model {pretrained_model_name} ...")
381
+ pipe = AutoPipelineForText2Image.from_pretrained(
382
+ pretrained_model_name,
383
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
384
+ variant="fp16" # use fp16 if possible
385
  )
386
+ device = "cuda" if torch.cuda.is_available() else "cpu"
387
+ pipe = pipe.to(device)
388
+ pipe.enable_attention_slicing()
389
+ print(f"[ImageGen] βœ… Successfully loaded {pretrained_model_name}.")
390
  return pipe
391
  except Exception as e:
392
+ print(f"[ImageGen] ❌ Failed to load {pretrained_model_name}: {e}")
393
  unload_pipelines()
394
  gc.collect()
395
  if torch.cuda.is_available():
 
399
  def load_pipeline():
400
  global pipe
401
  unload_pipelines(target="pipe")
 
402
  print("[ImageGen] Loading main (txt2img) pipeline...")
403
+ pipe = safe_load_pipeline(MODEL_REPO)
 
 
 
 
404
  print("[ImageGen] βœ… Text-to-image pipeline ready.")
405
  return pipe
406
 
407
  def load_img2img_pipeline():
408
  global img2img_pipe
409
  unload_pipelines(target="img2img_pipe")
 
410
  print("[ImageGen] Loading img2img pipeline...")
411
+ # For DreamShaper, img2img uses the same pipeline
412
+ img2img_pipe = safe_load_pipeline(MODEL_REPO)
 
 
 
413
  print("[ImageGen] βœ… Img2Img pipeline ready.")
414
  return img2img_pipe
415
 
 
499
  pipe = load_pipeline()
500
  images = []
501
  for i in range(num_images):
502
+ gen = torch.Generator(device).manual_seed(seed + i) if seed is not None else None
503
+ try:
504
+ img = pipe(prompt_or_json, num_inference_steps=30, generator=gen).images[0]
505
+ img_path = TMP_DIR / f"prompt_{i}.png"
506
+ img.save(img_path)
507
+ images.append(pil_to_base64(img))
508
+ except Exception as e:
509
+ print(f"[ImageGen] ⚠️ Failed on image {i}: {e}")
510
+
511
+ unload_pipelines(target="pipe")
512
+ print(f"[ImageGen] βœ… Generated {len(images)} image(s) successfully.")
513
+ return images