"""
Engine for efficient inference of our models.

Everything works around token sequences:
- The user can send token sequences to the engine
- The engine returns the next token

Notes:
- The engine itself operates purely on token id sequences. The tokenizer is
  only consulted to look up special tokens and to decode/encode calculator
  tool expressions and their results.

The whole thing is made as efficient as possible.
"""

import signal
import warnings
from collections import deque
from contextlib import contextmanager

import torch
import torch.nn.functional as F

from nanochat.common import compute_init
from nanochat.checkpoint_manager import load_model

# -----------------------------------------------------------------------------
# Calculator tool helpers

@contextmanager
def timeout(duration, formula):
    """
    Context manager that raises an Exception if its body runs for longer than
    `duration` seconds. `formula` is only used to build the error message.

    Implemented with SIGALRM. The alarm is always cancelled and the previously
    installed SIGALRM handler is restored on exit, whether the body finished,
    raised, or timed out.
    """
    def timeout_handler(signum, frame):
        raise Exception(f"'{formula}': timed out after {duration} seconds")

    previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(duration)
    try:
        yield
    finally:
        signal.alarm(0)
        # signal.signal() returns None when the old handler was not installed
        # from Python; None cannot be passed back in, so skip restoring then.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)


def eval_with_timeout(formula, max_time=3):
    """
    Evaluate the Python expression string `formula`, giving up after
    `max_time` seconds.

    Returns the value of the expression, or None if evaluation raised,
    timed out, or the timeout mechanism could not be set up.
    """
    try:
        with timeout(max_time, formula):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", SyntaxWarning)
                # SECURITY NOTE(review): eval() runs arbitrary Python. The only
                # caller in this file (use_calculator) restricts the input to
                # digits, arithmetic operators, parentheses and spaces first.
                # Do not call this directly with untrusted text.
                return eval(formula)
    except Exception:
        # It's ok, ignore wrong calculator usage. The alarm itself is already
        # cancelled by the timeout() context manager.
        return None


def use_calculator(expr):
    """
    Evaluate a simple arithmetic expression given as a string.

    Commas are stripped first. Returns None if the expression contains any
    character outside "0123456789*+-/.() ", if it uses the power operator,
    or if evaluation fails or times out. Otherwise returns the value.
    """
    expr = expr.replace(",", "")
    if any(ch not in "0123456789*+-/.() " for ch in expr):  # for now disallow non-numeric chars
        return None
    if "**" in expr:  # for now disallow power operator, could be very expensive
        return None
    return eval_with_timeout(expr)

# -----------------------------------------------------------------------------
class KVCache:
    """
    Works hand-in-hand with the GPT model to maintain the KV cache.
    Note that the .pos advances automatically after the last layer of the Transformer inserts.
    """

    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
        # Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
        self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
        self.kv_cache = None  # allocated lazily, once dtype/device are known
        self.pos = 0  # current position in time in the cache

    def reset(self):
        """Rewind the write position to 0. The allocated buffer is kept."""
        self.pos = 0

    def get_pos(self):
        """Return the current position in time in the cache."""
        return self.pos

    def prefill(self, other):
        """
        Prefill given another KV cache. Optionally expand along batch dim.
        This is used when we do batch 1 prefill and then want to generate
        multiple samples in parallel from there.

        Only the first `other.pos` time steps of `other` are copied, so
        `other` may have more capacity allocated than it has filled.
        """
        # 1) validate the shapes
        assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
        assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
        dim_names = ("num_layers", "k/v", "batch_size", "num_heads", "seq_len", "head_dim")
        for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):
            if ix in [0, 1, 3, 5]:
                # num_layers, k/v, num_heads, head_dim must match
                assert dim1 == dim2, f"{dim_names[ix]} mismatch: {dim1} != {dim2}"
            elif ix == 2:
                # batch_size can be expanded
                assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}"
            elif ix == 4:
                # seq_len: self must have room for everything other has filled
                assert dim1 >= other.pos, f"Seq len mismatch: {dim1} < {other.pos}"
        # 2) initialize the cache
        dtype, device = other.kv_cache.dtype, other.kv_cache.device
        self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
        # 3) copy the data over (a batch dim of 1 in other broadcasts across self's batch)
        self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache[:, :, :, :, :other.pos, :]
        # 4) update the pos
        self.pos = other.pos

    def insert_kv(self, layer_idx, k, v):
        """
        Write keys/values `k`, `v` of shape (B, H, T_add, D) for layer
        `layer_idx` at the current position and return views of all cached
        keys and values for that layer up to and including the new ones.

        The position is advanced by T_add only once the last layer inserted.
        """
        # Lazy initialize the cache here because we need to know the dtype/device
        if self.kv_cache is None:
            self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
        # Insert new keys/values to the cache and return the full cache so far
        B, H, T_add, D = k.size()
        t0, t1 = self.pos, self.pos + T_add
        # Dynamically grow the cache if needed
        if t1 > self.kv_cache.size(4):
            t_needed = t1 + 1024  # as much as we need plus buffer of 1024
            t_needed = (t_needed + 1023) & ~1023  # then round up to the nearest multiple of 1024
            # Tensor.resize_ would reinterpret the flat storage and scramble the
            # entries already cached, because the time axis is not the leading
            # dim. Allocate a bigger buffer and copy the old contents instead.
            old_len = self.kv_cache.size(4)
            grown_shape = list(self.kv_cache.shape)
            grown_shape[4] = t_needed
            grown = torch.empty(grown_shape, dtype=self.kv_cache.dtype, device=self.kv_cache.device)
            grown[:, :, :, :, :old_len] = self.kv_cache
            self.kv_cache = grown
            self.kv_shape = tuple(grown_shape)
        # Insert k, v into the cache
        self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
        self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
        # Return the full cached keys/values up to current position (as a view)
        key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
        value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
        # Increment pos after the last layer of the Transformer processes
        if layer_idx == self.kv_cache.size(0) - 1:
            self.pos = t1
        return key_view, value_view


