update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
d868d2e
verified
| from typing import Callable, Hashable, Optional, Tuple | |
| from pie_modules.metrics import F1Metric | |
| from pytorch_ie import Annotation, Document | |
class F1WithThresholdMetric(F1Metric):
    """F1 metric that ignores predictions whose score is below a threshold.

    Behaves like ``F1Metric`` except that, when counting, predicted annotations
    whose ``score`` attribute is lower than ``threshold`` are discarded. Gold
    annotations are never thresholded: they are the reference, and dropping
    them would hide missed annotations (fewer false negatives) and so inflate
    recall.

    Args:
        *args: Forwarded unchanged to ``F1Metric.__init__``.
        threshold: Minimum ``score`` a predicted annotation needs to be
            counted. Predictions without a ``score`` attribute are treated as
            having score ``0.0``, so they only survive a threshold of ``<= 0.0``.
            Defaults to ``0.0``.
        **kwargs: Forwarded unchanged to ``F1Metric.__init__``.
    """

    def __init__(self, *args, threshold: float = 0.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.threshold = threshold

    def calculate_counts(
        self,
        document: Document,
        annotation_filter: Optional[Callable[[Annotation], bool]] = None,
        annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
    ) -> Tuple[int, int, int]:
        """Count true positives, false positives and false negatives in a document.

        Gold annotations are read from ``document[self.layer]`` and predicted
        ones from ``document[self.layer].predictions``. Matching is done with
        set operations on the processed annotations, so two annotations match
        when their processed forms are equal (and hash equally).

        Args:
            document: The document to evaluate.
            annotation_filter: Optional predicate applied to both gold and
                predicted annotations; those for which it returns a falsy
                value are ignored. Defaults to keeping everything.
            annotation_processor: Optional function mapping each kept
                annotation to the hashable value that is compared. Defaults to
                the identity.

        Returns:
            A ``(tp, fp, fn)`` tuple.
        """
        annotation_processor = annotation_processor or (lambda ann: ann)
        annotation_filter = annotation_filter or (lambda ann: True)
        layer = document[self.layer]
        predicted_annotations = {
            annotation_processor(ann)
            for ann in layer.predictions
            # The confidence threshold applies to predictions only; a missing
            # ``score`` attribute is treated as 0.0.
            if annotation_filter(ann) and getattr(ann, "score", 0.0) >= self.threshold
        }
        # Gold annotations are the reference and are intentionally NOT thresholded.
        gold_annotations = {annotation_processor(ann) for ann in layer if annotation_filter(ann)}
        tp = len(predicted_annotations & gold_annotations)
        fn = len(gold_annotations - predicted_annotations)
        fp = len(predicted_annotations - gold_annotations)
        return tp, fp, fn