Spaces:
Runtime error
Runtime error
Commit
·
b5a3ebb
1
Parent(s):
7f98acf
add: docs for BM25sRetriever
Browse files
medrag_multi_modal/retrieval/bm25s_retrieval.py
CHANGED
|
@@ -16,6 +16,17 @@ LANGUAGE_DICT = {
|
|
| 16 |
|
| 17 |
|
| 18 |
class BM25sRetriever(weave.Model):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
language: str
|
| 20 |
use_stemmer: bool
|
| 21 |
_retriever: Optional[bm25s.BM25]
|
|
@@ -30,6 +41,34 @@ class BM25sRetriever(weave.Model):
|
|
| 30 |
self._retriever = retriever or bm25s.BM25()
|
| 31 |
|
| 32 |
def index(self, chunk_dataset_name: str, index_name: Optional[str] = None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
chunk_dataset = weave.ref(chunk_dataset_name).get().rows
|
| 34 |
corpus = [row["text"] for row in chunk_dataset]
|
| 35 |
corpus_tokens = bm25s.tokenize(
|
|
@@ -56,6 +95,23 @@ class BM25sRetriever(weave.Model):
|
|
| 56 |
|
| 57 |
@classmethod
|
| 58 |
def from_wandb_artifact(cls, index_artifact_address: str):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
if wandb.run:
|
| 60 |
artifact = wandb.run.use_artifact(
|
| 61 |
index_artifact_address, type="bm25s-index"
|
|
@@ -76,13 +132,48 @@ class BM25sRetriever(weave.Model):
|
|
| 76 |
|
| 77 |
@weave.op()
|
| 78 |
def retrieve(self, query: str, top_k: int = 2):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
query_tokens = bm25s.tokenize(
|
| 80 |
query,
|
| 81 |
stopwords=LANGUAGE_DICT[self.language],
|
| 82 |
stemmer=Stemmer(self.language) if self.use_stemmer else None,
|
| 83 |
)
|
| 84 |
results = self._retriever.retrieve(query_tokens, k=top_k)
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
"
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
class BM25sRetriever(weave.Model):
|
| 19 |
+
"""
|
| 20 |
+
`BM25sRetriever` is a class that provides functionality for indexing and
|
| 21 |
+
retrieving documents using the [BM25-Sparse](https://github.com/xhluca/bm25s).
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
language (str): The language of the documents to be indexed and retrieved.
|
| 25 |
+
use_stemmer (bool): A flag indicating whether to use stemming during tokenization.
|
| 26 |
+
retriever (Optional[bm25s.BM25]): An instance of the BM25 retriever. If not provided,
|
| 27 |
+
a new instance is created.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
language: str
|
| 31 |
use_stemmer: bool
|
| 32 |
_retriever: Optional[bm25s.BM25]
|
|
|
|
| 41 |
self._retriever = retriever or bm25s.BM25()
|
| 42 |
|
| 43 |
def index(self, chunk_dataset_name: str, index_name: Optional[str] = None):
|
| 44 |
+
"""
|
| 45 |
+
Indexes a dataset of text chunks using the BM25 algorithm.
|
| 46 |
+
|
| 47 |
+
This function takes a dataset of text chunks identified by `chunk_dataset_name`,
|
| 48 |
+
tokenizes the text using the BM25 tokenizer with optional stemming, and indexes
|
| 49 |
+
the tokenized text using the BM25 retriever. If an `index_name` is provided, the
|
| 50 |
+
index is saved to disk and logged as a Weights & Biases artifact.
|
| 51 |
+
|
| 52 |
+
!!! example "Example Usage"
|
| 53 |
+
```python
|
| 54 |
+
import weave
|
| 55 |
+
from dotenv import load_dotenv
|
| 56 |
+
|
| 57 |
+
import wandb
|
| 58 |
+
from medrag_multi_modal.retrieval import BM25sRetriever
|
| 59 |
+
|
| 60 |
+
load_dotenv()
|
| 61 |
+
weave.init(project_name="ml-colabs/medrag-multi-modal")
|
| 62 |
+
wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="bm25s-index")
|
| 63 |
+
retriever = BM25sRetriever()
|
| 64 |
+
retriever.index(chunk_dataset_name="grays-anatomy-text:v13", index_name="grays-anatomy-bm25s")
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
chunk_dataset_name (str): The name of the dataset containing text chunks to be indexed.
|
| 69 |
+
index_name (Optional[str]): The name to save the index under. If provided, the index
|
| 70 |
+
is saved to disk and logged as a Weights & Biases artifact.
|
| 71 |
+
"""
|
| 72 |
chunk_dataset = weave.ref(chunk_dataset_name).get().rows
|
| 73 |
corpus = [row["text"] for row in chunk_dataset]
|
| 74 |
corpus_tokens = bm25s.tokenize(
|
|
|
|
| 95 |
|
| 96 |
@classmethod
|
| 97 |
def from_wandb_artifact(cls, index_artifact_address: str):
|
| 98 |
+
"""
|
| 99 |
+
Creates an instance of the class from a Weights & Biases artifact.
|
| 100 |
+
|
| 101 |
+
This class method retrieves a BM25 index artifact from Weights & Biases,
|
| 102 |
+
downloads the artifact, and loads the BM25 retriever with the index and its
|
| 103 |
+
associated corpus. The method also extracts metadata from the artifact to
|
| 104 |
+
initialize the class instance with the appropriate language and stemming
|
| 105 |
+
settings.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
index_artifact_address (str): The address of the Weights & Biases artifact
|
| 109 |
+
containing the BM25 index.
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
An instance of the class initialized with the BM25 retriever and metadata
|
| 113 |
+
from the artifact.
|
| 114 |
+
"""
|
| 115 |
if wandb.run:
|
| 116 |
artifact = wandb.run.use_artifact(
|
| 117 |
index_artifact_address, type="bm25s-index"
|
|
|
|
| 132 |
|
| 133 |
@weave.op()
|
| 134 |
def retrieve(self, query: str, top_k: int = 2):
|
| 135 |
+
"""
|
| 136 |
+
Retrieves the top-k most relevant chunks for a given query using the BM25 algorithm.
|
| 137 |
+
|
| 138 |
+
This method tokenizes the input query using the BM25 tokenizer, which takes into
|
| 139 |
+
account the language-specific stopwords and optional stemming. It then retrieves
|
| 140 |
+
the top-k most relevant chunks from the BM25 index based on the tokenized query.
|
| 141 |
+
The results are returned as a list of dictionaries, each containing a chunk and
|
| 142 |
+
its corresponding relevance score.
|
| 143 |
+
|
| 144 |
+
!!! example "Example Usage"
|
| 145 |
+
```python
|
| 146 |
+
import weave
|
| 147 |
+
from dotenv import load_dotenv
|
| 148 |
+
|
| 149 |
+
from medrag_multi_modal.retrieval import BM25sRetriever
|
| 150 |
+
|
| 151 |
+
load_dotenv()
|
| 152 |
+
weave.init(project_name="ml-colabs/medrag-multi-modal")
|
| 153 |
+
retriever = BM25sRetriever.from_wandb_artifact(
|
| 154 |
+
index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-bm25s:v2"
|
| 155 |
+
)
|
| 156 |
+
retrieved_chunks = retriever.retrieve(query="What are Ribosomes?")
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
query (str): The input query string to search for relevant chunks.
|
| 161 |
+
top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.
|
| 162 |
+
|
| 163 |
+
Returns:
|
| 164 |
+
list: A list of dictionaries, each containing a retrieved chunk and its
|
| 165 |
+
relevance score.
|
| 166 |
+
"""
|
| 167 |
query_tokens = bm25s.tokenize(
|
| 168 |
query,
|
| 169 |
stopwords=LANGUAGE_DICT[self.language],
|
| 170 |
stemmer=Stemmer(self.language) if self.use_stemmer else None,
|
| 171 |
)
|
| 172 |
results = self._retriever.retrieve(query_tokens, k=top_k)
|
| 173 |
+
retrieved_chunks = []
|
| 174 |
+
for chunk, score in zip(
|
| 175 |
+
results["results"].flatten().tolist(),
|
| 176 |
+
results["scores"].flatten().tolist(),
|
| 177 |
+
):
|
| 178 |
+
retrieved_chunks.append({"chunk": chunk, "score": score})
|
| 179 |
+
return retrieved_chunks
|