Update core/image_generator.py
core/image_generator.py  CHANGED  (+10 -16)
@@ -181,19 +181,9 @@ def pil_to_base64(img: Image.Image) -> str:
 # UNIFIED IMAGE GENERATION FUNCTION
 # --------------------------------------------------------------
 async def generate_images(prompt_or_json, seed: int | None = None, num_images: int = 3):
-    """
-    Universal entrypoint.
-    - If input is a string → generate list of images
-    - If input is a JSON (dict) → generate character & keyframe images and return updated JSON
-    """
     global pipe, img2img_pipe
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
-    if pipe is None:
-        pipe = load_pipeline()
-    if img2img_pipe is None:
-        img2img_pipe = load_img2img_pipeline()
-
     # ----------------------------------------------------------
     # CASE 1: STRUCTURED JSON (story mode)
     # ----------------------------------------------------------
@@ -201,7 +191,8 @@ async def generate_images(prompt_or_json, seed: int | None = None, num_images: int = 3):
     story_json = prompt_or_json
     print("[ImageGen] Detected structured JSON input. Generating cinematic visuals...")
 
-    #
+    # Step 1: Load only txt2img for character generation
+    pipe = load_pipeline()
     seed_to_char_image = {}
     for char in story_json.get("characters", []):
         char_name = char["name"]
@@ -218,10 +209,13 @@ async def generate_images(prompt_or_json, seed: int | None = None, num_images: int = 3):
         image = pipe(f"{char_name}, {char_desc}", num_inference_steps=30, generator=generator).images[0]
         image.save(seed_image_path)
 
-        # Map seed → character image
         seed_to_char_image[char_seed] = image
 
-    #
+    # Free txt2img pipeline
+    unload_pipelines(target="pipe")
+
+    # Step 2: Load only img2img for keyframes
+    img2img_pipe = load_img2img_pipeline()
     for key, scene_data in story_json.items():
         if not key.startswith("scene"):
             continue
@@ -246,13 +240,11 @@ async def generate_images(prompt_or_json, seed: int | None = None, num_images: int = 3):
                 generator=generator
             ).images[0]
 
-            # Save temporarily
             out_path = TMP_DIR / f"{key}_{kf_key}_seed{frame_seed}.png"
             img.save(out_path)
-
-            # Replace prompt with actual base64 image
             frame[kf_key] = pil_to_base64(img)
 
+    unload_pipelines(target="all")  # unload both just in case
     print("[ImageGen] ✅ Story JSON image generation complete.")
     return story_json
 
@@ -260,6 +252,7 @@ async def generate_images(prompt_or_json, seed: int | None = None, num_images: int = 3):
     # CASE 2: NORMAL PROMPT
     # ----------------------------------------------------------
     print(f"[ImageGen] Generating {num_images} image(s) for prompt='{prompt_or_json}' seed={seed}")
+    pipe = load_pipeline()
     images = []
     for i in range(num_images):
         gen = torch.Generator(device).manual_seed(seed + i) if seed is not None else None
@@ -271,5 +264,6 @@ async def generate_images(prompt_or_json, seed: int | None = None, num_images: int = 3):
         except Exception as e:
             print(f"[ImageGen] ⚠️ Failed on image {i}: {e}")
 
+    unload_pipelines(target="pipe")
     print(f"[ImageGen] ✅ Generated {len(images)} image(s) successfully.")
     return images
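
The new code calls an unload_pipelines() helper that is defined elsewhere in core/image_generator.py and is not shown in this diff. As a rough, hypothetical sketch of what such a helper might look like (assuming the module-level globals pipe and img2img_pipe; the real implementation and the exact set of accepted target values may differ):

import gc
import torch

# Module-level pipeline handles; in the real file these hold the loaded pipelines.
pipe = None
img2img_pipe = None

def unload_pipelines(target: str = "all"):
    """Drop the selected pipeline reference(s) and release GPU memory (sketch only)."""
    global pipe, img2img_pipe
    if target in ("pipe", "all"):
        pipe = None
    if target in ("img2img_pipe", "all"):   # target name beyond "pipe"/"all" is a guess
        img2img_pipe = None
    gc.collect()                       # let Python reclaim the dropped pipeline objects
    if torch.cuda.is_available():
        torch.cuda.empty_cache()       # return freed VRAM to the CUDA allocator

Loading one pipeline at a time and freeing it before the next stage keeps peak GPU memory close to a single model's footprint, which appears to be the motivation for this change.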
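
For reference, the removed docstring described the two call modes: a plain prompt string returns a list of images, and a structured story JSON returns the same dict with its keyframe prompts replaced by base64 images. A minimal usage sketch, assuming the module is importable as core.image_generator and is run under asyncio; the story-JSON schema is only partly visible in this diff, so that branch is left abstract:

import asyncio
from core.image_generator import generate_images  # assumes the package layout implied by the file path

async def main():
    # Case 2: plain prompt -> list of PIL images
    images = await generate_images("a foggy harbor at dawn", seed=42, num_images=3)
    print(f"Got {len(images)} image(s)")

    # Case 1: structured story JSON with "characters" and "scene*" keys -> the same dict,
    # with keyframe prompts replaced by base64-encoded images. The exact schema is not
    # shown in this diff, so no example payload is constructed here.
    # story_json = await generate_images(story_json)

asyncio.run(main())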