Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -7,8 +7,12 @@ import re
 from spaces import GPU
 
 # --- 1. Configurations and Constants ---
-#
-
+# Define IDs for both models
+MODEL_CHOICES = {
+    "Latex2Layout-2000-sync (Base)": "ChaseHan/Latex2Layout-2000-sync",
+    "Latex2Layout-2000-sync-enhanced (Enhanced)": "ChaseHan/Latex2Layout-2000-sync-enhanced"
+}
+DEFAULT_MODEL_NAME = list(MODEL_CHOICES.keys())[0]
 
 # Target image size for model input
 TARGET_SIZE = (924, 1204)
@@ -34,37 +38,49 @@ DEFAULT_PROMPT = (
     """<image>Please carefully observe the document and detect the following regions: "title", "abstract", "heading", "footnote", "figure", "figure caption", "table", "table caption", "math", "text". Output each detected region's bbox coordinates in JSON format. The format of the output is: <answer>```json[{"bbox_2d": [x1, y1, x2, y2], "label": "region name", "order": "reading order"}]```</answer>."""
 )
 
-# --- 2. Load
-
+# --- 2. Load Models and Processor ---
+# NOTE: Quantization is used to fit two models in memory.
+# Ensure `bitsandbytes` and `accelerate` are in your requirements.txt
+print("Loading models, this will take some time and VRAM...")
+MODELS = {}
 try:
-
-
-
-
-
-
-
+    for name, model_id in MODEL_CHOICES.items():
+        print(f"Loading {name}...")
+        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            model_id,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            load_in_4bit=True  # Essential for loading two models
+        )
+        MODELS[name] = model
+
+    # Processor is the same for both models
+    processor = AutoProcessor.from_pretrained(list(MODEL_CHOICES.values())[0])
+    print("All models loaded successfully!")
 except Exception as e:
-    print(f"Error loading
+    print(f"Error loading models: {e}")
     exit()
 
 # --- 3. Core Inference and Visualization Function ---
 @GPU
-def analyze_and_visualize_layout(input_image: Image.Image, prompt: str, temperature: float, top_p: float, progress=gr.Progress(track_tqdm=True)):
+def analyze_and_visualize_layout(input_image: Image.Image, selected_model_name: str, prompt: str, temperature: float, top_p: float, progress=gr.Progress(track_tqdm=True)):
     """
     Takes an image and model parameters, runs inference, and returns a visualized image and raw text output.
     """
     if input_image is None:
         return None, "Please upload an image first."
 
-
+    # Select the model based on user's choice
+    model = MODELS[selected_model_name]
+    progress(0, desc=f"Resizing image for {selected_model_name}...")
+
     image = input_image.resize(TARGET_SIZE)
     image = image.convert("RGBA")
 
     messages = [
         {"role": "user", "content": [
             {"type": "image", "image": image},
-            {"type": "text", "text": prompt}
+            {"type": "text", "text": prompt}
         ]}
     ]
 
@@ -74,11 +90,10 @@ def analyze_and_visualize_layout(input_image: Image.Image, prompt: str, temperat
 
     progress(0.5, desc="Generating layout data...")
     with torch.no_grad():
-        # Pass new parameters to the model generation
         output_ids = model.generate(
             **inputs,
             max_new_tokens=4096,
-            do_sample=True,
+            do_sample=True,
             temperature=temperature,
             top_p=top_p
         )
@@ -150,8 +165,15 @@ with gr.Blocks(theme=gr.themes.Glass(), title="Academic Paper Layout Detection")
     with gr.Row():
         analyze_btn = gr.Button("✨ Analyze Layout", variant="primary", scale=1)
 
-    # ---
+    # --- Advanced Settings Panel ---
     with gr.Accordion("Advanced Settings", open=False):
+        # NEW: Model Selector
+        model_selector = gr.Radio(
+            choices=list(MODEL_CHOICES.keys()),
+            value=DEFAULT_MODEL_NAME,
+            label="Select Model",
+            info="Choose which model to use for inference."
+        )
         prompt_textbox = gr.Textbox(
             label="Prompt",
             value=DEFAULT_PROMPT,
@@ -181,7 +203,6 @@ with gr.Blocks(theme=gr.themes.Glass(), title="Academic Paper Layout Detection")
         examples=[["1.png"], ["2.png"], ["10.png"], ["11.png"], ["3.png"], ["7.png"], ["8.png"]],
         inputs=[input_image],
         label="Examples (Click to Run)",
-        # Examples now only populate the image input. The user clicks "Analyze" to run with current settings.
    )
 
    gr.Markdown("<p style='text-align:center; color:grey;'>Powered by the Latex2Layout dataset generated by Feijiang Han</p>")
@@ -189,7 +210,7 @@ with gr.Blocks(theme=gr.themes.Glass(), title="Academic Paper Layout Detection")
    # --- Event Handlers ---
    analyze_btn.click(
        fn=analyze_and_visualize_layout,
-        inputs=[input_image, prompt_textbox, temp_slider, top_p_slider],
+        inputs=[input_image, model_selector, prompt_textbox, temp_slider, top_p_slider],
        outputs=[output_image, output_text]
    )
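
The commit loads both checkpoints by passing load_in_4bit=True straight to from_pretrained. On recent transformers releases the same 4-bit setup is usually expressed with an explicit BitsAndBytesConfig; the sketch below shows that equivalent. It is not part of the commit, and it assumes transformers, accelerate, and bitsandbytes are installed (e.g. listed in the Space's requirements.txt, as the commit's own NOTE requires).

# Sketch only (not in the commit): explicit 4-bit config equivalent to load_in_4bit=True.
import torch
from transformers import AutoProcessor, BitsAndBytesConfig, Qwen2_5_VLForConditionalGeneration

MODEL_CHOICES = {
    "Latex2Layout-2000-sync (Base)": "ChaseHan/Latex2Layout-2000-sync",
    "Latex2Layout-2000-sync-enhanced (Enhanced)": "ChaseHan/Latex2Layout-2000-sync-enhanced",
}

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # matches the commit's torch_dtype
)

# Keep both 4-bit models resident in memory, keyed by the display name used in the UI.
MODELS = {
    name: Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",
    )
    for name, model_id in MODEL_CHOICES.items()
}

# One processor serves both checkpoints, as in the commit.
processor = AutoProcessor.from_pretrained(list(MODEL_CHOICES.values())[0])

Keeping both models resident is what lets the new model_selector radio switch checkpoints per request without reloading weights: the selected name is passed through analyze_and_visualize_layout and used as a key into MODELS.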