| from typing import Optional | |
| from inseq.commands.attribute_context.attribute_context_args import AttributeContextArgs | |
| from inseq.commands.attribute_context.attribute_context_helpers import ( | |
| AttributeContextOutput, | |
| filter_rank_tokens, | |
| get_filtered_tokens, | |
| ) | |
| from inseq.models import HuggingfaceModel | |
def get_formatted_attribute_context_results(
    model: HuggingfaceModel,
    args: AttributeContextArgs,
    output: AttributeContextOutput,
) -> list[tuple[str, Optional[str]]]:
    """Format the results of the context attribution process.

    Builds a flat list of ``(text, comment)`` tuples for display: each tuple is
    either a plain text fragment (comment ``None``) or a token tagged with a
    comment ("Context sensitive" for the attributed generated token,
    "Influential context" for highly-ranked context tokens). One section is
    produced per entry in ``output.cci_scores``.

    Args:
        model: The Hugging Face model used for tokenization/filtering.
        args: CLI arguments controlling special-token handling and the
            ranking thresholds (``attribution_std_threshold``,
            ``attribution_topk``).
        output: The attribution output containing the generated text,
            contexts, and per-example CCI scores.

    Returns:
        A list of ``(text, comment)`` tuples ready for rendering.
    """

    def format_context_comment(
        model: HuggingfaceModel,
        has_other_context: bool,
        special_tokens_to_keep: list[str],
        context: str,
        context_scores: list[float],
        other_context_scores: Optional[list[float]] = None,
        is_target: bool = False,
    ) -> list[tuple[str, Optional[str]]]:
        """Tokenize a context and tag its most influential tokens.

        Returns one ``(token, comment)`` tuple per context token, where the
        comment is "Influential context" for tokens ranked above threshold
        and ``None`` otherwise.
        """
        context_tokens = get_filtered_tokens(
            context,
            model,
            special_tokens_to_keep,
            replace_special_characters=True,
            is_target=is_target,
        )
        context_token_tuples = [(t, None) for t in context_tokens]
        # When the other context is present, ranking thresholds are computed
        # over the union of both contexts' scores. Use concatenation rather
        # than ``+=``: the original ``scores += other_context_scores`` aliased
        # and mutated the caller's score list in place, corrupting it for the
        # second call to this helper.
        scores = context_scores
        if has_other_context:
            scores = context_scores + other_context_scores
        context_ranked_tokens, _ = filter_rank_tokens(
            tokens=context_tokens,
            scores=scores,
            std_threshold=args.attribution_std_threshold,
            topk=args.attribution_topk,
        )
        for idx, _, tok in context_ranked_tokens:
            context_token_tuples[idx] = (tok, "Influential context")
        return context_token_tuples

    out = []
    output_current_tokens = get_filtered_tokens(
        output.output_current,
        model,
        args.special_tokens_to_keep,
        replace_special_characters=True,
        is_target=True,
    )
    for example_idx, cci_out in enumerate(output.cci_scores, start=1):
        # Fresh copy per example so each section tags its own CTI token.
        curr_output_tokens = [(t, None) for t in output_current_tokens]
        cti_idx = cci_out.cti_idx
        curr_output_tokens[cti_idx] = (
            curr_output_tokens[cti_idx][0],
            "Context sensitive",
        )
        if args.has_input_context:
            input_context_tokens = format_context_comment(
                model,
                args.has_output_context,
                args.special_tokens_to_keep,
                output.input_context,
                cci_out.input_context_scores,
                cci_out.output_context_scores,
            )
        if args.has_output_context:
            output_context_tokens = format_context_comment(
                model,
                args.has_input_context,
                args.special_tokens_to_keep,
                output.output_context,
                cci_out.output_context_scores,
                cci_out.input_context_scores,
                is_target=True,
            )
        out += [
            ("\n\n" if example_idx > 1 else "", None),
            (
                f"#{example_idx}.\nGenerated output:\t",
                None,
            ),
        ]
        out += curr_output_tokens
        if args.has_input_context:
            out += [("\nInput context:\t", None)]
            out += input_context_tokens
        if args.has_output_context:
            out += [("\nOutput context:\t", None)]
            out += output_context_tokens
    return out
