Update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529 (commit d868d2e, verified).
| from collections import defaultdict | |
| from functools import partial | |
| from typing import Callable, Hashable, Optional, Tuple, Dict, Collection, List, Set | |
| from pie_modules.metrics import F1Metric | |
| from pytorch_ie import Annotation, Document | |
def has_one_of_the_labels(ann: Annotation, label_field: str, labels: Collection[str]) -> bool:
    """Return True if the annotation's label, read from *label_field*, occurs in *labels*."""
    ann_label = getattr(ann, label_field)
    return ann_label in labels
def has_this_label(ann: Annotation, label_field: str, label: str) -> bool:
    """Return True if the annotation's label, read from *label_field*, equals *label*."""
    ann_label = getattr(ann, label_field)
    return ann_label == label
class F1WithBootstrappingMetric(F1Metric):
    """F1 metric variant that accumulates the *sets* of true-positive, false-positive
    and false-negative annotations (instead of plain counts) so that resampling over
    them — e.g. for bootstrapped confidence intervals — becomes possible.

    Args:
        *args: positional arguments forwarded to ``F1Metric``.
        bootstrap_n: number of bootstrap samples to draw. Defaults to 0 (disabled).
            NOTE(review): the value is only stored here; the resampling itself does
            not happen in this class — confirm against callers.
        **kwargs: keyword arguments forwarded to ``F1Metric``.
    """

    def __init__(self, *args, bootstrap_n: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        self.bootstrap_n = bootstrap_n

    def reset(self) -> None:
        """Clear the accumulated per-label annotation sets.

        Overrides the base reset: tp/fp/fn hold the annotations themselves
        (keyed by label, plus the synthetic "MICRO" key), not integer counts.
        """
        self.tp: Dict[str, Set[Annotation]] = defaultdict(set)
        self.fp: Dict[str, Set[Annotation]] = defaultdict(set)
        self.fn: Dict[str, Set[Annotation]] = defaultdict(set)

    def calculate_tp_fp_fn(
        self,
        document: Document,
        annotation_filter: Optional[Callable[[Annotation], bool]] = None,
        annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
    ) -> Tuple[Set[Annotation], Set[Annotation], Set[Annotation]]:
        """Compute TP/FP/FN annotation sets for a single document.

        Args:
            document: the document whose layer ``self.layer`` is evaluated;
                gold annotations come from the layer itself, predictions from
                its ``predictions`` attribute.
            annotation_filter: optional predicate; annotations failing it are
                ignored on both sides. Defaults to keeping everything.
            annotation_processor: optional mapping applied to each annotation
                before set comparison (e.g. to strip scores). Must yield
                hashable values. Defaults to the identity.

        Returns:
            A ``(tp, fp, fn)`` tuple of sets: predictions that match gold,
            predictions without a gold match, and gold without a prediction.
        """
        annotation_processor = annotation_processor or (lambda ann: ann)
        annotation_filter = annotation_filter or (lambda ann: True)
        predicted_annotations = {
            annotation_processor(ann)
            for ann in document[self.layer].predictions
            if annotation_filter(ann)
        }
        gold_annotations = {
            annotation_processor(ann) for ann in document[self.layer] if annotation_filter(ann)
        }
        tp = predicted_annotations & gold_annotations
        fp = predicted_annotations - gold_annotations
        fn = gold_annotations - predicted_annotations
        return tp, fp, fn

    def add_tp_fp_fn(
        self, tp: Set[Annotation], fp: Set[Annotation], fn: Set[Annotation], label: str
    ) -> None:
        """Merge the given TP/FP/FN annotation sets into the accumulators for *label*."""
        self.tp[label].update(tp)
        self.fp[label].update(fp)
        self.fn[label].update(fn)

    def _update(self, document: Document) -> None:
        """Accumulate TP/FP/FN sets for one document (micro and, if enabled, per label)."""
        new_values = self.calculate_tp_fp_fn(
            document=document,
            annotation_filter=(
                # With fixed labels, restrict the micro counts to those labels;
                # with inferred labels, count everything.
                partial(has_one_of_the_labels, label_field=self.label_field, labels=self.labels)
                if self.per_label and not self.infer_labels
                else None
            ),
            annotation_processor=self.annotation_processor,
        )
        self.add_tp_fp_fn(*new_values, label="MICRO")
        if self.infer_labels:
            layer = document[self.layer]
            # collect labels from gold data and predictions
            for ann in list(layer) + list(layer.predictions):
                label = getattr(ann, self.label_field)
                if label not in self.labels:
                    self.labels.append(label)
        if self.per_label:
            for label in self.labels:
                new_values = self.calculate_tp_fp_fn(
                    document=document,
                    annotation_filter=partial(
                        has_this_label, label_field=self.label_field, label=label
                    ),
                    annotation_processor=self.annotation_processor,
                )
                self.add_tp_fp_fn(*new_values, label=label)

    def _compute(self) -> Dict[str, Dict[str, float]]:
        """Compute precision/recall/F1 per accumulated key.

        Returns:
            Mapping from label (including "MICRO" and, if ``per_label``, "MACRO")
            to a dict with keys ``f1``, ``p``, ``r`` and — except for "MACRO" —
            ``s`` (support, i.e. the number of gold annotations).
        """
        res = dict()
        if self.per_label:
            res["MACRO"] = {"f1": 0.0, "p": 0.0, "r": 0.0}
        for label in self.tp.keys():
            tp, fp, fn = (
                len(self.tp[label]),
                len(self.fp[label]),
                len(self.fn[label]),
            )
            if tp == 0:
                # No true positives: precision/recall are 0 by convention
                # (also avoids 0/0 when fp or fn is 0 as well).
                p, r, f1 = 0.0, 0.0, 0.0
            else:
                p = tp / (tp + fp)
                r = tp / (tp + fn)
                f1 = 2 * p * r / (p + r)
            res[label] = {"f1": f1, "p": p, "r": r, "s": tp + fn}
            # Only real labels contribute to the macro average ("MICRO" is excluded
            # because it is never a member of self.labels).
            if self.per_label and label in self.labels:
                res["MACRO"]["f1"] += f1 / len(self.labels)
                res["MACRO"]["p"] += p / len(self.labels)
                res["MACRO"]["r"] += r / len(self.labels)
        if self.show_as_markdown:
            # BUGFIX: logging and pandas were never imported at module level, so
            # show_as_markdown=True raised a NameError. Import lazily here to keep
            # pandas out of the import path when markdown output is disabled.
            import logging

            import pandas as pd

            logger = logging.getLogger(__name__)
            logger.info(f"\n{self.layer}:\n{pd.DataFrame(res).round(3).T.to_markdown()}")
        return res