update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
d868d2e
verified
| import logging | |
| from collections import defaultdict | |
| from functools import partial | |
| from typing import ( | |
| Any, | |
| Callable, | |
| Collection, | |
| Dict, | |
| Hashable, | |
| List, | |
| Optional, | |
| Tuple, | |
| TypeAlias, | |
| Union, | |
| ) | |
| from pytorch_ie.core import Annotation, Document, DocumentMetric | |
| from pytorch_ie.utils.hydra import resolve_target | |
| from src.document.types import RelatedRelation | |
| logger = logging.getLogger(__name__) | |
def has_one_of_the_labels(ann: Annotation, label_field: str, labels: Collection[str]) -> bool:
    """Return True if the annotation's label (read via `label_field`) is one of `labels`."""
    ann_label = getattr(ann, label_field)
    return ann_label in labels
def has_this_label(ann: Annotation, label_field: str, label: str) -> bool:
    """Return True if the annotation's label (read via `label_field`) equals `label`."""
    current = getattr(ann, label_field)
    return current == label
# A single evaluated instance: a document paired with one of its (possibly processed) annotations.
InstanceType: TypeAlias = Tuple[Document, Annotation]
# The (tp, fp, fn) triple: lists of instances for true positives, false positives,
# and false negatives, in that order.
InstancesType: TypeAlias = Tuple[List[InstanceType], List[InstanceType], List[InstanceType]]
class TPFFPFNMetric(DocumentMetric):
    """Computes the lists of True Positive, False Positive, and False Negative
    annotations for a given layer. If labels are provided, it also computes
    the counts for each label separately.

    Works only with `RelatedRelation` annotations for now.

    Args:
        layer: The layer to compute the metrics for.
        labels: If provided, calculate metrics for each label. Either a non-empty
            collection of label strings, or the special string "INFERRED" to collect
            the label set from gold and predicted annotations while processing.
        label_field: The field to use for the label. Defaults to "label".
        annotation_processor: Optional callable applied to each annotation before
            comparison; a string is resolved to a callable via `resolve_target`.
    """

    def __init__(
        self,
        layer: str,
        labels: Optional[Union[Collection[str], str]] = None,
        label_field: str = "label",
        annotation_processor: Optional[Union[Callable[[Annotation], Hashable], str]] = None,
    ):
        super().__init__()
        self.layer = layer
        self.label_field = label_field
        self.annotation_processor: Optional[Callable[[Annotation], Hashable]]
        if isinstance(annotation_processor, str):
            # allow passing a dotted path, e.g. from a config file
            self.annotation_processor = resolve_target(annotation_processor)
        else:
            self.annotation_processor = annotation_processor
        self.per_label = labels is not None
        self.infer_labels = False
        # Always initialize so accessing self.labels is safe even when no labels
        # were provided (previously the attribute was undefined in that case).
        self.labels: List[str] = []
        if self.per_label:
            if isinstance(labels, str):
                if labels != "INFERRED":
                    raise ValueError(
                        "labels can only be 'INFERRED' if per_label is True and labels is a string"
                    )
                self.infer_labels = True
            elif isinstance(labels, Collection):
                if not all(isinstance(label, str) for label in labels):
                    raise ValueError("labels must be a collection of strings")
                if "MICRO" in labels or "MACRO" in labels:
                    raise ValueError(
                        "labels cannot contain 'MICRO' or 'MACRO' because they are used to capture aggregated metrics"
                    )
                if len(labels) == 0:
                    raise ValueError("labels cannot be empty")
                self.labels = list(labels)
            else:
                raise ValueError("labels must be a string or a collection of strings")

    def reset(self):
        """Clear all accumulated (tp, fp, fn) instance lists."""
        # keyed by label (or "MICRO" for the aggregate over all labels)
        self.tp_fp_fn: Dict[str, InstancesType] = defaultdict(lambda: (list(), list(), list()))

    def get_tp_fp_fn(
        self,
        document: Document,
        annotation_filter: Optional[Callable[[Annotation], bool]] = None,
        annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
    ) -> InstancesType:
        """Compare gold vs. predicted annotations of `self.layer` in `document`.

        Args:
            document: The document to evaluate.
            annotation_filter: Only annotations passing this predicate are considered;
                defaults to accepting everything.
            annotation_processor: Applied to each annotation before set comparison;
                defaults to the identity.

        Returns:
            A (tp, fp, fn) triple of lists of (document, processed_annotation) pairs.
        """
        annotation_processor = annotation_processor or (lambda ann: ann)
        annotation_filter = annotation_filter or (lambda ann: True)
        predicted_annotations = {
            annotation_processor(ann)
            for ann in document[self.layer].predictions
            if annotation_filter(ann)
        }
        gold_annotations = {
            annotation_processor(ann) for ann in document[self.layer] if annotation_filter(ann)
        }
        # NOTE: set semantics — duplicate (processed) annotations collapse to one entry
        tp = [(document, ann) for ann in predicted_annotations & gold_annotations]
        fn = [(document, ann) for ann in gold_annotations - predicted_annotations]
        fp = [(document, ann) for ann in predicted_annotations - gold_annotations]
        return tp, fp, fn

    def add_annotations(self, annotations: InstancesType, label: str):
        """Append new (tp, fp, fn) instances to the accumulated lists for `label`."""
        current = self.tp_fp_fn[label]
        self.tp_fp_fn[label] = (
            current[0] + annotations[0],
            current[1] + annotations[1],
            current[2] + annotations[2],
        )

    def _update(self, document: Document):
        """Accumulate (tp, fp, fn) for one document: the MICRO aggregate and,
        if per-label mode is active, one entry per label."""
        new_tp_fp_fn = self.get_tp_fp_fn(
            document=document,
            annotation_filter=(
                # restrict MICRO to the known labels only when they were given explicitly
                partial(has_one_of_the_labels, label_field=self.label_field, labels=self.labels)
                if self.per_label and not self.infer_labels
                else None
            ),
            annotation_processor=self.annotation_processor,
        )
        self.add_annotations(new_tp_fp_fn, label="MICRO")
        if self.infer_labels:
            layer = document[self.layer]
            # collect labels from gold data and predictions
            for ann in list(layer) + list(layer.predictions):
                label = getattr(ann, self.label_field)
                if label not in self.labels:
                    self.labels.append(label)
        if self.per_label:
            for label in self.labels:
                new_tp_fp_fn = self.get_tp_fp_fn(
                    document=document,
                    annotation_filter=partial(
                        has_this_label, label_field=self.label_field, label=label
                    ),
                    annotation_processor=self.annotation_processor,
                )
                self.add_annotations(new_tp_fp_fn, label=label)

    def format_texts(self, texts: List[str]) -> str:
        """Join multiple span texts into a single string with a <SEP> delimiter."""
        return "<SEP>".join(texts)

    def format_annotation(self, ann: Annotation) -> Dict[str, Any]:
        """Serialize a `RelatedRelation` into a flat dict of labels, texts, and scores.

        Raises:
            NotImplementedError: For any annotation type other than `RelatedRelation`.
        """
        if isinstance(ann, RelatedRelation):
            head_resolved = ann.head.resolve()
            tail_resolved = ann.tail.resolve()
            ref_resolved = ann.reference_span.resolve()
            # resolve() yields (label, texts) pairs; scores are rounded for readability
            return {
                "related_label": ann.label,
                "related_score": round(ann.score, 3),
                "query_label": head_resolved[0],
                "query_texts": self.format_texts(head_resolved[1]),
                "query_score": round(ann.head.score, 3),
                "ref_label": ref_resolved[0],
                "ref_texts": self.format_texts(ref_resolved[1]),
                "ref_score": round(ann.reference_span.score, 3),
                "rec_label": tail_resolved[0],
                "rec_texts": self.format_texts(tail_resolved[1]),
                "rec_score": round(ann.tail.score, 3),
            }
        else:
            raise NotImplementedError

    def format_instance(self, instance: InstanceType) -> Dict[str, Any]:
        """Serialize one (document, annotation) pair, attaching the document id if present."""
        document, annotation = instance
        result = self.format_annotation(annotation)
        if getattr(document, "id", None) is not None:
            result["document_id"] = document.id
        return result

    def _compute(self) -> Dict[str, Dict[str, list]]:
        """Return, per label key, the formatted tp/fp/fn instance lists."""
        res = dict()
        for k, instances in self.tp_fp_fn.items():
            res[k] = {
                "tp": [self.format_instance(instance) for instance in instances[0]],
                "fp": [self.format_instance(instance) for instance in instances[1]],
                "fn": [self.format_instance(instance) for instance in instances[2]],
            }
        return res