# Chain-of-Zoom / app.py
# HF Space by alexnasa — last change: "Update app.py" (commit dc4acf0, verified)
import gradio as gr
import subprocess
import os
import shutil
from pathlib import Path
import spaces
# import the updated recursive_multiscale_sr that expects a list of centers
from inference_coz_single import recursive_multiscale_sr
from PIL import Image, ImageDraw
# ------------------------------------------------------------------
# CONFIGURE THESE PATHS TO MATCH YOUR PROJECT STRUCTURE
# ------------------------------------------------------------------
# NOTE(review): neither constant is referenced anywhere in this file —
# presumably consumed by the inference pipeline or kept for parity with
# the upstream repo; confirm before removing.
INPUT_DIR = "samples"
OUTPUT_DIR = "inference_results/coz_vlmprompt"
def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
    """
    Resize the input PIL image so that its shorter side == `size`,
    then center-crop to exactly (size x size).

    Args:
        img: Source image (any mode PIL supports for resize/crop).
        size: Side length of the square output, in pixels.

    Returns:
        A new (size x size) PIL image.
    """
    w, h = img.size
    scale = size / min(w, h)
    # Use round() + max(): plain int() truncates, which can leave the
    # scaled shorter side at size-1 (e.g. 683x1024 @ size=512 -> 511),
    # pushing the crop box outside the resized image.  Clamping to at
    # least `size` guarantees the crop is always fully in-bounds.
    new_w = max(size, round(w * scale))
    new_h = max(size, round(h * scale))
    img = img.resize((new_w, new_h), Image.LANCZOS)
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return img.crop((left, top, left + size, top + size))
# ------------------------------------------------------------------
# HELPER: Draw four true “nested” rectangles, matching the SR logic
# ------------------------------------------------------------------
def make_preview_with_boxes(
    image_path: str,
    scale_option: str,
    cx_norm: float,
    cy_norm: float,
) -> tuple[Image.Image, list[tuple[float, float]]]:
    """
    Draw the four nested crop rectangles the SR pipeline would use.

    Returns:
        - the 512x512 preview image with the boxes drawn on it,
        - one (cx, cy) center per box, each normalized to its PARENT box
          (matching the recursive crop logic of the SR pipeline).
    """
    try:
        source = Image.open(image_path).convert("RGB")
    except Exception as e:
        # Unreadable file: show a grey placeholder carrying the error text.
        fallback = Image.new("RGB", (512, 512), (200, 200, 200))
        ImageDraw.Draw(fallback).text((20, 20), f"Error:\n{e}", fill="red")
        return fallback, []

    canvas = resize_and_center_crop(source, 512)

    factor = int(scale_option.replace("x", ""))
    if factor <= 1:
        # No zoom: every "box" is the full frame.
        box_sizes = [512.0] * 4
    else:
        # Each level shrinks the crop by another factor of `factor`.
        box_sizes = [512.0 / factor ** (level + 1) for level in range(4)]

    painter = ImageDraw.Draw(canvas)
    palette = ["red", "lime", "cyan", "yellow"]
    line_width = 3

    target_x = cx_norm * 512.0
    target_y = cy_norm * 512.0

    parent_x, parent_y, parent_size = 0.0, 0.0, 512.0
    centers: list[tuple[float, float]] = []

    for level, side in enumerate(box_sizes):
        # Center the box on the requested point, then clamp it so it
        # stays fully inside its parent box.
        left = min(max(target_x - side / 2.0, parent_x),
                   parent_x + parent_size - side)
        top = min(max(target_y - side / 2.0, parent_y),
                  parent_y + parent_size - side)

        painter.rectangle(
            [(int(round(left)), int(round(top))),
             (int(round(left + side)), int(round(top + side)))],
            outline=palette[level % len(palette)],
            width=line_width,
        )

        # Box center expressed in the parent box's normalized coordinates.
        centers.append((
            ((left - parent_x) + side / 2.0) / float(parent_size),
            ((top - parent_y) + side / 2.0) / float(parent_size),
        ))

        parent_x, parent_y, parent_size = left, top, side

    return canvas, centers
