jocko commited on
Commit
319c9f8
·
1 Parent(s): 5e132cf

fix image similarity detection

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +26 -23
src/streamlit_app.py CHANGED
@@ -143,8 +143,6 @@ TEXT_COLUMN = "complaints" # or "general_complaint", depending on your needs
143
  st.title("🩺 Multimodal Medical Chatbot")
144
 
145
  query = st.text_input("Enter your medical question or symptom description:")
146
- uploaded_files = st.file_uploader("Upload an image to find similar medical cases:", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
147
-
148
 
149
  if query:
150
  with st.spinner("Searching medical cases..."):
@@ -170,37 +168,42 @@ if query:
170
  else:
171
  st.warning("OpenAI API key not found. Please set OPENAI_API_KEY as a secret environment variable.")
172
 
 
 
 
173
  if uploaded_files is not None:
174
  with st.spinner("Searching medical cases..."):
175
- st.write("Processing img")
176
  st.write(f"Number of files: {len(uploaded_files)}")
177
  for uploaded_file in uploaded_files:
178
  st.write(f"File name: {uploaded_file.name}")
179
 
180
- print(uploaded_files)
181
- uploaded_file = uploaded_files[0]
182
- st.write(f'uploading file {uploaded_file.name}')
183
- query_image = Image.open(uploaded_file).convert("RGB")
184
- st.image(query_image, caption="Your uploaded image", use_container_width=True)
185
 
186
- # Embed uploaded image
187
- inputs = clip_processor(images=query_image, return_tensors="pt")
188
- with torch.no_grad():
189
- query_feat = clip_model.get_image_features(**inputs)
190
- query_feat /= query_feat.norm(p=2, dim=-1, keepdim=True)
 
 
 
 
 
 
 
191
 
192
- # Compute cosine similarity
193
- similarities = (dataset_image_features @ query_feat.T).squeeze(1) # [num_dataset_images]
194
 
195
- top_k = 3
196
- top_results = torch.topk(similarities, k=top_k)
197
 
198
- st.write(f"Top {top_k} similar medical cases:")
199
 
200
- for rank, idx in enumerate(top_results.indices):
201
- score = top_results.values[rank].item()
202
- similar_img = data[int(idx)]['image']
203
- st.image(similar_img, caption=f"Similarity: {score:.3f}", use_container_width=True)
204
- st.markdown(f"**Case description:** {data[int(idx)]['complaints']}")
205
 
206
  st.caption("This chatbot is for educational purposes only and does not provide medical advice.")
 
143
  st.title("🩺 Multimodal Medical Chatbot")
144
 
145
  query = st.text_input("Enter your medical question or symptom description:")
 
 
146
 
147
  if query:
148
  with st.spinner("Searching medical cases..."):
 
168
  else:
169
  st.warning("OpenAI API key not found. Please set OPENAI_API_KEY as a secret environment variable.")
170
 
171
+
172
+ uploaded_files = st.file_uploader("Upload an image to find similar medical cases:", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
173
+
174
  if uploaded_files is not None:
175
  with st.spinner("Searching medical cases..."):
 
176
  st.write(f"Number of files: {len(uploaded_files)}")
177
  for uploaded_file in uploaded_files:
178
  st.write(f"File name: {uploaded_file.name}")
179
 
180
+ st.write(uploaded_files)
 
 
 
 
181
 
182
+ if len(uploaded_files) > 0:
183
+ print(uploaded_files)
184
+ uploaded_file = uploaded_files[0]
185
+ st.write(f'uploading file {uploaded_file.name}')
186
+ query_image = Image.open(uploaded_file).convert("RGB")
187
+ st.image(query_image, caption="Your uploaded image", use_container_width=True)
188
+
189
+ # Embed uploaded image
190
+ inputs = clip_processor(images=query_image, return_tensors="pt")
191
+ with torch.no_grad():
192
+ query_feat = clip_model.get_image_features(**inputs)
193
+ query_feat /= query_feat.norm(p=2, dim=-1, keepdim=True)
194
 
195
+ # Compute cosine similarity
196
+ similarities = (dataset_image_features @ query_feat.T).squeeze(1) # [num_dataset_images]
197
 
198
+ top_k = 3
199
+ top_results = torch.topk(similarities, k=top_k)
200
 
201
+ st.write(f"Top {top_k} similar medical cases:")
202
 
203
+ for rank, idx in enumerate(top_results.indices):
204
+ score = top_results.values[rank].item()
205
+ similar_img = data[int(idx)]['image']
206
+ st.image(similar_img, caption=f"Similarity: {score:.3f}", use_container_width=True)
207
+ st.markdown(f"**Case description:** {data[int(idx)]['complaints']}")
208
 
209
  st.caption("This chatbot is for educational purposes only and does not provide medical advice.")