update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
d868d2e
verified
from typing import Hashable, List, Optional, Sequence, Set, TypeVar

from pytorch_ie.annotations import BinaryRelation
# Generic element (graph node) type for get_connected_components. Bound to
# Hashable because elements are used as networkx graph nodes.
H = TypeVar("H", bound=Hashable)
def get_connected_components(
    relations: Sequence[BinaryRelation],
    elements: Optional[Sequence[H]] = None,
    link_relation_label: Optional[str] = None,
    link_relation_relation_score_threshold: Optional[float] = None,
    add_singletons: bool = False,
) -> List[Set[H]]:
    """Group elements into connected components induced by "link" relations.

    Every relation that passes the label and score filters becomes an
    undirected edge between its ``head`` and ``tail``; all other relations
    are ignored. The connected components of the resulting graph are returned.

    Args:
        relations: Candidate relations. A relation is used as an edge only if
            it matches ``link_relation_label`` (when given) and its ``score``
            is ``>= link_relation_relation_score_threshold`` (when given).
        elements: Elements to add as (possibly isolated) nodes. Required when
            ``add_singletons`` is True; ignored otherwise. Relation endpoints
            are included in the graph even if they are not in ``elements``.
        link_relation_label: If set, only relations with this label are used
            as edges.
        link_relation_relation_score_threshold: If set, only relations whose
            score is at least this value are used as edges.
        add_singletons: If True, every element of ``elements`` is added to the
            graph, so elements not touched by any link relation show up as
            single-element components.

    Returns:
        One set of nodes per connected component, in the order produced by
        ``networkx.connected_components``.

    Raises:
        ImportError: If NetworkX is not installed.
        ValueError: If ``add_singletons`` is True but ``elements`` is None.
    """
    # NetworkX is an optional dependency, so it is imported lazily.
    try:
        import networkx as nx
    except ImportError as e:
        raise ImportError(
            "NetworkX must be installed to use get_connected_components. "
            "You can install NetworkX with `pip install networkx`."
        ) from e

    # Validate arguments before doing any work on the relations.
    if add_singletons and elements is None:
        raise ValueError("elements must be provided if add_singletons is True")

    # Convert the link relations to an undirected graph so that connected
    # components can be computed directly.
    g = nx.Graph()
    for rel in relations:
        # Short-circuit on the label so rel.score is only read for relations
        # that passed the label filter.
        if (link_relation_label is None or rel.label == link_relation_label) and (
            link_relation_relation_score_threshold is None
            or rel.score >= link_relation_relation_score_threshold
        ):
            g.add_edge(rel.head, rel.tail)

    if add_singletons:
        # Adding an already-present node (e.g. a relation endpoint) is a no-op
        # in networkx, so only isolated elements become new nodes.
        g.add_nodes_from(elements)

    return list(nx.connected_components(g))