import os
import time
import glob
from typing import List
import spaces
import gradio as gr
import torch
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
MODEL_ID = os.environ.get("SPARK_MODEL_ID", "internlm/Spark-VL-7B")
DTYPE = torch.bfloat16
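# Lazily-initialized singletons shared across Gradio requests (filled on first load).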
_model = None
_processor = None
_attn_impl = None
def _load_model_and_processor():
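    """Load and cache the model and processor.

    Tries the attention implementation from `ATTN_IMPL` (flash-attn by default)
    and falls back to eager if that fails; results are cached in module globals.
    """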
global _model, _processor, _attn_impl
if _model is not None and _processor is not None:
return _model, _processor
# Prefer flash-attn if available, otherwise fall back to eager.
attn_impl = os.environ.get("ATTN_IMPL", "flash_attention_2")
try:
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
MODEL_ID,
# `torch_dtype` was deprecated in Transformers; use `dtype` instead.
dtype=DTYPE,
attn_implementation=attn_impl,
device_map="auto",
)
_attn_impl = attn_impl
except (ImportError, ValueError, RuntimeError):
# Fallback for environments without flash-attn
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
MODEL_ID,
            # Same `dtype` kwarg as above (`torch_dtype` is deprecated)
dtype=DTYPE,
attn_implementation="eager",
device_map="auto",
)
_attn_impl = "eager"
processor = AutoProcessor.from_pretrained(MODEL_ID)
_model = model
_processor = processor
return _model, _processor
# Optionally preload the model at app startup so first click is fast.
# - On ZeroGPU, this will instantiate on CPU (no GPU at startup), so the
# first generate only needs to move tensors to CUDA.
# - You can disable by setting env `PRELOAD_MODEL=0`.
if os.environ.get("PRELOAD_MODEL", "1") not in ("0", "false", "False"):
try:
_load_model_and_processor()
print(f"[preload] Loaded {MODEL_ID} (attn_impl={_attn_impl})", flush=True)
except Exception as e:
        # Don't block the app if preload fails; fall back to lazy loading on first call
print(f"[preload] Skipped due to: {type(e).__name__}: {e}", flush=True)
def _prepare_inputs(image, prompt):
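    """Build chat-templated processor inputs for a single image + text prompt."""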
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt},
],
}
]
chat_text = _processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = _processor(
text=[chat_text],
        # Pass the single image directly; the chat template already inserts the image placeholder tokens
images=[image] if image is not None else None,
return_tensors="pt",
)
return inputs
def _decode(generated_ids, input_ids):
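    """Decode only the newly generated tokens, dropping the echoed prompt."""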
# Trim the prompt part before decoding
trimmed = generated_ids[:, input_ids.shape[1] :]
out = _processor.batch_decode(
trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
return out[0].strip() if out else ""
@spaces.GPU(duration=120)
def generate(image, prompt, max_new_tokens, temperature, top_p, top_k):
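    """Run one generation pass; `spaces.GPU` grants a ZeroGPU slot for up to 120 s per call."""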
if image is None:
return "Please upload an image."
prompt = (prompt or "").strip()
if not prompt:
return "Please enter a prompt."
start = time.time()
model, _ = _load_model_and_processor()
try:
# Ensure model resides on GPU during the call
dev = next(model.parameters()).device
if dev.type != "cuda":
model.to("cuda")
dev = torch.device("cuda")
except StopIteration:
dev = torch.device("cuda")
try:
inputs = _prepare_inputs(image, prompt)
inputs = {k: v.to(dev) if hasattr(v, "to") else v for k, v in inputs.items()}
        temperature = float(temperature)
        # Use greedy decoding when temperature is 0; sampling requires a strictly
        # positive temperature in `generate`.
        do_sample = temperature > 0
        gen_kwargs = {
            "max_new_tokens": int(max_new_tokens),
            "do_sample": do_sample,
            "use_cache": True,
        }
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = float(top_p)
            gen_kwargs["top_k"] = int(top_k)
with torch.inference_mode():
out_ids = model.generate(**inputs, **gen_kwargs)
text = _decode(out_ids, inputs["input_ids"])
took = time.time() - start
return f"{text}\n\n[attn={_attn_impl}, time={took:.1f}s]"
except Exception as e:
return f"Inference failed: {type(e).__name__}: {e}"
finally:
# Release GPU memory cache for ZeroGPU
try:
torch.cuda.empty_cache()
except Exception:
pass
def build_ui():
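    """Assemble the Gradio Blocks UI: inputs, example gallery, output, and citation."""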
with gr.Blocks() as demo:
gr.Markdown(
"""
# SPARK: Synergistic Policy And Reward Co-Evolving Framework

<h3 align="center">
📖<a href="https://arxiv.org/abs/2509.22624">Paper</a>
| 🤗<a href="https://huggingface.co/internlm/Spark-VL-7B">Models</a>
| 🤗<a href="https://huggingface.co/datasets/internlm/Spark-Data">Datasets</a>
| 🤗<a href="https://huggingface.co/papers/2509.22624">Daily Paper</a>
</h3>

**🌈 Introduction:** We propose SPARK, <strong>a unified framework that integrates policy and reward into a single model for joint and synchronous training</strong>. SPARK automatically derives reward and reflection data from verifiable reward signals, enabling <strong>self-learning and self-evolution</strong>.

**🤗 Models:** We release the checkpoints at [internlm/Spark-VL-7B](https://huggingface.co/internlm/Spark-VL-7B).

**🤗 Datasets:** Training data is available at [internlm/Spark-Data](https://huggingface.co/datasets/internlm/Spark-Data).

**💻 Training Code:** The training code and implementation details can be found at [InternLM/Spark](https://github.com/InternLM/Spark).

---

📸 **Upload an image and enter a prompt**, or 🖼️ **choose an input from the example gallery** (image + prompt).
"""
)
        # Build an image+prompt gallery from ./examples.
        # Each example is an image file with an optional sidecar .txt containing the prompt.
        # If a .txt with the same basename exists, its text is shown next to the image and
        # loaded into the prompt box when the thumbnail is clicked.
def _gather_examples() -> List[tuple]:
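            """Return (image_path, prompt_text) pairs found under ./examples."""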
pairs = [] # (image_path, prompt_text)
imgs = []
for ext in ("jpg", "jpeg", "png", "webp"):
imgs.extend(glob.glob(os.path.join("examples", f"*.{ext}")))
            # Sort, then deduplicate while preserving that order
for img_path in list(dict.fromkeys(sorted(imgs))):
stem, _ = os.path.splitext(img_path)
prompt_path = stem + ".txt"
prompt_text = None
if os.path.exists(prompt_path):
try:
with open(prompt_path, "r", encoding="utf-8") as fh:
prompt_text = fh.read().strip()
except Exception:
prompt_text = None
pairs.append((img_path, prompt_text))
return pairs
example_pairs = _gather_examples()
# Load default image if exists
default_path = os.path.join("examples", "example_0.png")
default_image = Image.open(default_path) if os.path.exists(default_path) else None
with gr.Row():
with gr.Column(scale=1):
image = gr.Image(type="pil", label="Image", value=default_image)
with gr.Column(scale=1):
prompt = gr.Textbox(
label="Prompt",
value=(
"As seen in the diagram, three darts are thrown at nine fixed balloons. "
"If a balloon is hit it will burst and the dart continues in the same direction "
"it had beforehand. How many balloons will not be hit by a dart?"
),
lines=4,
)
max_new_tokens = gr.Slider(512, 4096, value=1024, step=8, label="max_new_tokens")
temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="top_p")
top_k = gr.Slider(1, 200, value=50, step=1, label="top_k")
run = gr.Button("Generate")
# Clear prompt when image is removed
image.clear(fn=lambda: "", outputs=prompt)
# Examples section: table-like layout with image and prompt columns
gr.Markdown("## Examples")
# Handler for clicking on example images
def _on_example_click(img_path, prompt_text):
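            """Load the clicked example image and its prompt into the input widgets."""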
try:
img_val = Image.open(img_path)
except Exception:
img_val = None
return img_val, prompt_text
# Categorize examples by type
math_examples = []
reward_examples = []
other_examples = []
for img_path, prompt_text in example_pairs:
basename = os.path.basename(img_path)
if basename.startswith("example_0"):
math_examples.append((img_path, prompt_text))
elif basename.startswith("example_1"):
reward_examples.append((img_path, prompt_text))
else:
other_examples.append((img_path, prompt_text))
# Display math reasoning examples
if math_examples:
gr.Markdown("### πŸ“ Math Reasoning Examples")
for idx, (img_path, prompt_text) in enumerate(math_examples):
with gr.Row():
with gr.Column(scale=1):
ex_img = gr.Image(
value=img_path,
type="filepath",
label=f"Math Example {idx}",
interactive=False,
show_label=True,
height=200,
)
# Wire click event to load the example
ex_img.select(
fn=lambda ip=img_path, pt=prompt_text: _on_example_click(ip, pt),
outputs=[image, prompt],
)
with gr.Column(scale=3):
ex_text = gr.Textbox(
value=prompt_text or "",
label="Prompt",
lines=8,
max_lines=8,
interactive=False,
show_label=True,
)
# Display reward model examples
if reward_examples:
gr.Markdown("### 🎯 Reward Model Examples")
for idx, (img_path, prompt_text) in enumerate(reward_examples):
with gr.Row():
with gr.Column(scale=1):
ex_img = gr.Image(
value=img_path,
type="filepath",
label=f"Reward Example {idx}",
interactive=False,
show_label=True,
height=200,
)
# Wire click event to load the example
ex_img.select(
fn=lambda ip=img_path, pt=prompt_text: _on_example_click(ip, pt),
outputs=[image, prompt],
)
with gr.Column(scale=3):
ex_text = gr.Textbox(
value=prompt_text or "",
label="Prompt",
lines=8,
max_lines=8,
interactive=False,
show_label=True,
)
# Display other examples if any
if other_examples:
gr.Markdown("### πŸ“‹ Other Examples")
for idx, (img_path, prompt_text) in enumerate(other_examples):
with gr.Row():
with gr.Column(scale=1):
ex_img = gr.Image(
value=img_path,
type="filepath",
label=f"Example {idx}",
interactive=False,
show_label=True,
height=200,
)
# Wire click event to load the example
ex_img.select(
fn=lambda ip=img_path, pt=prompt_text: _on_example_click(ip, pt),
outputs=[image, prompt],
)
with gr.Column(scale=3):
ex_text = gr.Textbox(
value=prompt_text or "",
label="Prompt",
lines=8,
max_lines=8,
interactive=False,
show_label=True,
)
output = gr.Textbox(label="Model Output", lines=8)
run.click(
fn=generate,
inputs=[image, prompt, max_new_tokens, temperature, top_p, top_k],
outputs=output,
            show_progress="full",
)
# Citation section at the bottom
gr.Markdown(
"""
---
If you find this project useful, please cite:
```bibtex
@article{liu2025spark,
title={SPARK: Synergistic Policy And Reward Co-Evolving Framework},
author={Liu, Ziyu and Zang, Yuhang and Ding, Shengyuan and Cao, Yuhang and Dong, Xiaoyi and Duan, Haodong and Lin, Dahua and Wang, Jiaqi},
journal={arXiv preprint arXiv:2509.22624},
year={2025}
}
```
"""
)
    return demo


if __name__ == "__main__":
    demo = build_ui()
    demo.queue(max_size=10).launch()