karthikeya1212 commited on
Commit
25e7823
·
verified ·
1 Parent(s): 756c289

Update core/image_generator.py

Browse files
Files changed (1) hide show
  1. core/image_generator.py +85 -85
core/image_generator.py CHANGED
@@ -1,85 +1,85 @@
1
- # core/image_generator.py
2
- import os
3
- import torch
4
- from diffusers import StableDiffusionXLPipeline
5
- from huggingface_hub import hf_hub_download
6
- from pathlib import Path
7
- from typing import List
8
-
9
- # ---------------- MODEL CONFIG ----------------
10
- MODEL_REPO = "SG161222/RealVisXL_V4.0"
11
- MODEL_FILENAME = "realvisxlV40_v40LightningBakedvae.safetensors"
12
- MODEL_DIR = Path("/tmp/models/realvisxl_v4")
13
- os.makedirs(MODEL_DIR, exist_ok=True)
14
-
15
- # ---------------- MODEL DOWNLOAD ----------------
16
- def download_model() -> Path:
17
- """
18
- Downloads RealVisXL V4.0 model if not present.
19
- Returns the local model path.
20
- """
21
- model_path = MODEL_DIR / MODEL_FILENAME
22
- if not model_path.exists():
23
- print("[ImageGen] Downloading RealVisXL V4.0 model...")
24
- model_path = hf_hub_download(
25
- repo_id=MODEL_REPO,
26
- filename=MODEL_FILENAME,
27
- local_dir=str(MODEL_DIR),
28
- force_download=False,
29
- )
30
- print(f"[ImageGen] Model downloaded to: {model_path}")
31
- else:
32
- print("[ImageGen] Model already exists. Skipping download.")
33
- return model_path
34
-
35
- # ---------------- PIPELINE LOAD ----------------
36
- def load_pipeline() -> StableDiffusionXLPipeline:
37
- """
38
- Loads the RealVisXL V4.0 model for image generation.
39
- """
40
- model_path = download_model()
41
- print("[ImageGen] Loading model into pipeline...")
42
- pipe = StableDiffusionXLPipeline.from_single_file(
43
- str(model_path),
44
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
45
- )
46
- if torch.cuda.is_available():
47
- pipe.to("cuda")
48
- print("[ImageGen] Model ready.")
49
- return pipe
50
-
51
- # ---------------- GLOBAL PIPELINE CACHE ----------------
52
- pipe: StableDiffusionXLPipeline | None = None
53
-
54
- # ---------------- IMAGE GENERATION ----------------
55
- def generate_images(prompt: str, seed: int = None, num_images: int = 3) -> List:
56
- """
57
- Generates high-quality images using RealVisXL V4.0.
58
- Supports deterministic generation using a seed.
59
-
60
- Args:
61
- prompt (str): Text prompt for image generation.
62
- seed (int, optional): Seed for deterministic generation.
63
- num_images (int): Number of images to generate.
64
-
65
- Returns:
66
- List: Generated PIL images.
67
- """
68
- global pipe
69
- if pipe is None:
70
- pipe = load_pipeline()
71
-
72
- print(f"[ImageGen] Generating {num_images} image(s) for prompt: '{prompt}' with seed={seed}")
73
- images = []
74
-
75
- for i in range(num_images):
76
- generator = None
77
- if seed is not None:
78
- device = "cuda" if torch.cuda.is_available() else "cpu"
79
- generator = torch.Generator(device).manual_seed(seed + i) # slightly vary keyframes
80
-
81
- result = pipe(prompt, num_inference_steps=30, generator=generator).images[0]
82
- images.append(result)
83
-
84
- print(f"[ImageGen] Generated {len(images)} images successfully.")
85
- return images
 
1
+ # core/image_generator.py
2
+ import os
3
+ import torch
4
+ from diffusers import StableDiffusionXLPipeline
5
+ from huggingface_hub import hf_hub_download
6
+ from pathlib import Path
7
+ from typing import List
8
+
9
+ # ---------------- MODEL CONFIG ----------------
10
+ MODEL_REPO = "SG161222/RealVisXL_V4.0"
11
+ MODEL_FILENAME = "realvisxlV40_v40LightningBakedvae.safetensors"
12
+ MODEL_DIR = Path("/tmp/models/realvisxl_v4")
13
+ os.makedirs(MODEL_DIR, exist_ok=True)
14
+
15
+ # ---------------- MODEL DOWNLOAD ----------------
16
+ def download_model() -> Path:
17
+ """
18
+ Downloads RealVisXL V4.0 model if not present.
19
+ Returns the local model path.
20
+ """
21
+ model_path = MODEL_DIR / MODEL_FILENAME
22
+ if not model_path.exists():
23
+ print("[ImageGen] Downloading RealVisXL V4.0 model...")
24
+ model_path = hf_hub_download(
25
+ repo_id=MODEL_REPO,
26
+ filename=MODEL_FILENAME,
27
+ local_dir=str(MODEL_DIR),
28
+ force_download=False,
29
+ )
30
+ print(f"[ImageGen] Model downloaded to: {model_path}")
31
+ else:
32
+ print("[ImageGen] Model already exists. Skipping download.")
33
+ return model_path
34
+
35
+ # ---------------- PIPELINE LOAD ----------------
36
+ def load_pipeline() -> StableDiffusionXLPipeline:
37
+ """
38
+ Loads the RealVisXL V4.0 model for image generation.
39
+ """
40
+ model_path = download_model()
41
+ print("[ImageGen] Loading model into pipeline...")
42
+ pipe = StableDiffusionXLPipeline.from_single_file(
43
+ str(model_path),
44
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
45
+ )
46
+ if torch.cuda.is_available():
47
+ pipe.to("cuda")
48
+ print("[ImageGen] Model ready.")
49
+ return pipe
50
+
51
+ # ---------------- GLOBAL PIPELINE CACHE ----------------
52
+ pipe: StableDiffusionXLPipeline | None = None
53
+
54
+ # ---------------- IMAGE GENERATION ----------------
55
+ def generate_images(prompt: str, seed: int = None, num_images: int = 3) -> List:
56
+ """
57
+ Generates high-quality images using RealVisXL V4.0.
58
+ Supports deterministic generation using a seed.
59
+
60
+ Args:
61
+ prompt (str): Text prompt for image generation.
62
+ seed (int, optional): Seed for deterministic generation.
63
+ num_images (int): Number of images to generate.
64
+
65
+ Returns:
66
+ List: Generated PIL images.
67
+ """
68
+ global pipe
69
+ if pipe is None:
70
+ pipe = load_pipeline()
71
+
72
+ print(f"[ImageGen] Generating {num_images} image(s) for prompt: '{prompt}' with seed={seed}")
73
+ images = []
74
+
75
+ for i in range(num_images):
76
+ generator = None
77
+ if seed is not None:
78
+ device = "cuda" if torch.cuda.is_available() else "cpu"
79
+ generator = torch.Generator(device).manual_seed(seed + i) # slightly vary keyframes
80
+
81
+ result = pipe(prompt, num_inference_steps=30, generator=generator).images[0]
82
+ images.append(pil_to_base64(result))
83
+
84
+ print(f"[ImageGen] Generated {len(images)} images successfully.")
85
+ return images