jocko committed · Commit 6f5e256 · 1 Parent(s): 7f5755e

fix image similarity detection
src/streamlit_app.py  +15 -14  CHANGED
@@ -72,7 +72,22 @@ def load_medical_data():
     )
     return dataset
 
+# Cache dataset image embeddings (takes time, so cached)
+@st.cache_data(show_spinner=True)
+def embed_dataset_images(_dataset):
+    features = []
+    for item in _dataset:
+        # Load image from URL/path or raw bytes - adapt this if needed
+        img = item["image"]
+        inputs = clip_processor(images=img, return_tensors="pt")
+        with torch.no_grad():
+            feat = clip_model.get_image_features(**inputs)
+        feat /= feat.norm(p=2, dim=-1, keepdim=True)
+        features.append(feat.cpu())
+    return torch.cat(features, dim=0)
+
 data = load_medical_data()
+dataset_image_features = embed_dataset_images(data)
 
 from openai import OpenAI
 client = OpenAI(api_key=openai.api_key)
@@ -133,21 +148,7 @@ def get_similar_prompt(query):
     idx = top_result.indices[0].item()
     return data[idx]
 
-# Cache dataset image embeddings (takes time, so cached)
-@st.cache_data(show_spinner=True)
-def embed_dataset_images(_dataset):
-    features = []
-    for item in _dataset:
-        # Load image from URL/path or raw bytes - adapt this if needed
-        img = item["image"]
-        inputs = clip_processor(images=img, return_tensors="pt")
-        with torch.no_grad():
-            feat = clip_model.get_image_features(**inputs)
-        feat /= feat.norm(p=2, dim=-1, keepdim=True)
-        features.append(feat.cpu())
-    return torch.cat(features, dim=0)
-
-dataset_image_features = embed_dataset_images(data)
-
 if query:
     with st.spinner("Searching medical cases..."):
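For reference, a minimal sketch of how the cached, L2-normalised dataset_image_features could back the image-similarity lookup this commit fixes: embed the query image with the same CLIP processor/model and pick the highest dot-product match. The helper name find_similar_image and its wiring are assumptions for illustration only, not code from this Space.

import torch

# Hypothetical helper (not part of the commit): search the cached dataset
# embeddings for the image closest to a user-supplied query image.
def find_similar_image(query_img):
    # Embed the query with the same CLIP pipeline used in embed_dataset_images
    inputs = clip_processor(images=query_img, return_tensors="pt")
    with torch.no_grad():
        query_feat = clip_model.get_image_features(**inputs)
    query_feat /= query_feat.norm(p=2, dim=-1, keepdim=True)

    # On L2-normalised vectors, cosine similarity reduces to a dot product
    sims = (dataset_image_features @ query_feat.T).squeeze(1)
    best = sims.argmax().item()
    return data[best], sims[best].item()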