Update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529, commit d868d2e.
import logging
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Sequence, Union

from pandas import MultiIndex
from pytorch_ie import Annotation, AnnotationLayer, Document, DocumentMetric
from pytorch_ie.annotations import BinaryRelation
from pytorch_ie.core.metric import T
from pytorch_ie.utils.hydra import resolve_target

from src.hydra_callbacks.save_job_return_value import to_py_obj

logger = logging.getLogger(__name__)


class RankingMetricsSKLearn(DocumentMetric):
    """Ranking metrics for documents with binary relations.

    This metric computes ranking metrics for retrieval tasks where the relation heads are the
    queries and the relation tails are the candidates. The metric is computed for each head and
    the results are averaged. It is meant to be used with Scikit-learn metrics such as
    `sklearn.metrics.ndcg_score` (Normalized Discounted Cumulative Gain),
    `sklearn.metrics.label_ranking_average_precision_score` (LRAP), etc.; see
    https://scikit-learn.org/stable/modules/model_evaluation.html#multilabel-ranking-metrics.

    Args:
        metrics (Dict[str, Union[str, Callable]]): A dictionary of metric names and their
            corresponding functions. Each function can be given as a string (the dotted path of
            the function, e.g. sklearn.metrics.ndcg_score) or as a callable.
        layer (str): The name of the annotation layer containing the binary relations, e.g.
            "binary_relations" when applied to TextDocumentWithLabeledSpansAndBinaryRelations.
        use_manual_average (Optional[List[str]]): A list of metric names for which the score is
            calculated per head and the results are averaged manually. For all other metrics,
            all true and predicted scores are passed to the metric function at once.
        exclude_singletons (Optional[List[str]]): A list of metric names for which singletons,
            i.e. entries (heads) with only a single candidate, are excluded from the computation.
        label (Optional[str]): If provided, only relations with this label are used to compute
            the metrics. This is useful for filtering out relations that are not relevant for
            the task at hand (e.g. when multiple relation types share the same layer).
        score_threshold (float): Only relations with a score greater than or equal to this
            threshold are used to compute the metrics. Default is 0.0.
        default_score (float): The default score to use for relations that are missing from
            either the target or the prediction. Default is 0.0.
        use_all_spans (bool): Whether to consider all spans in the document as queries and
            candidates, or only the spans that occur in the target and prediction relations.
            Default is False.
        span_label_blacklist (Optional[List[str]]): If provided, relations whose head or tail
            label is in this list are ignored. With use_all_spans=True, this also restricts the
            candidates to spans whose label is not in the blacklist.
        show_as_markdown (bool): Whether to log the results as a markdown table. Default is
            False.
        markdown_precision (int): The precision for displaying the results in markdown.
            Default is 4.
        plot (bool): Whether to plot the results via do_plot(), which must be implemented in a
            subclass. Default is False.
    """

    def __init__(
        self,
        metrics: Dict[str, Union[str, Callable]],
        layer: str,
        use_manual_average: Optional[List[str]] = None,
        exclude_singletons: Optional[List[str]] = None,
        label: Optional[str] = None,
        score_threshold: float = 0.0,
        default_score: float = 0.0,
        use_all_spans: bool = False,
        span_label_blacklist: Optional[List[str]] = None,
        show_as_markdown: bool = False,
        markdown_precision: int = 4,
        plot: bool = False,
    ):
        self.metrics = {
            name: resolve_target(metric) if isinstance(metric, str) else metric
            for name, metric in metrics.items()
        }
        self.use_manual_average = set(use_manual_average or [])
        self.exclude_singletons = set(exclude_singletons or [])
        self.annotation_layer_name = layer
        self.annotation_label = label
        self.score_threshold = score_threshold
        self.default_score = default_score
        self.use_all_spans = use_all_spans
        self.span_label_blacklist = span_label_blacklist
        self.show_as_markdown = show_as_markdown
        self.markdown_precision = markdown_precision
        self.plot = plot
        super().__init__()

    def reset(self) -> None:
        self._preds: List[List[float]] = []
        self._targets: List[List[float]] = []
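
    # Map each relation head to its candidate tails and their scores, applying the label,
    # score-threshold, and span-label-blacklist filters.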
    def get_head2tail2score(
        self, relations: Sequence[BinaryRelation]
    ) -> Dict[Annotation, Dict[Annotation, float]]:
        result: Dict[Annotation, Dict[Annotation, float]] = defaultdict(dict)
        for rel in relations:
            if (
                (self.annotation_label is None or rel.label == self.annotation_label)
                and (rel.score >= self.score_threshold)
                and (
                    self.span_label_blacklist is None
                    or (
                        rel.head.label not in self.span_label_blacklist
                        and rel.tail.label not in self.span_label_blacklist
                    )
                )
            ):
                result[rel.head][rel.tail] = rel.score
        return result
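
    # _update collects, per query head, the gold and predicted scores over a shared set of
    # candidate tails; pairs that are missing on one side get default_score.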
    def _update(self, document: Document) -> None:
        annotation_layer: AnnotationLayer[BinaryRelation] = document[self.annotation_layer_name]
        target_head2tail2score = self.get_head2tail2score(annotation_layer)
        prediction_head2tail2score = self.get_head2tail2score(annotation_layer.predictions)

        all_spans = set()
        # get spans from all layers targeted by the annotation (relation) layer
        for span_layer in annotation_layer.target_layers.values():
            all_spans.update(span_layer)
        if self.span_label_blacklist is not None:
            all_spans = {span for span in all_spans if span.label not in self.span_label_blacklist}

        if self.use_all_spans:
            all_heads = all_spans
        else:
            all_heads = set(target_head2tail2score) | set(prediction_head2tail2score)

        all_targets: List[List[float]] = []
        all_predictions: List[List[float]] = []
        for head in all_heads:
            target_tail2score = target_head2tail2score.get(head, {})
            prediction_tail2score = prediction_head2tail2score.get(head, {})
            if self.use_all_spans:
                # use all spans as tails
                tails = set(span for span in all_spans if span != head)
            else:
                # use only the tails that are in the target or prediction
                tails = set(target_tail2score) | set(prediction_tail2score)
            target_scores = [target_tail2score.get(t, self.default_score) for t in tails]
            prediction_scores = [prediction_tail2score.get(t, self.default_score) for t in tails]
            all_targets.append(target_scores)
            all_predictions.append(prediction_scores)

        self._targets.extend(all_targets)
        self._preds.extend(all_predictions)

    def do_plot(self):
        raise NotImplementedError()
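
    # _compute applies each configured metric to the collected per-head score lists: either in a
    # single call over all entries or, for metrics in use_manual_average, per head with manual
    # averaging; singleton heads are dropped for metrics in exclude_singletons.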
    def _compute(self) -> T:
        if self.plot:
            self.do_plot()

        result = {}
        for name, metric in self.metrics.items():
            targets, preds = self._targets, self._preds
            if name in self.exclude_singletons:
                targets = [t for t in targets if len(t) > 1]
                preds = [p for p in preds if len(p) > 1]
                num_singletons = len(self._targets) - len(targets)
                logger.warning(
                    f"Excluding {num_singletons} singletons (out of {len(self._targets)} "
                    f"entries) from {name} metric calculation."
                )
            if name in self.use_manual_average:
                scores = [
                    metric(y_true=[tgts], y_score=[prds]) for tgts, prds in zip(targets, preds)
                ]
                result[name] = sum(scores) / len(scores) if len(scores) > 0 else 0.0
            else:
                result[name] = metric(y_true=targets, y_score=preds)

        result = to_py_obj(result)
        if self.show_as_markdown:
            import pandas as pd

            series = pd.Series(result)
            if isinstance(series.index, MultiIndex):
                if len(series.index.levels) > 1:
                    # in fact, this is not a series anymore
                    series = series.unstack(-1)
                else:
                    series.index = series.index.get_level_values(0)
            logger.info(
                f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
            )
        return result
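
A minimal usage sketch, not part of the PR: the metric names, the configuration values, and the `documents` iterable are illustrative assumptions, and it presumes the usual pytorch_ie `DocumentMetric` interface of calling the metric on each document and then calling `compute()`. Because the candidate lists of different heads can have different lengths, both metrics are scored per head via `use_manual_average`; NDCG is additionally listed in `exclude_singletons` because it is not meaningful for a head with only a single candidate.

# hypothetical usage sketch; assumes RankingMetricsSKLearn (defined above) is in scope
metric = RankingMetricsSKLearn(
    metrics={
        # strings are resolved to callables via resolve_target; passing the
        # functions themselves works as well
        "ndcg": "sklearn.metrics.ndcg_score",
        "lrap": "sklearn.metrics.label_ranking_average_precision_score",
    },
    layer="binary_relations",
    # score each head separately because the number of candidates varies per head
    use_manual_average=["ndcg", "lrap"],
    # NDCG requires more than one candidate, so drop singleton heads for it
    exclude_singletons=["ndcg"],
    show_as_markdown=True,
)

# `documents` is an assumed iterable of documents with gold relations in the
# "binary_relations" layer and predicted relations in its `.predictions`
for doc in documents:
    metric(doc)
scores = metric.compute()  # -> {"ndcg": ..., "lrap": ...}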