import collections
from typing import Any, Dict, Iterator, List, Optional

import torch
from transformers import AutoModel
from transformers.activations import ClippedGELUActivation, GELUActivation
from transformers.modeling_utils import PoolerEndLogits

from relik.reader.data.relik_reader_sample import RelikReaderSample

# map activation names, as they appear in configs, to torch modules
activation2functions = {
    "relu": torch.nn.ReLU(),
    "gelu": GELUActivation(),
    "gelu_10": ClippedGELUActivation(-10, 10),
}


class RelikReaderCoreModel(torch.nn.Module):
    def __init__(
        self,
        transformer_model: str,
        additional_special_symbols: int,
        num_layers: Optional[int] = None,
        activation: str = "gelu",
        linears_hidden_size: Optional[int] = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
    ) -> None:
        super().__init__()

        # Transformer model declaration
        self.transformer_model_name = transformer_model
        self.transformer_model = (
            AutoModel.from_pretrained(transformer_model)
            if num_layers is None
            else AutoModel.from_pretrained(
                transformer_model, num_hidden_layers=num_layers
            )
        )
        self.transformer_model.resize_token_embeddings(
            self.transformer_model.config.vocab_size + additional_special_symbols
        )

        self.activation = activation
        self.linears_hidden_size = linears_hidden_size
        self.use_last_k_layers = use_last_k_layers

        # named entity detection (NED) layers
        self.ned_start_classifier = self._get_projection_layer(
            self.activation, last_hidden=2, layer_norm=False
        )
        self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)

        # entity disambiguation (ED) layers
        self.ed_start_projector = self._get_projection_layer(self.activation)
        self.ed_end_projector = self._get_projection_layer(self.activation)

        # note: this explicitly overrides nn.Module's training flag
        # until train()/eval() is called
        self.training = training

        # criterion
        self.criterion = torch.nn.CrossEntropyLoss()
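
    # All projection heads built below share the same shape:
    # Dropout -> Linear -> activation -> Dropout -> Linear (-> optional LayerNorm).
    # The NED start head ends in 2 logits (token is / is not a span start),
    # while the ED projectors keep `linears_hidden_size` output features.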
    def _get_projection_layer(
        self,
        activation: str,
        last_hidden: Optional[int] = None,
        input_hidden=None,
        layer_norm: bool = True,
    ) -> torch.nn.Sequential:
        head_components = [
            torch.nn.Dropout(0.1),
            torch.nn.Linear(
                self.transformer_model.config.hidden_size * self.use_last_k_layers
                if input_hidden is None
                else input_hidden,
                self.linears_hidden_size,
            ),
            activation2functions[activation],
            torch.nn.Dropout(0.1),
            torch.nn.Linear(
                self.linears_hidden_size,
                self.linears_hidden_size if last_hidden is None else last_hidden,
            ),
        ]

        if layer_norm:
            head_components.append(
                torch.nn.LayerNorm(
                    self.linears_hidden_size if last_hidden is None else last_hidden,
                    self.transformer_model.config.layer_norm_eps,
                )
            )

        return torch.nn.Sequential(*head_components)

    def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        mask = mask.unsqueeze(-1)
        if next(self.parameters()).dtype == torch.float16:
            logits = logits * (1 - mask) - 65500 * mask
        else:
            logits = logits * (1 - mask) - 1e30 * mask
        return logits
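
    # Worked example for _mask_logits: with mask = [0, 1], a logits row
    # [2.3, 0.7] becomes [2.3, -1e30], so softmax drives the masked position
    # to ~0 probability. fp16 uses -65500 instead, since 1e30 overflows
    # half precision (whose maximum is ~65504).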
    def _get_model_features(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ):
        model_input = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "output_hidden_states": self.use_last_k_layers > 1,
        }

        if token_type_ids is not None:
            model_input["token_type_ids"] = token_type_ids

        model_output = self.transformer_model(**model_input)

        if self.use_last_k_layers > 1:
            model_features = torch.cat(
                model_output[1][-self.use_last_k_layers :], dim=-1
            )
        else:
            model_features = model_output[0]

        return model_features
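
    # Note on _get_model_features: when use_last_k_layers > 1, the hidden
    # states of the last k layers are concatenated along the feature
    # dimension, which is why the projection heads size their input as
    # hidden_size * use_last_k_layers.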
    def compute_ned_end_logits(
        self,
        start_predictions,
        start_labels,
        model_features,
        prediction_mask,
        batch_size,
    ) -> Optional[torch.Tensor]:
        # todo: maybe when constraining on the spans,
        #  we should not use a prediction_mask for the end tokens.
        #  at least we should not during training imo
        start_positions = start_labels if self.training else start_predictions
        start_positions_indices = (
            torch.arange(start_positions.size(1), device=start_positions.device)
            .unsqueeze(0)
            .expand(batch_size, -1)[start_positions > 0]
        ).to(start_positions.device)

        if len(start_positions_indices) > 0:
            expanded_features = torch.cat(
                [
                    model_features[i].unsqueeze(0).expand(x, -1, -1)
                    for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
                    if x > 0
                ],
                dim=0,
            ).to(start_positions_indices.device)

            expanded_prediction_mask = torch.cat(
                [
                    prediction_mask[i].unsqueeze(0).expand(x, -1)
                    for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
                    if x > 0
                ],
                dim=0,
            ).to(expanded_features.device)

            end_logits = self.ned_end_classifier(
                hidden_states=expanded_features,
                start_positions=start_positions_indices,
                p_mask=expanded_prediction_mask,
            )

            return end_logits

        return None
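
    # Note on compute_ned_end_logits: each sequence is replicated once per
    # predicted (or gold) start token, so PoolerEndLogits scores one
    # end-boundary distribution per start rather than one per sequence.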
    def compute_classification_logits(
        self,
        model_features,
        special_symbols_mask,
        prediction_mask,
        batch_size,
        start_positions=None,
        end_positions=None,
    ) -> torch.Tensor:
        if start_positions is None or end_positions is None:
            start_positions = torch.zeros_like(prediction_mask)
            end_positions = torch.zeros_like(prediction_mask)

        model_start_features = self.ed_start_projector(model_features)
        model_end_features = self.ed_end_projector(model_features)
        model_end_features[start_positions > 0] = model_end_features[end_positions > 0]

        model_ed_features = torch.cat(
            [model_start_features, model_end_features], dim=-1
        )

        # gather one ED representation per candidate special symbol
        classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item()
        special_symbols_representation = model_ed_features[special_symbols_mask].view(
            batch_size, classes_representations, -1
        )

        logits = torch.bmm(
            model_ed_features,
            torch.permute(special_symbols_representation, (0, 2, 1)),
        )

        logits = self._mask_logits(logits, prediction_mask)

        return logits
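
    # Note on compute_classification_logits: the batched matrix product
    # scores every token's concatenated [start; end] features against the
    # candidate (special-symbol) representations, yielding one logit per
    # token per candidate.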
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        start_labels: Optional[torch.Tensor] = None,
        end_labels: Optional[torch.Tensor] = None,
        use_predefined_spans: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        batch_size, seq_len = input_ids.shape

        model_features = self._get_model_features(
            input_ids, attention_mask, token_type_ids
        )

        # named entity detection if required
        if use_predefined_spans:  # no need to compute spans
            ned_start_logits, ned_start_probabilities, ned_start_predictions = (
                None,
                None,
                torch.clone(start_labels)
                if start_labels is not None
                else torch.zeros_like(input_ids),
            )
            ned_end_logits, ned_end_probabilities, ned_end_predictions = (
                None,
                None,
                torch.clone(end_labels)
                if end_labels is not None
                else torch.zeros_like(input_ids),
            )

            ned_start_predictions[ned_start_predictions > 0] = 1
            ned_end_predictions[ned_end_predictions > 0] = 1
        else:  # compute spans
            # start boundary prediction
            ned_start_logits = self.ned_start_classifier(model_features)
            ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask)
            ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
            ned_start_predictions = ned_start_probabilities.argmax(dim=-1)

            # end boundary prediction
            ned_start_labels = (
                torch.zeros_like(start_labels) if start_labels is not None else None
            )

            if ned_start_labels is not None:
                ned_start_labels[start_labels == -100] = -100
                ned_start_labels[start_labels > 0] = 1

            ned_end_logits = self.compute_ned_end_logits(
                ned_start_predictions,
                ned_start_labels,
                model_features,
                prediction_mask,
                batch_size,
            )

            if ned_end_logits is not None:
                ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
                ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
            else:
                ned_end_logits, ned_end_probabilities = None, None
                ned_end_predictions = ned_start_predictions.new_zeros(batch_size)

            # flattening end predictions
            # (flattening can happen only if the
            # end boundaries were not predicted using the gold labels)
            if not self.training:
                flattened_end_predictions = torch.clone(ned_start_predictions)
                flattened_end_predictions[flattened_end_predictions > 0] = 0

                batch_start_predictions = list()
                for elem_idx in range(batch_size):
                    batch_start_predictions.append(
                        torch.where(ned_start_predictions[elem_idx] > 0)[0].tolist()
                    )

                # check that the total number of start predictions
                # is equal to the number of end predictions
                total_start_predictions = sum(map(len, batch_start_predictions))
                total_end_predictions = len(ned_end_predictions)
                assert (
                    total_start_predictions == 0
                    or total_start_predictions == total_end_predictions
                ), (
                    f"Total number of start predictions = {total_start_predictions}. "
                    f"Total number of end predictions = {total_end_predictions}"
                )

                curr_end_pred_num = 0
                for elem_idx, bsp in enumerate(batch_start_predictions):
                    for sp in bsp:
                        ep = ned_end_predictions[curr_end_pred_num].item()
                        if ep < sp:
                            ep = sp
                        # if this end position was already taken by another
                        # span, drop the current one (no overlaps allowed)
                        if flattened_end_predictions[elem_idx, ep] == 1:
                            ned_start_predictions[elem_idx, sp] = 0
                        else:
                            flattened_end_predictions[elem_idx, ep] = 1
                        curr_end_pred_num += 1

                ned_end_predictions = flattened_end_predictions

        start_position, end_position = (
            (start_labels, end_labels)
            if self.training
            else (ned_start_predictions, ned_end_predictions)
        )

        # Entity disambiguation
        ed_logits = self.compute_classification_logits(
            model_features,
            special_symbols_mask,
            prediction_mask,
            batch_size,
            start_position,
            end_position,
        )
        ed_probabilities = torch.softmax(ed_logits, dim=-1)
        ed_predictions = torch.argmax(ed_probabilities, dim=-1)

        # output build
        output_dict = dict(
            batch_size=batch_size,
            ned_start_logits=ned_start_logits,
            ned_start_probabilities=ned_start_probabilities,
            ned_start_predictions=ned_start_predictions,
            ned_end_logits=ned_end_logits,
            ned_end_probabilities=ned_end_probabilities,
            ned_end_predictions=ned_end_predictions,
            ed_logits=ed_logits,
            ed_probabilities=ed_probabilities,
            ed_predictions=ed_predictions,
        )

        # compute loss if labels are provided
        if start_labels is not None and end_labels is not None and self.training:
            # named entity detection loss

            # start
            if ned_start_logits is not None:
                ned_start_loss = self.criterion(
                    ned_start_logits.view(-1, ned_start_logits.shape[-1]),
                    ned_start_labels.view(-1),
                )
            else:
                ned_start_loss = 0

            # end
            if ned_end_logits is not None:
                ned_end_labels = torch.zeros_like(end_labels)
                ned_end_labels[end_labels == -100] = -100
                ned_end_labels[end_labels > 0] = 1

                ned_end_loss = self.criterion(
                    ned_end_logits,
                    (
                        torch.arange(
                            ned_end_labels.size(1), device=ned_end_labels.device
                        )
                        .unsqueeze(0)
                        .expand(batch_size, -1)[ned_end_labels > 0]
                    ).to(ned_end_labels.device),
                )
            else:
                ned_end_loss = 0

            # entity disambiguation loss
            start_labels[ned_start_labels != 1] = -100
            ed_labels = torch.clone(start_labels)
            ed_labels[end_labels > 0] = end_labels[end_labels > 0]
            ed_loss = self.criterion(
                ed_logits.view(-1, ed_logits.shape[-1]),
                ed_labels.view(-1),
            )

            output_dict["ned_start_loss"] = ned_start_loss
            output_dict["ned_end_loss"] = ned_end_loss
            output_dict["ed_loss"] = ed_loss

            output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss

        return output_dict
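
    # batch_predict runs forward() in inference mode and decodes the raw
    # start/end/ED predictions into candidate titles, spans, and top-k
    # probabilities, stored under each sample's patch offset.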
    def batch_predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        sample: Optional[List[RelikReaderSample]] = None,
        top_k: int = 5,  # the number of most probable entities to predict
        *args,
        **kwargs,
    ) -> Iterator[RelikReaderSample]:
        forward_output = self.forward(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
        )

        ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
        ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy()
        ed_predictions = forward_output["ed_predictions"].cpu().numpy()
        ed_probabilities = forward_output["ed_probabilities"].cpu().numpy()

        batch_predictable_candidates = kwargs["predictable_candidates"]
        patch_offset = kwargs["patch_offset"]
        for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip(
            sample,
            ned_start_predictions,
            ned_end_predictions,
            ed_predictions,
            ed_probabilities,
            batch_predictable_candidates,
            patch_offset,
        ):
            ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0]
            ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0]

            final_class2predicted_spans = collections.defaultdict(list)
            spans2predicted_probabilities = dict()
            for start_token_index, end_token_index in zip(
                ne_start_indices, ne_end_indices
            ):
                # predicted candidate
                token_class = edp[start_token_index + 1] - 1
                predicted_candidate_title = pred_cands[token_class]
                final_class2predicted_spans[predicted_candidate_title].append(
                    [start_token_index, end_token_index]
                )

                # candidates probabilities
                classes_probabilities = edpr[start_token_index + 1]
                classes_probabilities_best_indices = classes_probabilities.argsort()[
                    ::-1
                ]
                titles_2_probs = []
                top_k = (
                    min(
                        top_k,
                        len(classes_probabilities_best_indices),
                    )
                    if top_k != -1
                    else len(classes_probabilities_best_indices)
                )
                for i in range(top_k):
                    titles_2_probs.append(
                        (
                            pred_cands[classes_probabilities_best_indices[i] - 1],
                            classes_probabilities[
                                classes_probabilities_best_indices[i]
                            ].item(),
                        )
                    )
                spans2predicted_probabilities[
                    (start_token_index, end_token_index)
                ] = titles_2_probs

            if "patches" not in ts._d:
                ts._d["patches"] = dict()

            ts._d["patches"][po] = dict()
            sample_patch = ts._d["patches"][po]

            sample_patch["predicted_window_labels"] = final_class2predicted_spans
            sample_patch["span_title_probabilities"] = spans2predicted_probabilities

            # additional info
            sample_patch["predictable_candidates"] = pred_cands

            yield ts
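

# The block below is not part of the original module: it is a minimal,
# hedged smoke test sketching how the model can be instantiated and called.
# "bert-base-uncased", the mask layouts, and all tensor sizes are assumptions
# chosen for illustration; real inputs come from the ReLiK data pipeline.
if __name__ == "__main__":
    model = RelikReaderCoreModel(
        transformer_model="bert-base-uncased",  # assumption: any HF encoder works
        additional_special_symbols=4,  # hypothetical number of candidate symbols
    )
    model.eval()

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(
        0, model.transformer_model.config.vocab_size, (batch_size, seq_len)
    )
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    # 0 = position may be predicted, 1 = position is masked out
    prediction_mask = torch.zeros(batch_size, seq_len, dtype=torch.long)
    # pretend tokens 1..4 are the candidate special symbols
    special_symbols_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    special_symbols_mask[:, 1:5] = True

    with torch.no_grad():
        output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            prediction_mask=prediction_mask,
            special_symbols_mask=special_symbols_mask,
        )
    # one ED logit per token per candidate: (batch_size, seq_len, 4)
    print(output["ed_logits"].shape)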