yuhangzang committed
Commit 483edf4 · 1 Parent(s): 2a6eacb
Files changed (3)
  1. README.md +9 -0
  2. app.py +112 -7
  3. requirements.txt +6 -0
README.md CHANGED
@@ -12,3 +12,12 @@ short_description: Generate captions for images with CapRL
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+Citation:
+
+@article{xing2025caprl,
+  title={CapRL: Stimulating Dense Image Caption Capabilities via Reinforcement Learning},
+  author={Xing, Long and Dong, Xiaoyi and Zang, Yuhang and Cao, Yuhang and Liang, Jianze and Huang, Qidong and Wang, Jiaqi and Wu, Feng and Lin, Dahua},
+  journal={arXiv preprint arXiv:2509.22647},
+  year={2025}
+}
app.py CHANGED
@@ -1,15 +1,120 @@
 import gradio as gr
 import spaces
 import torch
+from PIL import Image
+from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
 
-zero = torch.Tensor([0]).cuda()
-print(zero.device) # <-- 'cpu' 🤔
+MODEL_ID = "internlm/CapRL-3B"
+DEFAULT_PROMPT = "Describe the image in detail."
+MAX_NEW_TOKENS = 128
+
+
+def get_device() -> str:
+    return "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def select_dtype(device: str):
+    if device == "cuda":
+        if torch.cuda.is_bf16_supported():
+            return torch.bfloat16
+        return torch.float16
+    return torch.float32
+
+
+def load_model():
+    device = get_device()
+    dtype = select_dtype(device)
+
+    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+        MODEL_ID,
+        torch_dtype=dtype,
+        device_map="auto" if device == "cuda" else None,
+        trust_remote_code=True,
+    )
+
+    if device != "cuda":
+        model.to(device)
+
+    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+    return model, processor
+
+
+MODEL, PROCESSOR = load_model()
 
 @spaces.GPU
-def greet(n):
-    print(zero.device) # <-- 'cuda:0' 🤗
-    return f"Hello {zero + n} Tensor"
+@torch.inference_mode()
+def generate_caption(image: Image.Image):
+    if image is None:
+        return "", 0
 
-demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
-demo.launch()
+    device = MODEL.device
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image"},
+                {"type": "text", "text": DEFAULT_PROMPT},
+            ],
+        }
+    ]
+
+    prompt_text = PROCESSOR.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+
+    inputs = PROCESSOR(
+        text=[prompt_text],
+        images=[image],
+        return_tensors="pt",
+    ).to(device)
+
+    output_ids = MODEL.generate(
+        **inputs,
+        max_new_tokens=MAX_NEW_TOKENS,
+        do_sample=False,
+    )
+
+    # Decode only the newly generated tokens. The Qwen2.5-VL processor has no
+    # post_process_generation helper, so trim the prompt ids before decoding.
+    input_length = inputs["input_ids"].shape[-1]
+    generated_ids = output_ids[:, input_length:]
+    caption = PROCESSOR.batch_decode(
+        generated_ids, skip_special_tokens=True
+    )[0].strip()
+
+    num_generated_tokens = generated_ids.shape[-1]
+
+    return caption, int(num_generated_tokens)
+
+
+with gr.Blocks(title="CapRL Image Captioning") as demo:
+    gr.Markdown("# CapRL Image Captioning\nUpload an image to generate a caption with CapRL-3B.")
+
+    with gr.Row():
+        with gr.Column():
+            image_input = gr.Image(type="pil", label="Input Image")
+            generate_button = gr.Button("Generate Caption")
+        with gr.Column():
+            caption_output = gr.Textbox(label="Caption", lines=6)
+            token_output = gr.Number(label="Generated Tokens", precision=0)
+
+    generate_button.click(
+        fn=generate_caption,
+        inputs=image_input,
+        outputs=[caption_output, token_output],
+    )
+
+    image_input.upload(
+        fn=generate_caption,
+        inputs=image_input,
+        outputs=[caption_output, token_output],
+    )
+
+
+demo.launch()
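
For reference, the same inference path can be exercised outside the Space. This is a minimal sketch, assuming the `internlm/CapRL-3B` weights are accessible and using a placeholder image path `example.jpg`; it mirrors the prompt-building, generation, and prompt-trimming steps in app.py above.

```python
# Minimal standalone sketch of the captioning path in app.py.
# "example.jpg" is a placeholder; adjust dtype/device_map for your hardware.
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "internlm/CapRL-3B",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained("internlm/CapRL-3B", trust_remote_code=True)

image = Image.open("example.jpg").convert("RGB")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe the image in detail."},
        ],
    }
]
prompt = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
inputs = processor(text=[prompt], images=[image], return_tensors="pt").to(model.device)

with torch.inference_mode():
    output_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)

# Decode only the tokens generated after the prompt.
caption = processor.batch_decode(
    output_ids[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True
)[0].strip()
print(caption)
```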
requirements.txt ADDED
@@ -0,0 +1,6 @@
+gradio==5.49.1
+spaces
+transformers
+torch
+Pillow
+sentencepiece
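
Note: `load_model()` passes `device_map="auto"`, which relies on the `accelerate` package not listed above; if the Space fails to start, adding `accelerate` to requirements.txt is the likely fix. To run locally, `pip install -r requirements.txt` and then `python app.py`.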