update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
d868d2e
verified
| import logging | |
| from collections import Counter | |
| from typing import Dict, List, TypeVar | |
| from pytorch_ie import Annotation, AnnotationLayer, Document, DocumentStatistic | |
| from pytorch_ie.annotations import BinaryRelation | |
| from src.utils.graph_utils import get_connected_components | |
| logger = logging.getLogger(__name__) | |
| A = TypeVar("A") | |
| # TODO: remove when "counts" aggregation function is available in DocumentStatistic | |
def count_func(values: List[int]) -> Dict[int, int]:
    """Tally how often each value occurs in ``values``.

    Args:
        values: The values to count.

    Returns:
        A dict mapping each distinct value to its number of occurrences,
        with keys inserted in ascending order. An empty input yields an
        empty dict.
    """
    tally = Counter(values)
    # Order the (value, count) pairs by value only, so the resulting dict
    # iterates from the smallest to the largest value.
    ordered_pairs = sorted(tally.items(), key=lambda pair: pair[0])
    return dict(ordered_pairs)
class ConnectedComponentSizes(DocumentStatistic):
    """Document statistic that collects the sizes of connected components.

    For each document, the annotations of the relation layer's target layer
    are grouped into connected components via ``get_connected_components``
    and the size of every component is reported.

    By default the collected sizes are aggregated with ``count_func``,
    referenced by its dotted import path.
    NOTE(review): that path presumably has to match this module's actual
    location (``src/metrics/connected_component_sizes.py``) -- confirm.
    """

    # TODO: use "counts" aggregation function when available in DocumentStatistic
    DEFAULT_AGGREGATION_FUNCTIONS = ["src.metrics.connected_component_sizes.count_func"]

    def __init__(self, relation_layer: str, link_relation_label: str, **kwargs) -> None:
        """Store the layer name and the relation label used for linking.

        Args:
            relation_layer: Name of the document layer holding the relations.
            link_relation_label: Passed to ``get_connected_components`` as
                ``link_relation_label``.
            **kwargs: Forwarded unchanged to ``DocumentStatistic.__init__``.
        """
        super().__init__(**kwargs)
        self.link_relation_label = link_relation_label
        self.relation_layer = relation_layer

    def _collect(self, document: Document) -> List[int]:
        """Return the size of each connected component found in ``document``.

        Args:
            document: The document whose ``relation_layer`` (and that layer's
                ``target_layer``) are inspected.

        Returns:
            One integer per connected component: its number of elements.
        """
        layer_name = self.relation_layer
        relation_annotations: AnnotationLayer[BinaryRelation] = document[layer_name]
        span_annotations: AnnotationLayer[Annotation] = document[layer_name].target_layer
        # add_singletons=True: presumably elements that take part in no link
        # relation are reported as components of size one -- TODO confirm
        # against get_connected_components.
        components: List[List] = get_connected_components(
            elements=span_annotations,
            relations=relation_annotations,
            link_relation_label=self.link_relation_label,
            add_singletons=True,
        )
        return list(map(len, components))