Spaces: Running on Zero
Update inference_coz_single.py
inference_coz_single.py  CHANGED  (+59 -75)

@@ -25,66 +25,77 @@ def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:

Old side of the first hunk: the previous _generate_vlm_prompt took file paths for the two images. Removed lines are prefixed with "-"; text cut off in the page extract is marked "…".

 # Helper: Generate a single VLM prompt for recursive_multiscale
 # -------------------------------------------------------------------
 def _generate_vlm_prompt(
-    vlm_model,
-    vlm_processor,
-    process_vision_info,
-    …
-    …
     device: str = "cuda"
 ) -> str:
     """
-    Given two …
-    - …
-    - …
-    …
-    Returns a string like “cat on sofa, pet, indoor, living room”, etc.
     """
-    …
     message_text = (
         "The second image is a zoom-in of the first image. "
         "Based on this knowledge, what is in the second image? "
         "Give me a set of words."
     )

-    # (2) Build the two-image “chat” payload
     messages = [
         {"role": "system", "content": message_text},
         {
             "role": "user",
             "content": [
-                {"type": "image", "image": …},
-                {"type": "image", "image": …},
             ],
         },
     ]

-    # (3) …
     text = vlm_processor.apply_chat_template(
-        messages, …
     )
     image_inputs, video_inputs = process_vision_info(messages)
     inputs = vlm_processor(
-        text=[text],
-        images=image_inputs,
-        videos=video_inputs,
-        padding=True,
         return_tensors="pt",
     ).to(device)

-    # (4) Generate
     generated = vlm_model.generate(**inputs, max_new_tokens=128)
-    # strip off the prompt tokens from each generated sequence:
     trimmed = [
-        out_ids[len(in_ids) …
     ]
     out_text = vlm_processor.batch_decode(
         trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
     )[0]

-    # (5) Return exactly the bare words (no extra “,” if no additional user prompt)
     return out_text.strip()


 # -------------------------------------------------------------------
 # Main Function: recursive_multiscale_sr (with multiple centers)
 # -------------------------------------------------------------------

@@ -203,88 +214,61 @@ def recursive_multiscale_sr(

Old side of the second hunk: the previous recursion loop, which wrote the intermediate crops and SR outputs to temporary PNG files. Removed lines are prefixed with "-"; text cut off in the page extract is marked "…".

     ###############################
     # 6. Prepare the very first “full” image
     ###############################
-    # 6.1 Load + center crop → first_image
     img0 = Image.open(input_png_path).convert("RGB")
     img0 = resize_and_center_crop(img0, process_size)

-    # …
-    …
-    img0.save(prev_path)

-    # We will maintain lists of PIL outputs and prompts:
     sr_pil_list: list[Image.Image] = []
-    prompt_list: …

-    ###############################
-    # 7. Recursion loop (now up to rec_num times)
-    ###############################
     for rec in range(rec_num):
-        # (A) …
-        …
-        …

-        # (1) Compute the “low-res” window size:
-        new_w, new_h = w // upscale, h // upscale  # e.g. 128×128 for upscale=4
-
-        # (2) Map normalized center → pixel center, then clamp so crop stays in bounds:
         cx_norm, cy_norm = centers[rec]
         cx = int(cx_norm * w)
         cy = int(cy_norm * h)
-        half_w = new_w // 2
-        …
-        …
-        …
-        …
-        top = cy - half_h
-        # clamp left ∈ [0, w - new_w], top ∈ [0, h - new_h]
-        left = max(0, min(left, w - new_w))
-        top = max(0, min(top, h - new_h))
-        right = left + new_w
-        bottom = top + new_h

         cropped = prev_pil.crop((left, top, right, bottom))

-        # (B) …
-        …
-        zoom_path = os.path.join(td, f"step{rec+1}_zoom.png")
-        zoomed.save(zoom_path)

-        # (C) Generate …
         prompt_tag = _generate_vlm_prompt(
             vlm_model=vlm_model,
             vlm_processor=vlm_processor,
             process_vision_info=process_vision_info,
-            …
-            …
             device=device,
         )
-        # (By default, no extra user prompt is appended.)

-        # (D) Prepare …
         to_tensor = transforms.ToTensor()
-        lq = to_tensor( …
         lq = (lq * 2.0) - 1.0

-        # (E) …
         with torch.no_grad():
-            out_tensor = model_test(lq, prompt=prompt_tag)[0]
             out_tensor = out_tensor.clamp(-1.0, 1.0).cpu()
-        # back to PIL in [0,1]:
         out_pil = transforms.ToPILImage()((out_tensor * 0.5) + 0.5)

-        # (F) …
-        …
-        out_pil.save(out_path)
-        prev_path = out_path

-        # (G) Append …
         sr_pil_list.append(out_pil)
         prompt_list.append(prompt_tag)

-    # end for(rec)
-
-    ###############################
-    # 8. Return the SR outputs & prompts
-    ###############################
-    # The list sr_pil_list = [ SR1, SR2, …, SR_rec_num ] in order.
     return sr_pil_list, prompt_list

New side of the first hunk: _generate_vlm_prompt now receives the two PIL images directly instead of file paths. Added lines are prefixed with "+".

 # Helper: Generate a single VLM prompt for recursive_multiscale
 # -------------------------------------------------------------------
 def _generate_vlm_prompt(
+    vlm_model: Qwen2_5_VLForConditionalGeneration,
+    vlm_processor: AutoProcessor,
+    process_vision_info,      # this is your helper that turns “messages” → image_inputs / video_inputs
+    prev_pil: Image.Image,    # <– pass PIL instead of path
+    zoomed_pil: Image.Image,  # <– pass PIL instead of path
     device: str = "cuda"
 ) -> str:
     """
+    Given two PIL.Image inputs:
+      - prev_pil:   the “full” image at the previous recursion.
+      - zoomed_pil: the cropped+resized (zoom) image for this step.
+    Returns a single “recursive_multiscale” prompt string.
     """
+
+    # (1) System message
     message_text = (
         "The second image is a zoom-in of the first image. "
         "Based on this knowledge, what is in the second image? "
         "Give me a set of words."
     )

+    # (2) Build the two-image “chat” payload.
+    #
+    # Instead of passing a filename, we pass the actual PIL.Image.
+    # The processor’s `process_vision_info` should know how to turn
+    # a message of the form {"type": "image", "image": PIL_IMAGE} into tensors.
     messages = [
         {"role": "system", "content": message_text},
         {
             "role": "user",
             "content": [
+                {"type": "image", "image": prev_pil},
+                {"type": "image", "image": zoomed_pil},
             ],
         },
     ]

+    # (3) Now run the “chat” through the VL processor.
+    #
+    # - `apply_chat_template` will build the tokenized prompt (without running it yet).
+    # - `process_vision_info` should inspect the same `messages` list and return
+    #   `image_inputs` and `video_inputs` (tensors) for any attached PIL images.
     text = vlm_processor.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True
     )
     image_inputs, video_inputs = process_vision_info(messages)
+
     inputs = vlm_processor(
+        text=[text],
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
         return_tensors="pt",
     ).to(device)

+    # (4) Generate and decode
     generated = vlm_model.generate(**inputs, max_new_tokens=128)
     trimmed = [
+        out_ids[len(in_ids):]
+        for in_ids, out_ids in zip(inputs.input_ids, generated)
     ]
     out_text = vlm_processor.batch_decode(
         trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
     )[0]

     return out_text.strip()

+
 # -------------------------------------------------------------------
 # Main Function: recursive_multiscale_sr (with multiple centers)
 # -------------------------------------------------------------------
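
For orientation, here is a minimal sketch of how the updated helper could be driven end to end. It assumes the Qwen2.5-VL stack that the new type hints point to (Qwen2_5_VLForConditionalGeneration, AutoProcessor) together with process_vision_info from qwen_vl_utils; the checkpoint name and the two image files are placeholders, not taken from this repository.

# Hypothetical wiring for the updated helper; the checkpoint and file names are placeholders.
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from qwen_vl_utils import process_vision_info

from inference_coz_single import _generate_vlm_prompt

vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct", torch_dtype="auto"
).to("cuda")
vlm_processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")

prev_pil = Image.open("full_view.png").convert("RGB")    # previous-step image
zoomed_pil = Image.open("zoom_crop.png").convert("RGB")  # upsampled crop of it

# Both images stay in memory; nothing is written to disk for the VLM call.
prompt_tag = _generate_vlm_prompt(
    vlm_model=vlm_model,
    vlm_processor=vlm_processor,
    process_vision_info=process_vision_info,
    prev_pil=prev_pil,
    zoomed_pil=zoomed_pil,
    device="cuda",
)
print(prompt_tag)  # a short set of words describing the zoomed-in region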

New side of the second hunk: the recursion loop now keeps every intermediate image in memory and passes the PIL crops straight to the VLM helper, instead of writing temporary PNGs. Added lines are prefixed with "+".

     ###############################
     # 6. Prepare the very first “full” image
     ###############################
+    # (6.1) Load + center crop → first_image (512×512)
     img0 = Image.open(input_png_path).convert("RGB")
     img0 = resize_and_center_crop(img0, process_size)

+    # Note: we no longer need to write “prev.png” to disk. Just keep it in memory.
+    prev_pil = img0.copy()

     sr_pil_list: list[Image.Image] = []
+    prompt_list: list[str] = []

     for rec in range(rec_num):
+        # (A) Compute low-res crop window on prev_pil
+        w, h = prev_pil.size                        # (512×512)
+        new_w, new_h = w // upscale, h // upscale

         cx_norm, cy_norm = centers[rec]
         cx = int(cx_norm * w)
         cy = int(cy_norm * h)
+        half_w, half_h = new_w // 2, new_h // 2
+
+        left = max(0, min(cx - half_w, w - new_w))
+        top = max(0, min(cy - half_h, h - new_h))
+        right, bottom = left + new_w, top + new_h

         cropped = prev_pil.crop((left, top, right, bottom))

+        # (B) Upsample that crop back to (512×512)
+        zoomed_pil = cropped.resize((w, h), Image.BICUBIC)

+        # (C) Generate VLM prompt by passing PILs directly:
         prompt_tag = _generate_vlm_prompt(
             vlm_model=vlm_model,
             vlm_processor=vlm_processor,
             process_vision_info=process_vision_info,
+            prev_pil=prev_pil,          # <– PIL
+            zoomed_pil=zoomed_pil,      # <– PIL
             device=device,
         )

+        # (D) Prepare “zoomed_pil” → tensor in [−1, 1]
         to_tensor = transforms.ToTensor()
+        lq = to_tensor(zoomed_pil).unsqueeze(0).to(device)   # (1,3,512,512)
         lq = (lq * 2.0) - 1.0

+        # (E) Run SR inference
         with torch.no_grad():
+            out_tensor = model_test(lq, prompt=prompt_tag)[0]
             out_tensor = out_tensor.clamp(-1.0, 1.0).cpu()
         out_pil = transforms.ToPILImage()((out_tensor * 0.5) + 0.5)

+        # (F) Bookkeeping: set prev_pil = out_pil for next iteration
+        prev_pil = out_pil

+        # (G) Append to results
         sr_pil_list.append(out_pil)
         prompt_list.append(prompt_tag)

     return sr_pil_list, prompt_list
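
Since the clamped crop-window arithmetic in step (A) is the only geometry that changes from one recursion to the next, here is the same computation restated as a standalone sketch for sanity checking; the helper name and the example numbers are illustrative, not taken from the diff.

# Restates the step-(A) window math; the function name and example values are illustrative only.
def crop_window(w: int, h: int, cx_norm: float, cy_norm: float, upscale: int):
    """Return (left, top, right, bottom) of a (w//upscale) x (h//upscale) box
    centred as close to (cx_norm*w, cy_norm*h) as the image bounds allow."""
    new_w, new_h = w // upscale, h // upscale
    cx, cy = int(cx_norm * w), int(cy_norm * h)
    half_w, half_h = new_w // 2, new_h // 2
    left = max(0, min(cx - half_w, w - new_w))   # clamp so the box stays inside the image
    top = max(0, min(cy - half_h, h - new_h))
    return left, top, left + new_w, top + new_h

print(crop_window(512, 512, 0.5, 0.5, 4))    # (192, 192, 320, 320): centred box
print(crop_window(512, 512, 0.98, 0.5, 4))   # (384, 192, 512, 320): clamped at the right edge

Each iteration then bicubically upsamples that window back to the full processing size, asks the VLM for a prompt, and runs model_test on the result, so prev_pil is the only state carried between recursions.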