karthikeya1212 committed · verified
Commit 03b696f · 1 Parent(s): 2db3804

Update core/image_generator.py

Files changed (1)
  1. core/image_generator.py +22 -12
core/image_generator.py CHANGED
@@ -116,7 +116,7 @@
 # return images


-
+# core/image_generator.py
 import os
 import torch
 from diffusers import StableDiffusionXLPipeline
@@ -127,15 +127,15 @@ from io import BytesIO
 import base64
 from PIL import Image

-# ---------------- CACHE & MODEL DIRS ----------------
+# ---------------- CACHE & MODEL DIRECTORIES ----------------
 HF_CACHE_DIR = Path("/tmp/hf_cache")
 MODEL_DIR = Path("/tmp/models/realvisxl_v4")

+# Create directories safely (no chmod)
 for d in [HF_CACHE_DIR, MODEL_DIR]:
     d.mkdir(parents=True, exist_ok=True)
-    os.chmod(d, 0o777)

-# make sure env vars are applied *before* any huggingface import usage
+# Apply environment variables BEFORE any Hugging Face usage
 os.environ.update({
     "HF_HOME": str(HF_CACHE_DIR),
     "TRANSFORMERS_CACHE": str(HF_CACHE_DIR),
@@ -150,7 +150,10 @@ MODEL_FILENAME = "realvisxlV40_v40LightningBakedvae.safetensors"

 # ---------------- MODEL DOWNLOAD ----------------
 def download_model() -> Path:
-    """Downloads RealVisXL V4.0 model if not present, returns local path."""
+    """
+    Downloads RealVisXL V4.0 model if not present.
+    Returns local path.
+    """
     model_path = MODEL_DIR / MODEL_FILENAME
     if not model_path.exists():
         print("[ImageGen] Downloading RealVisXL V4.0 model...")
@@ -161,7 +164,7 @@ def download_model() -> Path:
                 local_dir=str(MODEL_DIR),
                 cache_dir=str(HF_CACHE_DIR),
                 force_download=False,
-                resume_download=True,  # <— safer for interrupted downloads
+                resume_download=True,  # safer if download interrupted
             )
         )
         print(f"[ImageGen] Model downloaded to: {model_path}")
@@ -171,7 +174,9 @@ def download_model() -> Path:

 # ---------------- PIPELINE LOAD ----------------
 def load_pipeline() -> StableDiffusionXLPipeline:
-    """Loads the RealVisXL V4.0 model for image generation."""
+    """
+    Loads the RealVisXL V4.0 model for image generation.
+    """
     model_path = download_model()
     print("[ImageGen] Loading model into pipeline...")

@@ -185,8 +190,10 @@ def load_pipeline() -> StableDiffusionXLPipeline:
     else:
         pipe.to("cpu")

-    pipe.safety_checker = None  # optional: skip safety for performance
-    pipe.enable_attention_slicing()  # <— reduces memory use on CPU
+    # Optional: skip safety checker to save memory/performance
+    pipe.safety_checker = None
+    # Enable attention slicing for memory-efficient CPU usage
+    pipe.enable_attention_slicing()

     print("[ImageGen] Model ready.")
     return pipe
@@ -194,8 +201,11 @@ def load_pipeline() -> StableDiffusionXLPipeline:
 # ---------------- GLOBAL PIPELINE CACHE ----------------
 pipe: StableDiffusionXLPipeline | None = None

-# ---------------- UTILITY: PIL → Base64 ----------------
+# ---------------- UTILITY: PIL → BASE64 ----------------
 def pil_to_base64(img: Image.Image) -> str:
+    """
+    Converts PIL image to base64 string for frontend.
+    """
     buffered = BytesIO()
     img.save(buffered, format="PNG")
     return f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode()}"
@@ -204,13 +214,13 @@ def pil_to_base64(img: Image.Image) -> str:
 def generate_images(prompt: str, seed: int | None = None, num_images: int = 3) -> List[str]:
     """
     Generates high-quality images using RealVisXL V4.0.
-    Returns list of base64-encoded PNGs.
+    Returns a list of base64-encoded PNGs.
     """
     global pipe
     if pipe is None:
         pipe = load_pipeline()

-    print(f"[ImageGen] Generating {num_images} image(s) for prompt: '{prompt}' seed={seed}")
+    print(f"[ImageGen] Generating {num_images} image(s) for prompt: '{prompt}' seed={seed}")
     images: List[str] = []

     for i in range(num_images):
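
The last hunk ends inside generate_images(), so the body of the per-image loop is not visible in this diff. The sketch below shows one common way such a loop is completed with a diffusers SDXL pipeline; it uses only names already defined in the module above, and the per-image seed offset, step count, guidance scale, and resolution are illustrative assumptions rather than values taken from this commit.

    # Hypothetical continuation of generate_images() -- not part of this commit's diff.
    # Uses only names defined above (pipe, pil_to_base64, images, prompt, seed, num_images);
    # steps, guidance scale, and resolution are assumed values.
    for i in range(num_images):
        generator = None
        if seed is not None:
            # Offset the seed per image so the num_images outputs differ deterministically.
            generator = torch.Generator(device=pipe.device).manual_seed(seed + i)

        result = pipe(
            prompt=prompt,
            num_inference_steps=8,   # Lightning-style checkpoints need few steps (assumed)
            guidance_scale=2.0,      # low CFG is typical for Lightning variants (assumed)
            width=1024,
            height=1024,
            generator=generator,
        )
        images.append(pil_to_base64(result.images[0]))

    return images

Each returned string is a "data:image/png;base64,..." URL, so a frontend can assign it directly to an img src attribute, or decode it with base64.b64decode() after stripping the prefix.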