Spaces:
Running
Running
| import os | |
| from typing import TYPE_CHECKING, Any, Optional | |
| import weave | |
| if TYPE_CHECKING: | |
| from byaldi import RAGMultiModalModel | |
| import wandb | |
| from PIL import Image | |
| from medrag_multi_modal.utils import get_wandb_artifact | |
| class CalPaliRetriever(weave.Model): | |
| """ | |
| CalPaliRetriever is a class that facilitates the retrieval of page images using ColPali. | |
| This class leverages the `byaldi.RAGMultiModalModel` to perform document retrieval tasks. | |
| It can be initialized with a pre-trained model or from a specified W&B artifact. The class | |
| also provides methods to index new data and to predict/retrieve documents based on a query. | |
| Attributes: | |
| model_name (str): The name of the model to be used for retrieval. | |
| """ | |
| model_name: str | |
| _docs_retrieval_model: Optional["RAGMultiModalModel"] = None | |
| _metadata: Optional[dict] = None | |
| _data_artifact_dir: Optional[str] = None | |
| def __init__( | |
| self, | |
| model_name: str = "vidore/colpali-v1.2", | |
| docs_retrieval_model: Optional["RAGMultiModalModel"] = None, | |
| data_artifact_dir: Optional[str] = None, | |
| metadata_dataset_name: Optional[str] = None, | |
| ): | |
| super().__init__(model_name=model_name) | |
| from byaldi import RAGMultiModalModel | |
| self._docs_retrieval_model = ( | |
| docs_retrieval_model or RAGMultiModalModel.from_pretrained(self.model_name) | |
| ) | |
| self._data_artifact_dir = data_artifact_dir | |
| self._metadata = ( | |
| [dict(row) for row in weave.ref(metadata_dataset_name).get().rows] | |
| if metadata_dataset_name | |
| else None | |
| ) | |
| def index(self, data_artifact_name: str, weave_dataset_name: str, index_name: str): | |
| """ | |
| Indexes a dataset of documents and saves the index as a Weave artifact. | |
| This method retrieves a dataset of documents from a Weave artifact using the provided | |
| data artifact name. It then indexes the documents using the document retrieval model | |
| and assigns the specified index name. The index is stored locally without storing the | |
| collection with the index and overwrites any existing index with the same name. | |
| If a Weave run is active, the method creates a new Weave artifact with the specified | |
| index name and type "colpali-index". It adds the local index directory to the artifact | |
| and saves it to Weave, including metadata with the provided Weave dataset name. | |
| !!! example "Indexing Data" | |
| First you need to install `Byaldi` library by Answer.ai. | |
| ```bash | |
| uv pip install Byaldi>=0.0.5 | |
| ``` | |
| Next, you can index the data by running the following code: | |
| ```python | |
| import wandb | |
| from medrag_multi_modal.retrieval import CalPaliRetriever | |
| wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="index") | |
| retriever = CalPaliRetriever() | |
| retriever.index( | |
| data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1", | |
| weave_dataset_name="grays-anatomy-images:v0", | |
| index_name="grays-anatomy", | |
| ) | |
| ``` | |
| ??? note "Optional Speedup using Flash Attention" | |
| If you have a GPU with Flash Attention support, you can enable it for ColPali by simply | |
| installing the `flash-attn` package. | |
| ```bash | |
| uv pip install flash-attn --no-build-isolation | |
| ``` | |
| Args: | |
| data_artifact_name (str): The name of the Weave artifact containing the dataset. | |
| weave_dataset_name (str): The name of the Weave dataset to include in the artifact metadata. | |
| index_name (str): The name to assign to the created index. | |
| """ | |
| data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset") | |
| self._docs_retrieval_model.index( | |
| input_path=data_artifact_dir, | |
| index_name=index_name, | |
| store_collection_with_index=False, | |
| overwrite=True, | |
| ) | |
| if wandb.run: | |
| artifact = wandb.Artifact( | |
| name=index_name, | |
| type="colpali-index", | |
| metadata={"weave_dataset_name": weave_dataset_name}, | |
| ) | |
| artifact.add_dir( | |
| local_path=os.path.join(".byaldi", index_name), name="index" | |
| ) | |
| artifact.save() | |
| def from_wandb_artifact( | |
| cls, | |
| index_artifact_name: str, | |
| metadata_dataset_name: str, | |
| data_artifact_name: str, | |
| ): | |
| """ | |
| Creates an instance of the class from Weights & Biases (wandb) artifacts. | |
| This method retrieves the necessary artifacts from wandb to initialize the | |
| ColPaliRetriever. It fetches the index artifact directory and the data artifact | |
| directory using the provided artifact names. It then loads the document retrieval | |
| model from the index path within the index artifact directory. Finally, it returns | |
| an instance of the class initialized with the retrieved document retrieval model, | |
| metadata dataset name, and data artifact directory. | |
| !!! example "Retrieving Documents" | |
| First you need to install `Byaldi` library by Answer.ai. | |
| ```bash | |
| uv pip install Byaldi>=0.0.5 | |
| ``` | |
| Next, you can retrieve the documents by running the following code: | |
| ```python | |
| import weave | |
| import wandb | |
| from medrag_multi_modal.retrieval import CalPaliRetriever | |
| weave.init(project_name="ml-colabs/medrag-multi-modal") | |
| retriever = CalPaliRetriever.from_wandb_artifact( | |
| index_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy:v0", | |
| metadata_dataset_name="grays-anatomy-images:v0", | |
| data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1", | |
| ) | |
| ``` | |
| ??? note "Optional Speedup using Flash Attention" | |
| If you have a GPU with Flash Attention support, you can enable it for ColPali by simply | |
| installing the `flash-attn` package. | |
| ```bash | |
| uv pip install flash-attn --no-build-isolation | |
| ``` | |
| Args: | |
| index_artifact_name (str): The name of the wandb artifact containing the index. | |
| metadata_dataset_name (str): The name of the dataset containing metadata. | |
| data_artifact_name (str): The name of the wandb artifact containing the data. | |
| Returns: | |
| An instance of the class initialized with the retrieved document retrieval model, | |
| metadata dataset name, and data artifact directory. | |
| """ | |
| from byaldi import RAGMultiModalModel | |
| index_artifact_dir = get_wandb_artifact(index_artifact_name, "colpali-index") | |
| data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset") | |
| docs_retrieval_model = RAGMultiModalModel.from_index( | |
| index_path=os.path.join(index_artifact_dir, "index") | |
| ) | |
| return cls( | |
| docs_retrieval_model=docs_retrieval_model, | |
| metadata_dataset_name=metadata_dataset_name, | |
| data_artifact_dir=data_artifact_dir, | |
| ) | |
| def predict(self, query: str, top_k: int = 3) -> list[dict[str, Any]]: | |
| """ | |
| Predicts and retrieves the top-k most relevant documents/images for a given query | |
| using ColPali. | |
| This function uses the document retrieval model to search for the most relevant | |
| documents based on the provided query. It returns a list of dictionaries, each | |
| containing the document image, document ID, and the relevance score. | |
| !!! example "Retrieving Documents" | |
| First you need to install `Byaldi` library by Answer.ai. | |
| ```bash | |
| uv pip install Byaldi>=0.0.5 | |
| ``` | |
| Next, you can retrieve the documents by running the following code: | |
| ```python | |
| import weave | |
| import wandb | |
| from medrag_multi_modal.retrieval import CalPaliRetriever | |
| weave.init(project_name="ml-colabs/medrag-multi-modal") | |
| retriever = CalPaliRetriever.from_wandb_artifact( | |
| index_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy:v0", | |
| metadata_dataset_name="grays-anatomy-images:v0", | |
| data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1", | |
| ) | |
| retriever.predict( | |
| query="which neurotransmitters convey information between Merkel cells and sensory afferents?", | |
| top_k=3, | |
| ) | |
| ``` | |
| ??? note "Optional Speedup using Flash Attention" | |
| If you have a GPU with Flash Attention support, you can enable it for ColPali by simply | |
| installing the `flash-attn` package. | |
| ```bash | |
| uv pip install flash-attn --no-build-isolation | |
| ``` | |
| Args: | |
| query (str): The search query string. | |
| top_k (int, optional): The number of top results to retrieve. Defaults to 10. | |
| Returns: | |
| list[dict[str, Any]]: A list of dictionaries where each dictionary contains: | |
| - "doc_image" (PIL.Image.Image): The image of the document. | |
| - "doc_id" (str): The ID of the document. | |
| - "score" (float): The relevance score of the document. | |
| """ | |
| results = self._docs_retrieval_model.search(query=query, k=top_k) | |
| retrieved_results = [] | |
| for result in results: | |
| retrieved_results.append( | |
| { | |
| "doc_image": Image.open( | |
| os.path.join(self._data_artifact_dir, f"{result['doc_id']}.png") | |
| ), | |
| "doc_id": result["doc_id"], | |
| "score": result["score"], | |
| } | |
| ) | |
| return retrieved_results | |