import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

MODEL_ID = "internlm/CapRL-3B"
DEFAULT_PROMPT = "Describe the image in detail."
MAX_NEW_TOKENS = 4096


def get_device() -> str:
    return "cuda" if torch.cuda.is_available() else "cpu"


def select_dtype(device: str):
    if device == "cuda":
        if torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float16
    return torch.float32


def load_model():
    device = get_device()
    dtype = select_dtype(device)
    # Use device_map="auto" for proper GPU allocation with the spaces.GPU decorator
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        device_map="auto",
        trust_remote_code=True,
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    return model, processor


MODEL, PROCESSOR = load_model()


@spaces.GPU
@torch.inference_mode()
def generate_caption(image: Image.Image):
    if image is None:
        return "", 0
    try:
        # Validate input type (Gradio should already hand us a PIL image)
        if not isinstance(image, Image.Image):
            return "Error: Invalid image format", 0

        # Downscale oversized images in place to prevent OOM
        max_size = 4096
        if image.width > max_size or image.height > max_size:
            image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

        device = MODEL.device
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": DEFAULT_PROMPT},
                ],
            }
        ]
        prompt_text = PROCESSOR.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = PROCESSOR(
            text=[prompt_text],
            images=[image],
            return_tensors="pt",
        ).to(device)

        generated_ids = MODEL.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
        )
        # Strip the prompt tokens so only newly generated tokens are decoded
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = PROCESSOR.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        caption = output_text[0].strip()

        input_ids = inputs.get("input_ids")
        input_length = input_ids.shape[-1] if input_ids is not None else 0
        total_length = generated_ids.shape[-1]
        num_generated_tokens = max(total_length - input_length, 0)
        return caption, int(num_generated_tokens)
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()
        return "Error: Out of GPU memory. Please try with a smaller image.", 0
    except Exception as e:
        return f"Error generating caption: {str(e)}", 0


with gr.Blocks(title="CapRL Image Captioning") as demo:
    gr.Markdown("# 🎨 CapRL for Image Captioning")
    gr.Markdown("### CapRL: Stimulating Dense Image Caption Capabilities via Reinforcement Learning")
    gr.Markdown("✨ Upload an image to generate a detailed caption with CapRL-3B! ✨")
    gr.Markdown(
        """
        📖 Paper | 🏠 Github | 🤗 CapRL-3B Model | 🤗 CapRL-InternVL3.5-8B Model | 🤗 CapRL-2M Dataset

        🤗 CapRL Collection | 📰 Daily Paper | 💾 CapRL-3B-GGUF | 💾 CapRL-3B-i1-GGUF
        """
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Input Image")
            generate_button = gr.Button("Generate Caption")
        with gr.Column():
            caption_output = gr.Textbox(label="Caption", lines=6)
            token_output = gr.Number(label="Generated Tokens", precision=0)

    generate_button.click(
        fn=generate_caption,
        inputs=image_input,
        outputs=[caption_output, token_output],
        show_progress=True,
    )
    image_input.upload(
        fn=generate_caption,
        inputs=image_input,
        outputs=[caption_output, token_output],
        show_progress=True,
    )

    gr.Examples(
        examples=[
            ["./examples/example_chinese.png"],
            ["./examples/example_receipt.jpg"],
            ["./examples/example_table.png"],
        ],
        inputs=image_input,
        outputs=[caption_output, token_output],
        fn=generate_caption,
        cache_examples=True,
        label="📸 Example Images",
    )

    gr.Markdown("### Citation")
    gr.Markdown("If you find this project useful, please kindly cite:")
    citation_text = """@article{xing2025caprl,
  title={{CapRL}: Stimulating Dense Image Caption Capabilities via Reinforcement Learning},
  author={Xing, Long and Dong, Xiaoyi and Zang, Yuhang and Cao, Yuhang and Liang, Jianze and Huang, Qidong and Wang, Jiaqi and Wu, Feng and Lin, Dahua},
  journal={arXiv preprint arXiv:2509.22647},
  year={2025}
}"""
    gr.Code(value=citation_text, language="markdown", label="BibTeX Citation")

demo.launch()