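"""Gradio demo: image captioning with the CapRL-3B vision-language model.

A minimal Hugging Face Spaces app: it loads internlm/CapRL-3B (a
Qwen2.5-VL-based captioner) once at import time and exposes an
upload-and-caption UI.
"""
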
import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

MODEL_ID = "internlm/CapRL-3B"
DEFAULT_PROMPT = "Describe the image in detail."
MAX_NEW_TOKENS = 128


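# Device/dtype selection: bfloat16 on GPUs that support it, float16 on other
# CUDA devices, float32 on CPU.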
def get_device() -> str:
    return "cuda" if torch.cuda.is_available() else "cpu"


def select_dtype(device: str):
    if device == "cuda":
        if torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float16
    return torch.float32


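# device_map="auto" lets accelerate place the weights on the available GPU(s);
# on CPU the model is moved explicitly after loading.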
def load_model():
    device = get_device()
    dtype = select_dtype(device)

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        device_map="auto" if device == "cuda" else None,
        trust_remote_code=True,
    )

    if device != "cuda":
        model.to(device)

    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    return model, processor


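# Load eagerly so Spaces downloads the weights during startup rather than on
# the first request.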
MODEL, PROCESSOR = load_model()


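# @spaces.GPU requests a GPU slice for the duration of the call on ZeroGPU
# hardware; torch.inference_mode() disables autograd bookkeeping during
# generation.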
@spaces.GPU
@torch.inference_mode()
def generate_caption(image: Image.Image):
    if image is None:
        return "", 0

    device = MODEL.device
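    # Single-turn chat with one image placeholder; the actual PIL image is
    # passed to the processor below.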
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": DEFAULT_PROMPT},
            ],
        }
    ]

    prompt_text = PROCESSOR.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs = PROCESSOR(
        text=[prompt_text],
        images=[image],
        return_tensors="pt",
    ).to(device)

    output_ids = MODEL.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
    )

    # Qwen2.5-VL processors expose no generation post-processing helper, so
    # strip the prompt tokens from the output and decode only the new ones.
    input_length = inputs["input_ids"].shape[-1]
    trimmed_ids = output_ids[:, input_length:]
    caption = PROCESSOR.batch_decode(
        trimmed_ids, skip_special_tokens=True
    )[0].strip()

    num_generated_tokens = max(output_ids.shape[-1] - input_length, 0)

    return caption, int(num_generated_tokens)


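# Two-column UI: image input and button on the left, caption and token count
# on the right. Both the button click and a fresh upload trigger captioning.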
with gr.Blocks(title="CapRL Image Captioning") as demo:
    gr.Markdown("# CapRL Image Captioning\nUpload an image to generate a caption with CapRL-3B.")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Input Image")
            generate_button = gr.Button("Generate Caption")
        with gr.Column():
            caption_output = gr.Textbox(label="Caption", lines=6)
            token_output = gr.Number(label="Generated Tokens", precision=0)

    generate_button.click(
        fn=generate_caption,
        inputs=image_input,
        outputs=[caption_output, token_output],
    )

    image_input.upload(
        fn=generate_caption,
        inputs=image_input,
        outputs=[caption_output, token_output],
    )


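# Spaces invokes launch() automatically; when run locally, the app serves on
# http://127.0.0.1:7860 by default.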
demo.launch()