from collections import deque

import torch

from nanochat.common import get_dist_info
from nanochat.dataset import parquets_iter_batched
from nanochat.tokenizer import get_tokenizer
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
    """Stream pretraining text from parquet files, tokenize, yield training batches."""
    assert split in ["train", "val"], "split must be 'train' or 'val'"
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    needed_tokens = B * T + 1  # +1 because we also need the target at the last token
    # get the tokenizer and the bos token
    tokenizer = get_tokenizer()
    bos_token = tokenizer.get_bos_token_id()
    # token_buffer streams tokens in on the right and pops them from the left;
    # scratch holds the tokens for one iteration
    token_buffer = deque()
    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
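    # scratch is allocated in pinned (page-locked) host memory; copies from
    # pinned memory to the GPU with non_blocking=True can run asynchronously
    # instead of forcing a synchronous staging copy.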
    # infinite iterator over document batches
    def document_batches():
        while True:
            # the outer batch arrives in the row-group size of the parquet files, e.g. usually 1024 rows
            for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):
                # for the tokenizer we want smaller batches, e.g. 128 rows
                for i in range(0, len(batch), tokenizer_batch_size):
                    yield batch[i:i+tokenizer_batch_size]
    batches = document_batches()
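    # Note (an assumption based on the argument names): start=ddp_rank with
    # step=ddp_world_size strides the parquet stream so that each rank reads
    # a disjoint shard of the documents.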
    batch_index = 0  # counts tokenizer batches consumed (not otherwise used here)
    while True:
        # Accumulate enough tokens for one iteration before yielding
        while len(token_buffer) < needed_tokens:
            doc_batch = next(batches)
            token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
            for tokens in token_lists:
                token_buffer.extend(tokens)
            batch_index += 1
        # Move tokens from the deque into the scratch buffer
        for i in range(needed_tokens):
            scratch[i] = token_buffer.popleft()
        # Create the inputs/targets as 1D tensors
        inputs_cpu = scratch[:-1].to(dtype=torch.int32)
        targets_cpu = scratch[1:]
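        # dtype note: token ids fit comfortably in int32 (presumably chosen to
        # halve host-to-device traffic for the inputs); targets stay int64
        # because PyTorch's cross-entropy expects long class indices.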
        # Reshape to 2D and move to GPU async
        inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
        targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
        yield inputs, targets
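
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal single-process smoke test for the loader. It assumes a CUDA device
# is available, that the nanochat tokenizer and parquet shards have already
# been prepared, and that get_dist_info() falls back to single-process
# defaults when not launched under torchrun; the B/T values are arbitrary.
if __name__ == "__main__":
    loader = tokenizing_distributed_data_loader(B=8, T=1024, split="val")
    for step in range(2):
        inputs, targets = next(loader)
        print(step, inputs.shape, inputs.dtype, targets.shape, targets.dtype)
        # expect: torch.Size([8, 1024]) torch.int32 / torch.int64 per batch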