# -----------------------------------------------------------------------------
@torch.inference_mode()
def sample_next_token(logits, rng, temperature=1.0, top_k=None):
    """
    Sample a single next token from given logits of shape (B, vocab_size).
    Returns (B, 1).

    temperature == 0.0 selects the argmax. If top_k is given, sampling is
    restricted to the top_k highest logits of each row.
    """
    assert temperature >= 0.0, "temperature must be non-negative"
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    if top_k is None:
        probs = F.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1, generator=rng)
    k = min(top_k, logits.size(-1))
    vals, idx = torch.topk(logits, k, dim=-1)
    probs = F.softmax(vals / temperature, dim=-1)
    choice = torch.multinomial(probs, num_samples=1, generator=rng)
    # map the position within the top-k back to the vocabulary index
    return idx.gather(1, choice)

# -----------------------------------------------------------------------------

class RowState:
    # Per-row state tracking during generation
    def __init__(self, current_tokens=None):
        self.current_tokens = current_tokens or []  # Current token sequence for this row
        self.forced_tokens = deque()  # Queue of tokens to force inject
        self.in_python_block = False  # Whether we are inside a python block
        self.python_expr_tokens = []  # Tokens of the current python expression
        self.completed = False  # Whether this row has completed generation


class Engine:
    """Streaming/batched token generation around a model, with calculator tool use."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer  # needed for tool use

    @torch.inference_mode()
    def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
        """
        Streaming generation: does a single batch 1 prefill of the prompt,
        clones the KV cache across `num_samples` rows, then decodes all rows
        in parallel.

        Yields (token_column, token_masks) once per step, each a list with one
        entry per row: the next token id, and 1 if it was sampled or 0 if it
        was forced (calculator tool output). Stops after `max_tokens` steps
        (if given) or once every row has produced <|assistant_end|> or <|bos|>.
        """
        assert isinstance(tokens, list) and len(tokens) > 0 and isinstance(tokens[0], int), \
            "expecting a non-empty list of ints"
        device = self.model.get_device()
        rng = torch.Generator(device=device)
        rng.manual_seed(seed)

        # Get the special tokens we need to coordinate the tool use state machine
        get_special = self.tokenizer.encode_special
        python_start = get_special("<|python_start|>")
        python_end = get_special("<|python_end|>")
        output_start = get_special("<|output_start|>")
        output_end = get_special("<|output_end|>")
        assistant_end = get_special("<|assistant_end|>")  # if sampled, ends row
        bos = self.tokenizer.get_bos_token_id()  # if sampled, ends row

        # 1) Run a batch 1 prefill of the prompt tokens
        m = self.model.config
        kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
        kv_cache_prefill = KVCache(
            batch_size=1,
            seq_len=len(tokens),
            **kv_model_kwargs,
        )
        ids = torch.tensor([tokens], dtype=torch.long, device=device)
        prefill_logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
        prefill_logits = prefill_logits[:, -1, :]  # (1, vocab_size) at last time step

        # 2) Replicate the KV cache for each sample/row
        kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
        kv_cache_decode = KVCache(
            batch_size=num_samples,
            seq_len=kv_length_hint,
            **kv_model_kwargs,
        )
        kv_cache_decode.prefill(kv_cache_prefill)
        del kv_cache_prefill  # no need to keep this memory around

        # 3) Initialize states for each sample
        row_states = [RowState(tokens.copy()) for _ in range(num_samples)]

        # 4) Main generation loop
        num_generated = 0
        while True:
            # Stop condition: we've reached max tokens
            if max_tokens is not None and num_generated >= max_tokens:
                break
            # Stop condition: all rows are completed
            if all(state.completed for state in row_states):
                break

            # Get the logits - either from prefill or from a forward pass
            if num_generated == 0:
                # The prompt is shared, so every row starts from the same
                # prefill logits, but each row samples its own first token.
                logits = prefill_logits.expand(num_samples, -1)  # (B, vocab_size)
            else:
                # Forward the model and get the next token for each row
                logits = self.model.forward(ids, kv_cache=kv_cache_decode)  # (B, T, vocab_size)
                logits = logits[:, -1, :]  # (B, vocab_size) at last time step
            next_ids = sample_next_token(logits, rng, temperature, top_k)  # (B, 1)
            sampled_tokens = next_ids[:, 0].tolist()

            # Process each row: choose the next token, update state, optional tool use
            token_column = []  # contains the next token id along each row
            token_masks = []  # contains the mask (was it sampled (1) or forced (0)?) along each row
            for i, state in enumerate(row_states):
                # Select the next token in this row
                is_forced = len(state.forced_tokens) > 0  # are there tokens waiting to be forced in deque?
                token_masks.append(0 if is_forced else 1)  # mask is 0 if forced, 1 if sampled
                next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
                token_column.append(next_token)
                # Update the state of this row to include the next token
                state.current_tokens.append(next_token)
                # On <|assistant_end|> or <|bos|>, mark the row as completed
                if next_token == assistant_end or next_token == bos:
                    state.completed = True
                # Handle tool logic
                if next_token == python_start:
                    state.in_python_block = True
                    state.python_expr_tokens = []
                elif next_token == python_end and state.in_python_block:
                    state.in_python_block = False
                    if state.python_expr_tokens:
                        expr = self.tokenizer.decode(state.python_expr_tokens)
                        result = use_calculator(expr)
                        if result is not None:
                            # Queue <|output_start|> result <|output_end|> to be forced in next
                            result_tokens = self.tokenizer.encode(str(result))
                            state.forced_tokens.append(output_start)
                            state.forced_tokens.extend(result_tokens)
                            state.forced_tokens.append(output_end)
                    state.python_expr_tokens = []
                elif state.in_python_block:
                    state.python_expr_tokens.append(next_token)

            # Yield the token column
            yield token_column, token_masks
            num_generated += 1
            # Prepare ids for next iteration
            ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)

    def generate_batch(self, tokens, num_samples=1, **kwargs):
        """
        Non-streaming batch generation that just returns the final token sequences.

        Returns a tuple (results, masks):
        - results: list of token sequences (list of lists of ints), each
          starting with the prompt tokens.
        - masks: parallel lists with 0 for prompt and forced tokens and 1 for
          sampled tokens.
        Terminal tokens (assistant_end, bos) are not included in the results.
        Extra keyword arguments are passed through to generate().
        """
        assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
        bos = self.tokenizer.get_bos_token_id()
        results = [tokens.copy() for _ in range(num_samples)]
        masks = [[0] * len(tokens) for _ in range(num_samples)]
        completed = [False] * num_samples
        for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
            for i, (token, mask) in enumerate(zip(token_column, token_masks)):
                if not completed[i]:
                    if token == assistant_end or token == bos:
                        completed[i] = True
                    else:
                        results[i].append(token)
                        masks[i].append(mask)
            # Stop if all rows are completed
            if all(completed):
                break
        return results, masks


if __name__ == "__main__":
    # Quick inline test to make sure that the naive/slow model.generate function
    # is equivalent to the faster Engine.generate function here.
    import time

    def _sync():
        # Only synchronize when CUDA is present so the timings also run without it.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    # init compute
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
    # load the model and tokenizer
    model, tokenizer, meta = load_model("base", device, phase="eval")
    bos_token_id = tokenizer.get_bos_token_id()
    # common hyperparameters
    kwargs = dict(max_tokens=64, temperature=0.0)
    # set the starting prompt
    prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
    # generate the reference sequence using the model.generate() function
    generated_tokens = []
    _sync()
    t0 = time.time()
    stream = model.generate(prompt_tokens, **kwargs)
    for token in stream:
        generated_tokens.append(token)
        chunk = tokenizer.decode([token])
        print(chunk, end="", flush=True)
    print()
    _sync()
    t1 = time.time()
    print(f"Reference time: {t1 - t0:.2f}s")
    reference_ids = generated_tokens
    # generate tokens with Engine
    generated_tokens = []
    engine = Engine(model, tokenizer)
    stream = engine.generate(prompt_tokens, num_samples=1, **kwargs)  # note: runs in fp32
    _sync()
    t0 = time.time()
    for token_column, token_masks in stream:
        token = token_column[0]  # only print out the first row
        generated_tokens.append(token)
        chunk = tokenizer.decode([token])
        print(chunk, end="", flush=True)
    print()
    _sync()
    t1 = time.time()
    print(f"Engine time: {t1 - t0:.2f}s")
    # compare the two sequences; the engine may stop early (e.g. on <|bos|>),
    # so walk them in lockstep instead of indexing by the reference length
    for i, (ref_token, gen_token) in enumerate(zip(reference_ids, generated_tokens)):
        if ref_token != gen_token:
            print(f"Mismatch at {i}: {ref_token} != {gen_token}")
            break
    else:
        if len(reference_ids) != len(generated_tokens):
            print(f"Length mismatch: {len(reference_ids)} != {len(generated_tokens)}")
    print(f"Match: {reference_ids == generated_tokens}")