update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
import logging
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Type, TypeVar, Union

from pie_datasets import Dataset, DatasetDict
from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
from pytorch_ie import Document
from pytorch_ie.annotations import BinaryRelation, Span
from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations
from pytorch_ie.utils.hydra import resolve_optional_document_type, resolve_target

logger = logging.getLogger(__name__)


# TODO: simply use DatasetDict.map() with set_batch_size_to_split_size=True and
# batched=True instead when https://github.com/ArneBinder/pie-datasets/pull/155 is merged
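# A hedged sketch of that future call (set_batch_size_to_split_size comes from the linked,
# unmerged PR, so the parameter name is an assumption, not a released pie-datasets API):
#   dataset.map(function=..., batched=True, set_batch_size_to_split_size=True, ...)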
def apply_func_to_splits(
    dataset: DatasetDict,
    function: Union[str, Callable],
    result_document_type: Type[Document],
    **kwargs,
) -> DatasetDict:
    """Apply a batched function to each split of the dataset, processing the whole split as
    a single batch, and convert the results to the given document type."""
    resolved_func = resolve_target(function)
    resolved_document_type = resolve_optional_document_type(document_type=result_document_type)
    result_dict = dict()
    split: Dataset
    for split_name, split in dataset.items():
        # use the split size as batch size so that the function sees all documents at once
        converted_dataset = split.map(
            function=resolved_func,
            batched=True,
            batch_size=len(split),
            result_document_type=resolved_document_type,
            **kwargs,
        )
        result_dict[split_name] = converted_dataset
    return DatasetDict(result_dict)
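

# A minimal usage sketch for apply_func_to_splits; the dotted target path is a hypothetical
# placeholder, not a function from this repository:
#
#     converted = apply_func_to_splits(
#         dataset=dataset,
#         function="src.utils.construct_text_pair_documents",  # resolved via resolve_target
#         result_document_type=TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
#     )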


S = TypeVar("S", bound=Span)


def shift_span(span: S, offset: int) -> S:
    """Shift the start and end of a span by a given offset."""
    return span.copy(start=span.start + offset, end=span.end + offset)
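

# Example (LabeledSpan from pytorch_ie.annotations, values shown for illustration only):
# shift_span(LabeledSpan(start=3, end=7, label="claim"), offset=10) returns a detached
# copy with start=13 and end=17.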


D = TypeVar("D", bound=TextDocumentWithLabeledSpansAndBinaryRelations)


def add_predicted_semantically_same_relations_to_document(
    document: D,
    doc_id2docs_with_predictions: Dict[
        str, List[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]
    ],
    relation_label: str,
    argument_label_blacklist: Optional[List[str]] = None,
    verbose: bool = False,
) -> D:
    # create a lookup from detached copies of the spans to the attached originals
    # (an attached span and its detached copy compare as unequal even if their contents match)
    span2span = {span.copy(): span for span in document.labeled_spans}
    for text_pair_doc_with_preds in doc_id2docs_with_predictions.get(document.id, []):
        # spans in the text pair document are relative to its text snippets, so shift them
        # back by the snippet start offsets to align them with the original document
        offset = text_pair_doc_with_preds.metadata["original_doc_span"]["start"]
        offset_pair = text_pair_doc_with_preds.metadata["original_doc_span_pair"]["start"]
        for coref_rel in text_pair_doc_with_preds.binary_coref_relations.predictions:
            head = shift_span(coref_rel.head, offset=offset)
            if head not in span2span:
                if verbose:
                    logger.warning(f"doc_id={document.id}: Head span {head} not found.")
                continue
            tail = shift_span(coref_rel.tail, offset=offset_pair)
            if tail not in span2span:
                if verbose:
                    logger.warning(f"doc_id={document.id}: Tail span {tail} not found.")
                continue
            # skip relations where the head or tail carries a blacklisted argument label
            if argument_label_blacklist is not None and (
                span2span[head].label in argument_label_blacklist
                or span2span[tail].label in argument_label_blacklist
            ):
                continue
            # attach the predicted coref relation as a binary relation between the original
            # (attached) spans, keeping the prediction score
            new_rel = BinaryRelation(
                head=span2span[head],
                tail=span2span[tail],
                label=relation_label,
                score=coref_rel.score,
            )
            document.binary_relations.predictions.append(new_rel)
    return document
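

# A minimal sketch of the expected inputs (ids, offsets, and labels are illustrative):
# each text pair document records where its two text snippets start in the original
# document, e.g.
#
#     text_pair_doc.metadata["original_doc_span"] == {"start": 120, ...}
#     text_pair_doc.metadata["original_doc_span_pair"] == {"start": 400, ...}
#
#     document = add_predicted_semantically_same_relations_to_document(
#         document=my_doc,
#         doc_id2docs_with_predictions={"doc-1": [text_pair_doc]},
#         relation_label="semantically_same",
#         argument_label_blacklist=["background_claim"],
#     )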


def integrate_coref_predictions_from_text_pair_documents(
    dataset: DatasetDict, data_dir: str, **kwargs
) -> DatasetDict:
    """Load text pair documents with predicted coreference relations from data_dir and add
    these predictions as binary relations to the matching documents in the dataset.

    Any additional kwargs are passed on to
    add_predicted_semantically_same_relations_to_document.
    """
    dataset_with_predictions = DatasetDict.from_json(data_dir=data_dir)
    for split_name in dataset.keys():
        ds_with_predictions = dataset_with_predictions[split_name]
        # group the text pair documents by the id of the original document they stem from
        original_doc_id2docs = defaultdict(list)
        for doc in ds_with_predictions:
            original_doc_id = doc.metadata["original_doc_id"]
            if original_doc_id != doc.metadata["original_doc_id_pair"]:
                raise ValueError(
                    f"Original document IDs do not match: "
                    f"{original_doc_id} != {doc.metadata['original_doc_id_pair']}. "
                    f"Cross-document coref is not supported."
                )
            original_doc_id2docs[original_doc_id].append(doc)
        dataset[split_name] = dataset[split_name].map(
            function=add_predicted_semantically_same_relations_to_document,
            fn_kwargs=dict(doc_id2docs_with_predictions=original_doc_id2docs, **kwargs),
        )
    return dataset
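

# A hedged end-to-end sketch (paths and the relation label are assumptions, not from the
# source):
#
#     dataset = DatasetDict.from_json(data_dir="data/my_dataset")
#     dataset = integrate_coref_predictions_from_text_pair_documents(
#         dataset=dataset,
#         data_dir="predictions/coref",  # serialized text pair docs with coref predictions
#         relation_label="semantically_same",
#     )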