import logging
import random

import ray
from transformers import RagConfig, RagRetriever, RagTokenizer
from transformers.models.rag.retrieval_rag import CustomHFIndex


logger = logging.getLogger(__name__)


class RayRetriever:
    """Wrapper around :class:`~transformers.RagRetriever` meant to be run as a Ray actor in a separate process."""

    def __init__(self):
        self.initialized = False

    def create_rag_retriever(self, config, question_encoder_tokenizer, generator_tokenizer, index):
        # Build the underlying RagRetriever once per actor; subsequent calls are no-ops.
        if not self.initialized:
            self.retriever = RagRetriever(
                config,
                question_encoder_tokenizer=question_encoder_tokenizer,
                generator_tokenizer=generator_tokenizer,
                index=index,
                init_retrieval=False,
            )
            self.initialized = True

    def init_retrieval(self):
        # Load the index in this actor's process.
        self.retriever.index.init_index()

    def retrieve(self, question_hidden_states, n_docs):
        doc_ids, retrieved_doc_embeds = self.retriever._main_retrieve(question_hidden_states, n_docs)
        return doc_ids, retrieved_doc_embeds


class RagRayDistributedRetriever(RagRetriever):
    """
    A distributed retriever built on top of the ``Ray`` API, a library
    for building distributed applications (https://docs.ray.io/en/master/).
    During training, all training workers initialize their own
    instance of a `RagRayDistributedRetriever`, and each instance of
    this distributed retriever shares a common set of Retrieval Ray
    Actors (https://docs.ray.io/en/master/walkthrough.html#remote-classes-actors)
    that load the index in separate processes. Ray
    handles the communication between the `RagRayDistributedRetriever`
    instances and the remote Ray actors. If training is done in a
    non-distributed setup, the index will simply be loaded in the same
    process as the training worker and Ray will not be used.

    Args:
        config (:class:`~transformers.RagConfig`):
            The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build.
        question_encoder_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
            The tokenizer that was used to tokenize the question.
            It is used to decode the question so it can be re-tokenized with the ``generator_tokenizer``.
        generator_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
            The tokenizer used for the generator part of the RagModel.
        retrieval_workers (:obj:`List[ray.ActorClass(RayRetriever)]`):
            A list of already initialized `RayRetriever` actors.
            These actor classes run on remote processes and are responsible for performing the index lookup.
        index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
            If specified, use this index instead of the one built from the configuration.
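
    Example (a minimal sketch; the model name, worker count, and dataset/index
    paths are illustrative, and the passages and FAISS index are assumed to have
    been built with examples/rag/use_own_knowledge_dataset.py)::

        import ray
        from transformers import RagConfig

        ray.init()
        workers = [ray.remote(RayRetriever).remote() for _ in range(2)]
        config = RagConfig.from_pretrained("facebook/rag-token-nq")
        config.index_name = "custom"
        config.passages_path = "path/to/my_knowledge_dataset"
        config.index_path = "path/to/my_knowledge_dataset_index.faiss"
        retriever = RagRayDistributedRetriever.from_pretrained(
            "facebook/rag-token-nq", actor_handles=workers, config=config
        )
        retriever.init_retrieval()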
| """ | |
    def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, retrieval_workers, index=None):
        if index is not None and index.is_initialized() and len(retrieval_workers) > 0:
            raise ValueError(
                "When using Ray for distributed fine-tuning, "
                "you'll need to provide the paths instead, "
                "as the dataset and the index are loaded "
                "separately. More info in examples/rag/use_own_knowledge_dataset.py "
            )
        super().__init__(
            config,
            question_encoder_tokenizer=question_encoder_tokenizer,
            generator_tokenizer=generator_tokenizer,
            index=index,
            init_retrieval=False,
        )
        self.retrieval_workers = retrieval_workers
        if len(self.retrieval_workers) > 0:
            ray.get(
                [
                    worker.create_rag_retriever.remote(config, question_encoder_tokenizer, generator_tokenizer, index)
                    for worker in self.retrieval_workers
                ]
            )

    def init_retrieval(self):
        """
        Retriever initialization function; it needs to be called from the
        training process. This function triggers retrieval initialization
        for all retrieval actors when running in a distributed setting, or
        loads the index into the current process if training is not distributed.
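
        Example (assuming ``retriever`` was built as in the class-level example)::

            retriever.init_retrieval()  # loads the index on every retrieval actor, or locally if there are none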
| """ | |
| logger.info("initializing retrieval") | |
| if len(self.retrieval_workers) > 0: | |
| ray.get([worker.init_retrieval.remote() for worker in self.retrieval_workers]) | |
| else: | |
| # Non-distributed training. Load index into this same process. | |
| self.index.init_index() | |

    def retrieve(self, question_hidden_states, n_docs):
        """
        Retrieves documents for the specified ``question_hidden_states``. If
        running training with multiple workers, a random retrieval actor is
        selected to perform the index lookup and return the result.

        Args:
            question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
                A batch of query vectors to retrieve with.
            n_docs (:obj:`int`):
                The number of docs retrieved per query.

        Output:
            retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`):
                The retrieval embeddings of the retrieved docs per query.
            doc_ids (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`):
                The ids of the documents in the index.
            doc_dicts (:obj:`List[dict]`):
                The contents (e.g. title and text) of the retrieved docs per query.
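
        Example (shapes are illustrative, and ``retriever`` is assumed to have been
        constructed and initialized as in the class-level example; ``vector_size``
        must match the question encoder, e.g. 768 for ``facebook/rag-token-nq``)::

            import numpy as np

            question_hidden_states = np.random.randn(2, 768).astype("float32")
            retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(question_hidden_states, n_docs=5)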
| """ | |
| if len(self.retrieval_workers) > 0: | |
| # Select a random retrieval actor. | |
| random_worker = self.retrieval_workers[random.randint(0, len(self.retrieval_workers) - 1)] | |
| doc_ids, retrieved_doc_embeds = ray.get(random_worker.retrieve.remote(question_hidden_states, n_docs)) | |
| else: | |
| doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs) | |
| return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids) | |

    @classmethod
    def get_tokenizers(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
        return super(RagRayDistributedRetriever, cls).get_tokenizers(retriever_name_or_path, indexed_dataset, **kwargs)

    @classmethod
    def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
        config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
        rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
        question_encoder_tokenizer = rag_tokenizer.question_encoder
        generator_tokenizer = rag_tokenizer.generator
        if indexed_dataset is not None:
            config.index_name = "custom"
            index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset)
        else:
            index = cls._build_index(config)
        return cls(
            config,
            question_encoder_tokenizer=question_encoder_tokenizer,
            generator_tokenizer=generator_tokenizer,
            retrieval_workers=actor_handles,
            index=index,
        )
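

if __name__ == "__main__":
    # A minimal sketch of the non-distributed fallback (no Ray actors), under the
    # assumption that the dataset and FAISS index were built with
    # examples/rag/use_own_knowledge_dataset.py; the paths below are illustrative.
    from datasets import load_from_disk

    dataset = load_from_disk("path/to/my_knowledge_dataset")
    dataset.load_faiss_index("embeddings", "path/to/my_knowledge_dataset_index.faiss")
    retriever = RagRayDistributedRetriever.from_pretrained(
        "facebook/rag-token-nq", actor_handles=[], indexed_dataset=dataset
    )
    retriever.init_retrieval()  # with no workers, the index is loaded in this process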