update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529 (commit d868d2e)

import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
from pytorch_ie import Document, DocumentMetric

logger = logging.getLogger(__name__)

class ScoreDistribution(DocumentMetric):
    """Computes the distribution of prediction scores for annotations in a layer. The scores are
    separated into true positives (TP) and false positives (FP) based on the gold annotations.

    Args:
        layer: The name of the annotation layer to analyze.
        label_field: The field name of the label to use for separating the scores per label. Default is "label".
        per_label: If True, the scores are separated per label. Default is False.
        show_plot: If True, a plot of the score distribution is shown. Default is False.
        equal_sample_size_binning: If True, the scores are binned into bins of equal sample size.
            If False, they are binned into bins of equal width. The former is useful when the
            distribution of scores is skewed. Default is True.
        plotting_backend: The plotting backend to use. Currently, only "plotly" is implemented. Default is "plotly".
        plotting_caption_mapping: A mapping to rename any caption entries for plotting, i.e., the layer name,
            labels, or TP/FP. Default is None.
        plotting_colors: A dictionary mapping from gold scores to colors for plotting. Default is None.
        plotly_use_create_distplot: If True, use plotly.figure_factory.create_distplot for the
            distribution plot; otherwise, use plotly.express.histogram. Default is True.
        plotly_barmode: An optional barmode (e.g. "overlay") for the plotly figure. Default is None.
        plotly_marginal: The marginal subplot type ("violin", "box", or "rug") for the histogram
            backend. Default is "violin".
        plotly_font: An optional dictionary with font settings for the plotly figure. Default is None.
        plotly_font_size: Deprecated. Use plotly_font with a "size" key instead.
        plotly_font_family: An optional font family for the plotly figure. Default is None.
        plotly_background_color: An optional background color for the plotly figure. Default is None.
    """

    def __init__(
        self,
        layer: str,
        label_field: str = "label",
        per_label: bool = False,
        show_plot: bool = False,
        equal_sample_size_binning: bool = True,
        plotting_backend: str = "plotly",
        plotting_caption_mapping: Optional[Dict[str, str]] = None,
        plotting_colors: Optional[Dict[str, str]] = None,
        plotly_use_create_distplot: bool = True,
        plotly_barmode: Optional[str] = None,
        plotly_marginal: Optional[str] = "violin",
        plotly_font: Optional[Dict[str, Any]] = None,
        plotly_font_size: Optional[int] = None,
        plotly_font_family: Optional[str] = None,
        plotly_background_color: Optional[str] = None,
    ):
        super().__init__()
        self.layer = layer
        self.label_field = label_field
        self.per_label = per_label
        self.equal_sample_size_binning = equal_sample_size_binning
        self.plotting_backend = plotting_backend
        self.show_plot = show_plot
        self.plotting_caption_mapping = plotting_caption_mapping or {}
        self.plotting_colors = plotting_colors
        self.plotly_use_create_distplot = plotly_use_create_distplot
        self.plotly_barmode = plotly_barmode
        self.plotly_marginal = plotly_marginal
        self.plotly_font = plotly_font or {}
        if plotly_font_size is not None:
            logger.warning(
                "Parameter 'plotly_font_size' is deprecated. Use 'plotly_font' with 'size' key instead."
            )
            self.plotly_font["size"] = plotly_font_size
        self.plotly_font_family = plotly_font_family
        self.plotly_background_color = plotly_background_color
        # nested score store: label -> {"TP": [scores], "FP": [scores]}
        self.scores: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))

    def reset(self):
        self.scores = defaultdict(lambda: defaultdict(list))

    def _update(self, document: Document):
        gold_annotations = set(document[self.layer])
        for ann in document[self.layer].predictions:
            if self.per_label:
                label = getattr(ann, self.label_field)
            else:
                label = "ALL"
            # a prediction is a TP if an equal annotation exists in the gold layer, otherwise a FP
            if ann in gold_annotations:
                self.scores[label]["TP"].append(ann.score)
            else:
                self.scores[label]["FP"].append(ann.score)

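    # For illustration (added, not part of the original module): after documents where two
    # predictions match the gold layer with scores 0.91 and 0.84 and one does not (score 0.35),
    # self.scores["ALL"] == {"TP": [0.91, 0.84], "FP": [0.35]}.
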
    def _combine_scores(
        self,
        scores_tp: List[float],
        scores_fp: List[float],
        col_name_pred: str = "prediction",
        col_name_gold: str = "gold",
    ) -> pd.DataFrame:
        scores_tp_df = pd.DataFrame(scores_tp, columns=[col_name_pred])
        scores_tp_df[col_name_gold] = 1.0
        scores_fp_df = pd.DataFrame(scores_fp, columns=[col_name_pred])
        scores_fp_df[col_name_gold] = 0.0
        scores_df = pd.concat([scores_tp_df, scores_fp_df])
        return scores_df

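    # For illustration (added): _combine_scores([0.9, 0.8], [0.4]) yields one row per
    # prediction with a binary gold indicator:
    #        prediction  gold
    #     0         0.9   1.0
    #     1         0.8   1.0
    #     0         0.4   0.0
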
    def _get_calibration_data_and_metrics(
        self, scores: pd.DataFrame, q: int = 20
    ) -> Tuple[pd.DataFrame, pd.Series]:
        from sklearn.metrics import brier_score_loss

        if self.equal_sample_size_binning:
            # Create bins with equal number of samples.
            scores["bin"] = pd.qcut(scores["prediction"], q=q, labels=False)
        else:
            # Create bins with equal width.
            scores["bin"] = pd.cut(
                scores["prediction"],
                bins=q,
                include_lowest=True,
                right=True,
                labels=False,
            )
        calibration_data = (
            scores.groupby("bin")
            .apply(
                lambda x: pd.Series(
                    {
                        "avg_score": x["prediction"].mean(),
                        "fraction_positive": x["gold"].mean(),
                        "count": len(x),
                    }
                )
            )
            .reset_index()
        )
        total_count = scores.shape[0]
        calibration_data["bin_weight"] = calibration_data["count"] / total_count
        # Calculate the absolute differences and squared differences.
        calibration_data["abs_diff"] = abs(
            calibration_data["avg_score"] - calibration_data["fraction_positive"]
        )
        calibration_data["squared_diff"] = (
            calibration_data["avg_score"] - calibration_data["fraction_positive"]
        ) ** 2
        # Compute Expected Calibration Error (ECE): weighted average of absolute differences.
        ece = (calibration_data["abs_diff"] * calibration_data["bin_weight"]).sum()
        # Compute Maximum Calibration Error (MCE): maximum absolute difference.
        mce = calibration_data["abs_diff"].max()
        # Compute Mean Squared Error (MSE): weighted average of squared differences.
        mse = (calibration_data["squared_diff"] * calibration_data["bin_weight"]).sum()
        # Compute the Brier Score on the raw predictions.
        brier = brier_score_loss(scores["gold"], scores["prediction"])
        values = {
            "ece": ece,
            "mce": mce,
            "mse": mse,
            "brier": brier,
        }
        return calibration_data, pd.Series(values)

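    # Worked example (added): with two bins of equal weight 0.5, average scores 0.2 and 0.8,
    # and positive fractions 0.1 and 0.9, the differences are |0.2 - 0.1| = 0.1 and
    # |0.8 - 0.9| = 0.1, so ECE = 0.5 * 0.1 + 0.5 * 0.1 = 0.1, MCE = 0.1, and
    # MSE = 0.5 * 0.01 + 0.5 * 0.01 = 0.01.
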
    def calculate_calibration_metrics(self, scores_combined: pd.DataFrame) -> pd.DataFrame:
        calibration_data_dict = {}
        calibration_metrics_dict = {}
        for label, current_scores in scores_combined.groupby("label"):
            calibration_data, calibration_metrics = self._get_calibration_data_and_metrics(
                current_scores, q=20
            )
            calibration_data_dict[label] = calibration_data
            calibration_metrics_dict[label] = calibration_metrics
        all_calibration_data = pd.concat(
            calibration_data_dict, names=["label", "idx"]
        ).reset_index(level=0)
        all_calibration_metrics = pd.concat(calibration_metrics_dict, axis=1).T
        if self.show_plot:
            self.plot_calibration_data(calibration_data=all_calibration_data)
        return all_calibration_metrics

    def calculate_correlation(self, scores: pd.DataFrame) -> pd.Series:
        result_dict = {}
        for label, current_scores in scores.groupby("label"):
            result_dict[label] = current_scores.drop("label", axis=1).corr()["prediction"]["gold"]
        return pd.Series(result_dict, name="correlation")

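    # Note (added): pandas .corr() defaults to Pearson correlation; with the binary "gold"
    # column this equals the point-biserial correlation between the predicted score and
    # whether the prediction is a true positive.
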
    @property
    def mapped_layer(self) -> str:
        return self.plotting_caption_mapping.get(self.layer, self.layer)

    def plot_score_distribution(self, scores: pd.DataFrame):
        if self.plotting_backend == "plotly":
            for label in scores["label"].unique():
                description = f"Distribution of Predicted Scores for {self.mapped_layer}"
                if self.per_label:
                    label_mapped = self.plotting_caption_mapping.get(label, label)
                    description += f" ({label_mapped})"
                current_scores = scores[scores["label"] == label]
                if self.plotly_use_create_distplot:
                    import plotly.figure_factory as ff

                    # group the predicted scores by gold score (1.0 for TP, 0.0 for FP)
                    scores_dict = (
                        current_scores.groupby("gold")["prediction"].apply(list).to_dict()
                    )
                    group_labels, hist_data = zip(*scores_dict.items())
                    group_labels_renamed = [
                        self.plotting_caption_mapping.get(group_label, group_label)
                        for group_label in group_labels
                    ]
                    if self.plotting_colors is not None:
                        colors = [
                            self.plotting_colors[group_label] for group_label in group_labels
                        ]
                    else:
                        colors = None
                    fig = ff.create_distplot(
                        hist_data,
                        group_labels=group_labels_renamed,
                        show_hist=True,
                        colors=colors,
                        bin_size=0.025,
                    )
                else:
                    import plotly.express as px

                    fig = px.histogram(
                        current_scores,
                        x="prediction",
                        color="gold",
                        marginal=self.plotly_marginal,  # "violin", "box", or "rug"
                        hover_data=current_scores.columns,
                        color_discrete_map=self.plotting_colors,
                        nbins=50,
                    )
                fig.update_layout(
                    height=600,
                    width=800,
                    title_text=description,
                    title_x=0.5,
                    font=self.plotly_font,
                    legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
                )
                if self.plotly_barmode is not None:
                    fig.update_layout(barmode=self.plotly_barmode)
                if self.plotly_font_family is not None:
                    fig.update_layout(font_family=self.plotly_font_family)
                if self.plotly_background_color is not None:
                    fig.update_layout(
                        plot_bgcolor=self.plotly_background_color,
                        paper_bgcolor=self.plotly_background_color,
                    )
                fig.show()
        else:
            raise NotImplementedError(f"Plotting backend {self.plotting_backend} not implemented")

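    # Note (added): with the default curve_type="kde", create_distplot estimates the density
    # via scipy, so the distplot branch assumes scipy is installed.
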
    def plot_calibration_data(self, calibration_data: pd.DataFrame):
        import plotly.express as px
        import plotly.graph_objects as go

        color = "label" if self.per_label else None
        x_col = "avg_score"
        y_col = "fraction_positive"
        fig = px.scatter(
            calibration_data,
            x=x_col,
            y=y_col,
            color=color,
            trendline="ols",
            labels=self.plotting_caption_mapping,
        )
        if not self.per_label:
            fig["data"][1]["name"] = "prediction vs. gold"
        # show legend only for trendlines
        for idx, trace_data in enumerate(fig["data"]):
            if idx % 2 == 0:
                trace_data["showlegend"] = False
            else:
                trace_data["showlegend"] = True
        # add the optimal line
        minimum = calibration_data[x_col].min()
        maximum = calibration_data[x_col].max()
        fig.add_trace(
            go.Scatter(
                x=[minimum, maximum],
                y=[minimum, maximum],
                mode="lines",
                name="optimal",
                line=dict(color="black", dash="dash"),
            )
        )
        fig.update_layout(
            height=600,
            width=800,
            title_text=f"Mean Binned Scores for {self.mapped_layer}",
            title_x=0.5,
            font=self.plotly_font,
        )
        fig.update_layout(
            legend=dict(
                yanchor="top",
                y=0.99,
                xanchor="left",
                x=0.01,
                title="OLS trendline" + ("s" if self.per_label else ""),
            ),
        )
        if self.plotly_background_color is not None:
            fig.update_layout(
                plot_bgcolor=self.plotly_background_color,
                paper_bgcolor=self.plotly_background_color,
            )
        if self.plotly_font_family is not None:
            fig.update_layout(font_family=self.plotly_font_family)
        fig.show()

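    # Note (added): px.scatter with trendline="ols" fits the trendlines with statsmodels,
    # which therefore needs to be installed for this plot.
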
    def _compute(self) -> Dict[str, Dict[str, Any]]:
        scores_combined = pd.concat(
            {
                label: self._combine_scores(scores["TP"], scores["FP"])
                for label, scores in self.scores.items()
            },
            names=["label", "idx"],
        ).reset_index(level=0)
        result_df = scores_combined.groupby("label")["prediction"].agg(["mean", "std", "count"])
        if self.show_plot:
            self.plot_score_distribution(scores=scores_combined)
        calibration_metrics = self.calculate_calibration_metrics(scores_combined)
        calibration_metrics["correlation"] = self.calculate_correlation(scores_combined)
        result_df = pd.concat(
            {"prediction": result_df, "prediction vs. gold": calibration_metrics}, axis=1
        )
        if not self.per_label:
            result = result_df.xs("ALL")
        else:
            result = result_df.T.stack().unstack()
        # nested result: outer key is the metric group ("prediction" or "prediction vs. gold"),
        # inner mapping holds the individual metric values (per label if per_label is True)
        result_dict = {
            main_key: result.xs(main_key).T.to_dict()
            for main_key in result.index.get_level_values(0).unique()
        }
        return result_dict
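

# Minimal usage sketch (added; illustration only). The layer name "binary_relations" and the
# synthetic scores are assumptions; in practice the metric is updated with documents whose
# layer holds gold annotations plus scored predictions, e.g. via metric._update(document).
if __name__ == "__main__":
    import numpy as np

    rng = np.random.default_rng(42)
    metric = ScoreDistribution(layer="binary_relations")
    # fill the internal score store directly: TPs tend to score high, FPs low
    metric.scores["ALL"]["TP"].extend(rng.beta(5, 2, size=200).tolist())
    metric.scores["ALL"]["FP"].extend(rng.beta(2, 5, size=200).tolist())
    result = metric._compute()
    # result has the shape:
    # {"prediction": {"mean": ..., "std": ..., "count": ...},
    #  "prediction vs. gold": {"ece": ..., "mce": ..., "mse": ..., "brier": ..., "correlation": ...}}
    print(result)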