import os
import time
import glob
from typing import List

import spaces
import gradio as gr
import torch
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

MODEL_ID = os.environ.get("SPARK_MODEL_ID", "internlm/Spark-VL-7B")
DTYPE = torch.bfloat16

_model = None
_processor = None
_attn_impl = None

def _load_model_and_processor():
    global _model, _processor, _attn_impl
    if _model is not None and _processor is not None:
        return _model, _processor
    # Prefer flash-attn if available, otherwise fall back to eager.
    attn_impl = os.environ.get("ATTN_IMPL", "flash_attention_2")
    try:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            MODEL_ID,
            # `torch_dtype` was deprecated in Transformers; use `dtype` instead.
            dtype=DTYPE,
            attn_implementation=attn_impl,
            device_map="auto",
        )
        _attn_impl = attn_impl
    except (ImportError, ValueError, RuntimeError):
        # Fallback for environments without flash-attn.
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            MODEL_ID,
            # Use the new `dtype` kwarg for consistency with deprecations.
            dtype=DTYPE,
            attn_implementation="eager",
            device_map="auto",
        )
        _attn_impl = "eager"
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    _model = model
    _processor = processor
    return _model, _processor

# Optionally preload the model at app startup so the first click is fast.
# - On ZeroGPU this instantiates on CPU (no GPU at startup), so the first
#   generate call only needs to move tensors to CUDA.
# - Disable by setting env `PRELOAD_MODEL=0`.
if os.environ.get("PRELOAD_MODEL", "1") not in ("0", "false", "False"):
    try:
        _load_model_and_processor()
        print(f"[preload] Loaded {MODEL_ID} (attn_impl={_attn_impl})", flush=True)
    except Exception as e:
        # Don't block the app if preload fails; fall back to lazy load on the first call.
        print(f"[preload] Skipped due to: {type(e).__name__}: {e}", flush=True)

def _prepare_inputs(image, prompt):
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    chat_text = _processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = _processor(
        text=[chat_text],
        # Pass the single image directly; the chat template contains the image placeholder.
        images=[image] if image is not None else None,
        return_tensors="pt",
    )
    return inputs

def _decode(generated_ids, input_ids):
    # Trim the prompt part before decoding.
    trimmed = generated_ids[:, input_ids.shape[1] :]
    out = _processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return out[0].strip() if out else ""

# ZeroGPU: request a GPU for the duration of each call to `generate`.
@spaces.GPU
def generate(image, prompt, max_new_tokens, temperature, top_p, top_k):
    if image is None:
        return "Please upload an image."
    prompt = (prompt or "").strip()
    if not prompt:
        return "Please enter a prompt."
    start = time.time()
    model, _ = _load_model_and_processor()
    try:
        # Ensure the model resides on the GPU during the call.
        dev = next(model.parameters()).device
        if dev.type != "cuda":
            model.to("cuda")
            dev = torch.device("cuda")
    except StopIteration:
        dev = torch.device("cuda")
    try:
        inputs = _prepare_inputs(image, prompt)
        inputs = {k: v.to(dev) if hasattr(v, "to") else v for k, v in inputs.items()}
        gen_kwargs = {
            "max_new_tokens": int(max_new_tokens),
            "do_sample": True,
            "temperature": float(temperature),
            "top_p": float(top_p),
            "top_k": int(top_k),
            "use_cache": True,
        }
        with torch.inference_mode():
            out_ids = model.generate(**inputs, **gen_kwargs)
        text = _decode(out_ids, inputs["input_ids"])
        took = time.time() - start
        return f"{text}\n\n[attn={_attn_impl}, time={took:.1f}s]"
    except Exception as e:
        return f"Inference failed: {type(e).__name__}: {e}"
    finally:
        # Release the GPU memory cache for ZeroGPU.
        try:
            torch.cuda.empty_cache()
        except Exception:
            pass
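
# Optional smoke test for `generate`, guarded behind an env flag so it never runs
# inside the Space. `SPARK_SMOKE_TEST` and the image path are illustrative
# assumptions, not part of the original app; a CUDA-capable machine is required.
if os.environ.get("SPARK_SMOKE_TEST", "0") == "1":
    _test_image = Image.open("examples/example_0.png")  # assumes this example file exists
    print(generate(_test_image, "Describe the image.", 256, 0.7, 0.9, 50))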

def build_ui():
    with gr.Blocks() as demo:
        gr.Markdown(
            """
# Spark: Synergistic Policy And Reward Co-Evolving Framework

<h3 align="center">
<a href="https://arxiv.org/abs/2509.22624">Paper</a>
| 🤗 <a href="https://huggingface.co/internlm/Spark-VL-7B">Models</a>
| 🤗 <a href="https://huggingface.co/datasets/internlm/Spark-Data">Datasets</a>
| 🤗 <a href="https://huggingface.co/papers/2509.22624">Daily Paper</a>
</h3>

**Introduction:** We propose SPARK, <strong>a unified framework that integrates policy and reward into a single model for joint and synchronous training</strong>. SPARK automatically derives reward and reflection data from verifiable rewards, enabling <strong>self-learning and self-evolution</strong>.

**🤗 Models:** We release the checkpoints at [internlm/Spark-VL-7B](https://huggingface.co/internlm/Spark-VL-7B).

**🤗 Datasets:** Training data is available at [internlm/Spark-Data](https://huggingface.co/datasets/internlm/Spark-Data).

**💻 Training Code:** The training code and implementation details can be found at [InternLM/Spark](https://github.com/InternLM/Spark).

---

📸 **Upload an image and enter a prompt**, or 🖼️ **choose an input from the example gallery** (image + prompt).
"""
        )
        # Build an image+prompt gallery from ./examples.
        # Each example is an image file with an optional sidecar .txt containing the prompt.
        # If a .txt is present (same basename), we display a caption and load the
        # prompt alongside the image when the thumbnail is selected.
        def _gather_examples() -> List[tuple]:
            pairs = []  # (image_path, prompt_text)
            imgs = []
            for ext in ("jpg", "jpeg", "png", "webp"):
                imgs.extend(glob.glob(os.path.join("examples", f"*.{ext}")))
            # Deduplicate while keeping order.
            for img_path in list(dict.fromkeys(sorted(imgs))):
                stem, _ = os.path.splitext(img_path)
                prompt_path = stem + ".txt"
                prompt_text = None
                if os.path.exists(prompt_path):
                    try:
                        with open(prompt_path, "r", encoding="utf-8") as fh:
                            prompt_text = fh.read().strip()
                    except Exception:
                        prompt_text = None
                pairs.append((img_path, prompt_text))
            return pairs

        example_pairs = _gather_examples()

        # Load the default image if it exists.
        default_path = os.path.join("examples", "example_0.png")
        default_image = Image.open(default_path) if os.path.exists(default_path) else None
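
        # Illustrative layout of ./examples (filenames beyond example_0.png are
        # assumptions based on the prefix convention used further below):
        #     examples/example_0.png + examples/example_0.txt  -> math reasoning example
        #     examples/example_1.png + examples/example_1.txt  -> reward model example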

        with gr.Row():
            with gr.Column(scale=1):
                image = gr.Image(type="pil", label="Image", value=default_image)
            with gr.Column(scale=1):
                prompt = gr.Textbox(
                    label="Prompt",
                    value=(
                        "As seen in the diagram, three darts are thrown at nine fixed balloons. "
                        "If a balloon is hit it will burst and the dart continues in the same direction "
                        "it had beforehand. How many balloons will not be hit by a dart?"
                    ),
                    lines=4,
                )
                max_new_tokens = gr.Slider(512, 4096, value=1024, step=8, label="max_new_tokens")
                temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
                top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="top_p")
                top_k = gr.Slider(1, 200, value=50, step=1, label="top_k")
                run = gr.Button("Generate")

        # Clear the prompt when the image is removed.
        image.clear(fn=lambda: "", outputs=prompt)

        # Examples section: table-like layout with image and prompt columns.
        gr.Markdown("## Examples")

        # Handler for clicking on example images.
        def _on_example_click(img_path, prompt_text):
            try:
                img_val = Image.open(img_path)
            except Exception:
                img_val = None
            return img_val, prompt_text

        # Categorize examples by type.
        math_examples = []
        reward_examples = []
        other_examples = []
        for img_path, prompt_text in example_pairs:
            basename = os.path.basename(img_path)
            if basename.startswith("example_0"):
                math_examples.append((img_path, prompt_text))
            elif basename.startswith("example_1"):
                reward_examples.append((img_path, prompt_text))
            else:
                other_examples.append((img_path, prompt_text))

        # Display math reasoning examples.
        if math_examples:
            gr.Markdown("### Math Reasoning Examples")
            for idx, (img_path, prompt_text) in enumerate(math_examples):
                with gr.Row():
                    with gr.Column(scale=1):
                        ex_img = gr.Image(
                            value=img_path,
                            type="filepath",
                            label=f"Math Example {idx}",
                            interactive=False,
                            show_label=True,
                            height=200,
                        )
                        # Wire the click event to load the example.
                        ex_img.select(
                            fn=lambda ip=img_path, pt=prompt_text: _on_example_click(ip, pt),
                            outputs=[image, prompt],
                        )
                    with gr.Column(scale=3):
                        ex_text = gr.Textbox(
                            value=prompt_text or "",
                            label="Prompt",
                            lines=8,
                            max_lines=8,
                            interactive=False,
                            show_label=True,
                        )

        # Display reward model examples.
        if reward_examples:
            gr.Markdown("### 🎯 Reward Model Examples")
            for idx, (img_path, prompt_text) in enumerate(reward_examples):
                with gr.Row():
                    with gr.Column(scale=1):
                        ex_img = gr.Image(
                            value=img_path,
                            type="filepath",
                            label=f"Reward Example {idx}",
                            interactive=False,
                            show_label=True,
                            height=200,
                        )
                        # Wire the click event to load the example.
                        ex_img.select(
                            fn=lambda ip=img_path, pt=prompt_text: _on_example_click(ip, pt),
                            outputs=[image, prompt],
                        )
                    with gr.Column(scale=3):
                        ex_text = gr.Textbox(
                            value=prompt_text or "",
                            label="Prompt",
                            lines=8,
                            max_lines=8,
                            interactive=False,
                            show_label=True,
                        )

        # Display other examples, if any.
        if other_examples:
            gr.Markdown("### Other Examples")
            for idx, (img_path, prompt_text) in enumerate(other_examples):
                with gr.Row():
                    with gr.Column(scale=1):
                        ex_img = gr.Image(
                            value=img_path,
                            type="filepath",
                            label=f"Example {idx}",
                            interactive=False,
                            show_label=True,
                            height=200,
                        )
                        # Wire the click event to load the example.
                        ex_img.select(
                            fn=lambda ip=img_path, pt=prompt_text: _on_example_click(ip, pt),
                            outputs=[image, prompt],
                        )
                    with gr.Column(scale=3):
                        ex_text = gr.Textbox(
                            value=prompt_text or "",
                            label="Prompt",
                            lines=8,
                            max_lines=8,
                            interactive=False,
                            show_label=True,
                        )

        output = gr.Textbox(label="Model Output", lines=8)

        run.click(
            fn=generate,
            inputs=[image, prompt, max_new_tokens, temperature, top_p, top_k],
            outputs=output,
            show_progress=True,
        )

        # Citation section at the bottom.
        gr.Markdown(
            """
---
If you find this project useful, please kindly cite:

```bibtex
@article{liu2025spark,
  title={SPARK: Synergistic Policy And Reward Co-Evolving Framework},
  author={Liu, Ziyu and Zang, Yuhang and Ding, Shengyuan and Cao, Yuhang and Dong, Xiaoyi and Duan, Haodong and Lin, Dahua and Wang, Jiaqi},
  journal={arXiv preprint arXiv:2509.22624},
  year={2025}
}
```
"""
        )

    demo.queue(max_size=10).launch()
    return demo

if __name__ == "__main__":
    build_ui()
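
# Usage notes (a sketch, not part of the Space runtime): the script reads three
# optional environment variables -- SPARK_MODEL_ID (checkpoint to load),
# ATTN_IMPL (preferred attention backend, with an automatic fallback to "eager"),
# and PRELOAD_MODEL (set to "0" to skip loading at startup). A hypothetical local
# run, assuming this file is saved as app.py, might look like:
#
#     SPARK_MODEL_ID=internlm/Spark-VL-7B PRELOAD_MODEL=0 python app.py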