"""
Functions for evaluating the CORE metric, as described in the DCLM paper.
https://arxiv.org/abs/2406.11794
TODOs:
- All tasks roughly match the reference except for squad: we get 31% while the reference is 37%. Figure out why.
"""
import random
from jinja2 import Template
import torch
import torch.distributed as dist
# -----------------------------------------------------------------------------
# Prompt rendering utilities
def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
"""Render complete prompts for a multiple choice question"""
template_str = """
{%- for example in fewshot_examples -%}
{{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }}
{% endfor -%}
{{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
template = Template(template_str)
fewshot_examples = fewshot_examples or []
context = {
'fewshot_examples': fewshot_examples,
'continuation_delimiter': continuation_delimiter,
'item': item
}
prompts = [template.render(choice=choice, **context) for choice in item['choices']]
return prompts
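
# Illustrative sketch of what render_prompts_mc produces (the item and delimiter
# below are hypothetical, not from any real eval dataset), zero-shot:
#   item = {'query': 'The sky is', 'choices': [' blue', ' green'], 'gold': 0}
#   render_prompts_mc(item, '')  ->  ['The sky is blue', 'The sky is green']
# i.e. one prompt per choice, all sharing the same context prefix.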
def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
"""Render complete prompts for a schema question"""
template_str = """
{%- for example in fewshot_examples -%}
{{ example.context_options[example.gold] }}{{ continuation_delimiter }}{{ example.continuation }}
{% endfor -%}
{{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip()
template = Template(template_str)
fewshot_examples = fewshot_examples or []
context = {
'fewshot_examples': fewshot_examples,
'continuation_delimiter': continuation_delimiter,
'item': item
}
prompts = [template.render(context=context_option, **context)
for context_option in item['context_options']]
return prompts
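
# Illustrative sketch of render_prompts_schema (hypothetical item, zero-shot):
#   item = {'context_options': ['Because it rained,', 'Because it was sunny,'],
#           'continuation': ' the ground got wet', 'gold': 0}
#   render_prompts_schema(item, '')
#   -> ['Because it rained, the ground got wet',
#       'Because it was sunny, the ground got wet']
# i.e. one prompt per context option, all sharing the same continuation suffix.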
def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
"""
Render complete prompt for a language modeling task.
Notice that we manually trim the context in the template,
which in some datasets seems to have trailing whitespace (which we don't want).
"""
template_str = """
{%- for example in fewshot_examples -%}
{{ example.context | trim }}{{ continuation_delimiter }}{{ example.continuation }}
{% endfor -%}
{{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip()
template = Template(template_str)
fewshot_examples = fewshot_examples or []
context = {
'fewshot_examples': fewshot_examples,
'continuation_delimiter': continuation_delimiter,
'item': item
}
# Return two prompts: without and with the continuation
prompt_without = template.render(include_continuation=False, **context)
prompt_with = template.render(include_continuation=True, **context)
# Due to the way the data seems to be stored, I think I need to strip in the case of LM here.
# Otherwise we may get trailing whitespaces in prompt_without (which get absorbed into the next
# token in prompt_with), meaning we don't get a nice and clean prefix in the token space
# to detect the final continuation. Tokenizers...
prompt_without = prompt_without.strip()
return [prompt_without, prompt_with]
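
# Illustrative sketch of render_prompts_lm (hypothetical item, zero-shot):
#   item = {'context': 'Roses are red, violets are', 'continuation': ' blue'}
#   render_prompts_lm(item, '')
#   -> ['Roses are red, violets are', 'Roses are red, violets are blue']
# The first prompt is a strict string prefix of the second; the continuation is
# everything the second prompt adds on top of the first.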
def find_common_length(token_sequences, direction='left'):
"""
Find the length of the common prefix or suffix across token sequences
- direction: 'left' for prefix, 'right' for suffix
"""
min_len = min(len(seq) for seq in token_sequences)
indices = {
'left': range(min_len),
'right': range(-1, -min_len-1, -1)
}[direction]
# Find the first position where the token sequences differ
for i, idx in enumerate(indices):
token = token_sequences[0][idx]
if not all(seq[idx] == token for seq in token_sequences):
return i
return min_len
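
# Quick illustration of find_common_length on toy token sequences (made-up ids):
#   seqs = [[5, 1, 2, 3], [5, 1, 4, 3]]
#   find_common_length(seqs, direction='left')   -> 2  (shared prefix [5, 1])
#   find_common_length(seqs, direction='right')  -> 1  (shared suffix [3])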
def stack_sequences(tokens, pad_token_id):
"""Stack up a list of token sequences, pad to longest on the right"""
bsz, seq_len = len(tokens), max(len(x) for x in tokens)
input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long)
for i, x in enumerate(tokens):
input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long)
return input_ids
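
# Quick illustration of stack_sequences with a made-up pad id of 0:
#   stack_sequences([[1, 2, 3], [4, 5]], pad_token_id=0)
#   -> tensor([[1, 2, 3],
#              [4, 5, 0]])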
def batch_sequences_mc(tokenizer, prompts):
    # In multiple choice, contexts are the same but the continuation is different (common prefix)
    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    # figure out the start and end of each continuation
    answer_start_idx = find_common_length(tokens, direction='left')
    start_indices = [answer_start_idx] * len(prompts)
    end_indices = [len(x) for x in tokens]
    return tokens, start_indices, end_indices
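
# Sketch of the indices batch_sequences_mc returns, on made-up token ids
# (assuming the tokenizer produced these two sequences for two choices):
#   tokens = [[BOS, 10, 11, 12], [BOS, 10, 11, 13]]
#   -> start_indices = [3, 3]   (first position where the choices diverge)
#   -> end_indices   = [4, 4]   (full sequence lengths)
# so each choice is scored over tokens[si:ei], i.e. just the differing tail.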
def batch_sequences_schema(tokenizer, prompts):
    # In schema tasks, contexts vary but the continuation is the same (common suffix)
    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    # figure out the start and end of the shared continuation in each sequence
    suffix_length = find_common_length(tokens, direction='right')
    end_indices = [len(x) for x in tokens]
    start_indices = [ei - suffix_length for ei in end_indices]
    return tokens, start_indices, end_indices
def batch_sequences_lm(tokenizer, prompts):
    # In LM tasks, we have two prompts: without and with the continuation
    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    tokens_without, tokens_with = tokens
    start_idx, end_idx = len(tokens_without), len(tokens_with)
    assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with"
    assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with"
    # we only need the with-continuation prompt in the LM task, i.e. a batch size of 1
    return [tokens_with], [start_idx], [end_idx]
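
# Sketch of what batch_sequences_lm returns, on made-up token ids:
#   tokens_without = [BOS, 7, 8]         # prompt without the continuation
#   tokens_with    = [BOS, 7, 8, 9, 10]  # prompt with the continuation appended
#   -> ([[BOS, 7, 8, 9, 10]], [3], [5])
# i.e. only the with-continuation sequence is forwarded, and the continuation
# occupies positions 3 and 4 of that sequence.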
@torch.no_grad()
def forward_model(model, input_ids):
"""
Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions.
The last column of losses is set to nan because we don't have autoregressive targets there.
"""
batch_size, seq_len = input_ids.size()
outputs = model(input_ids)
# Roll the tensor to the left by one position to get the (autoregressive) target ids
target_ids = torch.roll(input_ids, shifts=-1, dims=1)
# Calculate cross entropy at all positions
losses = torch.nn.functional.cross_entropy(
outputs.view(batch_size * seq_len, -1),
target_ids.view(batch_size * seq_len),
reduction='none'
).view(batch_size, seq_len)
# Set the last column to be nan because there is no autoregressive loss there
losses[:, -1] = float('nan')
# Get the argmax predictions at each position
predictions = outputs.argmax(dim=-1)
return losses, predictions
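
# Index alignment, spelled out: for a batch element b,
#   losses[b, t]      is the cross entropy of predicting input_ids[b, t+1]
#   predictions[b, t] is the argmax guess for input_ids[b, t+1]
# which is why the scoring code below slices losses/predictions with
# [start-1:end-1] to score the continuation tokens input_ids[b, start:end].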
@torch.no_grad()
def evaluate_example(idx, model, tokenizer, data, device, task_meta):
"""Evaluate a single example, return True if correct, False otherwise"""
item = data[idx]
task_type = task_meta['task_type']
num_fewshot = task_meta['num_fewshot']
continuation_delimiter = task_meta['continuation_delimiter']
# Sample few-shot examples (excluding current item)
fewshot_examples = []
if num_fewshot > 0:
rng = random.Random(1234 + idx)
available_indices = [i for i in range(len(data)) if i != idx]
fewshot_indices = rng.sample(available_indices, num_fewshot)
fewshot_examples = [data[i] for i in fewshot_indices]
# Render prompts and batch sequences based on task type
if task_type == 'multiple_choice':
prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples)
tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts)
elif task_type == 'schema':
prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples)
tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts)
elif task_type == 'language_modeling':
prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples)
tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts)
else:
raise ValueError(f"Unsupported task type: {task_type}")
    # Some models can't forward sequences beyond a certain length (e.g. GPT-2)
    # In these cases, we have to truncate sequences to max length and adjust the indices
    if hasattr(model, 'max_seq_len') and model.max_seq_len is not None:
        max_tokens = model.max_seq_len
        new_tokens, new_start_idxs, new_end_idxs = [], [], []
        for t, s, e in zip(tokens, start_idxs, end_idxs):
            if len(t) > max_tokens:
                num_to_crop = len(t) - max_tokens
                new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
                new_start_idxs.append(s - num_to_crop) # shift the indices down
                new_end_idxs.append(e - num_to_crop)
                assert s - num_to_crop >= 0, "this should never happen right?"
                assert e - num_to_crop >= 0, "this should never happen right?"
            else:
                new_tokens.append(t) # keep unchanged
                new_start_idxs.append(s)
                new_end_idxs.append(e)
        tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs
    # Stack up all the sequences into a batch
    pad_token_id = tokenizer.get_bos_token_id() # using BOS as the pad token is ok; padded positions are never scored
    input_ids = stack_sequences(tokens, pad_token_id)
    input_ids = input_ids.to(device)
    # Forward the model, get the autoregressive loss and argmax prediction at each token
    losses, predictions = forward_model(model, input_ids)
    # See if the losses/predictions come out correctly
    if task_type == 'language_modeling':
        # language modeling task is currently always batch size 1
        si = start_idxs[0]
        ei = end_idxs[0]
        # predictions[i] predicts input_ids[i+1] autoregressively
        predicted_tokens = predictions[0, si-1:ei-1]
        actual_tokens = input_ids[0, si:ei]
        is_correct = torch.all(predicted_tokens == actual_tokens).item()
    elif task_type in ['multiple_choice', 'schema']:
        # For MC/schema: find the option with the lowest average loss
        mean_losses = [losses[i, si-1:ei-1].mean().item()
                       for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
        pred_idx = mean_losses.index(min(mean_losses))
        is_correct = pred_idx == item['gold']
    else:
        raise ValueError(f"Unsupported task type: {task_type}")
    return is_correct
def evaluate_task(model, tokenizer, data, device, task_meta):
"""
This function is responsible for evaluating one task across many examples.
It also handles dispatch to all processes if the script is run with torchrun.
"""
rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1
correct = torch.zeros(len(data), dtype=torch.float32, device=device)
# stride the examples to each rank
for idx in range(rank, len(data), world_size):
is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta)
correct[idx] = float(is_correct)
# sync results across all the processes if running distributed
if world_size > 1:
dist.barrier()
dist.all_reduce(correct, op=dist.ReduceOp.SUM)
# compute the mean
mean_correct = correct.mean().item()
return mean_correct
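
# Rough usage sketch (the concrete values below are assumptions for illustration,
# not part of this module): `data` is a list of example dicts for one task, and
# `task_meta` carries the task type, few-shot count, and delimiter that the
# renderers above expect.
#
#   task_meta = {
#       'task_type': 'multiple_choice',   # or 'schema' / 'language_modeling'
#       'num_fewshot': 5,                 # illustrative value
#       'continuation_delimiter': ' ',    # dataset-dependent; illustrative value
#   }
#   accuracy = evaluate_task(model, tokenizer, data, device='cuda', task_meta=task_meta)
#
# The CORE score is then aggregated from these per-task accuracies downstream of
# this function.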