Spaces:
Running
Running
| import streamlit as st | |
| import json | |
| from typing import List | |
| from fastembed import LateInteractionTextEmbedding, TextEmbedding | |
| from fastembed import SparseTextEmbedding, SparseEmbedding | |
| from qdrant_client import QdrantClient, models | |
| from tokenizers import Tokenizer | |
| ############################# | |
| # 1. Utility / Helper Code | |
| ############################# | |
def load_tokenizer():
    """
    Return the HF tokenizer for the first supported sparse model.

    Handy for decoding sparse-embedding token ids back into readable
    tokens (optional usage).
    """
    model_info = SparseTextEmbedding.list_supported_models()[0]
    hf_repo = model_info["sources"]["hf"]
    return Tokenizer.from_pretrained(hf_repo)
def load_models():
    """
    Instantiate the three embedding models used throughout the app.

    Returns:
        tuple: (dense_model, late_interaction_model, sparse_model) where
        dense is BGE-small, late-interaction is ColBERTv2, and sparse is BM25.
    """
    dense_model = TextEmbedding("BAAI/bge-small-en-v1.5")
    late_interaction_model = LateInteractionTextEmbedding("colbert-ir/colbertv2.0")
    sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25")
    return dense_model, late_interaction_model, sparse_model
def build_qdrant_index(data):
    """
    Build an in-memory Qdrant collection ("items") with dense, late-interaction
    (ColBERT) and sparse (BM25) vectors for every item in *data*.

    Args:
        data: dict with an "items" key; each item needs "name" and
            "description" and may optionally carry an "id" (falls back to the
            item's list index when absent).

    Returns:
        QdrantClient: in-memory client holding the populated "items" collection.

    Raises:
        ValueError: if data contains no items (previously a bare IndexError).
        KeyError: if "items", "name" or "description" is missing.
    """
    items = data["items"]
    if not items:
        raise ValueError("No items found in the provided JSON.")

    # Text that gets embedded: "<name> - <description>".
    descriptions = [f"{item['name']} - {item['description']}" for item in items]
    # BUGFIX: fall back to the list index when an item has no "id" key —
    # the sample JSON shipped with the UI omits it, which used to raise
    # KeyError here. (Unused name-only embeddings were also removed; they
    # cost a full embedding pass and were never queried.)
    metadata = [
        {"name": item["name"], "item_id": item.get("id", idx)}
        for idx, item in enumerate(items)
    ]

    # Load (cached) models.
    dense_embedding_model, late_embedding_model, sparse_model = load_models()

    # Generate one embedding of each kind per item.
    dense_embeddings = list(dense_embedding_model.embed(descriptions))
    late_embeddings = list(late_embedding_model.embed(descriptions))
    sparse_embeddings: List[SparseEmbedding] = list(
        sparse_model.embed(descriptions, batch_size=6)
    )

    # Create an in-memory Qdrant instance and the collection schema.
    qdrant_client = QdrantClient(":memory:")
    qdrant_client.create_collection(
        collection_name="items",
        vectors_config={
            "dense": models.VectorParams(
                size=len(dense_embeddings[0]),
                distance=models.Distance.COSINE,
            ),
            "late": models.VectorParams(
                # ColBERT produces a matrix (tokens x dim); size is per-token dim.
                size=len(late_embeddings[0][0]),
                distance=models.Distance.COSINE,
                multivector_config=models.MultiVectorConfig(
                    comparator=models.MultiVectorComparator.MAX_SIM
                ),
            ),
        },
        sparse_vectors_config={
            "sparse": models.SparseVectorParams(
                # Let Qdrant apply IDF weighting server-side for BM25.
                modifier=models.Modifier.IDF,
            ),
        },
    )

    # Upload one point per item, carrying all three vector kinds.
    points = [
        models.PointStruct(
            id=idx,
            payload=payload,
            vector={
                "late": late_embeddings[idx].tolist(),
                "dense": dense_embeddings[idx],
                "sparse": sparse_embeddings[idx].as_object(),
            },
        )
        for idx, payload in enumerate(metadata)
    ]
    qdrant_client.upload_points(
        collection_name="items",
        points=points,
    )
    return qdrant_client
