Spaces: Running on Zero
	| import gradio as gr | |
| import spaces | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration | |
# Hugging Face Hub checkpoint loaded by this demo.
MODEL_ID: str = "internlm/CapRL-3B"
# Instruction text paired with every image sent to the model.
DEFAULT_PROMPT: str = "Describe the image in detail."
# Cap on the number of newly generated tokens per caption.
MAX_NEW_TOKENS: int = 128
def get_device() -> str:
    """Return ``"cuda"`` when torch reports a usable CUDA device, else ``"cpu"``."""
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        return "cuda"
    return "cpu"
def select_dtype(device: str):
    """Choose the torch dtype used to load model weights on *device*.

    On ``"cuda"`` this prefers ``torch.bfloat16`` when the GPU supports it and
    falls back to ``torch.float16``. Any other device string gets
    ``torch.float32``.
    """
    if device != "cuda":
        return torch.float32
    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
def load_model():
    """Load the CapRL-3B vision-language model together with its processor.

    The target device comes from ``get_device()`` and the weight dtype from
    ``select_dtype()``. On CUDA, placement is delegated to
    ``device_map="auto"``. On any other device the model is loaded without a
    device map and then moved explicitly.

    Returns:
        A ``(model, processor)`` tuple.
    """
    target_device = get_device()
    on_cuda = target_device == "cuda"
    load_kwargs = {
        "torch_dtype": select_dtype(target_device),
        "device_map": "auto" if on_cuda else None,
        "trust_remote_code": True,
    }
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(MODEL_ID, **load_kwargs)
    if not on_cuda:
        model.to(target_device)
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    return model, processor


# Loaded once at import time; the request handler reuses these globals.
MODEL, PROCESSOR = load_model()
@spaces.GPU  # ZeroGPU: request a GPU for the duration of each call (no-op elsewhere, per `spaces` docs)
def generate_caption(image: Image.Image) -> tuple[str, int]:
    """Caption *image* with the loaded CapRL model.

    Args:
        image: PIL image from the Gradio ``Image`` component. ``None`` when
            the component is empty.

    Returns:
        ``(caption, num_generated_tokens)``. ``("", 0)`` when *image* is
        ``None``. The token count covers only tokens produced by
        ``generate``, not the prompt.
    """
    if image is None:
        return "", 0

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": DEFAULT_PROMPT},
            ],
        }
    ]
    prompt_text = PROCESSOR.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = PROCESSOR(
        text=[prompt_text],
        images=[image],
        return_tensors="pt",
    ).to(MODEL.device)

    output_ids = MODEL.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
    )

    # generate() returns the prompt ids followed by the completion. Slice off
    # the prompt so the caption does not echo the chat template / prompt text.
    input_length = inputs["input_ids"].shape[-1]
    new_token_ids = output_ids[:, input_length:]

    # Plain batch_decode replaces the former `post_process_generation` call,
    # which the Qwen2.5-VL processor does not provide.
    caption = PROCESSOR.batch_decode(
        new_token_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0].strip()
    return caption, int(new_token_ids.shape[-1])
with gr.Blocks(title="CapRL Image Captioning") as demo:
    gr.Markdown(
        "# CapRL Image Captioning\n"
        "Upload an image to generate a caption with CapRL-3B."
    )
    with gr.Row():
        # Left column: the input image plus a manual trigger.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Input Image")
            generate_button = gr.Button("Generate Caption")
        # Right column: the caption text and the generated-token count.
        with gr.Column():
            caption_output = gr.Textbox(label="Caption", lines=6)
            token_output = gr.Number(label="Generated Tokens", precision=0)

    # Captioning runs on an explicit button click and also as soon as an
    # image is uploaded. Both events share the same handler, input and outputs.
    for register_event in (generate_button.click, image_input.upload):
        register_event(
            fn=generate_caption,
            inputs=image_input,
            outputs=[caption_output, token_output],
        )

demo.launch()