update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
commit d868d2e
import logging
import warnings
from collections import defaultdict
from functools import partial
from typing import Callable, Iterable, List, Optional, Set, Tuple

import numpy as np
import pandas as pd
from pytorch_ie import DocumentMetric
from pytorch_ie.annotations import BinaryRelation
from sklearn.metrics import average_precision_score, ndcg_score

logger = logging.getLogger(__name__)

NEG_INF = -1e9  # smaller than any real score
# metrics


def true_mrr(y_true: np.ndarray, y_score: np.ndarray, k: int | None = None) -> float:
    """
    Macro MRR over *all* queries.

    • Reciprocal rank is 0 when a query has no relevant item.
    • If k is given, restrict the search to the top-k list.
    """
    if y_true.size == 0:
        return np.nan
    rr = []
    for t, s in zip(y_true, y_score):
        if t.sum() == 0:
            rr.append(0.0)
            continue
        order = np.argsort(-s)
        if k is not None:
            order = order[:k]
        # first position where t == 1, +1 for 1-based rank
        first_hit = np.flatnonzero(t[order] > 0)
        rank = first_hit[0] + 1 if first_hit.size else np.inf
        rr.append(0.0 if np.isinf(rank) else 1.0 / rank)
    return np.mean(rr)
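
# Worked example (toy arrays for illustration, not part of the original module):
# the first query's only positive lands at rank 2, the second query has no
# positives and therefore contributes a reciprocal rank of 0.
#
#   true_mrr(
#       np.array([[1, 0, 0], [0, 0, 0]]),
#       np.array([[0.2, 0.9, 0.1], [0.5, 0.4, 0.3]]),
#   )  # -> (1/2 + 0) / 2 = 0.25
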
def macro_ndcg(y_true: np.ndarray, y_score: np.ndarray, k: int | None = None) -> float:
    """
    Macro NDCG@k over all queries.

    ndcg_score returns 0 when a query has no positives, so no masking is required.
    """
    if y_true.size == 0:
        return np.nan
    return ndcg_score(y_true, y_score, k=k)
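
# Worked example (toy arrays, illustrative only): the single positive is ranked
# second, so DCG = 1/log2(3) and IDCG = 1, giving NDCG of roughly 0.63.
#
#   macro_ndcg(np.array([[1, 0, 0]]), np.array([[0.2, 0.9, 0.1]]))  # -> ~0.63
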
def macro_map(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """
    Macro MAP: mean of Average Precision per query.

    Queries without positives contribute AP = 0.
    """
    if y_true.size == 0:
        return np.nan
    ap = []
    for t, s in zip(y_true, y_score):
        if t.sum() == 0:
            ap.append(0.0)
        else:
            ap.append(average_precision_score(t, s))
    return np.mean(ap)
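
# Worked example (toy arrays, illustrative only): the positives sit at ranks 1
# and 3, so AP = (1/1 + 2/3) / 2 ≈ 0.833 for this single query.
#
#   macro_map(np.array([[1, 0, 1]]), np.array([[0.9, 0.8, 0.7]]))  # -> ~0.833
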
def ap_micro(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """
    Micro AP over the entire pool: all query/candidate pairs are flattened into a
    single ranking problem.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="No positive class found in y_true")
        return average_precision_score(y_true.ravel(), y_score.ravel())
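
# Worked example (toy arrays, illustrative only): after flattening, a high-scoring
# negative from the second query outranks the only positive, so micro AP drops to
# 0.5 even though the first query alone would score AP = 1.0.
#
#   ap_micro(np.array([[1, 0], [0, 0]]), np.array([[0.7, 0.1], [0.8, 0.2]]))  # -> 0.5
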
# ---------------------------
# Recall@k
# ---------------------------
def recall_at_k_micro(y_true: np.ndarray, y_score: np.ndarray, k: int = 5) -> float:
    """
    Micro Recall@k (a.k.a. instance-level recall).

    – Each *positive instance* counts once, regardless of which query it belongs to.
    – Denominator = total #positives across the whole pool.
    """
    total_pos = y_true.sum()
    if total_pos == 0:
        return np.nan
    topk = np.argsort(-y_score, axis=1)[:, :k]  # indices of top-k per query
    rows = np.arange(topk.shape[0])[:, None]
    hits = (y_true[rows, topk] > 0).sum()  # total #hits (instances)
    return hits / total_pos
def recall_at_k_macro(y_true: np.ndarray, y_score: np.ndarray, k: int = 5) -> float:
    """
    Macro Recall@k (query-level recall).

    – First compute recall per *query* (#hits / #positives in that query).
    – Then average across all queries that actually contain ≥1 positive.
    """
    mask = y_true.sum(axis=1) > 0  # keep only valid queries
    if not mask.any():
        return np.nan
    Yt, Ys = y_true[mask], y_score[mask]
    topk = np.argsort(-Ys, axis=1)[:, :k]
    rows = np.arange(Yt.shape[0])[:, None]
    hits_per_q = (Yt[rows, topk] > 0).sum(axis=1)  # shape: (n_queries,)
    pos_per_q = Yt.sum(axis=1)
    return np.mean(hits_per_q / pos_per_q)  # average of query recalls
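
# Worked example (toy arrays, illustrative only) contrasting the two recall
# variants at k=1: query 1 recovers 1 of its 2 positives (0.5), query 2 recovers
# its single positive (1.0).
#
#   yt = np.array([[1, 1, 0, 0], [1, 0, 0, 0]])
#   ys = np.array([[0.9, 0.1, 0.8, 0.7], [0.9, 0.2, 0.1, 0.05]])
#   recall_at_k_macro(yt, ys, k=1)  # -> (0.5 + 1.0) / 2 = 0.75
#   recall_at_k_micro(yt, ys, k=1)  # -> 2 hits / 3 positives ≈ 0.667
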
# ---------------------------
# Precision@k
# ---------------------------
def precision_at_k_micro(y_true: np.ndarray, y_score: np.ndarray, k: int = 5) -> float:
    """
    Micro Precision@k (pool-level precision).

    – Numerator = total #hits across all queries.
    – Denominator = total #predictions considered (n_queries · k).
    """
    if y_true.size == 0:
        return np.nan
    topk = np.argsort(-y_score, axis=1)[:, :k]
    rows = np.arange(topk.shape[0])[:, None]
    hits = (y_true[rows, topk] > 0).sum()
    total_pred = y_true.shape[0] * k
    return hits / total_pred
def precision_at_k_macro(y_true: np.ndarray, y_score: np.ndarray, k: int = 5) -> float:
    """
    Macro Precision@k (query-level precision).

    – Compute precision = (#hits / k) for each query, **including those with zero
      positives**, then average.
    """
    if y_true.size == 0:
        return np.nan
    topk = np.argsort(-y_score, axis=1)[:, :k]
    rows = np.arange(topk.shape[0])[:, None]
    rel = y_true[rows, topk] > 0  # shape: (n_queries, k)
    precision_per_q = rel.mean(axis=1)  # mean over k positions
    return precision_per_q.mean()
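
# Worked example (toy arrays, illustrative only) at k=2: query 1 places one
# positive in its top 2 (precision 0.5), query 2 has none (precision 0.0).
# With a fixed k and equally wide rows, the micro and macro averages coincide.
#
#   yt = np.array([[1, 0, 0], [0, 0, 0]])
#   ys = np.array([[0.9, 0.8, 0.1], [0.5, 0.4, 0.3]])
#   precision_at_k_macro(yt, ys, k=2)  # -> (0.5 + 0.0) / 2 = 0.25
#   precision_at_k_micro(yt, ys, k=2)  # -> 1 hit / (2 * 2) = 0.25
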
# helper methods
def bootstrap(
    metric_fn: Callable[[np.ndarray, np.ndarray], float],
    y_true: np.ndarray,
    y_score: np.ndarray,
    n: int = 1000,
    rng=None,
) -> dict[str, float]:
    """Bootstrap a metric over queries: resample rows with replacement n times and
    report the mean together with the 95% percentile confidence interval."""
    rng = np.random.default_rng(rng)
    idx = np.arange(len(y_true))
    vals: list[float] = []
    while len(vals) < n:
        sample = rng.choice(idx, size=len(idx), replace=True)
        t = y_true[sample]
        s = y_score[sample]
        if t.sum() == 0:  # no positive at all -> resample
            continue
        vals.append(metric_fn(t, s))
    result = np.asarray(vals)
    # get 95% confidence interval
    lo, hi = np.percentile(result, [2.5, 97.5])
    return {"mean": result.mean(), "low": lo, "high": hi}
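
# Usage sketch (illustrative call; `y_true`/`y_score` as produced by
# construct_y_true_and_score below, rng seed chosen arbitrarily):
#
#   ci = bootstrap(macro_map, y_true, y_score, n=200, rng=0)
#   ci["mean"], ci["low"], ci["high"]  # point estimate and 95% CI bounds
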
def evaluate_with_ranx(
    pred_rels: set[BinaryRelation],
    target_rels: set[BinaryRelation],
    metrics: list[str],
    include_queries_without_gold: bool = True,
) -> dict[str, float]:
    # lazy import to not require ranx via requirements.txt
    import ranx

    all_rels = set(pred_rels) | set(target_rels)
    all_heads = {rel.head for rel in all_rels}
    head2id = {head: f"q_{idx}" for idx, head in enumerate(sorted(all_heads))}
    tail_and_label2id = {(ann.tail, ann.label): f"d_{idx}" for idx, ann in enumerate(all_rels)}
    qrels_dict: dict[str, dict[str, int]] = defaultdict(dict)  # {query_id: {doc_id: 1}}
    run_dict: dict[str, dict[str, float]] = defaultdict(dict)  # {query_id: {doc_id: score}}
    for target_rel in target_rels:
        query_id = head2id[target_rel.head]
        doc_id = tail_and_label2id[(target_rel.tail, target_rel.label)]
        if target_rel.score != 1.0:
            raise ValueError(
                f"target score must be 1.0, but got {target_rel.score} for {target_rel}"
            )
        qrels_dict[query_id][doc_id] = 1
    for pred_rel in pred_rels:
        query_id = head2id[pred_rel.head]
        doc_id = tail_and_label2id[(pred_rel.tail, pred_rel.label)]
        run_dict[query_id][doc_id] = pred_rel.score
    if include_queries_without_gold:
        # add query ids that have predictions but no gold annotations to qrels_dict
        for query_id in set(head2id.values()) - set(qrels_dict):
            qrels_dict[query_id] = {}
    # evaluate
    qrels = ranx.Qrels(qrels_dict)
    run = ranx.Run(run_dict)
    results = ranx.evaluate(qrels, run, metrics, make_comparable=True)
    return results
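
# The dicts built above follow ranx's Qrels/Run input format. A minimal standalone
# sketch (requires `pip install ranx`; ids and scores are made up for illustration):
#
#   qrels = ranx.Qrels({"q_0": {"d_0": 1}})
#   run = ranx.Run({"q_0": {"d_0": 0.9, "d_1": 0.4}})
#   ranx.evaluate(qrels, run, ["mrr", "ndcg@5"])  # -> {"mrr": ..., "ndcg@5": ...}
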
def deduplicate_relations(
    relations: Iterable[BinaryRelation], caption: str
) -> Set[BinaryRelation]:
    """Collapse duplicate relations, keeping the maximum score per annotation."""
    pred2scores = defaultdict(set)
    for ann in relations:
        pred2scores[ann].add(round(ann.score, 4))
    # warning for duplicates
    preds_with_duplicates = [ann for ann, scores in pred2scores.items() if len(scores) > 1]
    if len(preds_with_duplicates) > 0:
        logger.warning(
            f"there are {len(preds_with_duplicates)} {caption} with duplicates: "
            f"{preds_with_duplicates}. We will take the max score for each annotation."
        )
    # take the max score for each annotation
    result = {ann.copy(score=max(scores)) for ann, scores in pred2scores.items()}
    return result
def construct_y_true_and_score(
    preds: Iterable[BinaryRelation], targets: Iterable[BinaryRelation]
) -> Tuple[np.ndarray, np.ndarray]:
    """Build per-query target and prediction score matrices of shape (n_queries, max_len)."""
    # helper constructs
    all_anns = set(preds) | set(targets)
    head2relations = defaultdict(list)
    for ann in all_anns:
        head2relations[ann.head].append(ann)
    target2score = {rel: rel.score for rel in targets}
    pred2score = {rel: rel.score for rel in preds}
    max_len = max(len(relations) for relations in head2relations.values())
    target_rows, pred_rows = [], []
    for query in head2relations:
        relations = head2relations[query]
        # missing predictions get NEG_INF so they always rank behind real predictions
        # (alternatives considered: 0.0 or a tiny random score)
        missing_pred_score = NEG_INF
        missing_target_score = 0
        query_scores = [
            (target2score.get(ann, missing_target_score), pred2score.get(ann, missing_pred_score))
            for ann in relations
        ]
        # sort by descending order of prediction score
        query_scores_sorted = np.array(sorted(query_scores, key=lambda x: x[1], reverse=True))
        # pad so every row has the same length (targets with 0, predictions with NEG_INF)
        pad_width = max_len - len(query_scores)
        query_target = np.pad(
            query_scores_sorted[:, 0], (0, pad_width), constant_values=missing_target_score
        )
        query_pred = np.pad(
            query_scores_sorted[:, 1], (0, pad_width), constant_values=missing_pred_score
        )
        target_rows.append(query_target)
        pred_rows.append(query_pred)
    y_true = np.vstack(target_rows)  # shape (n_queries, max_len)
    y_score = np.vstack(pred_rows)
    return y_true, y_score
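
# Shape sketch (illustrative): with two queries that have 3 and 2 candidate
# relations respectively, both returned matrices have shape (2, 3); the shorter
# row is padded with 0 in y_true and NEG_INF in y_score, so padded slots never
# outrank real predictions.
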
class SemanticallySameRankingMetric(DocumentMetric):
    def __init__(
        self,
        layer: str,
        label: Optional[str] = None,
        add_reversed: bool = False,
        require_positive_gold: bool = False,
        bootstrap_n: Optional[int] = None,
        k_values: Optional[List[int]] = None,
        return_coverage: bool = True,
        show_as_markdown: bool = False,
        use_ranx: bool = False,
        add_stats_to_result: bool = False,
    ) -> None:
        super().__init__()
        self.layer = layer
        self.label = label
        self.add_reversed = add_reversed
        self.require_positive_gold = require_positive_gold
        self.bootstrap_n = bootstrap_n
        self.k_values = k_values if k_values is not None else [1, 5, 10]
        self.return_coverage = return_coverage
        self.show_as_markdown = show_as_markdown
        self.use_ranx = use_ranx
        self.add_stats_to_result = add_stats_to_result
        self.metrics = {
            "macro_ndcg": macro_ndcg,
            "macro_mrr": true_mrr,
            "macro_map": macro_map,
            "micro_ap": ap_micro,
        }
        for name, func in [
            ("macro_ndcg", macro_ndcg),
            ("micro_recall", recall_at_k_micro),
            ("micro_precision", precision_at_k_micro),
            ("macro_recall", recall_at_k_macro),
            ("macro_precision", precision_at_k_macro),
        ]:
            for k in self.k_values:
                self.metrics[f"{name}@{k}"] = partial(func, k=k)  # type: ignore
        self.ranx_metrics = ["map", "mrr", "ndcg"]
        for name in ["recall", "precision", "ndcg"]:
            for k in self.k_values:
                self.ranx_metrics.append(f"{name}@{k}")
    def reset(self) -> None:
        """
        Reset the metric to its initial state.
        """
        self._preds: List[BinaryRelation] = []
        self._targets: List[BinaryRelation] = []

    def _update(self, document):
        layer = document[self.layer]
        ann: BinaryRelation
        for ann in layer:
            if self.label is None or ann.label == self.label:
                if ann.score > 0.0:
                    self._targets.append(ann.copy())
                    if self.add_reversed:
                        self._targets.append(ann.copy(head=ann.tail, tail=ann.head))
        for ann in layer.predictions:
            if self.label is None or ann.label == self.label:
                if ann.score > 0.0:
                    self._preds.append(ann.copy())
                    if self.add_reversed:
                        self._preds.append(ann.copy(head=ann.tail, tail=ann.head))
    def _compute(self):
        # take the max score for each annotation
        preds_deduplicated = deduplicate_relations(self._preds, "predictions")
        targets_deduplicated = deduplicate_relations(self._targets, "targets")
        stats = {
            "gold": len(targets_deduplicated),
            "preds": len(preds_deduplicated),
            "queries": len(
                set(ann.head for ann in targets_deduplicated)
                | set(ann.head for ann in preds_deduplicated)
            ),
        }
        if self.use_ranx:
            if self.bootstrap_n is not None:
                raise ValueError(
                    "Ranx does not support bootstrapping. Please set bootstrap_n=None."
                )
            scores = evaluate_with_ranx(
                preds_deduplicated,
                targets_deduplicated,
                metrics=self.ranx_metrics,
                include_queries_without_gold=not self.require_positive_gold,
            )
            if self.add_stats_to_result:
                scores.update(stats)
            df = pd.DataFrame.from_records([scores], index=["score"])
        else:
            y_true, y_score = construct_y_true_and_score(
                preds=preds_deduplicated, targets=targets_deduplicated
            )
            # original definition: share of queries with ≥1 positive
            coverage = (y_true.sum(axis=1) > 0).mean()
            # keep only queries that actually have at least one gold positive
            if self.require_positive_gold:
                mask = y_true.sum(axis=1) > 0  # shape: (n_queries,)
                y_true = y_true[mask]
                y_score = y_score[mask]
            if self.bootstrap_n is not None:
                scores = {
                    name: bootstrap(fn, y_true, y_score, n=self.bootstrap_n)
                    for name, fn in self.metrics.items()
                }
                if self.add_stats_to_result:
                    scores["stats"] = stats
                df = pd.DataFrame(scores)
            else:
                scores = {name: fn(y_true, y_score) for name, fn in self.metrics.items()}
                if self.add_stats_to_result:
                    scores.update(stats)
                df = pd.DataFrame.from_records([scores], index=["score"])
            if self.return_coverage:
                scores["coverage"] = coverage
        if self.show_as_markdown:
            if not self.add_stats_to_result:
                logger.info(
                    f'\nstatistics ({self.layer}):\n{pd.Series(stats, name="value").to_markdown()}'
                )
            logger.info(f"\n{self.layer}:\n{df.round(4).T.to_markdown()}")
        return scores
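

# Minimal smoke test over the pure-numpy metrics (toy data; a sketch added for
# illustration, not part of the original module). The full
# SemanticallySameRankingMetric additionally needs pytorch-ie documents, which
# are out of scope here.
if __name__ == "__main__":
    y_true_demo = np.array([[1, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0]])
    y_score_demo = np.array(
        [[0.9, 0.1, 0.8, 0.7], [0.9, 0.2, 0.1, 0.05], [0.5, 0.4, 0.3, 0.2]]
    )
    for metric_name, metric_fn in [
        ("macro_mrr", true_mrr),
        ("macro_ndcg", macro_ndcg),
        ("macro_map", macro_map),
        ("micro_ap", ap_micro),
        ("micro_recall@1", partial(recall_at_k_micro, k=1)),
        ("macro_recall@1", partial(recall_at_k_macro, k=1)),
    ]:
        # print each metric on the same toy matrices to sanity-check the functions
        print(metric_name, round(float(metric_fn(y_true_demo, y_score_demo)), 4))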