jocko committed on
Commit b9bdf95 · 1 Parent(s): c8b7285

fix image similarity detection

Files changed (1):
  src/streamlit_app.py  +15 -15
src/streamlit_app.py CHANGED
@@ -41,19 +41,19 @@ os.environ["OPIK_WORKSPACE"] = os.getenv("OPIK_WORKSPACE")
 # ========== 📥 Load Models ==========
 @st.cache_resource(show_spinner=False)
 def load_models():
-    clip_model = CLIPModel.from_pretrained(
+    _clip_model = CLIPModel.from_pretrained(
         "openai/clip-vit-base-patch32",
         cache_dir=os.environ["TRANSFORMERS_CACHE"]
     )
-    clip_processor = CLIPProcessor.from_pretrained(
+    _clip_processor = CLIPProcessor.from_pretrained(
         "openai/clip-vit-base-patch32",
         cache_dir=os.environ["TRANSFORMERS_CACHE"]
     )
-    text_model = SentenceTransformer(
+    _text_model = SentenceTransformer(
         "all-MiniLM-L6-v2",
         cache_folder=os.environ["SENTENCE_TRANSFORMERS_HOME"]
     )
-    return clip_model, clip_processor, text_model
+    return _clip_model, _clip_processor, _text_model
 
 clip_model, clip_processor, text_model = load_models()
 
@@ -76,9 +76,9 @@ def embed_dataset_images(_dataset):
     for item in _dataset:
         # Load image from URL/path or raw bytes - adapt this if needed
         img = item["image"]
-        inputs = clip_processor(images=img, return_tensors="pt")
+        inputs_img = clip_processor(images=img, return_tensors="pt")
         with torch.no_grad():
-            feat = clip_model.get_image_features(**inputs)
+            feat = clip_model.get_image_features(**inputs_img)
         feat /= feat.norm(p=2, dim=-1, keepdim=True)
         features.append(feat.cpu())
     return torch.cat(features, dim=0)
@@ -113,27 +113,27 @@ combined_texts = prepare_combined_texts(data)
 def embed_dataset_texts(_texts):
     return text_model.encode(_texts, convert_to_tensor=True)
 
-def embed_query_text(query):
-    return text_model.encode([query], convert_to_tensor=True)[0]
+def embed_query_text(_query):
+    return text_model.encode([_query], convert_to_tensor=True)[0]
 
 @track
-def get_chat_completion_openai(client, prompt: str):
-    return client.chat.completions.create(
+def get_chat_completion_openai(_client, _prompt: str):
+    return _client.chat.completions.create(
         model="gpt-4o",  # or "gpt-4" if you need the older GPT-4
-        messages=[{"role": "user", "content": prompt}],
+        messages=[{"role": "user", "content": _prompt}],
         temperature=0.5,
         max_tokens=150
     )
 
 @track
-def get_similar_prompt(query):
+def get_similar_prompt(_query):
     text_embeddings = embed_dataset_texts(combined_texts)  # cached
-    query_embedding = embed_query_text(query)  # recalculated each time
+    query_embedding = embed_query_text(_query)  # recalculated each time
 
     cos_scores = util.pytorch_cos_sim(query_embedding, text_embeddings)[0]
     top_result = torch.topk(cos_scores, k=1)
-    idx = top_result.indices[0].item()
-    return data[idx]
+    _idx = top_result.indices[0].item()
+    return data[_idx]
 
 
 # Pick which text column to use
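
A note on the underscore renames: Streamlit's @st.cache_data and @st.cache_resource skip hashing any parameter whose name starts with an underscore, which is why the cached helpers in this file take _dataset and _texts. A minimal sketch of that convention; embed_corpus and its arguments are illustrative names, not taken from this app:

import streamlit as st

@st.cache_data(show_spinner=False)
def embed_corpus(_model, texts):
    # `_model` is excluded from Streamlit's argument hashing, so an
    # unhashable model object can be passed without breaking the cache;
    # `texts` is hashed, so new inputs trigger recomputation.
    return _model.encode(texts, convert_to_tensor=True)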
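For the image-similarity detection the commit message refers to, the unit-normalized features returned by embed_dataset_images would be matched against a query image embedded the same way, mirroring get_similar_prompt. The query side is not shown in this hunk, so the following is only a sketch under that assumption; embed_query_image and get_similar_image are hypothetical names:

def embed_query_image(query_img):
    # Hypothetical helper: same preprocessing and L2 normalization as
    # embed_dataset_images, so cosine similarity is a plain dot product.
    inputs_img = clip_processor(images=query_img, return_tensors="pt")
    with torch.no_grad():
        feat = clip_model.get_image_features(**inputs_img)
    feat /= feat.norm(p=2, dim=-1, keepdim=True)
    return feat[0].cpu()

def get_similar_image(query_img, dataset_features, dataset):
    # dataset_features: (N, D) tensor from embed_dataset_images, unit-norm rows
    query_embedding = embed_query_image(query_img)
    cos_scores = dataset_features @ query_embedding  # (N,) cosine scores
    idx = torch.topk(cos_scores, k=1).indices[0].item()
    return dataset[idx]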