# ------------------------------------------------------------------
# HELPER FUNCTION FOR INFERENCE (build a list of identical centers)
# ------------------------------------------------------------------
@spaces.GPU()
def run_with_upload(
    uploaded_image_path: str,
    upscale_option: str,
    cx_norm: float,
    cy_norm: float,
):
    """
    Perform chain-of-zoom super-resolution on an image, recursively
    zooming toward one fixed point.

    Args:
        uploaded_image_path (str): Path to the input image file on disk.
        upscale_option (str): Upscale factor as a string: "1x", "2x" or "4x".
            - "1x" means no upscaling.
            - "2x" means 2× enlargement per zoom step.
            - "4x" means 4× enlargement per zoom step.
        cx_norm (float): Normalized X-coordinate (0 to 1) of the zoom center.
        cy_norm (float): Normalized Y-coordinate (0 to 1) of the zoom center.

    Returns:
        list[PIL.Image.Image]: One super-resolved image per recursion step
        (4 steps), each centered on the requested point.

    Note:
        The same center is repeated at every recursion level so the zoom
        stays anchored on one spot.  Inference is delegated to the
        modified `recursive_multiscale_sr` pipeline.
    """
    if uploaded_image_path is None:
        return []

    per_step_scale = int(upscale_option.replace("x", ""))
    recursion_depth = 4  # matches the SR pipeline's default depth

    sr_images, _ = recursive_multiscale_sr(
        uploaded_image_path,
        upscale=per_step_scale,
        rec_num=recursion_depth,
        centers=[(cx_norm, cy_norm)] * recursion_depth,
    )
    # Gradio's Gallery component expects a list of images.
    return sr_images
@spaces.GPU()
def magnify(
    uploaded_image_path: str,
    upscale_option: str,
    centres: list
):
    """
    Perform chain-of-zoom super-resolution on a given image, using recursive
    multi-scale upscaling with a (possibly different) center per zoom level.

    Args:
        uploaded_image_path (str): Path to the input image file on disk.
        upscale_option (str): The desired upscale factor as a string. Valid options are "1x", "2x", and "4x".
            - "1x" means no upscaling.
            - "2x" means 2× enlargement per zoom step.
            - "4x" means 4× enlargement per zoom step.
        centres (list): One normalized (cx, cy) pair (each 0 to 1) per
            recursion level, as produced by `make_preview_with_boxes`.
            May be None/empty if no preview has been generated yet.

    Returns:
        list[PIL.Image.Image]: A list of progressively zoomed-in and super-resolved
        images, one per recursion step (one per entry in `centres`).

    Note:
        This function uses a modified version of the `recursive_multiscale_sr`
        pipeline for inference.
    """
    # Guard both inputs: the gr.State holding `centres` starts out as None
    # (and stays empty until a preview callback has run), so clicking the
    # run button early would otherwise crash on len(None).
    if uploaded_image_path is None or not centres:
        return []

    upscale_value = int(upscale_option.replace("x", ""))
    rec_num = len(centres)  # recursion depth follows the preview's box count
    sr_list, _ = recursive_multiscale_sr(
        uploaded_image_path,
        upscale=upscale_value,
        rec_num=rec_num,
        centers=centres,
    )
    # Return the list of PIL images (Gradio Gallery expects a list)
    return sr_list
