import logging

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
# Hugging Face Hub checkpoint served by this demo.
MODEL_ID: str = "internlm/CapRL-3B"
# Instruction text sent to the model alongside every image.
DEFAULT_PROMPT: str = "Describe the image in detail."
# Upper bound on the number of tokens generated per caption.
MAX_NEW_TOKENS: int = 4_096
def get_device() -> str:
    """Return ``"cuda"`` when PyTorch reports a usable CUDA device, else ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def select_dtype(device: str):
    """Pick the torch dtype used to load the model weights on *device*.

    Only the exact string ``"cuda"`` selects reduced precision: bfloat16 when
    the GPU supports it, float16 otherwise. Any other device string
    (including ``"cpu"``) gets float32.
    """
    if device != "cuda":
        return torch.float32
    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
def load_model():
    """Load the checkpoint named by ``MODEL_ID`` together with its processor.

    The weight dtype is derived from the detected device via ``select_dtype``.

    Returns:
        A ``(model, processor)`` tuple.
    """
    weights_dtype = select_dtype(get_device())
    # Use device_map="auto" for proper GPU allocation with spaces.GPU decorator
    # (device_map relies on the `accelerate` package being installed).
    loaded_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=weights_dtype,
    )
    loaded_processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    return loaded_model, loaded_processor


# Loaded once at import time; every captioning request reuses these objects.
MODEL, PROCESSOR = load_model()
# Longest side (in pixels) an input image may have before it is downscaled,
# so very large uploads do not exhaust memory during preprocessing/inference.
_MAX_IMAGE_SIDE = 4096

_LOGGER = logging.getLogger(__name__)


def _shrink_to_fit(image: Image.Image, max_size: int = _MAX_IMAGE_SIDE) -> Image.Image:
    """Return *image* unchanged, or a downscaled copy if either side exceeds *max_size*.

    ``Image.thumbnail`` resizes in place (preserving aspect ratio), so it is
    applied to a copy to avoid mutating the caller's object.
    """
    if image.width <= max_size and image.height <= max_size:
        return image
    shrunk = image.copy()
    shrunk.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
    return shrunk


@spaces.GPU
@torch.inference_mode()
def generate_caption(image: Image.Image):
    """Generate a detailed caption for *image* with the loaded CapRL model.

    Args:
        image: PIL image from the Gradio ``Image`` component (``type="pil"``),
            or ``None`` when the input is empty.

    Returns:
        A ``(caption, num_generated_tokens)`` tuple. The token count is the
        number of ids produced after the prompt (any end-of-turn token the
        model emits is included). On failure the caption slot carries a
        human-readable error message and the count is ``0``; errors are
        returned instead of raised so they appear in the output textbox.
    """
    if image is None:
        return "", 0
    try:
        if not isinstance(image, Image.Image):
            return "Error: Invalid image format", 0
        image = _shrink_to_fit(image)

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": DEFAULT_PROMPT},
                ],
            }
        ]
        prompt_text = PROCESSOR.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = PROCESSOR(
            text=[prompt_text],
            images=[image],
            return_tensors="pt",
        ).to(MODEL.device)

        generated_ids = MODEL.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
        )
        # generate() returns prompt + continuation; keep only the new ids.
        new_token_ids = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = PROCESSOR.batch_decode(
            new_token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        caption = output_text[0].strip()
        return caption, int(len(new_token_ids[0]))
    except torch.cuda.OutOfMemoryError:
        # Release cached blocks so the next request has a chance to fit.
        torch.cuda.empty_cache()
        return "Error: Out of GPU memory. Please try with a smaller image.", 0
    except Exception as e:
        # UI boundary: surface the message to the user, but keep the traceback
        # in the server log instead of discarding it.
        _LOGGER.exception("Caption generation failed")
        return f"Error generating caption: {e}", 0
# Gradio UI: image in, caption + generated-token count out.
with gr.Blocks(title="CapRL Image Captioning") as demo:
    gr.Markdown("# 🎨 CapRL for Image Captioning")
    gr.Markdown("### CapRL: Stimulating Dense Image Caption Capabilities via Reinforcement Learning")
    gr.Markdown("✨ Upload an image to generate a detailed caption with CapRL-3B! ✨")
    gr.Markdown(
        """
📖 Paper | 🏠 Github | 🤗 CapRL-3B Model | 🤗 CapRL-InternVL3.5-8B Model |
🤗 CapRL-2M Dataset
🤗 CapRL Collection | 📰 Daily Paper | 💾 CapRL-3B-GGUF | 💾 CapRL-3B-i1-GGUF
"""
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Input Image")
            generate_button = gr.Button("Generate Caption")
        with gr.Column():
            caption_output = gr.Textbox(label="Caption", lines=6)
            token_output = gr.Number(label="Generated Tokens", precision=0)

    # The button and a fresh upload both run the same handler, so they share
    # one set of event arguments. `show_progress` takes "full" / "minimal" /
    # "hidden"; "full" is the documented value for the complete indicator.
    caption_event = {
        "fn": generate_caption,
        "inputs": image_input,
        "outputs": [caption_output, token_output],
        "show_progress": "full",
    }
    generate_button.click(**caption_event)
    image_input.upload(**caption_event)

    # cache_examples=True runs generate_caption on these files when the app
    # starts, so clicking an example shows a precomputed result.
    gr.Examples(
        examples=[
            ["./examples/example_chinese.png"],
            ["./examples/example_receipt.jpg"],
            ["./examples/example_table.png"],
        ],
        inputs=image_input,
        outputs=[caption_output, token_output],
        fn=generate_caption,
        cache_examples=True,
        label="📸 Example Images",
    )

    gr.Markdown("### Citation")
    gr.Markdown("If you find this project useful, please kindly cite:")
    citation_text = """@article{xing2025caprl,
title={{CapRL}: Stimulating Dense Image Caption Capabilities via Reinforcement Learning},
author={Xing, Long and Dong, Xiaoyi and Zang, Yuhang and Cao, Yuhang and Liang, Jianze and Huang, Qidong and Wang, Jiaqi and Wu, Feng and Lin, Dahua},
journal={arXiv preprint arXiv:2509.22647},
year={2025}
}"""
    gr.Code(value=citation_text, language="markdown", label="BibTeX Citation")

demo.launch()