jocko committed · b9bdf95
1 Parent(s): c8b7285

fix image similarity detection

src/streamlit_app.py CHANGED (+15 -15)
@@ -41,19 +41,19 @@ os.environ["OPIK_WORKSPACE"] = os.getenv("OPIK_WORKSPACE")
 # ========== 📥 Load Models ==========
 @st.cache_resource(show_spinner=False)
 def load_models():
-    clip_model = CLIPModel.from_pretrained(
+    _clip_model = CLIPModel.from_pretrained(
         "openai/clip-vit-base-patch32",
         cache_dir=os.environ["TRANSFORMERS_CACHE"]
     )
-    clip_processor = CLIPProcessor.from_pretrained(
+    _clip_processor = CLIPProcessor.from_pretrained(
         "openai/clip-vit-base-patch32",
         cache_dir=os.environ["TRANSFORMERS_CACHE"]
     )
-    text_model = SentenceTransformer(
+    _text_model = SentenceTransformer(
         "all-MiniLM-L6-v2",
         cache_folder=os.environ["SENTENCE_TRANSFORMERS_HOME"]
     )
-    return clip_model, clip_processor, text_model
+    return _clip_model, _clip_processor, _text_model
 
 clip_model, clip_processor, text_model = load_models()
 
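The only change in this hunk is a rename: the locals built inside load_models() gain a leading underscore. Inside this function the underscores are purely cosmetic, since st.cache_resource hashes arguments, not locals. A plausible motivation for adopting the convention app-wide (an assumption, not stated in the commit) is Streamlit's caching rule: a parameter whose name starts with "_" is excluded from cache-key hashing, which is the standard way to pass unhashable objects such as models into cached helpers. A minimal sketch of that rule, using a hypothetical embed_cached helper rather than anything in this app:

# Sketch of Streamlit's underscore convention. `embed_cached` is hypothetical;
# st.cache_data hashes arguments to build a cache key, and a leading "_"
# tells it to skip that argument (models and clients are not hashable).
import streamlit as st

@st.cache_data(show_spinner=False)
def embed_cached(_model, text: str):
    # `_model` is excluded from the cache key; only `text` is hashed.
    return _model.encode([text], convert_to_tensor=True)[0]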
@@ -76,9 +76,9 @@ def embed_dataset_images(_dataset):
     for item in _dataset:
         # Load image from URL/path or raw bytes - adapt this if needed
         img = item["image"]
-        inputs = clip_processor(images=img, return_tensors="pt")
+        inputs_img = clip_processor(images=img, return_tensors="pt")
         with torch.no_grad():
-            feat = clip_model.get_image_features(**inputs)
+            feat = clip_model.get_image_features(**inputs_img)
         feat /= feat.norm(p=2, dim=-1, keepdim=True)
         features.append(feat.cpu())
     return torch.cat(features, dim=0)
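Each dataset image is encoded with CLIP and L2-normalised before stacking, so cosine similarity against an equally normalised query embedding reduces to a plain dot product. An illustrative sketch of the matching step the commit title refers to (not code from the diff; it reuses the app's clip_model and clip_processor globals and assumes query_img is a PIL image):

import torch

def most_similar_image(query_img, dataset_feats):
    # Embed and L2-normalise the query exactly as embed_dataset_images does.
    inputs = clip_processor(images=query_img, return_tensors="pt")
    with torch.no_grad():
        q = clip_model.get_image_features(**inputs)
    q /= q.norm(p=2, dim=-1, keepdim=True)
    # Unit vectors on both sides, so the dot product is cosine similarity.
    scores = dataset_feats @ q.squeeze(0)  # shape: (num_images,)
    return scores.argmax().item()          # index of the closest image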
@@ -113,27 +113,27 @@ combined_texts = prepare_combined_texts(data)
 def embed_dataset_texts(_texts):
     return text_model.encode(_texts, convert_to_tensor=True)
 
-def embed_query_text(query):
-    return text_model.encode([query], convert_to_tensor=True)[0]
+def embed_query_text(_query):
+    return text_model.encode([_query], convert_to_tensor=True)[0]
 
 @track
-def get_chat_completion_openai(client, prompt: str):
-    return client.chat.completions.create(
+def get_chat_completion_openai(_client, _prompt: str):
+    return _client.chat.completions.create(
         model="gpt-4o", # or "gpt-4" if you need the older GPT-4
-        messages=[{"role": "user", "content": prompt}],
+        messages=[{"role": "user", "content": _prompt}],
         temperature=0.5,
         max_tokens=150
     )
 
 @track
-def get_similar_prompt(query):
+def get_similar_prompt(_query):
     text_embeddings = embed_dataset_texts(combined_texts) # cached
-    query_embedding = embed_query_text(query)
+    query_embedding = embed_query_text(_query) # recalculated each time
 
     cos_scores = util.pytorch_cos_sim(query_embedding, text_embeddings)[0]
     top_result = torch.topk(cos_scores, k=1)
-    idx = top_result.indices[0].item()
-    return data[idx]
+    _idx = top_result.indices[0].item()
+    return data[_idx]
 
 
 # Pick which text column to use
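Together these helpers form a retrieve-then-generate loop: embed the query with the MiniLM text model, pick the highest-scoring dataset row by cosine similarity, then send a prompt to OpenAI. A hypothetical usage sketch; the client setup and prompt template are assumptions, not shown in this diff:

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

user_query = "a red vintage car parked by the beach"
match = get_similar_prompt(user_query)  # nearest dataset row by cosine score

# Hypothetical template; the app's real prompt construction is not in this diff.
prompt = f"Using this similar example: {match}\nRespond to: {user_query}"
reply = get_chat_completion_openai(client, prompt)
print(reply.choices[0].message.content)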