Spaces:
Runtime error
Runtime error
| import argparse | |
| import json | |
| import numpy as np | |
| import tqdm | |
| from pathlib import Path | |
| from pprint import pprint | |
| from collections import defaultdict, Counter | |
| from transformers import AutoTokenizer | |
| import scrl.utils as utils | |
| from scrl.model import load_checkpoint | |
| from scrl.metrics import compute_token_f1, rouge_scorer, ROUGE_TYPES | |
| from nltk import word_tokenize | |
| from scrl.rewards import load_rewards | |
| from scrl.config import load_config | |
| import time | |
def main(args):
    """Generate a summary for every record of a JSONL dataset.

    Loads the model checkpoint at ``args.checkpoint`` onto ``args.device``,
    reads records from ``args.dataset`` (each record must provide ``"id"`` and
    ``"text"`` keys), predicts summaries in batches of ``args.batch_size``,
    prints the elapsed prediction time, and, when ``args.output`` is set,
    hands the ``{"id", "pred-summary"}`` records to ``utils.write_jsonl``.

    Args:
        args: Parsed command-line namespace (as produced by ``parse_args``).

    Raises:
        ValueError: If the model returns a different number of summaries than
            the number of sources in a batch.
    """
    model = load_checkpoint(Path(args.checkpoint), device=args.device)
    # NOTE(review): the tokenizer name is fixed; presumably it must match the
    # tokenizer the checkpoint was trained with — confirm before changing.
    tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
    dataset = list(utils.read_jsonl(args.dataset))
    batches = utils.batchify(dataset, args.batch_size)

    outputs = []
    # perf_counter is monotonic, so the timing cannot be skewed (or go
    # negative) because of wall-clock adjustments, unlike time.time().
    start = time.perf_counter()
    for items in tqdm.tqdm(batches):
        sources = [x["text"] for x in items]
        summaries = list(model.predict(sources, tokenizer, args.device))
        # zip() would silently drop unmatched records; fail loudly instead so
        # the output file never ends up missing ids.
        if len(summaries) != len(sources):
            raise ValueError(
                f"model returned {len(summaries)} summaries for "
                f"{len(sources)} sources"
            )
        outputs.extend(
            {"id": item["id"], "pred-summary": summary}
            for item, summary in zip(items, summaries)
        )
    elapsed = time.perf_counter() - start
    print("Seconds:", elapsed)

    if args.output:
        utils.write_jsonl(outputs, args.output, "w")
def _positive_int(value):
    """argparse ``type=`` helper: parse *value* as an int that is at least 1."""
    number = int(value)  # a ValueError here is reported by argparse as invalid
    if number < 1:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {value!r}"
        )
    return number


def parse_args(argv=None):
    """Parse the command-line options of the prediction script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behavior.

    Returns:
        argparse.Namespace with ``dataset``, ``output``, ``checkpoint``,
        ``device`` and ``batch_size`` attributes.

    Exits via ``SystemExit`` (argparse's standard behavior) on invalid input,
    including a ``--batch-size`` that is not a positive integer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--dataset', required=True,
        help="JSONL file whose records have 'id' and 'text' fields",
    )
    parser.add_argument(
        '--output', required=False,
        help="optional JSONL path for the predicted summaries",
    )
    parser.add_argument(
        '--checkpoint', required=True,
        help="model checkpoint to load",
    )
    parser.add_argument(
        '--device', default="cpu",
        help="device passed to the model (default: cpu)",
    )
    # Reject 0/negative sizes up front instead of failing inside batching.
    parser.add_argument(
        '--batch-size', type=_positive_int, default=4,
        help="number of records per prediction batch (default: 4)",
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: read the command line, then run prediction.
    cli_args = parse_args()
    main(cli_args)