# ------------------------------------------------------------------
# BUILD THE GRADIO INTERFACE (two sliders + correct preview)
# ------------------------------------------------------------------
# Page-level CSS: center the UI and cap its width at 1024px.
css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
"""

with gr.Blocks(css=css) as demo:
    # Per-session list of (cx, cy) zoom centers (one per recursion level),
    # written by the preview callback and read back by the run button.
    session_centres = gr.State()

    with gr.Column(elem_id="col-container"):
        gr.HTML(
            """
            <div style="text-align: left;">
                <p style="font-size:16px; display: inline; margin: 0;">
                    <strong>Chain-of-Zoom</strong> – Extreme Super-Resolution via Scale Autoregression and Preference Alignment
                </p>
                <a href="https://github.com/bryanswkim/Chain-of-Zoom" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
                    [Github]
                </a>
            </div>
            <div style="text-align: left;">
                <strong>HF Space by:</strong>
                <a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
                    <img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="GitHub Repo">
                </a>
            </div>
            """
        )
        with gr.Row():
            with gr.Column():
                # 1) Image upload component ("filepath" type: callbacks
                #    receive a path string, not a PIL image)
                upload_image = gr.Image(
                    label="Input image",
                    type="filepath"
                )
                # 2) Radio for choosing 1× / 2× / 4× upscaling
                upscale_radio = gr.Radio(
                    choices=["1x", "2x", "4x"],
                    value="2x",
                    show_label=False
                )
                # 3) Two sliders for normalized center (0..1)
                center_x = gr.Slider(
                    label="Center X (normalized)",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.5
                )
                center_y = gr.Slider(
                    label="Center Y (normalized)",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.5
                )
                # 4) Button to launch inference
                run_button = gr.Button("🔎 Chain-of-Zoom it", variant="primary")
                gr.Markdown("*Click anywhere on the preview image to select coordinates to zoom*")
                # 5) Preview (512×512 + four truly nested boxes)
                preview_with_box = gr.Image(
                    label="Preview",
                    type="pil",
                    interactive=False
                )
            with gr.Column():
                # 6) Gallery to display multiple output images
                output_gallery = gr.Gallery(
                    label="Inference Results",
                    show_label=True,
                    elem_id="gallery",
                    columns=[2], rows=[2]
                )
        examples = gr.Examples(
            # List of example-rows. Each row is [input_image, scale, cx, cy].
            # NOTE(review): cached examples call run_with_upload directly and
            # do not populate session_centres.
            examples=[["samples/0479.png", "4x", 0.5, 0.5], ["samples/0064.png", "4x", 0.5, 0.5], ["samples/0245.png", "4x", 0.5, 0.5], ["samples/0393.png", "4x", 0.5, 0.5]],
            inputs=[upload_image, upscale_radio, center_x, center_y],
            outputs=[output_gallery],
            fn=run_with_upload,
            cache_examples=True
        )

    # ------------------------------------------------------------------
    # CALLBACK #1: update the preview whenever inputs change
    # ------------------------------------------------------------------
    def update_preview(
        img_path: str,
        scale_opt: str,
        cx: float,
        cy: float
    ) -> tuple[Image.Image | None, list[tuple[float, float]]]:
        """
        If no image uploaded, show blank; otherwise, draw four nested boxes
        exactly as the SR pipeline would crop at each recursion.

        Returns a (preview image, per-level centers) pair; the second item
        is stored in `session_centres` for the run button.
        """
        if img_path is None:
            return None, []
        return make_preview_with_boxes(img_path, scale_opt, cx, cy)

    def get_select_coords(input_img, evt: gr.SelectData):
        """Map a click on the preview image to normalized slider values."""
        print("coordinates selected")
        i = evt.index[1]  # vertical pixel coordinate (y)
        j = evt.index[0]  # horizontal pixel coordinate (x)
        w, h = input_img.size
        # Normalize by the preview's own dimensions and push the values
        # back into the two sliders (which in turn refresh the preview).
        return gr.update(value=j/w), gr.update(value=i/h)

    preview_with_box.select(get_select_coords, [preview_with_box], [center_x, center_y])

    # Any change to the image, scale, or center sliders redraws the
    # preview and refreshes the stored per-level centers.
    upload_image.change(
        fn=update_preview,
        inputs=[upload_image, upscale_radio, center_x, center_y],
        outputs=[preview_with_box, session_centres],
        show_api=False
    )
    upscale_radio.change(
        fn=update_preview,
        inputs=[upload_image, upscale_radio, center_x, center_y],
        outputs=[preview_with_box, session_centres],
        show_api=False
    )
    center_x.change(
        fn=update_preview,
        inputs=[upload_image, upscale_radio, center_x, center_y],
        outputs=[preview_with_box, session_centres],
        show_api=False
    )
    center_y.change(
        fn=update_preview,
        inputs=[upload_image, upscale_radio, center_x, center_y],
        outputs=[preview_with_box, session_centres],
        show_api=False
    )

    # ------------------------------------------------------------------
    # CALLBACK #2: on button-click, run the SR pipeline
    # ------------------------------------------------------------------
    run_button.click(
        fn=magnify,
        inputs=[upload_image, upscale_radio, session_centres],
        outputs=[output_gallery]
    )

# ------------------------------------------------------------------
# START THE GRADIO SERVER
# ------------------------------------------------------------------
demo.queue()
demo.launch(share=True, mcp_server=True)