def run_queries(qdrant_client, query_text: str) -> dict:
    """
    Run six retrieval strategies against the "items" collection and return
    the top-5 results of each, keyed by a short strategy label:

        "C"       - ColBERT late-interaction only
        "S"       - sparse (BM25) only
        "D"       - dense only
        "S+D-F"   - sparse + dense fused with Reciprocal Rank Fusion (RRF)
        "S+D-F-C" - RRF(sparse, dense) prefetch, reranked by ColBERT
        "S+D-C"   - dense -> sparse cascade prefetch, reranked by ColBERT

    Args:
        qdrant_client: client returned by build_qdrant_index.
        query_text: free-text search query.

    Returns:
        dict mapping strategy label to a Qdrant query response
        (each response exposes .points with .payload and .score).
    """
    # Load models
    # NOTE(review): models are re-instantiated on every search; presumably
    # load_models should be cached (e.g. st.cache_resource) — confirm.
    dense_embedding_model, late_embedding_model, sparse_model = load_models()
    # Generate single-query embeddings (query_embed applies query-side encoding).
    dense_query = next(dense_embedding_model.query_embed(query_text))
    late_query = next(late_embedding_model.query_embed(query_text))
    sparse_query = next(sparse_model.query_embed(query_text))
    # For the fusion approach, we need a list form for prefetch.
    # NOTE(review): this uses embed() (document-side) rather than
    # query_embed(); for BM25 the two weightings can differ — confirm intent.
    tsq = list(sparse_model.embed(query_text, batch_size=6))
    # We'll store top-5 results for each approach
    results = {}
    # 1) ColBERT (late) — MaxSim over token-level multivectors.
    results["C"] = qdrant_client.query_points(
        collection_name="items",
        query=late_query,
        using="late",
        limit=5,
        with_payload=True
    )
    # 2) Sparse only — BM25 with server-side IDF.
    results["S"] = qdrant_client.query_points(
        collection_name="items",
        query=models.SparseVector(**sparse_query.as_object()),
        using="sparse",
        limit=5,
        with_payload=True
    )
    # 3) Dense only — cosine similarity on BGE embeddings.
    results["D"] = qdrant_client.query_points(
        collection_name="items",
        query=dense_query,
        using="dense",
        limit=5,
        with_payload=True
    )
    # 4) Hybrid fusion (RRF for Sparse+Dense): both branches run as
    # prefetch, then their ranked lists are merged by reciprocal rank.
    results["S+D-F"] = qdrant_client.query_points(
        collection_name="items",
        prefetch=[
            models.Prefetch(
                query=dense_query,
                using="dense",
                limit=100,
            ),
            models.Prefetch(
                query=tsq[0].as_object(),
                using="sparse",
                limit=50,
            )
        ],
        query=models.FusionQuery(fusion=models.Fusion.RRF),
        limit=5,
        with_payload=True
    )
    # 5) Hybrid fusion + ColBERT: the RRF-fused top-10 is reranked by the
    # late-interaction model.
    sparse_dense_prefetch = models.Prefetch(
        prefetch=[
            models.Prefetch(query=dense_query, using="dense", limit=100),
            models.Prefetch(query=tsq[0].as_object(), using="sparse", limit=50),
        ],
        limit=10,
        query=models.FusionQuery(fusion=models.Fusion.RRF),
    )
    results["S+D-F-C"] = qdrant_client.query_points(
        collection_name="items",
        prefetch=[sparse_dense_prefetch],
        query=late_query,
        using="late",
        limit=5,
        with_payload=True
    )
    # 6) Hybrid no-fusion + ColBERT: a cascade instead of a fusion —
    # dense top-100 is filtered by sparse to top-50, then ColBERT reranks.
    old_prefetch = models.Prefetch(
        prefetch=[
            models.Prefetch(
                prefetch=[
                    models.Prefetch(query=dense_query, using="dense", limit=100)
                ],
                query=tsq[0].as_object(),
                using="sparse",
                limit=50,
            )
        ]
    )
    results["S+D-C"] = qdrant_client.query_points(
        collection_name="items",
        prefetch=[old_prefetch],
        query=late_query,
        using="late",
        limit=5,
        with_payload=True
    )
    return results
