""" Functions for evaluating the CORE metric, as described in the DCLM paper. https://arxiv.org/abs/2406.11794 TODOs: - All tasks ~match except for squad. We get 31% reference is 37%. Figure out why. """ import random from jinja2 import Template import torch import torch.distributed as dist # ----------------------------------------------------------------------------- # Prompt rendering utilities def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None): """Render complete prompts for a multiple choice question""" template_str = """ {%- for example in fewshot_examples -%} {{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }} {% endfor -%} {{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip() template = Template(template_str) fewshot_examples = fewshot_examples or [] context = { 'fewshot_examples': fewshot_examples, 'continuation_delimiter': continuation_delimiter, 'item': item } prompts = [template.render(choice=choice, **context) for choice in item['choices']] return prompts def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None): """Render complete prompts for a schema question""" template_str = """ {%- for example in fewshot_examples -%} {{ example.context_options[example.gold] }}{{ continuation_delimiter }}{{ example.continuation }} {% endfor -%} {{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip() template = Template(template_str) fewshot_examples = fewshot_examples or [] context = { 'fewshot_examples': fewshot_examples, 'continuation_delimiter': continuation_delimiter, 'item': item } prompts = [template.render(context=context_option, **context) for context_option in item['context_options']] return prompts def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None): """ Render complete prompt for a language modeling task. Notice that we manually trim the context in the template, which in some datasets seems to have trailing whitespace (which we don't want). """ template_str = """ {%- for example in fewshot_examples -%} {{ example.context | trim }}{{ continuation_delimiter }}{{ example.continuation }} {% endfor -%} {{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip() template = Template(template_str) fewshot_examples = fewshot_examples or [] context = { 'fewshot_examples': fewshot_examples, 'continuation_delimiter': continuation_delimiter, 'item': item } # Return two prompts: without and with the continuation prompt_without = template.render(include_continuation=False, **context) prompt_with = template.render(include_continuation=True, **context) # Due to the way the data seems to be stored, I think I need to strip in the case of LM here. # Otherwise we may get trailing whitespaces in prompt_without (which get absorbed into the next # token in prompt_with), meaning we don't get a nice and clean prefix in the token space # to detect the final continuation. Tokenizers... 
    prompt_without = prompt_without.strip()
    return [prompt_without, prompt_with]

def find_common_length(token_sequences, direction='left'):
    """
    Find the length of the common prefix or suffix across token sequences
    - direction: 'left' for prefix, 'right' for suffix
    """
    min_len = min(len(seq) for seq in token_sequences)
    indices = {
        'left': range(min_len),
        'right': range(-1, -min_len-1, -1)
    }[direction]
    # Find the first position where the token sequences differ
    for i, idx in enumerate(indices):
        token = token_sequences[0][idx]
        if not all(seq[idx] == token for seq in token_sequences):
            return i
    return min_len

def stack_sequences(tokens, pad_token_id):
    """Stack up a list of token sequences, pad to longest on the right"""
    bsz, seq_len = len(tokens), max(len(x) for x in tokens)
    input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long)
    for i, x in enumerate(tokens):
        input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long)
    return input_ids

def batch_sequences_mc(tokenizer, prompts):
    # In multiple choice, contexts are the same but the continuation is different (common prefix)
    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    # figure out the start and end of each continuation
    answer_start_idx = find_common_length(tokens, direction='left')
    start_indices = [answer_start_idx] * len(prompts)
    end_indices = [len(x) for x in tokens]
    return tokens, start_indices, end_indices

def batch_sequences_schema(tokenizer, prompts):
    # In schema tasks, contexts vary but continuation is the same (common suffix)
    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    # figure out the start and end of each context
    suffix_length = find_common_length(tokens, direction='right')
    end_indices = [len(x) for x in tokens]
    start_indices = [ei - suffix_length for ei in end_indices]
    return tokens, start_indices, end_indices

def batch_sequences_lm(tokenizer, prompts):
    # In LM tasks, we have two prompts: without and with continuation
    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    tokens_without, tokens_with = tokens
    start_idx, end_idx = len(tokens_without), len(tokens_with)
    assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with"
    assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with"
    # we only need the with-continuation prompt in the LM task, i.e. batch size of 1
    return [tokens_with], [start_idx], [end_idx]

@torch.no_grad()
def forward_model(model, input_ids):
    """
    Take a BxT tensor of token ids, return a BxT tensor of losses and argmax predictions.
    The last column of losses is set to nan because we don't have autoregressive targets there.
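
    Illustrative example: for a single row input_ids = [t0, t1, t2, t3], torch.roll
    produces targets [t1, t2, t3, t0], so the first three columns of losses score the
    model's next-token predictions, and the last column (whose rolled "target" t0 wrapped
    around from the front) is discarded by setting it to nan.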
""" batch_size, seq_len = input_ids.size() outputs = model(input_ids) # Roll the tensor to the left by one position to get the (autoregressive) target ids target_ids = torch.roll(input_ids, shifts=-1, dims=1) # Calculate cross entropy at all positions losses = torch.nn.functional.cross_entropy( outputs.view(batch_size * seq_len, -1), target_ids.view(batch_size * seq_len), reduction='none' ).view(batch_size, seq_len) # Set the last column to be nan because there is no autoregressive loss there losses[:, -1] = float('nan') # Get the argmax predictions at each position predictions = outputs.argmax(dim=-1) return losses, predictions @torch.no_grad() def evaluate_example(idx, model, tokenizer, data, device, task_meta): """Evaluate a single example, return True if correct, False otherwise""" item = data[idx] task_type = task_meta['task_type'] num_fewshot = task_meta['num_fewshot'] continuation_delimiter = task_meta['continuation_delimiter'] # Sample few-shot examples (excluding current item) fewshot_examples = [] if num_fewshot > 0: rng = random.Random(1234 + idx) available_indices = [i for i in range(len(data)) if i != idx] fewshot_indices = rng.sample(available_indices, num_fewshot) fewshot_examples = [data[i] for i in fewshot_indices] # Render prompts and batch sequences based on task type if task_type == 'multiple_choice': prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples) tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts) elif task_type == 'schema': prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples) tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts) elif task_type == 'language_modeling': prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples) tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts) else: raise ValueError(f"Unsupported task type: {task_type}") # Some models can't forward sequences beyond a certain length (e.g. GPT-2) # In these cases, we have to truncate sequences to max length and adjust the indices if hasattr(model, 'max_seq_len') and model.max_seq_len is not None: max_tokens = model.max_seq_len new_tokens, new_start_idxs, new_end_idxs = [], [], [] for t, s, e in zip(tokens, start_idxs, end_idxs): if len(t) > max_tokens: num_to_crop = len(t) - max_tokens new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens new_start_idxs.append(s - num_to_crop) # shift the indices down new_end_idxs.append(e - num_to_crop) assert s - num_to_crop >= 0, "this should never happen right?" assert e - num_to_crop >= 0, "this should never happen right?" 
            else:
                new_tokens.append(t)  # keep unchanged
                new_start_idxs.append(s)
                new_end_idxs.append(e)
        tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs

    # Stack up all the sequences into a batch
    pad_token_id = tokenizer.get_bos_token_id()  # using BOS as the pad token is ok
    input_ids = stack_sequences(tokens, pad_token_id)
    input_ids = input_ids.to(device)

    # Forward the model, get the autoregressive loss and argmax prediction at each token
    losses, predictions = forward_model(model, input_ids)

    # Score the example from the losses/predictions
    if task_type == 'language_modeling':
        # the language modeling task is currently always batch size 1
        si = start_idxs[0]
        ei = end_idxs[0]
        # predictions[i] predicts input_ids[i+1] autoregressively
        predicted_tokens = predictions[0, si-1:ei-1]
        actual_tokens = input_ids[0, si:ei]
        is_correct = torch.all(predicted_tokens == actual_tokens).item()
    elif task_type in ['multiple_choice', 'schema']:
        # For MC/schema: find the option with the lowest average loss
        mean_losses = [losses[i, si-1:ei-1].mean().item() for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
        pred_idx = mean_losses.index(min(mean_losses))
        is_correct = pred_idx == item['gold']
    else:
        raise ValueError(f"Unsupported task type: {task_type}")

    return is_correct

def evaluate_task(model, tokenizer, data, device, task_meta):
    """
    This function is responsible for evaluating one task across many examples.
    It also handles dispatch to all processes if the script is run with torchrun.
    """
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    correct = torch.zeros(len(data), dtype=torch.float32, device=device)
    # stride the examples across the ranks
    for idx in range(rank, len(data), world_size):
        is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta)
        correct[idx] = float(is_correct)
    # sync the results across all processes if running distributed
    if world_size > 1:
        dist.barrier()
        dist.all_reduce(correct, op=dist.ReduceOp.SUM)
    # compute the mean accuracy
    mean_correct = correct.mean().item()
    return mean_correct
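
# -----------------------------------------------------------------------------
# Illustrative usage sketch. The task_meta layout matches what evaluate_example expects;
# the concrete values below are placeholders, and loading the model/tokenizer/data is
# assumed to happen elsewhere in the repo:
#
#   task_meta = {'task_type': 'multiple_choice', 'num_fewshot': 5, 'continuation_delimiter': ' '}
#   data = [{'query': ..., 'choices': [...], 'gold': 0}, ...]
#   accuracy = evaluate_task(model, tokenizer, data, device, task_meta)

if __name__ == "__main__":
    # Tiny self-checks of the pure helpers above (illustrative; no model or tokenizer needed)
    seqs = [[7, 1, 2, 9], [7, 1, 3, 9], [7, 1, 4, 9]]
    assert find_common_length(seqs, direction='left') == 2   # shared prefix [7, 1]
    assert find_common_length(seqs, direction='right') == 1  # shared suffix [9]
    padded = stack_sequences([[1, 2, 3], [4, 5]], pad_token_id=0)
    assert padded.tolist() == [[1, 2, 3], [4, 5, 0]]
    print("helper self-checks passed")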