hackerloi45 committed on
Commit 7116b90 · 1 Parent(s): c65ef6e

fixed search by text

Files changed (1):
  app.py  +145 -99

app.py CHANGED

@@ -1,79 +1,96 @@
  import os
  import uuid
  import gradio as gr
  import qdrant_client
- from qdrant_client import models
  from sentence_transformers import SentenceTransformer
- from PIL import Image

  # ===============================
- # Setup
  # ===============================
  UPLOAD_DIR = "uploaded_images"
  os.makedirs(UPLOAD_DIR, exist_ok=True)

  COLLECTION = "lost_and_found"

- # Qdrant client (in-memory for Hugging Face)
- qclient = qdrant_client.QdrantClient(":memory:")
  encoder = SentenceTransformer("clip-ViT-B-32")

- # Create collection if not exists
  if not qclient.collection_exists(COLLECTION):
      qclient.create_collection(
          collection_name=COLLECTION,
-         vectors_config=models.VectorParams(size=512, distance=models.Distance.COSINE),
      )


  # ===============================
- # Encode Function (Text or Image)
  # ===============================
  def encode_data(text=None, image=None):
-     if isinstance(image, Image.Image):  # Image is already PIL
          return encoder.encode(image.convert("RGB"))
-     elif isinstance(image, str):  # Path to image
          return encoder.encode(Image.open(image).convert("RGB"))
-     elif text:
          return encoder.encode([text])[0]
-     else:
-         return None


  # ===============================
- # Add Item
  # ===============================
  def add_item(text, image, uploader_name, uploader_phone):
      try:
          img_path = None
          vector = None

-         if isinstance(image, Image.Image):  # PIL image
              img_id = str(uuid.uuid4())
              img_path = os.path.join(UPLOAD_DIR, f"{img_id}.png")
              image.save(img_path)
              vector = encode_data(image=image)

          elif text:
              vector = encode_data(text=text)

          if vector is None:
-             return "❌ Please provide an image or text."

          qclient.upsert(
              collection_name=COLLECTION,
-             points=[
-                 models.PointStruct(
-                     id=str(uuid.uuid4()),
-                     vector=vector.tolist(),
-                     payload={
-                         "text": text or "",
-                         "uploader_name": uploader_name or "N/A",
-                         "uploader_phone": uploader_phone or "N/A",
-                         "image_path": img_path,
-                     },
-                 )
-             ],
          )
          return "✅ Item added to database!"
      except Exception as e:
@@ -81,104 +98,133 @@ def add_item(text, image, uploader_name, uploader_phone):


  # ===============================
- # Search Function
  # ===============================
  def search_items(text, image, max_results, min_score):
      try:
-         vector = None
-         if isinstance(image, Image.Image):  # Search with PIL
-             vector = encode_data(image=image)
-         elif text:
-             vector = encode_data(text=text)
-
-         if vector is None:
-             return "❌ Please provide an image or text.", []
-
-         results = qclient.search(
              collection_name=COLLECTION,
-             query_vector=vector.tolist(),
-             limit=max_results,
-             score_threshold=min_score,
          )

-         if not results:
              return "No matches found.", []

-         # Format results
-         result_texts, result_imgs = [], []
-         for r in results:
-             payload = r.payload
-             result_texts.append(
-                 f"id:{r.id} | score:{r.score:.3f} | "
-                 f"text:{payload.get('text','')} | "
-                 f"finder:{payload.get('uploader_name','N/A')} "
-                 f"({payload.get('uploader_phone','N/A')})"
              )
-             if payload.get("image_path") and os.path.exists(payload["image_path"]):
-                 result_imgs.append(payload["image_path"])

-         return "\n".join(result_texts), result_imgs
      except Exception as e:
          return f"❌ Error: {e}", []


  # ===============================
- # Delete All
  # ===============================
  def clear_database():
-     qclient.delete_collection(COLLECTION)
-     qclient.create_collection(
-         collection_name=COLLECTION,
-         vectors_config=models.VectorParams(size=512, distance=models.Distance.COSINE),
-     )
-     return "🗑️ Database cleared!"


  # ===============================
  # Gradio UI
  # ===============================
  with gr.Blocks() as demo:
-     gr.Markdown("## 🗝️ Lost & Found - Database")
-
-     # --- Add Item Tab ---
-     with gr.Tab("➕ Add Item"):
-         with gr.Row():
-             text_input = gr.Textbox(label="Description (optional)")
-             img_input = gr.Image(type="pil", label="Upload Image")
-         with gr.Row():
-             uploader_name = gr.Textbox(label="Finder Name")
-             uploader_phone = gr.Textbox(label="Finder Phone")
-         add_btn = gr.Button("Add to Database")
-         add_output = gr.Textbox(label="Status")
-
-         add_btn.click(
-             add_item,
-             inputs=[text_input, img_input, uploader_name, uploader_phone],
-             outputs=add_output,
-         )
-
-     # --- Search Tab ---
-     with gr.Tab("🔍 Search"):
-         with gr.Row():
-             search_text = gr.Textbox(label="Search by text (optional)")
-             search_img = gr.Image(type="pil", label="Search by image (optional)")
-         with gr.Row():
-             max_results = gr.Slider(1, 10, value=5, step=1, label="Max results")
-             min_score = gr.Slider(0.5, 1.0, value=0.8, step=0.01, label="Min similarity threshold")
          search_btn = gr.Button("Search")
          search_text_out = gr.Textbox(label="Search results (text)")
          search_gallery = gr.Gallery(label="Search Results", columns=2, height="auto")

-         search_btn.click(
-             search_items,
-             inputs=[search_text, search_img, max_results, min_score],
-             outputs=[search_text_out, search_gallery],
-         )
-
-     # --- Admin Tab ---
      with gr.Tab("🗑️ Admin"):
-         clear_btn = gr.Button("Clear Database")
          clear_out = gr.Textbox(label="Status")
-         clear_btn.click(clear_database, outputs=clear_out)

- demo.launch()

@@ -1,79 +1,96 @@
  import os
  import uuid
  import gradio as gr
+ import numpy as np
+ from PIL import Image
  import qdrant_client
+ from qdrant_client import QdrantClient
+ from qdrant_client.http.models import VectorParams, Distance, PointStruct
  from sentence_transformers import SentenceTransformer

  # ===============================
+ # Config / Setup
  # ===============================
  UPLOAD_DIR = "uploaded_images"
  os.makedirs(UPLOAD_DIR, exist_ok=True)

  COLLECTION = "lost_and_found"

+ # Qdrant client (in-memory for Spaces; replace with actual url/api_key if you use a remote Qdrant)
+ qclient = QdrantClient(":memory:")
+
+ # SentenceTransformer encoder (CLIP)
  encoder = SentenceTransformer("clip-ViT-B-32")

+ # Create collection if missing (use the model vector size)
+ VECTOR_SIZE = encoder.get_sentence_embedding_dimension()
  if not qclient.collection_exists(COLLECTION):
      qclient.create_collection(
          collection_name=COLLECTION,
+         vectors_config=VectorParams(size=VECTOR_SIZE, distance=Distance.COSINE),
      )


  # ===============================
+ # Encoding function
+ # (image handling MUST remain unchanged per request)
  # ===============================
  def encode_data(text=None, image=None):
+     """
+     Returns a vector (numpy array) for either a PIL Image, an image path (str),
+     or text (string). Image-handling kept exactly as requested.
+     """
+     # --- IMAGE branch (unchanged) ---
+     if isinstance(image, Image.Image):
+         # NOTE: per your instruction, do not modify the image encoding logic
          return encoder.encode(image.convert("RGB"))
+     if isinstance(image, str):
          return encoder.encode(Image.open(image).convert("RGB"))
+
+     # --- TEXT branch (safe to adjust) ---
+     if text:
          return encoder.encode([text])[0]
+
+     return None


  # ===============================
+ # Add Item (finder uploads a found item)
  # ===============================
  def add_item(text, image, uploader_name, uploader_phone):
      try:
          img_path = None
          vector = None

+         # If image provided (PIL), save and encode by image (image priority)
+         if isinstance(image, Image.Image):
              img_id = str(uuid.uuid4())
              img_path = os.path.join(UPLOAD_DIR, f"{img_id}.png")
              image.save(img_path)
              vector = encode_data(image=image)

+         # If no image but text provided -> encode text
          elif text:
              vector = encode_data(text=text)

          if vector is None:
+             return "❌ Please provide at least an image or some text."
+
+         # Ensure vector is numpy array
+         vec = np.asarray(vector, dtype=float)
+
+         payload = {
+             "text": text or "",
+             "uploader_name": (uploader_name or "N/A"),
+             "uploader_phone": (uploader_phone or "N/A"),
+             "image_path": img_path,
+             "has_image": bool(img_path),
+         }

          qclient.upsert(
              collection_name=COLLECTION,
+             points=[PointStruct(id=str(uuid.uuid4()), vector=vec.tolist(), payload=payload)],
+             wait=True,
          )
          return "✅ Item added to database!"
      except Exception as e:
@@ -81,104 +98,133 @@ def add_item(text, image, uploader_name, uploader_phone):


  # ===============================
+ # Search (fixed to handle text OR image OR both)
+ # - keep image coding intact (per request)
+ # - if both text+image supplied, average normalized vectors (cross-modal)
  # ===============================
  def search_items(text, image, max_results, min_score):
      try:
+         text_vec = None
+         img_vec = None
+
+         # get vectors (do not change image encoding)
+         if isinstance(image, Image.Image):
+             img_vec = encode_data(image=image)
+             img_vec = np.asarray(img_vec, dtype=float)
+         if text and len(text.strip()) > 0:
+             text_vec = encode_data(text=text)
+             text_vec = np.asarray(text_vec, dtype=float)
+
+         # If both provided -> combine (normalize then average)
+         if img_vec is not None and text_vec is not None:
+             # normalize
+             n1 = np.linalg.norm(img_vec) + 1e-12
+             n2 = np.linalg.norm(text_vec) + 1e-12
+             v1 = img_vec / n1
+             v2 = text_vec / n2
+             qvec = v1 + v2
+             qvec = qvec / (np.linalg.norm(qvec) + 1e-12)
+         elif img_vec is not None:
+             qvec = img_vec
+         elif text_vec is not None:
+             qvec = text_vec
+         else:
+             return "❌ Please provide an image or some text to search.", []
+
+         # Run search
+         hits = qclient.search(
              collection_name=COLLECTION,
+             query_vector=qvec.tolist(),
+             limit=int(max_results),
+             score_threshold=float(min_score),
+             with_payload=True,
          )

+         if not hits:
              return "No matches found.", []

+         result_texts = []
+         gallery_items = []  # list of image paths (or placeholders)
+
+         for h in hits:
+             payload = h.payload or {}
+             score = getattr(h, "score", None)
+             score_str = f"{float(score):.3f}" if score is not None else "N/A"
+             uploader_name = payload.get("uploader_name", "N/A") or "N/A"
+             uploader_phone = payload.get("uploader_phone", "N/A") or "N/A"
+
+             desc = (
+                 f"id:{h.id} | score:{score_str} | text:{payload.get('text','')} "
+                 f"| finder:{uploader_name} ({uploader_phone})"
              )
+             result_texts.append(desc)
+
+             img_path = payload.get("image_path")
+             if img_path and os.path.exists(img_path):
+                 gallery_items.append(img_path)
+             else:
+                 # append a small placeholder (you can also skip adding)
+                 # Gradio can display an empty string but better to put a placeholder image path if desired
+                 # We'll skip adding placeholders so gallery only shows real images
+                 pass
+
+         return "\n".join(result_texts), gallery_items

      except Exception as e:
          return f"❌ Error: {e}", []


  # ===============================
+ # Clear DB
  # ===============================
  def clear_database():
+     try:
+         if qclient.collection_exists(COLLECTION):
+             qclient.delete_collection(COLLECTION)
+         qclient.create_collection(
+             collection_name=COLLECTION,
+             vectors_config=VectorParams(size=VECTOR_SIZE, distance=Distance.COSINE),
+         )
+         # delete uploaded images
+         for f in os.listdir(UPLOAD_DIR):
+             try:
+                 os.remove(os.path.join(UPLOAD_DIR, f))
+             except Exception:
+                 pass
+         return "🗑️ Database cleared!"
+     except Exception as e:
+         return f"❌ Error clearing DB: {e}"


  # ===============================
  # Gradio UI
  # ===============================
  with gr.Blocks() as demo:
+     gr.Markdown("## 🔎 Lost & Found — Add Found Items (finder) & Search (lost)")
+
+     with gr.Tab("➕ Add Found Item"):
+         text_in = gr.Textbox(label="Description (optional)")
+         img_in = gr.Image(type="pil", label="Upload Image (optional)")
+         uploader_name = gr.Textbox(label="Your name (finder)")
+         uploader_phone = gr.Textbox(label="Your phone (finder)")
+         add_btn = gr.Button("Add to database")
+         add_status = gr.Textbox(label="Status")
+         add_btn.click(add_item, inputs=[text_in, img_in, uploader_name, uploader_phone], outputs=[add_status])
+
+     with gr.Tab("🔍 Search Lost Item"):
+         search_text = gr.Textbox(label="Search by text (optional)")
+         search_img = gr.Image(type="pil", label="Search by image (optional)")
+         max_results = gr.Slider(1, 20, value=5, step=1, label="Max results")
+         min_score = gr.Slider(0.0, 1.0, value=0.75, step=0.01, label="Min similarity score")
          search_btn = gr.Button("Search")
          search_text_out = gr.Textbox(label="Search results (text)")
          search_gallery = gr.Gallery(label="Search Results", columns=2, height="auto")
+         search_btn.click(search_items, inputs=[search_text, search_img, max_results, min_score], outputs=[search_text_out, search_gallery])

      with gr.Tab("🗑️ Admin"):
+         clear_btn = gr.Button("Clear database")
          clear_out = gr.Textbox(label="Status")
+         clear_btn.click(clear_database, outputs=[clear_out])

+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860)
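For reference, a minimal standalone sketch (not part of the commit) of the query-vector combination the new search_items uses: each available CLIP vector is L2-normalized, summed, and renormalized, so a text-only query (the case this commit fixes) passes through as a single normalized vector, while a text+image query weights both modalities equally. The helper name combine_query_vectors and the toy 3-d arrays are illustrative only, not from app.py.

import numpy as np

def combine_query_vectors(img_vec=None, text_vec=None):
    """Normalize each available vector, sum them, and renormalize (as in search_items)."""
    vecs = [np.asarray(v, dtype=float) for v in (img_vec, text_vec) if v is not None]
    if not vecs:
        return None  # caller should report "provide an image or some text"
    vecs = [v / (np.linalg.norm(v) + 1e-12) for v in vecs]
    combined = np.sum(vecs, axis=0)
    return combined / (np.linalg.norm(combined) + 1e-12)

# Text-only query: the vector is simply normalized and used as-is.
print(combine_query_vectors(text_vec=[0.3, 0.4, 0.0]))  # -> [0.6 0.8 0. ]
# Image + text query: both modalities contribute equally after normalization.
print(combine_query_vectors(img_vec=[1.0, 0.0, 0.0], text_vec=[0.0, 1.0, 0.0]))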