|
|
""" |
|
|
Engine for efficient inference of our models. |
|
|
|
|
|
Everything works around token sequences: |
|
|
- The user can send token sequences to the engine |
|
|
- The engine returns the next token |
|
|
|
|
|
Notes: |
|
|
- The engine knows nothing about tokenization, it's purely token id sequences. |
|
|
|
|
|
The whole thing is made as efficient as possible. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import signal |
|
|
import warnings |
|
|
from contextlib import contextmanager |
|
|
from collections import deque |
|
|
from nanochat.common import compute_init |
|
|
from nanochat.checkpoint_manager import load_model |
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
def timeout(duration, formula):
    """Context manager that raises if the enclosed block runs longer than `duration` seconds.

    Uses SIGALRM, so it is Unix-only and must be entered from the main thread.
    `formula` is used only to build a descriptive error message.
    """
    def timeout_handler(signum, frame):
        raise Exception(f"'{formula}': timed out after {duration} seconds")

    previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(duration)
    try:
        yield
    finally:
        # Always cancel the pending alarm and restore the previous handler,
        # even if the body raised — otherwise a stale SIGALRM could fire later
        # in unrelated code, with our handler still installed.
        signal.alarm(0)
        signal.signal(signal.SIGALRM, previous_handler)
|
|
|
|
|
def eval_with_timeout(formula, max_time=3):
    """Evaluate `formula` with Python's eval(), giving up after `max_time` seconds.

    Returns the evaluated value, or None on any failure (timeout, syntax error,
    division by zero, ...). Intended to be called only through use_calculator,
    which whitelists the characters of `formula` first — never pass untrusted
    input here directly, since eval() runs arbitrary code.
    """
    try:
        with timeout(max_time, formula):
            with warnings.catch_warnings():
                # eval of odd numeric strings can emit SyntaxWarning; keep quiet
                warnings.simplefilter("ignore", SyntaxWarning)
                return eval(formula)
    except Exception:
        # Best-effort: make sure no alarm stays pending, then swallow the error
        # and signal failure to the caller via None.
        signal.alarm(0)
        return None
|
|
|
|
|
def use_calculator(expr):
    """Evaluate a math expression safely.

    Returns the numeric result, or None if the expression contains anything
    other than basic arithmetic characters (or fails to evaluate). The
    character whitelist plus the '**' ban keeps eval() from executing
    arbitrary code or pathologically expensive exponentiations.
    """
    expr = expr.replace(",", "")  # tolerate "1,234"-style thousands separators
    allowed_chars = set("0123456789*+-/.() ")
    if any(ch not in allowed_chars for ch in expr):
        return None
    if "**" in expr:  # power operator can be made arbitrarily expensive
        return None
    return eval_with_timeout(expr)
|
|
|
|
|
|
|
|
class KVCache:
    """
    Works hand-in-hand with the GPT model to maintain the KV cache.
    Note that the .pos advances automatically after the last layer of the Transformer inserts.

    Storage is one tensor of shape
    (num_layers, 2, batch_size, num_heads, seq_len, head_dim),
    where index 0/1 along dim 1 selects keys/values. Allocation is lazy:
    the tensor is created on first insert/prefill so it can adopt the
    dtype/device of the incoming activations.
    """

    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
        self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
        self.kv_cache = None  # lazily allocated on first insert/prefill
        self.pos = 0  # number of time steps currently cached

    def reset(self):
        """Rewind the cache to empty without freeing storage."""
        self.pos = 0

    def get_pos(self):
        """Return the number of time steps currently cached."""
        return self.pos

    def prefill(self, other):
        """
        Prefill given another KV cache. Optionally expand along batch dim.
        This is used when we do batch 1 prefill and then want to generate
        multiple samples in parallel from there.
        """
        assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
        assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
        for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):
            if ix in [0, 1, 3, 5]:
                # layers / kv pair / heads / head_dim must match exactly
                assert dim1 == dim2, f"Batch dim mismatch: {dim1} != {dim2}"
            elif ix == 2:
                # batch dim may be broadcast-expanded from 1
                assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}"
            elif ix == 4:
                # we need room for at least the prefilled sequence
                assert dim1 >= dim2, f"Seq len mismatch: {dim1} < {dim2}"
        dtype, device = other.kv_cache.dtype, other.kv_cache.device
        self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
        # copy only the filled prefix (slice both sides: other's storage may
        # be larger than its filled region); broadcasting expands batch dim
        self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache[:, :, :, :, :other.pos, :]
        self.pos = other.pos

    def insert_kv(self, layer_idx, k, v):
        """
        Insert keys/values of shape (B, H, T_add, D) for one layer at the
        current position, returning views over the full cached prefix
        (length pos + T_add). Advances .pos only after the last layer.
        """
        # lazy-allocate storage, adopting the dtype/device of the activations
        if self.kv_cache is None:
            self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
        B, H, T_add, D = k.size()
        t0, t1 = self.pos, self.pos + T_add
        # dynamically grow the cache along the time dim if needed
        if t1 > self.kv_cache.size(4):
            t_needed = t1 + 1024  # what we need plus a buffer
            t_needed = (t_needed + 1023) & ~1023  # round up to a multiple of 1024
            current_shape = list(self.kv_cache.shape)
            current_shape[4] = t_needed
            # NOTE: Tensor.resize_() must NOT be used here — growing a middle
            # dimension in place does not preserve the logical element layout
            # and would scramble the cached keys/values. Allocate fresh
            # storage and copy the filled prefix over instead.
            grown = torch.empty(current_shape, dtype=self.kv_cache.dtype, device=self.kv_cache.device)
            grown[:, :, :, :, :t0, :] = self.kv_cache[:, :, :, :, :t0, :]
            self.kv_cache = grown
        # write the new keys/values at [t0, t1)
        self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
        self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
        # hand back views over the whole cached prefix for attention
        key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
        value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
        # advance the position only once the last layer has inserted
        if layer_idx == self.kv_cache.size(0) - 1:
            self.pos = t1
        return key_view, value_view
|
|
|
|
|
|
|
|
|
|
|
@torch.inference_mode()
def sample_next_token(logits, rng, temperature=1.0, top_k=None):
    """Sample a single next token from given logits of shape (B, vocab_size). Returns (B, 1)."""
    assert temperature >= 0.0, "temperature must be non-negative"
    # greedy decoding: deterministic, no RNG consumed
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    if top_k is None:
        # sample from the full temperature-scaled softmax distribution
        scaled = logits / temperature
        return torch.multinomial(F.softmax(scaled, dim=-1), num_samples=1, generator=rng)
    # restrict sampling to the k most likely candidates, then map the
    # sampled position back to vocabulary indices
    num_candidates = min(top_k, logits.size(-1))
    top_vals, top_idx = torch.topk(logits, num_candidates, dim=-1)
    probs = F.softmax(top_vals / temperature, dim=-1)
    choice = torch.multinomial(probs, num_samples=1, generator=rng)
    return top_idx.gather(1, choice)
|
|
|
|
|
|
|
|
|
|
|
class RowState:
    """Per-row decoding state tracked by Engine.generate.

    Tracks the full token history, a queue of deterministic tokens to force
    next (e.g. calculator tool output), whether we are currently inside a
    <|python_start|>...<|python_end|> span, the tokens collected for that
    span, and whether the row has produced a terminal token.
    """

    def __init__(self, current_tokens=None):
        # Explicit None check: `current_tokens or []` would silently replace a
        # caller-supplied empty list with a fresh one (inconsistent aliasing).
        self.current_tokens = current_tokens if current_tokens is not None else []
        self.forced_tokens = deque()  # tokens to emit before sampling resumes
        self.in_python_block = False
        self.python_expr_tokens = []
        self.completed = False
|
|
|
|
|
class Engine:
    """Efficient batched, KV-cached autoregressive generation over token ids.

    The engine knows nothing about text: inputs and outputs are token id
    sequences. Calculator tool use is handled inline by detecting
    <|python_start|>/<|python_end|> spans and force-feeding the computed
    result tokens back into the stream.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        # tokenizer is used only for special-token ids and for the
        # calculator round-trip (decode expression, encode result)
        self.tokenizer = tokenizer

    @torch.inference_mode()
    def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
        """Stream generated tokens, one column per decode step.

        Does a single batch-1 prefill of `tokens`, then clones the KV cache
        across `num_samples` rows and decodes all rows in parallel.
        Yields (token_column, token_masks) per step: token_column holds one
        token per sample; token_masks is 1 for sampled tokens and 0 for
        forced (tool-injected) tokens.
        """
        assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
        device = self.model.get_device()
        rng = torch.Generator(device=device)
        rng.manual_seed(seed)

        # Resolve the special tokens used for termination and tool-use detection
        get_special = lambda s: self.tokenizer.encode_special(s)
        python_start = get_special("<|python_start|>")
        python_end = get_special("<|python_end|>")
        output_start = get_special("<|output_start|>")
        output_end = get_special("<|output_end|>")
        assistant_end = get_special("<|assistant_end|>")
        bos = self.tokenizer.get_bos_token_id()

        # Phase 1: prefill the whole prompt at batch size 1 and sample the
        # first next-token once (it is shared across all samples below)
        m = self.model.config
        kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
        kv_cache_prefill = KVCache(
            batch_size=1,
            seq_len=len(tokens),
            **kv_model_kwargs,
        )
        ids = torch.tensor([tokens], dtype=torch.long, device=device)
        logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
        logits = logits[:, -1, :]  # only the last position predicts the next token
        next_ids = sample_next_token(logits, rng, temperature, top_k)
        sampled_tokens = next_ids[:, 0].tolist()

        # Phase 2: replicate the prefilled cache across num_samples rows.
        # Size the decode cache up front when max_tokens bounds the run.
        kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
        kv_cache_decode = KVCache(
            batch_size=num_samples,
            seq_len=kv_length_hint,
            **kv_model_kwargs,
        )
        kv_cache_decode.prefill(kv_cache_prefill)
        del kv_cache_prefill  # free the batch-1 cache; decode cache has a copy

        # Per-row bookkeeping (history, forced tokens, tool state, completion)
        row_states = [RowState(tokens.copy()) for _ in range(num_samples)]

        # Phase 3: decode loop — one token column per iteration
        num_generated = 0
        first_iteration = True
        while True:
            # stop on the token budget...
            if max_tokens is not None and num_generated >= max_tokens:
                break
            # ...or when every row has emitted a terminal token
            if all(state.completed for state in row_states):
                break

            if first_iteration:
                # reuse the token sampled during prefill, broadcast to all rows
                sampled_tokens = [sampled_tokens[0]] * num_samples
                first_iteration = False
            else:
                # forward only the last token column; the KV cache holds the rest
                logits = self.model.forward(ids, kv_cache=kv_cache_decode)
                logits = logits[:, -1, :]
                next_ids = sample_next_token(logits, rng, temperature, top_k)
                sampled_tokens = next_ids[:, 0].tolist()

            # Choose each row's actual next token: a queued forced token (from
            # a previous tool call) takes precedence over the sampled one.
            token_column = []
            token_masks = []
            for i, state in enumerate(row_states):
                is_forced = len(state.forced_tokens) > 0
                token_masks.append(0 if is_forced else 1)  # 0 = forced, 1 = sampled
                next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
                token_column.append(next_token)
                state.current_tokens.append(next_token)

                # terminal tokens end the row (it keeps "generating" padding
                # steps until all rows finish, but is marked completed)
                if next_token == assistant_end or next_token == bos:
                    state.completed = True

                # Tool-use state machine: collect tokens between python_start
                # and python_end, evaluate them, and queue the result tokens.
                if next_token == python_start:
                    state.in_python_block = True
                    state.python_expr_tokens = []
                elif next_token == python_end and state.in_python_block:
                    state.in_python_block = False
                    if state.python_expr_tokens:
                        expr = self.tokenizer.decode(state.python_expr_tokens)
                        result = use_calculator(expr)
                        if result is not None:
                            # inject <|output_start|> result <|output_end|> as forced tokens
                            result_tokens = self.tokenizer.encode(str(result))
                            state.forced_tokens.append(output_start)
                            state.forced_tokens.extend(result_tokens)
                            state.forced_tokens.append(output_end)
                    state.python_expr_tokens = []
                elif state.in_python_block:
                    state.python_expr_tokens.append(next_token)

            yield token_column, token_masks
            num_generated += 1

            # the chosen column (forced or sampled) is next step's model input
            ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)

    def generate_batch(self, tokens, num_samples=1, **kwargs):
        """
        Non-streaming batch generation that just returns the final token sequences.
        Returns a list of token sequences (list of lists of ints).
        Terminal tokens (assistant_end, bos) are not included in the results.
        """
        assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
        bos = self.tokenizer.get_bos_token_id()
        results = [tokens.copy() for _ in range(num_samples)]
        # prompt tokens are all masked 0 (not sampled by the model)
        masks = [[0] * len(tokens) for _ in range(num_samples)]
        completed = [False] * num_samples
        for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
            for i, (token, mask) in enumerate(zip(token_column, token_masks)):
                if not completed[i]:
                    if token == assistant_end or token == bos:
                        # mark done; terminal token itself is dropped
                        completed[i] = True
                    else:
                        results[i].append(token)
                        masks[i].append(mask)
            # stop pulling from the generator as soon as every row is done
            if all(completed):
                break
        return results, masks
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    """
    Quick inline test to make sure that the naive/slow model.generate function
    is equivalent to the faster Engine.generate function here.
    """
    import time

    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
    model, tokenizer, meta = load_model("base", device, phase="eval")
    bos_token_id = tokenizer.get_bos_token_id()

    # greedy decoding so both implementations must produce identical tokens
    kwargs = dict(max_tokens=64, temperature=0.0)
    prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)

    # torch.cuda.synchronize() raises when CUDA is unavailable; make the
    # timing code safe on CPU-only machines
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)

    # Reference: the naive/slow model.generate path
    generated_tokens = []
    sync()
    t0 = time.time()
    stream = model.generate(prompt_tokens, **kwargs)
    for token in stream:
        generated_tokens.append(token)
        chunk = tokenizer.decode([token])
        print(chunk, end="", flush=True)
    print()
    sync()
    t1 = time.time()
    print(f"Reference time: {t1 - t0:.2f}s")
    reference_ids = generated_tokens

    # Candidate: the Engine.generate path (single sample)
    generated_tokens = []
    engine = Engine(model, tokenizer)
    stream = engine.generate(prompt_tokens, num_samples=1, **kwargs)
    sync()
    t0 = time.time()
    for token_column, token_masks in stream:
        token = token_column[0]
        generated_tokens.append(token)
        chunk = tokenizer.decode([token])
        print(chunk, end="", flush=True)
    print()
    sync()
    t1 = time.time()
    print(f"Engine time: {t1 - t0:.2f}s")

    # Report the first mismatch (if any), then the overall verdict
    for i in range(len(reference_ids)):
        if reference_ids[i] != generated_tokens[i]:
            print(f"Mismatch at {i}: {reference_ids[i]} != {generated_tokens[i]}")
            break
    print(f"Match: {reference_ids == generated_tokens}")
|
|
|