Spaces:
Runtime error
Runtime error
| import json | |
| from typing import List, Literal, Protocol, Tuple, TypedDict, Union | |
| from pyserini.analysis import get_lucene_analyzer | |
| from pyserini.index import IndexReader | |
| from pyserini.search import DenseSearchResult, JLuceneSearcherResult | |
| from pyserini.search.faiss.__main__ import init_query_encoder | |
| from pyserini.search.faiss import FaissSearcher | |
| from pyserini.search.hybrid import HybridSearcher | |
| from pyserini.search.lucene import LuceneSearcher | |
# Names accepted for the `encoder_class` argument of `init_searcher_and_reader`;
# the chosen value is forwarded unchanged to pyserini's `init_query_encoder`.
EncoderClass = Literal["dkrr", "dpr", "tct_colbert", "ance", "sentence", "contriever", "auto"]
# Options describing a Lucene analyzer. A dict of this shape is unpacked as
# keyword arguments into `get_lucene_analyzer(**analyzer_args)`, so the key
# names have to match that function's parameter names exactly.
AnalyzerArgs = TypedDict(
    "AnalyzerArgs",
    {
        "language": str,
        "stemming": bool,
        "stemmer": str,
        "stopwords": bool,
        "huggingFaceTokenizer": str,
    },
)
class _SearchResultRequired(TypedDict):
    """Keys that every search result carries."""

    # Taken from the "id" field of the stored document JSON.
    docid: str
    # Taken from the "contents" field of the stored document JSON.
    text: str
    # Relevance score reported by the searcher for this hit.
    score: float


class SearchResult(_SearchResultRequired, total=False):
    """
    A single retrieved document as returned by `_search`.

    `docid`, `text` and `score` are always present. `language` is optional:
    `_search` builds results without it, so declaring it as a required key
    made every result produced in this file violate its own declared type.
    """

    # Optional; not populated by `_search` in this file.
    language: str
class Searcher(Protocol):
    """Structural type for any object that exposes a `search` method."""

    def search(
        self,
        query: str,
        **kwargs,
    ) -> List[Union[DenseSearchResult, JLuceneSearcherResult]]:
        """Return the hits for `query`; searcher-specific options are passed as keywords."""
def init_searcher_and_reader(
    sparse_index_path: str = None,
    bm25_k1: float = None,
    bm25_b: float = None,
    analyzer_args: AnalyzerArgs = None,
    dense_index_path: str = None,
    encoder_name_or_path: str = None,
    encoder_class: EncoderClass = None,
    tokenizer_name: str = None,
    device: str = None,
    prefix: str = None
) -> Tuple[Union[FaissSearcher, HybridSearcher, LuceneSearcher], Union[IndexReader, None]]:
    """
    Initialize and return an appropriate searcher together with an index reader

    Parameters
    ----------
    sparse_index_path: str
        Path to sparse index
    bm25_k1: float
        BM25 k1 parameter. Only applied when `bm25_b` is given as well.
    bm25_b: float
        BM25 b parameter. Only applied when `bm25_k1` is given as well.
    analyzer_args: AnalyzerArgs
        Keyword arguments forwarded to `get_lucene_analyzer`; when given, the
        resulting analyzer is set on the sparse searcher.
    dense_index_path: str
        Path to dense index
    encoder_name_or_path: str
        Path to query encoder checkpoint or encoder name
    encoder_class: str
        Query encoder class to use. If None, infer from `encoder_name_or_path`
    tokenizer_name: str
        Tokenizer name or path
    device: str
        Device to load Query encoder on.
    prefix: str
        Query prefix if exists

    Returns
    -------
    Searcher: FaissSearcher | HybridSearcher | LuceneSearcher
        A sparse searcher when only `sparse_index_path` is given, a dense
        searcher when only `dense_index_path` is given, a hybrid searcher when
        both are given.
    Reader: IndexReader | None
        A reader over the sparse index when both index paths are given,
        otherwise None (sparse-only or dense-only setup).

    Raises
    ------
    ValueError
        If neither `sparse_index_path` nor `dense_index_path` is provided.
    """
    # Fail early with a clear message; the original fell through to
    # `return ssearcher, reader` with `ssearcher` never assigned.
    if not sparse_index_path and not dense_index_path:
        raise ValueError(
            "At least one of `sparse_index_path` or `dense_index_path` must be provided"
        )

    ssearcher = None
    if sparse_index_path:
        ssearcher = LuceneSearcher(sparse_index_path)
        if analyzer_args:
            analyzer = get_lucene_analyzer(**analyzer_args)
            ssearcher.set_analyzer(analyzer)
        # Compare against None rather than relying on truthiness: 0 is falsy,
        # so an explicit `bm25_b=0` (or `bm25_k1=0`) used to be silently ignored.
        if bm25_k1 is not None and bm25_b is not None:
            ssearcher.set_bm25(bm25_k1, bm25_b)

    if not dense_index_path:
        # Sparse-only setup: no reader is created, as in the original code.
        return ssearcher, None

    encoder = init_query_encoder(
        encoder=encoder_name_or_path,
        encoder_class=encoder_class,
        tokenizer_name=tokenizer_name,
        topics_name=None,
        encoded_queries=None,
        device=device,
        prefix=prefix
    )

    # Only open a reader when there is a sparse index to read from; the
    # original passed `sparse_index_path` unconditionally, i.e. None in a
    # dense-only setup.
    reader = IndexReader(sparse_index_path) if sparse_index_path else None
    dsearcher = FaissSearcher(dense_index_path, encoder)

    if ssearcher is None:
        return dsearcher, reader

    hsearcher = HybridSearcher(dense_searcher=dsearcher, sparse_searcher=ssearcher)
    return hsearcher, reader
| def _search(searcher: Searcher, reader: IndexReader, query: str, num_results: int = 10) -> List[SearchResult]: | |
| """ | |
| Parameters: | |
| ----------- | |
| searcher: FaissSearcher | HybridSearcher | LuceneSearcher | |
| A sparse, dense or hybrid searcher | |
| query: str | |
| Query for which to retrieve results | |
| num_results: int | |
| Maximum number of results to retrieve | |
| Returns: | |
| -------- | |
| Dict: | |
| """ | |
| def _get_dict(r: Union[DenseSearchResult, JLuceneSearcherResult]): | |
| if isinstance(r, JLuceneSearcherResult): | |
| return json.loads(r.raw) | |
| elif isinstance(r, DenseSearchResult): | |
| # Get document from sparse_index using index reader | |
| return json.loads(reader.doc(r.docid).raw()) | |
| search_results = searcher.search(query, k=num_results) | |
| all_results = [ | |
| SearchResult( | |
| docid=result["id"], | |
| text=result["contents"], | |
| score=search_results[idx].score | |
| ) for idx, result in enumerate(map(lambda r: _get_dict(r), search_results)) | |
| ] | |
| return all_results | |