Spaces:
Build error
Build error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import random | |
| from typing import List | |
| from fairseq.data import BaseWrapperDataset, data_utils | |
class RandomInputDataset(BaseWrapperDataset):
    """Wrap a dataset and replace one field of each item with a random input.

    Every item fetched from the wrapped ``dataset`` has the value found at
    ``input_key_path`` overwritten by an element drawn with
    :func:`random.choice` from ``random_input_dataset``. At collation time
    the replaced values are padded into one batch with
    ``data_utils.collate_tokens`` and written into the collated batch at the
    same key path. The path is rooted at ``collated["net_input"]`` when
    ``add_to_input`` is true, and at the batch root otherwise.

    Args:
        dataset: the dataset being wrapped.
        random_input_dataset: indexable, sized collection the random
            replacement values are drawn from.
        input_key_path: keys leading to the field to replace; a single
            string is treated as a one-element path.
        add_to_input: if true, the collated values are placed under
            ``collated["net_input"]`` instead of at the batch root.
        pad_idx: padding value passed to ``collate_tokens``.
    """

    def __init__(
        self,
        dataset,
        random_input_dataset,
        input_key_path: List[str],
        add_to_input,
        pad_idx,
    ):
        super().__init__(dataset)
        self.random_input_dataset = random_input_dataset
        if isinstance(input_key_path, str):
            input_key_path = [input_key_path]
        # Raise instead of assert so the check survives ``python -O``.
        if len(input_key_path) == 0:
            raise ValueError("input_key_path must contain at least one key")
        self.input_key_path = input_key_path
        self.add_to_input = add_to_input
        self.pad_idx = pad_idx

    def get_target(self, item):
        """Walk all but the last key of ``input_key_path`` through ``item``.

        Returns:
            A ``(key, container)`` pair such that ``container[key]`` is the
            targeted field. The last key is not dereferenced, so callers can
            either read or assign the field.
        """
        target_loc = item
        for p in self.input_key_path[:-1]:
            target_loc = target_loc[p]
        return self.input_key_path[-1], target_loc

    def get_target_value(self, item):
        """Return the value stored at ``input_key_path`` inside ``item``."""
        k, target_loc = self.get_target(item)
        return target_loc[k]

    def __getitem__(self, index):
        """Return item ``index`` of the wrapped dataset with its target field
        replaced by a random element of ``random_input_dataset``."""
        item = self.dataset[index]
        k, target_loc = self.get_target(item)
        # NOTE(review): the item returned by the wrapped dataset is mutated in
        # place, and the draw uses the global ``random`` state, so results are
        # only reproducible if the caller seeds ``random``.
        target_loc[k] = random.choice(self.random_input_dataset)
        return item

    def collater(self, samples):
        """Collate ``samples`` with the wrapped dataset, then insert the padded
        random inputs into the batch.

        The wrapped collater may drop samples and may also emit the survivors
        in a different order than ``samples`` (for example, sorted by length).
        The random inputs are therefore gathered in the order given by
        ``collated["id"]``, so row ``i`` of the padded tensor belongs to the
        sample in row ``i`` of the batch.
        """
        collated = self.dataset.collater(samples)
        if len(collated) == 0:
            return collated
        samples_by_id = {s["id"]: s for s in samples}
        batch_ids = collated["id"].tolist()
        random_inputs = data_utils.collate_tokens(
            [self.get_target_value(samples_by_id[i]) for i in batch_ids],
            pad_idx=self.pad_idx,
            left_pad=False,
        )
        k, target_loc = self.get_target(
            collated if not self.add_to_input else collated["net_input"]
        )
        target_loc[k] = random_inputs
        return collated