karthikeya1212 committed on
Commit
c65196c
·
verified ·
1 Parent(s): 0050d71

Update core/image_generator.py

Browse files
Files changed (1) hide show
  1. core/image_generator.py +10 -16
core/image_generator.py CHANGED
@@ -181,19 +181,9 @@ def pil_to_base64(img: Image.Image) -> str:
181
  # UNIFIED IMAGE GENERATION FUNCTION
182
  # --------------------------------------------------------------
183
  async def generate_images(prompt_or_json, seed: int | None = None, num_images: int = 3):
184
- """
185
- Universal entrypoint.
186
- - If input is a string β†’ generate list of images
187
- - If input is a JSON (dict) β†’ generate character & keyframe images and return updated JSON
188
- """
189
  global pipe, img2img_pipe
190
  device = "cuda" if torch.cuda.is_available() else "cpu"
191
 
192
- if pipe is None:
193
- pipe = load_pipeline()
194
- if img2img_pipe is None:
195
- img2img_pipe = load_img2img_pipeline()
196
-
197
  # ----------------------------------------------------------
198
  # CASE 1: STRUCTURED JSON (story mode)
199
  # ----------------------------------------------------------
@@ -201,7 +191,8 @@ async def generate_images(prompt_or_json, seed: int | None = None, num_images: i
201
  story_json = prompt_or_json
202
  print("[ImageGen] Detected structured JSON input. Generating cinematic visuals...")
203
 
204
- # ---------------- Step 1: Character Images ----------------
 
205
  seed_to_char_image = {}
206
  for char in story_json.get("characters", []):
207
  char_name = char["name"]
@@ -218,10 +209,13 @@ async def generate_images(prompt_or_json, seed: int | None = None, num_images: i
218
  image = pipe(f"{char_name}, {char_desc}", num_inference_steps=30, generator=generator).images[0]
219
  image.save(seed_image_path)
220
 
221
- # Map seed β†’ character image
222
  seed_to_char_image[char_seed] = image
223
 
224
- # ---------------- Step 2: Keyframe Images ----------------
 
 
 
 
225
  for key, scene_data in story_json.items():
226
  if not key.startswith("scene"):
227
  continue
@@ -246,13 +240,11 @@ async def generate_images(prompt_or_json, seed: int | None = None, num_images: i
246
  generator=generator
247
  ).images[0]
248
 
249
- # Save temporarily
250
  out_path = TMP_DIR / f"{key}_{kf_key}_seed{frame_seed}.png"
251
  img.save(out_path)
252
-
253
- # Replace prompt with actual base64 image
254
  frame[kf_key] = pil_to_base64(img)
255
 
 
256
  print("[ImageGen] βœ… Story JSON image generation complete.")
257
  return story_json
258
 
@@ -260,6 +252,7 @@ async def generate_images(prompt_or_json, seed: int | None = None, num_images: i
260
  # CASE 2: NORMAL PROMPT
261
  # ----------------------------------------------------------
262
  print(f"[ImageGen] Generating {num_images} image(s) for prompt='{prompt_or_json}' seed={seed}")
 
263
  images = []
264
  for i in range(num_images):
265
  gen = torch.Generator(device).manual_seed(seed + i) if seed is not None else None
@@ -271,5 +264,6 @@ async def generate_images(prompt_or_json, seed: int | None = None, num_images: i
271
  except Exception as e:
272
  print(f"[ImageGen] ⚠️ Failed on image {i}: {e}")
273
 
 
274
  print(f"[ImageGen] βœ… Generated {len(images)} image(s) successfully.")
275
  return images
 
181
  # UNIFIED IMAGE GENERATION FUNCTION
182
  # --------------------------------------------------------------
183
  async def generate_images(prompt_or_json, seed: int | None = None, num_images: int = 3):
 
 
 
 
 
184
  global pipe, img2img_pipe
185
  device = "cuda" if torch.cuda.is_available() else "cpu"
186
 
 
 
 
 
 
187
  # ----------------------------------------------------------
188
  # CASE 1: STRUCTURED JSON (story mode)
189
  # ----------------------------------------------------------
 
191
  story_json = prompt_or_json
192
  print("[ImageGen] Detected structured JSON input. Generating cinematic visuals...")
193
 
194
+ # Step 1: Load only txt2img for character generation
195
+ pipe = load_pipeline()
196
  seed_to_char_image = {}
197
  for char in story_json.get("characters", []):
198
  char_name = char["name"]
 
209
  image = pipe(f"{char_name}, {char_desc}", num_inference_steps=30, generator=generator).images[0]
210
  image.save(seed_image_path)
211
 
 
212
  seed_to_char_image[char_seed] = image
213
 
214
+ # Free txt2img pipeline
215
+ unload_pipelines(target="pipe")
216
+
217
+ # Step 2: Load only img2img for keyframes
218
+ img2img_pipe = load_img2img_pipeline()
219
  for key, scene_data in story_json.items():
220
  if not key.startswith("scene"):
221
  continue
 
240
  generator=generator
241
  ).images[0]
242
 
 
243
  out_path = TMP_DIR / f"{key}_{kf_key}_seed{frame_seed}.png"
244
  img.save(out_path)
 
 
245
  frame[kf_key] = pil_to_base64(img)
246
 
247
+ unload_pipelines(target="all") # unload both just in case
248
  print("[ImageGen] βœ… Story JSON image generation complete.")
249
  return story_json
250
 
 
252
  # CASE 2: NORMAL PROMPT
253
  # ----------------------------------------------------------
254
  print(f"[ImageGen] Generating {num_images} image(s) for prompt='{prompt_or_json}' seed={seed}")
255
+ pipe = load_pipeline()
256
  images = []
257
  for i in range(num_images):
258
  gen = torch.Generator(device).manual_seed(seed + i) if seed is not None else None
 
264
  except Exception as e:
265
  print(f"[ImageGen] ⚠️ Failed on image {i}: {e}")
266
 
267
+ unload_pipelines(target="pipe")
268
  print(f"[ImageGen] βœ… Generated {len(images)} image(s) successfully.")
269
  return images