# CapRL image-captioning demo (Hugging Face Space, running on ZeroGPU)

								import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

MODEL_ID = "internlm/CapRL-3B"
DEFAULT_PROMPT = "Describe the image in detail."
MAX_NEW_TOKENS = 128


def get_device() -> str:
    return "cuda" if torch.cuda.is_available() else "cpu"


def select_dtype(device: str) -> torch.dtype:
    # Prefer bf16 where the GPU supports it, fall back to fp16 on older GPUs,
    # and use fp32 on CPU.
    if device == "cuda":
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return torch.float32


def load_model():
    device = get_device()
    dtype = select_dtype(device)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        device_map="auto" if device == "cuda" else None,
        trust_remote_code=True,
    )
    if device != "cuda":
        # device_map handles placement on CUDA; move the model manually otherwise.
        model.to(device)
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    return model, processor


MODEL, PROCESSOR = load_model()


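# ZeroGPU attaches a GPU only while a @spaces.GPU-decorated function runs;
# loading once at import time lets every request reuse the same weights.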
@spaces.GPU
@torch.inference_mode()
def generate_caption(image: Image.Image) -> tuple[str, int]:
    """Caption an image with CapRL-3B and report how many tokens were generated."""
    if image is None:
        return "", 0
    device = MODEL.device
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": DEFAULT_PROMPT},
            ],
        }
    ]
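    # apply_chat_template renders these messages into the model's chat prompt,
    # inserting an image placeholder that the processor pairs with `images=`.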
    prompt_text = PROCESSOR.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = PROCESSOR(
        text=[prompt_text],
        images=[image],
        return_tensors="pt",
    ).to(device)
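    # Greedy decoding (do_sample=False) keeps captions deterministic across runs.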
    output_ids = MODEL.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
    )
    # generate() returns the prompt plus the completion, and Qwen2.5-VL's
    # processor has no separate post-processing step, so slice off the prompt
    # and decode only the newly generated tokens.
    input_length = inputs["input_ids"].shape[-1]
    generated_ids = output_ids[:, input_length:]
    caption = PROCESSOR.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0].strip()
    num_generated_tokens = generated_ids.shape[-1]
    return caption, int(num_generated_tokens)
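

# Hypothetical local smoke test ("example.jpg" is a placeholder path,
# not part of the Space):
#   caption, n_tokens = generate_caption(Image.open("example.jpg"))
#   print(n_tokens, caption)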


with gr.Blocks(title="CapRL Image Captioning") as demo:
    gr.Markdown("# CapRL Image Captioning\nUpload an image to generate a caption with CapRL-3B.")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Input Image")
            generate_button = gr.Button("Generate Caption")
        with gr.Column():
            caption_output = gr.Textbox(label="Caption", lines=6)
            token_output = gr.Number(label="Generated Tokens", precision=0)
    generate_button.click(
        fn=generate_caption,
        inputs=image_input,
        outputs=[caption_output, token_output],
    )
    # Also caption automatically as soon as an image is uploaded.
    image_input.upload(
        fn=generate_caption,
        inputs=image_input,
        outputs=[caption_output, token_output],
    )


demo.launch()