Update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529 (commit d868d2e, verified).

import argparse
import json
from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np
import pandas as pd
import plotly.express as px


def get_col_name(col: str) -> str:
    # Parse a column name that encodes a tuple as a string (e.g. "('...', 'f1')")
    # and return its last non-empty element.
    parts = [part[1:-1] for part in col[1:-1].split(", ") if part[1:-1] != ""]
    return parts[-1]


def get_idx_entry(s: str, keep_only_last_part: bool = False) -> Tuple[str, str]:
    # Split a single "key=value" entry; optionally keep only the last dot-separated part of the key.
    k, v = s.split("=", 1)
    if keep_only_last_part:
        k = k.split(".")[-1]
    return k, v


def get_idx_dict(job_id: str, keep_only_last_part: bool = False) -> Dict[str, str]:
    # Parse a job id consisting of "key=value" entries joined by "-" into a dict.
    return dict(
        get_idx_entry(part, keep_only_last_part=keep_only_last_part) for part in job_id.split("-")
    )
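

# Illustrative example (not part of the original script); the job-id format of "key=value"
# pairs joined by "-" is an assumption inferred from the parsing logic above:
#   >>> get_idx_dict("generation.num_beams=4-generation.max_length=512", keep_only_last_part=True)
#   {'num_beams': '4', 'max_length': '512'}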


def unflatten_index(
    index: Iterable[str],
    keep_only_last_part: bool = False,
    dtypes: Optional[Dict[str, Any]] = None,
) -> pd.MultiIndex:
    as_df = pd.DataFrame.from_records(
        [get_idx_dict(idx, keep_only_last_part=keep_only_last_part) for idx in index]
    )
    if dtypes is not None:
        dtypes_valid = {col: dtype for col, dtype in dtypes.items() if col in as_df.columns}
        as_df = as_df.astype(dtypes_valid)
    return pd.MultiIndex.from_frame(as_df.convert_dtypes())
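

# Illustrative example (not part of the original script); the job-id strings and dtypes below
# are hypothetical. Each "key=value" entry becomes one index level:
#   >>> idx = unflatten_index(
#   ...     ["model.num_beams=1-max_length=128", "model.num_beams=4-max_length=128"],
#   ...     keep_only_last_part=True,
#   ...     dtypes={"num_beams": int},
#   ... )
#   >>> list(idx.names)
#   ['num_beams', 'max_length']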


def col_to_str(col_entries: Iterable[str], names: Iterable[Optional[str]], sep: str) -> str:
    return sep.join(
        [
            f"{name}={col_entry}" if name is not None else col_entry
            for col_entry, name in zip(col_entries, names)
        ]
    )


def flatten_index(index: pd.MultiIndex, names: Optional[List[Optional[str]]] = None) -> pd.Index:
    names = names or index.names
    if names is None:
        raise ValueError("names must be provided if index has no names")
    return pd.Index([col_to_str(col, names=names, sep=",") for col in index])
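

# Illustrative example (not part of the original script): entries without a level name are kept
# as-is, named levels are rendered as "name=value":
#   >>> col_to_str(["f1", "test"], names=[None, "split"], sep=",")
#   'f1,split=test'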


def prepare_quality_and_throughput_dfs(
    metric_data_path: str,
    job_return_value_path: str,
    char_total: int,
    index_dtypes: Optional[Dict[str, Any]] = None,
    job_id_prefix: Optional[str] = None,
) -> Tuple[pd.DataFrame, pd.Series]:
    with open(metric_data_path) as f:
        data = json.load(f)
    # "data" holds the saved result of the evaluation command (use only the last line of the output!)
    df = pd.DataFrame.from_dict(data)
    df.columns = [get_col_name(col) for col in df.columns]
    f1_series = df.set_index([col for col in df.columns if col != "f1"])["f1"]
    f1_df = f1_series.apply(lambda x: pd.Series(x)).T

    with open(job_return_value_path) as f:
        job_return_value = json.load(f)
    job_ids = job_return_value["job_id"]
    if job_id_prefix is not None:
        job_ids = [
            f"{job_id_prefix},{job_id}" if job_id.strip() != "" else job_id_prefix
            for job_id in job_ids
        ]
    index = unflatten_index(
        job_ids,
        keep_only_last_part=True,
        dtypes=index_dtypes,
    )
    prediction_time_series = pd.Series(
        job_return_value["prediction_time"], index=index, name="prediction_time"
    )
    f1_df.index = prediction_time_series.index
    # throughput in thousands of characters per second
    k_chars_per_s = char_total / (prediction_time_series * 1000)
    k_chars_per_s.name = "1k_chars_per_s"
    return f1_df, k_chars_per_s
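

# Illustrative throughput calculation (not part of the original script); the prediction time of
# 100 s is a made-up figure, char_total=383154 is the test-split value hard-coded in __main__:
#   383154 chars / (100 s * 1000) = 3.83  ->  ~3.83k characters per second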


def get_pareto_front_mask(df: pd.DataFrame, x_col: str, y_col: str) -> pd.Series:
    """
    Return a boolean mask indicating which rows belong to the Pareto front.

    In this version, we assume you want to maximize both x_col and y_col.
    A point A is said to dominate point B if:
        A[x_col] >= B[x_col] AND
        A[y_col] >= B[y_col] AND
        at least one is strictly greater.
    Then B is not on the Pareto front.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing the data points.
    x_col : str
        Name of the column to treat as the first objective (maximize).
    y_col : str
        Name of the column to treat as the second objective (maximize).

    Returns
    -------
    pd.Series
        A boolean Series (aligned with df.index) where True means
        the row is on the Pareto front.
    """
    # Extract the relevant columns as a NumPy array for speed.
    data = df[[x_col, y_col]].values
    n = len(data)
    is_dominated = np.zeros(n, dtype=bool)

    for i in range(n):
        # If it's already marked dominated, skip checks
        if is_dominated[i]:
            continue
        for j in range(n):
            if i == j:
                continue
            # Check if j dominates i
            if (
                data[j, 0] >= data[i, 0]
                and data[j, 1] >= data[i, 1]
                and (data[j, 0] > data[i, 0] or data[j, 1] > data[i, 1])
            ):
                is_dominated[i] = True
                break

    # Return True for points not dominated by any other
    return pd.Series(~is_dominated, index=df.index)
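

# Illustrative example (not part of the original script): the middle point (2.0, 1.0) is
# dominated by (3.0, 2.0), so only the first and last rows stay on the Pareto front:
#   >>> toy = pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [3.0, 1.0, 2.0]})
#   >>> get_pareto_front_mask(toy, x_col="x", y_col="y").tolist()
#   [True, False, True]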


def main(
    job_return_value_path_test: List[str],
    job_return_value_path_val: List[str],
    metric_data_path_test: List[str],
    metric_data_path_val: List[str],
    char_total_test: int,
    char_total_val: int,
    job_id_prefixes: Optional[List[str]] = None,
    metric_filters: Optional[List[str]] = None,
    index_filters: Optional[List[str]] = None,
    index_blacklist: Optional[List[str]] = None,
    label_mapping: Optional[Dict[str, str]] = None,
    plot_method: str = "line",  # can be "scatter" or "line"
    pareto_front: bool = False,
    show_as: str = "figure",
    columns: Optional[List[str]] = None,
    color_column: Optional[str] = None,
):
    label_mapping = label_mapping or {}
    if job_id_prefixes is not None:
        if len(job_id_prefixes) != len(job_return_value_path_test):
            raise ValueError(
                f"job_id_prefixes ({len(job_id_prefixes)}) and "
                f"job_return_value_path_test ({len(job_return_value_path_test)}) "
                f"must have the same length"
            )
        # replace empty strings with None
        job_id_prefixes_with_none = [
            job_id_prefix if job_id_prefix != "" else None for job_id_prefix in job_id_prefixes
        ]
    else:
        job_id_prefixes_with_none = [None] * len(job_return_value_path_test)

    # combine input data for test and val
    char_total = {"test": char_total_test, "val": char_total_val}
    metric_data_path = {"test": metric_data_path_test, "val": metric_data_path_val}
    job_return_value_path = {"test": job_return_value_path_test, "val": job_return_value_path_val}

    # prepare dataframes
    common_kwargs = dict(
        index_dtypes={
            "max_argument_distance": int,
            "max_length": int,
            "num_beams": int,
        }
    )
    f1_df_list: Dict[str, List[pd.DataFrame]] = {"test": [], "val": []}
    k_chars_per_s_list: Dict[str, List[pd.Series]] = {"test": [], "val": []}
    for split in metric_data_path:
        if len(metric_data_path[split]) != len(job_return_value_path[split]):
            raise ValueError(
                f"metric_data_path[{split}] ({len(metric_data_path[split])}) and "
                f"job_return_value_path[{split}] ({len(job_return_value_path[split])}) "
                f"must have the same length"
            )
        for current_metric_data_path, current_job_return_value_path, job_id_prefix in zip(
            metric_data_path[split], job_return_value_path[split], job_id_prefixes_with_none
        ):
            current_f1_df, current_k_chars_per_s = prepare_quality_and_throughput_dfs(
                current_metric_data_path,
                current_job_return_value_path,
                char_total=char_total[split],
                job_id_prefix=job_id_prefix,
                **common_kwargs,
            )
            f1_df_list[split].append(current_f1_df)
            k_chars_per_s_list[split].append(current_k_chars_per_s)
    f1_df_dict = {split: pd.concat(f1_df_list[split], axis=0) for split in f1_df_list}
    k_chars_per_s_dict = {
        split: pd.concat(k_chars_per_s_list[split], axis=0) for split in k_chars_per_s_list
    }

    # combine dataframes for test and val
    f1_df = pd.concat(f1_df_dict, names=["split"] + f1_df_dict["test"].index.names)
    f1_df.columns = [col_to_str(col, names=f1_df.columns.names, sep=",") for col in f1_df.columns]
    k_chars_per_s = pd.concat(
        k_chars_per_s_dict,
        names=["split"] + k_chars_per_s_dict["test"].index.names,
    )

    # combine quality and throughput data
    df_plot = pd.concat([f1_df, k_chars_per_s], axis=1)
    df_plot = (
        df_plot.reset_index()
        .set_index(list(f1_df.index.names) + [k_chars_per_s.name])
        .unstack("split")
    )
    df_plot.columns = flatten_index(df_plot.columns, names=[None, "split"])

    # remove all columns that are not needed
    if metric_filters is not None:
        for fil in metric_filters:
            df_plot.drop(columns=[col for col in df_plot.columns if fil not in col], inplace=True)
            df_plot.columns = [col.replace(fil, "") for col in df_plot.columns]
    # flatten the columns
    df_plot.columns = [
        ",".join([part for part in col.split(",") if part != ""]) for col in df_plot.columns
    ]

    v: Any
    if index_filters is not None:
        for k_v in index_filters:
            k, v = k_v.split("=")
            if k in common_kwargs["index_dtypes"]:
                v = common_kwargs["index_dtypes"][k](v)
            df_plot = df_plot.xs(v, level=k, axis=0)
    if index_blacklist is not None:
        for k_v in index_blacklist:
            k, v = k_v.split("=")
            if k in common_kwargs["index_dtypes"]:
                v = common_kwargs["index_dtypes"][k](v)
            df_plot = df_plot.drop(v, level=k, axis=0)

    if columns is not None:
        df_plot = df_plot[columns]

    x = "1k_chars_per_s"
    y = df_plot.columns
    if pareto_front:
        for col in y:
            current_data = df_plot[col].dropna().reset_index(x).copy()
            pareto_front_mask = get_pareto_front_mask(current_data, x_col=x, y_col=col)
            current_data.loc[~pareto_front_mask, col] = np.nan
            current_data_reset = current_data.reset_index().set_index(df_plot.index.names)
            df_plot[col] = current_data_reset[col]
    # remove nan rows
    df_plot = df_plot.dropna(how="all")

    # plot
    # Create a custom color sequence (concatenating multiple palettes if needed)
    custom_colors = px.colors.qualitative.Dark24 + px.colors.qualitative.Light24
    text_cols = list(df_plot.index.names)
    text_cols.remove(x)
    df_plot_reset = df_plot.reset_index()
    if len(text_cols) > 1:
        df_plot_reset[",".join(text_cols)] = (
            df_plot_reset[text_cols].astype(str).agg(", ".join, axis=1)
        )
    text_col = ",".join(text_cols)

    if show_as == "figure":
        _plot_method = getattr(px, plot_method)
        df_plot_sorted = df_plot_reset.sort_values(by=x)
        fig = _plot_method(
            df_plot_sorted,
            x=x,
            y=y,
            text=text_col if plot_method != "scatter" else None,
            color=color_column,
            color_discrete_sequence=custom_colors,
            hover_data=text_cols,
        )
        # set connectgaps to True to connect the lines
        fig.update_traces(connectgaps=True)

        legend_title = "Evaluation Setup"
        if metric_filters:
            whitelist_filters_mapped = [label_mapping.get(fil, fil) for fil in metric_filters]
            legend_title += f" ({', '.join(whitelist_filters_mapped)})"
        text_cols_mapped = [label_mapping.get(col, col) for col in text_cols]
        title = f"Impact of {', '.join(text_cols_mapped)} on Prediction Quality and Throughput"
        if index_filters:
            index_filters_mapped = [label_mapping.get(fil, fil) for fil in index_filters]
            title += f" ({', '.join(index_filters_mapped)})"
        if pareto_front:
            title += " (Pareto Front)"

        fig.update_layout(
            xaxis_title="Throughput (1k chars/s)",
            yaxis_title="Quality (F1)",
            title=title,
            # position the title (0.0 is the left edge, 1.0 the right edge)
            title_x=0.2,
            # black title
            title_font=dict(color="black"),
            # change legend title
            legend_title=legend_title,
            font_family="Computer Modern",
            # white background
            plot_bgcolor="white",
            paper_bgcolor="white",
        )
        update_axes_kwargs = dict(
            tickfont=dict(color="black"),
            title_font=dict(color="black"),
            ticks="inside",  # ensure tick markers are drawn
            tickcolor="black",
            tickwidth=1,
            ticklen=10,
            linecolor="black",
            # show grid
            gridcolor="lightgray",
        )
        fig.update_yaxes(**update_axes_kwargs)
        fig.update_xaxes(**update_axes_kwargs)
        fig.show()
    elif show_as == "markdown":
        # Print the DataFrame as a Markdown table
        print(df_plot_reset.to_markdown(index=False, floatfmt=".4f"))
    elif show_as == "json":
        # Print the DataFrame as a JSON object
        print(df_plot_reset.to_json(orient="columns", indent=4))
    else:
        raise ValueError(f"Unknown show_as value: {show_as}. Use 'figure', 'markdown' or 'json'.")


if __name__ == "__main__":
    """
    # Example usage 1 (pipeline model, data source:
    # https://github.com/ArneBinder/pie-document-level/issues/388#issuecomment-2752829257):
    python src/analysis/show_inference_params_on_quality_and_throughput.py \
        --job-return-value-path-test logs/prediction/multiruns/default/2025-03-26_01-31-05/job_return_value.json \
        --job-return-value-path-val logs/prediction/multiruns/default/2025-03-26_16-49-36/job_return_value.json \
        --metric-data-path-test data/evaluation/argumentation_structure/inference_pipeline_test.json \
        --metric-data-path-val data/evaluation/argumentation_structure/inference_pipeline_validation.json \
        --metric-filters task=are discont_comp=true split=val

    # Example usage 2 (joint model, data source:
    # https://github.com/ArneBinder/pie-document-level/issues/390#issuecomment-2759888004):
    python src/analysis/show_inference_params_on_quality_and_throughput.py \
        --job-return-value-path-test logs/prediction/multiruns/default/2025-03-28_01-34-07/job_return_value.json \
        --job-return-value-path-val logs/prediction/multiruns/default/2025-03-28_02-57-00/job_return_value.json \
        --metric-data-path-test data/evaluation/argumentation_structure/inference_joint_test.json \
        --metric-data-path-val data/evaluation/argumentation_structure/inference_joint_validation.json \
        --metric-filters task=are discont_comp=true split=val \
        --plot-method scatter
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--job-return-value-path-test",
        type=str,
        nargs="+",
        required=True,
    )
    parser.add_argument(
        "--job-return-value-path-val",
        type=str,
        nargs="+",
        required=True,
    )
    parser.add_argument(
        "--metric-data-path-test",
        type=str,
        nargs="+",
        required=True,
    )
    parser.add_argument(
        "--metric-data-path-val",
        type=str,
        nargs="+",
        required=True,
    )
    parser.add_argument(
        "--job-id-prefixes",
        type=str,
        nargs="*",
        default=None,
    )
    parser.add_argument(
        "--plot-method",
        type=str,
        default="line",
        choices=["scatter", "line"],
        help="Plot method to use (default: line)",
    )
    parser.add_argument(
        "--color-column",
        type=str,
        default=None,
        help="Column to use for color coding (default: None)",
    )
    parser.add_argument(
        "--metric-filters",
        type=str,
        nargs="*",
        default=None,
        help="Filters to apply to the metric data in the format 'key=value'",
    )
    parser.add_argument(
        "--index-filters",
        type=str,
        nargs="*",
        default=None,
        help="Filters to apply to the index data in the format 'key=value'",
    )
    parser.add_argument(
        "--index-blacklist",
        type=str,
        nargs="*",
        default=None,
        help="Blacklist to apply to the index data in the format 'key=value'",
    )
    parser.add_argument(
        "--columns",
        type=str,
        nargs="*",
        default=None,
        help="Columns to plot (default: all)",
    )
    parser.add_argument(
        "--pareto-front",
        action="store_true",
        help="Whether to show only the Pareto front",
    )
    parser.add_argument(
        "--show-as",
        type=str,
        default="figure",
        choices=["figure", "markdown", "json"],
        help="How to show the results (default: figure)",
    )

    kwargs = vars(parser.parse_args())
    main(
        char_total_test=383154,
        char_total_val=182794,
        label_mapping={
            "max_argument_distance": "Max. Argument Distance",
            "max_length": "Max. Length",
            "num_beams": "Num. Beams",
            "task=are": "ARE",
            "discont_comp=true": "Discont. Comp.",
            "split=val": "Validation Split",
        },
        **kwargs,
    )