Spaces:
Running
on
Zero
Running
on
Zero
| import datasets | |
| import numpy as np | |
| import spaces | |
| from sentence_transformers import CrossEncoder, SentenceTransformer | |
| from table import BASE_REPO_ID | |
# Hugging Face model ids for the two retrieval stages.
_BI_ENCODER_ID = "BAAI/bge-base-en-v1.5"
_CROSS_ENCODER_ID = "BAAI/bge-reranker-base"

# Load the "train" split of the base repository and attach a FAISS index over
# its "embedding" column so nearest-neighbour lookups can be run against it.
ds = datasets.load_dataset(BASE_REPO_ID, split="train")
ds.add_faiss_index(column="embedding")

# Stage 1: bi-encoder that embeds the query for the FAISS lookup.
bi_model = SentenceTransformer(_BI_ENCODER_ID)
# Stage 2: cross-encoder that scores (query, document) pairs for reranking.
ce_model = CrossEncoder(_CROSS_ENCODER_ID)
def search(query: str, candidate_pool_size: int = 100, retrieval_k: int = 50) -> list[dict]:
    """Retrieve papers for *query* with the bi-encoder, then rerank with the cross-encoder.

    Args:
        query: Free-text search query.
        candidate_pool_size: Number of nearest neighbours fetched from the
            FAISS index on the ``"embedding"`` column before reranking.
        retrieval_k: Maximum number of reranked results to return.

    Returns:
        At most ``retrieval_k`` dicts of the form
        ``{"paper_id": ..., "ce_score": float}``, ordered by descending
        cross-encoder score. The list is empty when either size argument is
        non-positive or when the index lookup yields no rows.
    """
    # A non-positive pool size is not a meaningful nearest-neighbour query, and
    # a negative ``retrieval_k`` would otherwise slice off the *tail* of the
    # ranking instead of returning nothing. Bail out before doing model work.
    if candidate_pool_size <= 0 or retrieval_k <= 0:
        return []

    # Query-side instruction prefix; documents are embedded without it.
    prefix = "Represent this sentence for searching relevant passages: "
    q_vec = bi_model.encode(prefix + query, normalize_embeddings=True)
    _, retrieved = ds.get_nearest_examples("embedding", q_vec, k=candidate_pool_size)

    # ``retrieved`` is column-oriented; pull each column once instead of
    # re-indexing the mapping for every row.
    titles = retrieved["title"]
    if not titles:
        # Nothing came back from the index, so there is nothing to rerank.
        return []
    abstracts = retrieved["abstract"]
    paper_ids = retrieved["paper_id"]

    ce_inputs = [(query, f"{title} {abstract}") for title, abstract in zip(titles, abstracts)]
    ce_scores = ce_model.predict(ce_inputs, batch_size=16)

    # argsort is ascending; reverse for best-first, then keep the top results.
    top_idx = np.argsort(ce_scores)[::-1][:retrieval_k]
    return [{"paper_id": paper_ids[i], "ce_score": float(ce_scores[i])} for i in top_idx]