Commit 945806b · 1 Parent(s): 35cd62d
README.md CHANGED
@@ -1,13 +1,10 @@
 ---
-title: Nv Reason Cxr
-emoji: 📉
-colorFrom: blue
-colorTo: yellow
+title: NV-Reason-CXR-3B Demo
+emoji: 🩻
+colorFrom: green
+colorTo: gray
 sdk: gradio
-sdk_version: 5.49.1
+sdk_version: 5.44.1
 app_file: app.py
-pinned: false
-short_description: NV-Reason-CXR-3B Demo
+license: apache-2.0
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
app.py ADDED
@@ -0,0 +1,156 @@
+import gradio as gr
+from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
+from threading import Thread
+import torch
+import spaces
+import os
+
+pretrained_model_name_or_path = os.environ.get("MODEL", "nvidia/NV-Reason-CXR-3B")
+
+auth_token = os.environ.get("HF_TOKEN") or True
+DEFAULT_PROMPT = "Find abnormalities and support devices."
+
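+# Load the model weights in bfloat16 and run on GPU. HF_TOKEN authenticates
+# gated/private checkpoints; token=True falls back to the cached CLI login.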
+model = AutoModelForImageTextToText.from_pretrained(
+    pretrained_model_name_or_path=pretrained_model_name_or_path,
+    dtype=torch.bfloat16,
+    token=auth_token,
+).eval().to("cuda")
+
+processor = AutoProcessor.from_pretrained(
+    pretrained_model_name_or_path,
+    use_fast=True,
+)
+
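+# On ZeroGPU Spaces, @spaces.GPU attaches a GPU for the duration of each call.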
+@spaces.GPU
+def model_inference(text, history, image):
+
+    print(f"text: {text}")
+    print(f"history: {history}")
+
+    if len(text) == 0:
+        raise gr.Error("Please input a query.", duration=3, print_exception=False)
+
+    if image is None:
+        raise gr.Error("Please provide an image.", duration=3, print_exception=False)
+
+    # print(f"image0: {image} size: {image.size}")
+
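+    # Rebuild chat-template messages from the Gradio history, keeping turns
+    # from the first user prompt that received an assistant reply onward;
+    # earlier turns are dropped since only the first message carries the image.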
+    messages = []
+    if len(history) > 0:
+        valid_index = None
+        for i in range(len(history)):
+            h = history[i]
+            if len(h.get("content").strip()) > 0:
+                if valid_index is None and h['role'] == 'assistant':
+                    valid_index = i - 1
+                messages.append({"role": h['role'], "content": [{"type": "text", "text": h['content']}]})
+
+        if valid_index is None:
+            messages = []
+        if len(messages) > 0 and valid_index > 0:
+            messages = messages[valid_index:]  # remove previous messages (without image)
+
+    # current prompt
+    messages.append({"role": "user", "content": [{"type": "text", "text": text}]})
+    messages[0]['content'].insert(0, {"type": "image"})
+    print(f"messages: {messages}")
+
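+    # Render the conversation with the model's chat template, then tokenize
+    # the text prompt together with the image.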
+    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+    inputs = processor(text=prompt, images=[image], return_tensors="pt")
+    inputs = inputs.to("cuda")
+
+    # Generate in a background thread and stream partial output to the UI
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+    generation_args = dict(inputs, streamer=streamer, max_new_tokens=4096)
+
+    with torch.inference_mode():
+        thread = Thread(target=model.generate, kwargs=generation_args)
+        thread.start()
+
+        yield "..."
+        buffer = ""
+
+        for new_text in streamer:
+            buffer += new_text
+            yield buffer
+
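+# Build the Gradio UI: image upload and examples on the left, chat on the right.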
+with gr.Blocks() as demo:
+
+    # gr.Markdown('<h1 style="text-align:center; margin: 0.2em 0;">Demo.</h1>')
+    send_btn = gr.Button("Send", variant="primary", render=False)
+    textbox = gr.Textbox(show_label=False, placeholder="Enter your text here and press ENTER", render=False, submit_btn="Send")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            image_input = gr.Image(type="pil", visible=True, sources="upload", show_label=False)
+
+            clear_btn = gr.Button("Clear", variant="secondary")
+
+            ex = gr.Examples(
+                examples=[
+                    ["example_images/35.jpg", "Examine the chest X-ray."],
+                    ["example_images/363.jpg", "Provide a comprehensive image analysis, and list all abnormalities."],
+                    ["example_images/4747.jpg", "Find abnormalities and support devices."],
+                    ["example_images/87.jpg", "Find abnormalities and support devices."],
+                    ["example_images/6218.jpg", "Find abnormalities and support devices."],
+                    ["example_images/6447.jpg", "Find abnormalities and support devices."],
+                ],
+                inputs=[image_input, textbox],
+            )
+
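+        # Right column: the streaming chat panel driven by model_inference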
+        with gr.Column(scale=2):
+            chat_interface = gr.ChatInterface(
+                fn=model_inference,
+                type="messages",
+                chatbot=gr.Chatbot(type="messages", label="AI", render_markdown=True, sanitize_html=False, allow_tags=True, height='35vw', container=False, show_share_button=False),
+                textbox=textbox,
+                additional_inputs=image_input,
+                multimodal=False,
+                fill_height=False,
+                show_api=False,
+            )
+            gr.HTML('<span style="color:lightgray">Start with a full prompt: Find abnormalities and support devices.<br>'
+                    'Follow up with additional questions, such as "Provide differentials" or "Write a structured report".</span>')
+
+    # Clear chat history when an example is selected (keep example-populated inputs intact)
+    ex.load_input_event.then(
+        lambda: ([], [], [], None),
+        None,
+        [chat_interface.chatbot, chat_interface.chatbot_state, chat_interface.chatbot_value, chat_interface.saved_input],
+        queue=False,
+        show_api=False,
+    )
+
+    # Clear chat history when a new image is uploaded via the image input
+    image_input.upload(
+        lambda: ([], [], [], None, DEFAULT_PROMPT),
+        None,
+        [chat_interface.chatbot, chat_interface.chatbot_state, chat_interface.chatbot_value, chat_interface.saved_input, textbox],
+        queue=False,
+        show_api=False,
+    )
+
+    # Clear everything on Clear button click
+    clear_btn.click(
+        lambda: ([], [], [], None, "", None),
+        None,
+        [chat_interface.chatbot, chat_interface.chatbot_state, chat_interface.chatbot_value, chat_interface.saved_input, textbox, image_input],
+        queue=False,
+        show_api=False,
+    )
+
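+# Queue at most 10 pending requests; bind to 0.0.0.0 so the server is
+# reachable from outside the Space container.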
+demo.queue(max_size=10)
+demo.launch(debug=False, server_name="0.0.0.0")
+
example_images/35.jpg ADDED
example_images/363.jpg ADDED
example_images/4747.jpg ADDED
example_images/6218.jpg ADDED
example_images/6447.jpg ADDED
example_images/87.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,8 @@
+torch==2.7.1
+torchvision==0.22.1
+transformers==4.56.0
+huggingface_hub
+gradio==5.44.1
+spaces==0.40.1
+# qwen_vl_utils==0.0.11
+