| ############################# | |
| # 2. Streamlit Main App | |
| ############################# | |
def main():
    """
    Streamlit entry point: paste an items.json catalogue, build the in-memory
    Qdrant index, then compare six search strategies side by side.
    """
    st.title("Semantic Search Sandbox")

    # Initialize session state if not present.
    if "json_loaded" not in st.session_state:
        st.session_state["json_loaded"] = False
    if "qdrant_client" not in st.session_state:
        st.session_state["qdrant_client"] = None

    #######################################
    # Show JSON input only if not loaded
    #######################################
    if not st.session_state["json_loaded"]:
        st.subheader("Paste items.json Here")
        # BUGFIX: sample items now include the "id" field that
        # build_qdrant_index stores as payload metadata (omitting it
        # crashed the index build with KeyError).
        default_json = """
{
  "items": [
    {
      "id": 1,
      "name": "Example1",
      "description": "An example item"
    },
    {
      "id": 2,
      "name": "Example2",
      "description": "Another item for demonstration"
    }
  ]
}
""".strip()
        json_text = st.text_area("JSON Input", value=default_json, height=300)
        if st.button("Load JSON"):
            # BUGFIX: parse errors and index-build errors were previously
            # both reported as "Error parsing JSON"; attribute them correctly.
            try:
                data = json.loads(json_text)
            except json.JSONDecodeError as e:
                st.error(f"Error parsing JSON: {e}")
            else:
                try:
                    # Build Qdrant index in memory.
                    st.session_state["qdrant_client"] = build_qdrant_index(data)
                except Exception as e:
                    st.error(f"Error building index: {e}")
                else:
                    st.session_state["json_loaded"] = True
                    st.success("JSON loaded and Qdrant index built successfully!")
                    st.rerun()
    else:
        # The data is loaded; offer a reset button to load new JSON.
        if st.button("Load a different JSON"):
            st.session_state["json_loaded"] = False
            st.session_state["qdrant_client"] = None
        else:
            # Show the search interface.
            query_text = st.text_input("Search Query", value="ACB 1.0 Ports")
            if st.button("Search"):
                if st.session_state["qdrant_client"] is None:
                    st.warning("Please load valid JSON first.")
                    return
                # Run all six query strategies.
                results_dict = run_queries(st.session_state["qdrant_client"], query_text)
                # Display results in a grid of up to 3 columns per row.
                col_names = list(results_dict.keys())
                n_cols = 3
                rows_needed = (len(col_names) + n_cols - 1) // n_cols
                for row_idx in range(rows_needed):
                    cols = st.columns(n_cols)
                    for col_idx in range(n_cols):
                        method_idx = row_idx * n_cols + col_idx
                        if method_idx >= len(col_names):
                            continue
                        method = col_names[method_idx]
                        qdrant_result = results_dict[method]
                        with cols[col_idx]:
                            st.markdown(f"### {method}")
                            for point in qdrant_result.points:
                                name = point.payload.get("name", "Unnamed")
                                item_id = point.payload.get("item_id", "")
                                # BUGFIX: a legitimate score of 0.0 previously
                                # rendered as "N/A"; only a missing score does now.
                                score = round(point.score, 4) if point.score is not None else "N/A"
                                st.write(f"- **{item_id}-{name}** (score={score})")


if __name__ == "__main__":
    main()