"""Gradio chat demo that streams answers about an uploaded image from an image-text-to-text model."""

import os
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor, TextIteratorStreamer

# Model id is overridable through the MODEL environment variable.
pretrained_model_name_or_path = os.environ.get("MODEL", "nvidia/NV-Reason-CXR-3B")
# A missing or empty HF_TOKEN is replaced by True, which is passed as `token=` below
# (presumably so the Hub client falls back to locally stored credentials -- confirm).
auth_token = os.environ.get("HF_TOKEN") or True
DEFAULT_PROMPT = "Find abnormalities and support devices."

# The model is loaded once at import time, switched to eval mode and moved to the GPU.
model = AutoModelForImageTextToText.from_pretrained(
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    dtype=torch.bfloat16,
    token=auth_token,
).eval().to("cuda")
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, use_fast=True)


def _build_messages(text, history):
    """Turn the chat history plus the new query into chat-template messages.

    Entries whose content is not a non-blank string are skipped. Everything
    before the user turn that precedes the first assistant reply is dropped
    (those are earlier attempts made without an image). If no assistant reply
    exists yet, the whole history is discarded. The image placeholder is
    attached to the first remaining message.
    """
    turns = []
    first_reply = None  # position of the first assistant turn *within turns*
    for entry in history or []:
        content = entry.get("content")
        if not isinstance(content, str) or not content.strip():
            continue
        if first_reply is None and entry["role"] == "assistant":
            # Track the index in the filtered list, not in `history`, so that
            # skipped entries cannot shift the slice below.
            first_reply = len(turns)
        turns.append({"role": entry["role"], "content": [{"type": "text", "text": content}]})

    if first_reply is None:
        turns = []
    else:
        # Keep the turn right before the first reply and everything after it.
        turns = turns[max(first_reply - 1, 0):]

    # Current prompt.
    turns.append({"role": "user", "content": [{"type": "text", "text": text}]})
    turns[0]["content"].insert(0, {"type": "image"})
    return turns


@spaces.GPU
def model_inference(text, history, image):
    """Stream the model's answer to ``text`` about ``image``.

    Args:
        text: The user's current query.
        history: Previous chat turns as dicts with ``role`` and ``content`` keys.
        image: The uploaded image, or ``None`` when nothing was provided.

    Yields:
        First the placeholder ``"..."``, then the accumulated response text
        after every chunk produced by the streamer.

    Raises:
        gr.Error: If the query is empty/blank, no image was provided, or
            generation fails in the worker thread.
    """
    print(f"text: {text}")
    print(f"history: {history}")

    if not text or not text.strip():
        raise gr.Error("Please input a query.", duration=3, print_exception=False)
    if image is None:
        raise gr.Error("Please provide an image.", duration=3, print_exception=False)

    messages = _build_messages(text, history)
    print(f"messages: {messages}")

    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[image], return_tensors="pt")
    inputs = inputs.to("cuda")

    # Generate
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_args = dict(inputs, streamer=streamer, max_new_tokens=4096)
    failures = []

    def _generate():
        # Inference mode is thread-local, so it has to be entered inside the
        # worker thread rather than around the Thread construction.
        try:
            with torch.inference_mode():
                model.generate(**generation_args)
        except Exception as exc:  # thread boundary: re-raised by the consumer below
            failures.append(exc)
            # Unblock the consumer loop; otherwise it would wait on the queue forever.
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    yield "..."
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer

    thread.join()
    if failures:
        raise gr.Error("Generation failed. Please try again.") from failures[0]


# NOTE(review): the nesting of the layout below was reconstructed from a
# whitespace-collapsed source -- confirm it against the intended UI.
with gr.Blocks() as demo:
    # NOTE(review): any markup in this literal (heading, link target for "here")
    # appears to have been stripped from the source -- restore it if so.
    gr.HTML("\n\nNV-Reason-CXR-3B Demo. Check out the model card details here.\n\n")

    # Both are created unrendered; the textbox is placed by the ChatInterface below.
    send_btn = gr.Button("Send", variant="primary", render=False)
    textbox = gr.Textbox(
        show_label=False,
        placeholder="Enter your text here and press ENTER",
        render=False,
        submit_btn="Send",
    )

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", visible=True, sources="upload", show_label=False)
            clear_btn = gr.Button("Clear", variant="secondary")
            with gr.Accordion("Examples", open=True):
                ex = gr.Examples(
                    examples=[
                        ["example_images/35.jpg", "Examine the chest X-ray."],
                        ["example_images/363.jpg", "Provide a comprehensive image analysis, and list all abnormalities."],
                        ["example_images/4747.jpg", "Find abnormalities and support devices."],
                        ["example_images/87.jpg", "Find abnormalities and support devices."],
                        ["example_images/6218.jpg", "Find abnormalities and support devices."],
                        ["example_images/6447.jpg", "Find abnormalities and support devices."],
                    ],
                    inputs=[image_input, textbox],
                    label=None,
                )
                ex.dataset.show_label = False

        with gr.Column(scale=2):
            chat_interface = gr.ChatInterface(
                fn=model_inference,
                type="messages",
                chatbot=gr.Chatbot(
                    type="messages",
                    label="AI",
                    render_markdown=True,
                    # NOTE(review): model output is rendered without HTML sanitizing.
                    sanitize_html=False,
                    allow_tags=True,
                    height='35vw',
                    container=False,
                    show_share_button=False,
                ),
                textbox=textbox,
                additional_inputs=image_input,
                multimodal=False,
                fill_height=False,
                show_api=False,
            )
            # NOTE(review): a stray backslash (presumably a line continuation) and
            # possibly stripped markup were removed from this literal -- confirm.
            gr.HTML(
                "Start with a full prompt: Find abnormalities and support devices.\n"
                "Follow up with additional questions, such as Provide differentials"
                " or Write a structured report.\n"
            )

    # Components that together hold the conversation state of the chat interface.
    chat_state = [
        chat_interface.chatbot,
        chat_interface.chatbot_state,
        chat_interface.chatbot_value,
        chat_interface.saved_input,
    ]

    # Clear chat history when an example is selected (keep example-populated inputs intact)
    ex.load_input_event.then(
        lambda: ([], [], [], None),
        None,
        chat_state,
        queue=False,
        show_api=False,
    )

    # Clear chat history when a new image is uploaded via the image input
    image_input.upload(
        lambda: ([], [], [], None, DEFAULT_PROMPT),
        None,
        chat_state + [textbox],
        queue=False,
        show_api=False,
    )

    # Clear everything on Clear button click
    clear_btn.click(
        lambda: ([], [], [], None, "", None),
        None,
        chat_state + [textbox, image_input],
        queue=False,
        show_api=False,
    )

demo.queue(max_size=10)
demo.launch(debug=False, server_name="0.0.0.0")