caprl / app.py
yuhangzang
update
483edf4
raw
history blame
3.05 kB
import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
# Model identifier handed to ``from_pretrained`` for both the model and its processor.
MODEL_ID = "internlm/CapRL-3B"
# Text instruction placed next to the image in every chat request.
DEFAULT_PROMPT = "Describe the image in detail."
# Passed as ``max_new_tokens`` to ``generate``; caps the length of each caption.
MAX_NEW_TOKENS = 128
def get_device() -> str:
    """Return ``"cuda"`` when PyTorch reports a usable CUDA device, else ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def select_dtype(device: str):
    """Pick the torch dtype used to load the model on ``device``.

    Args:
        device: Device name; only the exact string ``"cuda"`` selects a
            half-precision dtype.

    Returns:
        ``torch.bfloat16`` on CUDA when bf16 is supported, ``torch.float16``
        on CUDA otherwise, and ``torch.float32`` for any other device.
    """
    # Anything that is not CUDA stays in full precision.
    if device != "cuda":
        return torch.float32
    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
def load_model():
    """Load the captioning model and its processor for ``MODEL_ID``.

    The dtype comes from ``select_dtype``. On CUDA the weights are placed
    with ``device_map="auto"``; on any other device no device map is used
    and the model is moved to that device explicitly.

    Returns:
        A ``(model, processor)`` tuple.
    """
    device = get_device()
    on_gpu = device == "cuda"
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=select_dtype(device),
        device_map="auto" if on_gpu else None,
        trust_remote_code=True,
    )
    # With a device map the loader already placed the weights; otherwise move them here.
    if not on_gpu:
        model.to(device)
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    return model, processor
# Load the model and processor once at import time; generate_caption reads
# these module-level globals on every call.
MODEL, PROCESSOR = load_model()
@spaces.GPU
@torch.inference_mode()
def generate_caption(image: Image.Image):
    """Caption ``image`` with the module-level ``MODEL`` and ``PROCESSOR``.

    A single-turn chat prompt (one image placeholder plus ``DEFAULT_PROMPT``)
    is rendered through the processor's chat template, and the model
    generates greedily for up to ``MAX_NEW_TOKENS`` new tokens.

    Args:
        image: The PIL image supplied by the UI, or ``None`` when no image
            has been provided.

    Returns:
        A ``(caption, num_generated_tokens)`` tuple. ``caption`` holds only
        the newly generated text, stripped of surrounding whitespace.
        ``("", 0)`` is returned when ``image`` is ``None``.
    """
    if image is None:
        return "", 0

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": DEFAULT_PROMPT},
            ],
        }
    ]
    prompt_text = PROCESSOR.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = PROCESSOR(
        text=[prompt_text],
        images=[image],
        return_tensors="pt",
    ).to(MODEL.device)

    output_ids = MODEL.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
    )

    # generate() returns the prompt ids followed by the continuation. Decode
    # only the continuation so the chat prompt is not echoed into the caption.
    input_ids = inputs.get("input_ids")
    prompt_length = input_ids.shape[-1] if input_ids is not None else 0
    new_token_ids = output_ids[:, prompt_length:]

    caption = PROCESSOR.batch_decode(
        new_token_ids, skip_special_tokens=True
    )[0].strip()
    return caption, int(new_token_ids.shape[-1])
# Two-column layout: image input and button on the left, caption and token
# count on the right.
with gr.Blocks(title="CapRL Image Captioning") as demo:
    gr.Markdown("# CapRL Image Captioning\nUpload an image to generate a caption with CapRL-3B.")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Input Image")
            generate_button = gr.Button("Generate Caption")
        with gr.Column():
            caption_output = gr.Textbox(label="Caption", lines=6)
            token_output = gr.Number(label="Generated Tokens", precision=0)

    # Clicking the button and uploading an image both run the same handler
    # with the same input and outputs; register the click listener first.
    for register_listener in (generate_button.click, image_input.upload):
        register_listener(
            fn=generate_caption,
            inputs=image_input,
            outputs=[caption_output, token_output],
        )

demo.launch()