import logging

import torch
from accelerate import Accelerator
from arguments import EvaluationArguments
from datasets import load_dataset
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed


class ConstantLengthDataset(IterableDataset):
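    """Iterable dataset that packs streamed text examples into constant-length token sequences.

    Examples are buffered, tokenized in one batch, concatenated with a BOS separator
    token, and sliced into chunks of exactly `seq_length` tokens.
    """
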
    def __init__(self, tokenizer, dataset, seq_length=1024, num_of_sequences=1024, chars_per_token=3.6):
        self.tokenizer = tokenizer
        self.concat_token_id = tokenizer.bos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
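        # Rough character count needed to yield num_of_sequences sequences,
        # assuming ~3.6 characters per token on average.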
        self.input_characters = seq_length * chars_per_token * num_of_sequences

    def __iter__(self):
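        # Fill a character buffer from the stream, tokenize it in one call,
        # then slice the concatenated token ids into fixed-length chunks.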
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.input_characters:
                    break
                try:
                    buffer.append(next(iterator)["content"])
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    more_examples = False
                    break
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    yield torch.tensor(input_ids)


def create_dataloader(args):
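    # Stream the dataset so evaluation starts without downloading the full split.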
    ds_kwargs = {"streaming": True}
    valid_data = load_dataset(args.dataset_name, split="train", **ds_kwargs)
    valid_dataset = ConstantLengthDataset(tokenizer, valid_data, seq_length=args.seq_length)
    eval_dataloader = DataLoader(valid_dataset, batch_size=args.batch_size)
    return eval_dataloader


def evaluate(args):
    model.eval()
    losses = []
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs = model(batch, labels=batch)
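        # Repeat the scalar loss so every process contributes one value per sample
        # before gathering across devices.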
        loss = outputs.loss.repeat(args.batch_size)
        losses.append(accelerator.gather(loss))
        if args.max_eval_steps > 0 and step >= args.max_eval_steps:
            break
    loss = torch.mean(torch.cat(losses))
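    # Perplexity is exp(mean loss); torch.exp returns inf rather than raising on
    # overflow for tensor input, so the except branch is only a conservative guard.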
    try:
        perplexity = torch.exp(loss)
    except OverflowError:
        perplexity = torch.tensor(float("inf"))
    return loss.item(), perplexity.item()


# Setup Accelerator
accelerator = Accelerator()

# Parse configuration
parser = HfArgumentParser(EvaluationArguments)
args = parser.parse_args()
set_seed(args.seed)

# Logging
logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)

# Load dataset and dataloader
eval_dataloader = create_dataloader(args)

# Prepare everything with our `accelerator`.
model, eval_dataloader = accelerator.prepare(model, eval_dataloader)

# Evaluate the loaded checkpoint (nothing is saved here)
logger.info("Evaluating model")
eval_loss, perplexity = evaluate(args)
logger.info(f"loss/eval: {eval_loss}, perplexity: {perplexity}")
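
# A minimal launch sketch, not part of the original file: it assumes this script is
# saved as validation_loss.py and that `arguments.py` defines an EvaluationArguments
# dataclass with the fields referenced above (model_ckpt, dataset_name, seq_length,
# batch_size, max_eval_steps, seed); HfArgumentParser turns each field into a CLI
# flag of the same name. The model and dataset names below are illustrative.
#
#   accelerate launch validation_loss.py \
#       --model_ckpt codeparrot/codeparrot-small \
#       --dataset_name codeparrot/codeparrot-clean-valid \
#       --seq_length 1024 --batch_size 8 --max_eval_steps -1 --seed 1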