# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional

import torch
from fairseq.dataclass import FairseqDataclass
from fairseq.models import (
    FairseqIncrementalDecoder,
    FairseqLanguageModel,
    register_model,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from omegaconf import II


logger = logging.getLogger(__name__)

@dataclass
class TransformerXLConfig(FairseqDataclass):
    # defaults come from the original Transformer-XL code
    cutoffs: List[int] = field(default_factory=lambda: [20000, 40000, 200000])
    d_model: int = 500
    n_head: int = 10
    d_head: int = 50
    d_inner: int = 1000
    div_val: int = 1
    n_layer: int = 12
    mem_len: int = 0
    clamp_len: int = -1
    same_length: bool = False
    dropout: float = 0.0
    dropatt: float = 0.0
    checkpoint_activations: bool = False
    offload_activations: bool = False
    max_target_positions: int = II("task.max_target_positions")
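
# Illustrative sketch, not part of the original file: during training these
# fields are normally populated from the fairseq command line / Hydra config,
# and ``max_target_positions`` is interpolated from the task config via
# ``II("task.max_target_positions")``. Constructed directly (e.g. in a quick
# test), an override might look like:
#
#   cfg = TransformerXLConfig(n_layer=2, d_model=64, n_head=4, d_head=16, mem_len=32)
#   cfg.max_target_positions = 512  # no task config to interpolate from here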

@register_model("transformer_xl", dataclass=TransformerXLConfig)
class TransformerXLLanguageModel(FairseqLanguageModel):
    @classmethod
    def build_model(cls, cfg: TransformerXLConfig, task):
        return cls(TransformerXLDecoder(cfg, task))

class TransformerXLDecoder(FairseqIncrementalDecoder):
    def __init__(self, cfg, task):
        try:
            from transformers.models.transfo_xl import (
                TransfoXLConfig,
                TransfoXLLMHeadModel,
            )
        except ImportError:
            from transformers.configuration_transfo_xl import TransfoXLConfig
            from transformers.modeling_transfo_xl import TransfoXLLMHeadModel

        super().__init__(task.target_dictionary)
        self.cfg = cfg

        # remove any cutoffs larger than the vocab size
        cutoffs = [
            cutoff for cutoff in cfg.cutoffs if cutoff < len(task.target_dictionary)
        ]

        config = TransfoXLConfig(
            vocab_size=len(task.target_dictionary),
            cutoffs=cutoffs,
            d_model=cfg.d_model,
            d_embed=cfg.d_model,
            n_head=cfg.n_head,
            d_head=cfg.d_head,
            d_inner=cfg.d_inner,
            div_val=cfg.div_val,
            n_layer=cfg.n_layer,
            mem_len=cfg.mem_len,
            clamp_len=cfg.clamp_len,
            same_length=cfg.same_length,
            dropout=cfg.dropout,
            dropatt=cfg.dropatt,
        )
        logger.info(config)
        self.model = TransfoXLLMHeadModel(config)

        # Workaround a bug in huggingface's ``ProjectedAdaptiveLogSoftmax``
        # which adds ``None`` values to an ``nn.ParameterList``, which is not
        # supported in PyTorch. Instead we can replace this with an
        # ``nn.ModuleList``, which does support ``None`` values.
        try:
            if all(p is None for p in self.model.crit.out_projs._parameters.values()):
                self.model.crit.out_projs = torch.nn.ModuleList(
                    [None] * len(self.model.crit.out_projs._parameters)
                )
        except Exception:
            pass

        if cfg.checkpoint_activations or cfg.offload_activations:
            for i in range(len(self.model.transformer.layers)):
                self.model.transformer.layers[i] = checkpoint_wrapper(
                    self.model.transformer.layers[i],
                    offload_to_cpu=cfg.offload_activations,
                )
                # TODO: may save mem to wrap(layer.pos_ff.CoreNet[3])

        self._mems = None

    def forward(
        self,
        src_tokens,
        src_lengths=None,  # unused
        incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
        encoder_out=None,
    ):
        if incremental_state is not None:  # used during inference
            mems = self.get_incremental_state(incremental_state, "mems")
            src_tokens = src_tokens[:, -1:]  # only keep the most recent token
        else:
            # during training, reuse the memories carried over from the
            # previous forward pass (truncated BPTT)
            mems = self._mems

        output = self.model(
            input_ids=src_tokens,
            mems=mems,
            return_dict=False,
        )

        if len(output) >= 2:
            if incremental_state is not None:
                self.set_incremental_state(incremental_state, "mems", output[1])
            else:
                self._mems = output[1]

        return (output[0],)
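
    # Incremental decoding sketch (illustrative, not part of the original
    # file): fairseq's sequence generator passes the same ``incremental_state``
    # dict at every step, so the Transformer-XL memories stored under "mems"
    # carry the already-decoded context while only the newest token is fed
    # through the model:
    #
    #   state = {}
    #   for tokens_so_far in prefixes:            # each of shape (batch, t)
    #       logits, = decoder(tokens_so_far, incremental_state=state)
    #       # logits has shape (batch, 1, vocab): scores for the newest token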

    def max_positions(self):
        return self.cfg.max_target_positions

    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]],
        new_order: torch.Tensor,
    ):
        """Reorder incremental state.

        This will be called when the order of the input has changed from the
        previous time step. A typical use case is beam search, where the input
        order changes between time steps based on the selection of beams.
        """
        mems = self.get_incremental_state(incremental_state, "mems")
        if mems is not None:
            # HuggingFace TransfoXL memories are shaped (mem_len, batch, d_model),
            # so reordering beams means indexing along dim 1 (the batch dimension)
            new_mems = [mems_i.index_select(1, new_order) for mems_i in mems]
            self.set_incremental_state(incremental_state, "mems", new_mems)
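

if __name__ == "__main__":
    # Standalone smoke test -- a sketch, not part of the original module. It
    # assumes a ``transformers`` version that still exposes TransfoXL under
    # ``transformers.models.transfo_xl`` (the model was later deprecated), and
    # stands in for the fairseq task with minimal hypothetical stubs
    # (``_StubDictionary`` / ``_StubTask``) so the wrapper can be exercised
    # without a full training setup.
    class _StubDictionary:
        def __len__(self):
            return 1000  # tiny vocab: all default adaptive-softmax cutoffs get dropped

    class _StubTask:
        target_dictionary = _StubDictionary()

    cfg = TransformerXLConfig(
        n_layer=2, d_model=64, n_head=4, d_head=16, d_inner=128, mem_len=32
    )
    model = TransformerXLLanguageModel.build_model(cfg, _StubTask())

    tokens = torch.randint(0, len(_StubTask.target_dictionary), (2, 16))  # (batch, seq_len)
    (logits,) = model.decoder(tokens)
    print(logits.shape)  # roughly (batch, seq_len, vocab) log-probabilities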