| """ | |
| Functions for evaluating the CORE metric, as described in the DCLM paper. | |
| https://arxiv.org/abs/2406.11794 | |
| TODOs: | |
| - All tasks ~match except for squad. We get 31% reference is 37%. Figure out why. | |
| """ | |

import random

from jinja2 import Template
import torch
import torch.distributed as dist

# -----------------------------------------------------------------------------
# Prompt rendering utilities

def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
    """Render complete prompts for a multiple choice question"""
    template_str = """
{%- for example in fewshot_examples -%}
{{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }}
{% endfor -%}
{{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
    template = Template(template_str)
    fewshot_examples = fewshot_examples or []
    context = {
        'fewshot_examples': fewshot_examples,
        'continuation_delimiter': continuation_delimiter,
        'item': item
    }
    prompts = [template.render(choice=choice, **context) for choice in item['choices']]
    return prompts
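
# A hedged usage sketch; the item below is hypothetical, not taken from the real task data:
#   item = {'query': 'Question: What color is the sky?\nAnswer:',
#           'choices': ['red', 'blue', 'green'], 'gold': 1}
#   render_prompts_mc(item, ' ') returns one prompt per choice, e.g.
#   'Question: What color is the sky?\nAnswer: blue' for the second choice.
# Few-shot examples, if provided, are prepended in the same query/answer format.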

def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
    """Render complete prompts for a schema question"""
    template_str = """
{%- for example in fewshot_examples -%}
{{ example.context_options[example.gold] }}{{ continuation_delimiter }}{{ example.continuation }}
{% endfor -%}
{{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip()
    template = Template(template_str)
    fewshot_examples = fewshot_examples or []
    context = {
        'fewshot_examples': fewshot_examples,
        'continuation_delimiter': continuation_delimiter,
        'item': item
    }
    prompts = [template.render(context=context_option, **context)
               for context_option in item['context_options']]
    return prompts
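
# A hedged sketch of the schema case (hypothetical, Winogrande-style item): here the
# continuation is fixed and the context varies, e.g.
#   item = {'context_options': ['The trophy did not fit in the suitcase because the trophy',
#                               'The trophy did not fit in the suitcase because the suitcase'],
#           'continuation': ' was too big.', 'gold': 0}
# render_prompts_schema returns one prompt per context option, each ending in the same
# continuation, so the model is later scored on which context makes that continuation likeliest.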

def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
    """
    Render complete prompt for a language modeling task.
    Notice that we manually trim the context in the template,
    which in some datasets seems to have trailing whitespace (which we don't want).
    """
    template_str = """
{%- for example in fewshot_examples -%}
{{ example.context | trim }}{{ continuation_delimiter }}{{ example.continuation }}
{% endfor -%}
{{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip()
    template = Template(template_str)
    fewshot_examples = fewshot_examples or []
    context = {
        'fewshot_examples': fewshot_examples,
        'continuation_delimiter': continuation_delimiter,
        'item': item
    }
    # Return two prompts: without and with the continuation
    prompt_without = template.render(include_continuation=False, **context)
    prompt_with = template.render(include_continuation=True, **context)
    # Due to the way the data seems to be stored, I think I need to strip in the case of LM here.
    # Otherwise we may get trailing whitespace in prompt_without (which gets absorbed into the next
    # token in prompt_with), meaning we don't get a nice and clean prefix in the token space
    # to detect the final continuation. Tokenizers...
    prompt_without = prompt_without.strip()
    return [prompt_without, prompt_with]
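
# A hedged sketch (hypothetical item): with item = {'context': 'The capital of France is',
# 'continuation': ' Paris'} and continuation_delimiter='', this returns
#   ['The capital of France is', 'The capital of France is Paris']
# i.e. a prefix/full pair, so the continuation can later be located as a clean suffix
# in token space.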

def find_common_length(token_sequences, direction='left'):
    """
    Find the length of the common prefix or suffix across token sequences
    - direction: 'left' for prefix, 'right' for suffix
    """
    min_len = min(len(seq) for seq in token_sequences)
    indices = {
        'left': range(min_len),
        'right': range(-1, -min_len-1, -1)
    }[direction]
    # Find the first position where the token sequences differ
    for i, idx in enumerate(indices):
        token = token_sequences[0][idx]
        if not all(seq[idx] == token for seq in token_sequences):
            return i
    return min_len
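
# Example: for token sequences [5, 7, 9, 2] and [5, 7, 8, 3, 2], the common prefix has
# length 2 (direction='left') and the common suffix has length 1 (direction='right'),
# since only the trailing 2 matches.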

def stack_sequences(tokens, pad_token_id):
    """Stack up a list of token sequences, pad to longest on the right"""
    bsz, seq_len = len(tokens), max(len(x) for x in tokens)
    input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long)
    for i, x in enumerate(tokens):
        input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long)
    return input_ids
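
# Example: stack_sequences([[1, 2, 3], [1, 2]], pad_token_id=0) == tensor([[1, 2, 3], [1, 2, 0]])
# Padding is on the right; padded positions are harmless because the caller only reads
# losses inside each example's [start, end) span.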

def batch_sequences_mc(tokenizer, prompts):
    # In multiple choice, contexts are the same but the continuation is different (common prefix)
    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    # figure out the start and end of each continuation
    answer_start_idx = find_common_length(tokens, direction='left')
    start_indices = [answer_start_idx] * len(prompts)
    end_indices = [len(x) for x in tokens]
    return tokens, start_indices, end_indices
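
# For MC, the scored span [start, end) of each row therefore covers that row's own
# continuation tokens, i.e. everything after the shared question prefix.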

def batch_sequences_schema(tokenizer, prompts):
    # In schema tasks, contexts vary but continuation is the same (common suffix)
    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    # figure out the start and end of the shared continuation (the common suffix)
    suffix_length = find_common_length(tokens, direction='right')
    end_indices = [len(x) for x in tokens]
    start_indices = [ei - suffix_length for ei in end_indices]
    return tokens, start_indices, end_indices
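
# For schema, the scored span [start, end) is the same continuation in every row, so the
# later comparison asks which context makes that continuation cheapest to predict.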

def batch_sequences_lm(tokenizer, prompts):
    # In LM tasks, we have two prompts: without and with continuation
    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    tokens_without, tokens_with = tokens
    start_idx, end_idx = len(tokens_without), len(tokens_with)
    assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with"
    assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with"
    # we only need the with-continuation prompt in the LM task, i.e. batch size of 1
    return [tokens_with], [start_idx], [end_idx]
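
# Example: prompts = ['The capital of France is', 'The capital of France is Paris'] tokenize
# to a prefix/full pair, and the scored span [start_idx, end_idx) covers only the
# continuation tokens of the full prompt.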

def forward_model(model, input_ids):
    """
    Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions.
    The last column of losses is set to nan because we don't have autoregressive targets there.
    """
    batch_size, seq_len = input_ids.size()
    outputs = model(input_ids)
    # Roll the tensor to the left by one position to get the (autoregressive) target ids
    target_ids = torch.roll(input_ids, shifts=-1, dims=1)
    # Calculate cross entropy at all positions
    losses = torch.nn.functional.cross_entropy(
        outputs.view(batch_size * seq_len, -1),
        target_ids.view(batch_size * seq_len),
        reduction='none'
    ).view(batch_size, seq_len)
    # Set the last column to be nan because there is no autoregressive loss there
    losses[:, -1] = float('nan')
    # Get the argmax predictions at each position
    predictions = outputs.argmax(dim=-1)
    return losses, predictions
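
# Shape sketch: for input_ids of shape (B, T), `outputs` is assumed to be raw logits of shape
# (B, T, vocab_size); losses and predictions then come out as (B, T), where position t scores
# the model's prediction of input_ids[:, t+1].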

def evaluate_example(idx, model, tokenizer, data, device, task_meta):
    """Evaluate a single example, return True if correct, False otherwise"""
    item = data[idx]
    task_type = task_meta['task_type']
    num_fewshot = task_meta['num_fewshot']
    continuation_delimiter = task_meta['continuation_delimiter']
    # Sample few-shot examples (excluding current item)
    fewshot_examples = []
    if num_fewshot > 0:
        rng = random.Random(1234 + idx)
        available_indices = [i for i in range(len(data)) if i != idx]
        fewshot_indices = rng.sample(available_indices, num_fewshot)
        fewshot_examples = [data[i] for i in fewshot_indices]
    # Render prompts and batch sequences based on task type
    if task_type == 'multiple_choice':
        prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples)
        tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts)
    elif task_type == 'schema':
        prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples)
        tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts)
    elif task_type == 'language_modeling':
        prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples)
        tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts)
    else:
        raise ValueError(f"Unsupported task type: {task_type}")
    # Some models can't forward sequences beyond a certain length (e.g. GPT-2)
    # In these cases, we have to truncate sequences to max length and adjust the indices
    if hasattr(model, 'max_seq_len') and model.max_seq_len is not None:
        max_tokens = model.max_seq_len
        new_tokens, new_start_idxs, new_end_idxs = [], [], []
        for t, s, e in zip(tokens, start_idxs, end_idxs):
            if len(t) > max_tokens:
                num_to_crop = len(t) - max_tokens
                new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
                new_start_idxs.append(s - num_to_crop) # shift the indices down
                new_end_idxs.append(e - num_to_crop)
                assert s - num_to_crop >= 0, "this should never happen right?"
                assert e - num_to_crop >= 0, "this should never happen right?"
            else:
                new_tokens.append(t) # keep unchanged
                new_start_idxs.append(s)
                new_end_idxs.append(e)
        tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs
    # Stack up all the sequences into a batch
    pad_token_id = tokenizer.get_bos_token_id() # using the BOS token as the pad token is ok
    input_ids = stack_sequences(tokens, pad_token_id)
    input_ids = input_ids.to(device)
    # Forward the model, get the autoregressive loss and argmax prediction at each token
    losses, predictions = forward_model(model, input_ids)
    # See if the losses/predictions come out correctly
    if task_type == 'language_modeling':
        # the language modeling task is currently always batch size 1
        si = start_idxs[0]
        ei = end_idxs[0]
        # predictions[i] predicts input_ids[i+1] autoregressively
        predicted_tokens = predictions[0, si-1:ei-1]
        actual_tokens = input_ids[0, si:ei]
        is_correct = torch.all(predicted_tokens == actual_tokens).item()
    elif task_type in ['multiple_choice', 'schema']:
        # For MC/schema: find the option with the lowest average loss
        mean_losses = [losses[i, si-1:ei-1].mean().item()
                       for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
        pred_idx = mean_losses.index(min(mean_losses))
        is_correct = pred_idx == item['gold']
    else:
        raise ValueError(f"Unsupported task type: {task_type}")
    return is_correct

def evaluate_task(model, tokenizer, data, device, task_meta):
    """
    This function is responsible for evaluating one task across many examples.
    It also handles dispatch to all the processes if the script is run with torchrun.
    """
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    correct = torch.zeros(len(data), dtype=torch.float32, device=device)
    # stride the examples across the ranks
    for idx in range(rank, len(data), world_size):
        is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta)
        correct[idx] = float(is_correct)
    # sync the results across all processes if running distributed
    if world_size > 1:
        dist.barrier()
        dist.all_reduce(correct, op=dist.ReduceOp.SUM)
    # compute the mean accuracy
    mean_correct = correct.mean().item()
    return mean_correct
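
# A hedged usage sketch; task_meta keys mirror what evaluate_example reads, and the concrete
# values below are illustrative rather than taken from a real task config:
#   task_meta = {'task_type': 'multiple_choice', 'num_fewshot': 5, 'continuation_delimiter': ' '}
#   accuracy = evaluate_task(model, tokenizer, data, device, task_meta)
# Under torchrun, every rank must call this with the same `data`, so that the strided loop
# covers each example exactly once before the all_reduce.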