import json
import os
import shutil
from typing import Optional, Union

import bm25s
import huggingface_hub
import weave
from bm25s import BM25
from datasets import Dataset, load_dataset
from Stemmer import Stemmer

from medrag_multi_modal.utils import fetch_from_huggingface, save_to_huggingface
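
# Maps the human-readable language names accepted by this class to the
# two-letter codes that `bm25s.tokenize` expects for its `stopwords` argument.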
LANGUAGE_DICT = {
    "english": "en",
    "french": "fr",
    "german": "de",
}


class BM25sRetriever(weave.Model):
    """
    `BM25sRetriever` is a class that provides functionality for indexing and
    retrieving documents using the
    [BM25-Sparse](https://github.com/xhluca/bm25s) library.

    Args:
        language (str): The language of the documents to be indexed and retrieved.
        use_stemmer (bool): A flag indicating whether to use stemming during tokenization.
        retriever (Optional[bm25s.BM25]): An instance of the BM25 retriever. If not
            provided, a new instance is created.
    """

    language: Optional[str]
    use_stemmer: bool = True
    _retriever: Optional[BM25]

    def __init__(
        self,
        language: str = "english",
        use_stemmer: bool = True,
        retriever: Optional[BM25] = None,
    ):
        super().__init__(language=language, use_stemmer=use_stemmer)
        self._retriever = retriever or BM25()

    def index(
        self,
        chunk_dataset: Union[Dataset, str],
        index_repo_id: Optional[str] = None,
        cleanup: bool = True,
    ):
        """
        Indexes a dataset of text chunks using the BM25 algorithm.

        This method retrieves a dataset of text chunks from a specified source, tokenizes
        the text using the BM25 tokenizer with optional stemming, and indexes the tokenized
        text using the BM25 retriever. If an `index_repo_id` is provided, the index is
        saved to disk and uploaded to the corresponding Huggingface repository.

        !!! example "Example Usage"
            ```python
            import weave
            from dotenv import load_dotenv
            from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever

            load_dotenv()
            weave.init(project_name="ml-colabs/medrag-multi-modal")
            retriever = BM25sRetriever()
            retriever.index(
                chunk_dataset="geekyrakshit/grays-anatomy-chunks-test",
                index_repo_id="geekyrakshit/grays-anatomy-index",
            )
            ```

        Args:
            chunk_dataset (Union[Dataset, str]): The Huggingface dataset containing the
                text chunks to be indexed. Either a dataset repository name or a dataset
                object can be provided.
            index_repo_id (Optional[str]): The Huggingface repository to which the index
                artifact is saved.
            cleanup (bool, optional): Whether to delete the local index directory after
                saving the vector index.
        """
        chunk_dataset = (
            load_dataset(chunk_dataset, split="chunks")
            if isinstance(chunk_dataset, str)
            else chunk_dataset
        )
        corpus = [row["text"] for row in chunk_dataset]
        # Tokenize the corpus with language-specific stopwords and an optional
        # Snowball stemmer, then build the BM25 index over the tokens.
        corpus_tokens = bm25s.tokenize(
            corpus,
            stopwords=LANGUAGE_DICT[self.language],
            stemmer=Stemmer(self.language) if self.use_stemmer else None,
        )
        self._retriever.index(corpus_tokens)
        if index_repo_id:
            # Save the index (together with the full chunk rows as its corpus)
            # to a local staging directory before uploading it to the Hub.
            os.makedirs(".huggingface", exist_ok=True)
            index_save_dir = os.path.join(".huggingface", index_repo_id.split("/")[-1])
            self._retriever.save(
                index_save_dir, corpus=[dict(row) for row in chunk_dataset]
            )
            commit_type = (
                "update"
                if huggingface_hub.repo_exists(index_repo_id, repo_type="model")
                else "add"
            )
            # Persist the tokenization settings so `from_index` can restore them.
            with open(os.path.join(index_save_dir, "config.json"), "w") as config_file:
                json.dump(
                    {
                        "language": self.language,
                        "use_stemmer": self.use_stemmer,
                    },
                    config_file,
                    indent=4,
                )
            save_to_huggingface(
                index_repo_id,
                index_save_dir,
                commit_message=f"{commit_type}: BM25s index",
            )
            if cleanup:
                shutil.rmtree(index_save_dir)

    @classmethod
    def from_index(cls, index_repo_id: str):
        """
        Creates an instance of the class from a Huggingface repository.

        This class method retrieves a BM25 index artifact from a Huggingface repository,
        downloads the artifact, and loads the BM25 retriever with the index and its
        associated corpus. The method also extracts metadata from the artifact to
        initialize the class instance with the appropriate language and stemming
        settings.

        !!! example "Example Usage"
            ```python
            import weave
            from dotenv import load_dotenv
            from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever

            load_dotenv()
            weave.init(project_name="ml-colabs/medrag-multi-modal")
            retriever = BM25sRetriever.from_index(
                index_repo_id="geekyrakshit/grays-anatomy-index"
            )
            ```

        Args:
            index_repo_id (str): The Huggingface repository containing the index
                artifact to load.

        Returns:
            An instance of the class initialized with the BM25 retriever and metadata
            from the artifact.
        """
        index_dir = fetch_from_huggingface(index_repo_id, ".huggingface")
        retriever = bm25s.BM25.load(index_dir, load_corpus=True)
        with open(os.path.join(index_dir, "config.json"), "r") as config_file:
            config = json.load(config_file)
        return cls(retriever=retriever, **config)

    def retrieve(self, query: str, top_k: int = 2):
        """
        Retrieves the top-k most relevant chunks for a given query using the BM25 algorithm.

        This method tokenizes the input query using the BM25 tokenizer, which takes into
        account the language-specific stopwords and optional stemming. It then retrieves
        the top-k most relevant chunks from the BM25 index based on the tokenized query.
        The results are returned as a list of dictionaries, each containing a chunk and
        its corresponding relevance score.

        !!! example "Example Usage"
            ```python
            import weave
            from dotenv import load_dotenv
            from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever

            load_dotenv()
            weave.init(project_name="ml-colabs/medrag-multi-modal")
            retriever = BM25sRetriever.from_index(
                index_repo_id="geekyrakshit/grays-anatomy-index"
            )
            retrieved_chunks = retriever.retrieve(query="What are Ribosomes?")
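            # Illustrative shape of the result (values hypothetical; the keys
            # come from the columns of the indexed chunk dataset plus "score"):
            # [{"text": "...", "score": 12.3}, {"text": "...", "score": 9.8}]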
            ```

        Args:
            query (str): The input query string to search for relevant chunks.
            top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.

        Returns:
            list: A list of dictionaries, each containing a retrieved chunk and its
                relevance score.
        """
        query_tokens = bm25s.tokenize(
            query,
            stopwords=LANGUAGE_DICT[self.language],
            stemmer=Stemmer(self.language) if self.use_stemmer else None,
        )
        results = self._retriever.retrieve(query_tokens, k=top_k)
        # `results.documents` and `results.scores` are arrays of shape
        # (n_queries, top_k); flatten them since we query one string at a time.
        retrieved_chunks = []
        for chunk, score in zip(
            results.documents.flatten().tolist(),
            results.scores.flatten().tolist(),
        ):
            retrieved_chunks.append({**chunk, "score": score})
        return retrieved_chunks

    def predict(self, query: str, top_k: int = 2):
        """
        Predicts the top-k most relevant chunks for a given query using the BM25 algorithm.

        This function is a wrapper around the `retrieve` method. It takes an input query
        string, tokenizes it using the BM25 tokenizer, and retrieves the top-k most
        relevant chunks from the BM25 index. The results are returned as a list of
        dictionaries, each containing a chunk and its corresponding relevance score.

        !!! example "Example Usage"
            ```python
            import weave
            from dotenv import load_dotenv
            from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever

            load_dotenv()
            weave.init(project_name="ml-colabs/medrag-multi-modal")
            retriever = BM25sRetriever.from_index(
                index_repo_id="geekyrakshit/grays-anatomy-index"
            )
            retrieved_chunks = retriever.predict(query="What are Ribosomes?")
            ```

        Args:
            query (str): The input query string to search for relevant chunks.
            top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.

        Returns:
            list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
        """
        return self.retrieve(query, top_k)
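

# ---------------------------------------------------------------------------
# End-to-end usage sketch (illustrative): builds an index, reloads it from the
# Hub, and runs a query. It reuses the Weave project and Huggingface repo IDs
# from the docstring examples above; substitute your own, and supply
# credentials via a `.env` file as in those examples.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from dotenv import load_dotenv

    load_dotenv()
    weave.init(project_name="ml-colabs/medrag-multi-modal")

    # Build the index once and push it to the Hub...
    BM25sRetriever().index(
        chunk_dataset="geekyrakshit/grays-anatomy-chunks-test",
        index_repo_id="geekyrakshit/grays-anatomy-index",
    )

    # ...then reload it and query it.
    retriever = BM25sRetriever.from_index(
        index_repo_id="geekyrakshit/grays-anatomy-index"
    )
    print(retriever.retrieve(query="What are Ribosomes?", top_k=2))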