# Evaluation script for SCRL sentence-compression models (paste residue removed).
"""Evaluate an SCRL sentence-compression model on a JSONL dataset.

For every item the script predicts a compression, then reports token-F1,
ROUGE precision/recall/F1, and length / compression-ratio statistics.
"""
import argparse
import json
import sys
from collections import Counter, defaultdict
from pathlib import Path
from pprint import pprint

import numpy as np
import tqdm
from transformers import AutoTokenizer

# Make the project-local SCRL package importable before importing from it.
sys.path.append("/home/hdd/lijinyi/CompressionInAvalon/promptcompressor/SCRL_new")
print(sys.path)

import scrl.utils as utils
from scrl.model import load_checkpoint, load_model
from scrl.eval_metrics import compute_token_f1, rouge_scorer, ROUGE_TYPES

import nltk
from nltk import word_tokenize

# word_tokenize requires the punkt tokenizer models; fetch them up front.
nltk.download('punkt')
print("punkt done!")
def main(args):
    """Run the evaluation loop and print aggregate metrics.

    Exactly one of ``args.model_dir`` and ``args.checkpoint`` must be set:
    a model directory loads its "best" checkpoint; a checkpoint path loads
    that checkpoint directly.
    """
    # Load the model from either source; providing both or neither is an error.
    if args.model_dir is not None and args.checkpoint is None:
        model = load_model(
            Path(args.model_dir), device=args.device, prefix="best"
        )
    elif args.model_dir is None and args.checkpoint is not None:
        model = load_checkpoint(Path(args.checkpoint), device=args.device)
    else:
        raise Exception("Provide either a model directory or checkpoint.")
    # BUGFIX: the original code unconditionally re-ran
    # ``model = load_model(Path(args.model_dir), device=args.device)`` here,
    # which discarded the checkpoint selected above and crashed with
    # ``TypeError`` (Path(None)) whenever only --checkpoint was given.
    # That redundant reload has been removed.

    tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
    dataset = list(utils.read_jsonl(args.dataset))
    all_scores = defaultdict(list)

    for item in tqdm.tqdm(dataset):
        src = item["text"]
        if args.lower_src:
            src = src.lower()
        tgts = item["summaries"]

        pred = model.predict([src], tokenizer, args.device)[0]
        if args.max_chars > 0:
            pred = pred[:args.max_chars]

        pred_tokens = word_tokenize(pred)
        if args.lower_summary:
            pred_tokens = [t.lower() for t in pred_tokens]
        # With --pretokenized the source text is assumed to already be
        # whitespace-tokenized.  (The original also computed
        # word_tokenize(src) unconditionally first, then overwrote it —
        # that dead tokenization pass is removed.)
        if args.pretokenized:
            src_tokens = src.split()
        else:
            src_tokens = word_tokenize(src)

        # Score the prediction against every reference summary, then average
        # per-item before accumulating into the corpus-level scores.
        item_scores = defaultdict(list)
        for tgt in tgts:
            if args.pretokenized:
                tgt_tokens = tgt.split()
            else:
                tgt_tokens = word_tokenize(tgt)
            if args.lower_summary:
                tgt_tokens = [t.lower() for t in tgt_tokens]
            token_fscore = compute_token_f1(tgt_tokens, pred_tokens, use_counts=True)
            rouge_scores = rouge_scorer.score(tgt, pred)
            for rouge_type, rouge_type_scores in rouge_scores.items():
                item_scores[f"{rouge_type}-p"].append(rouge_type_scores.precision)
                item_scores[f"{rouge_type}-r"].append(rouge_type_scores.recall)
                item_scores[f"{rouge_type}-f"].append(rouge_type_scores.fmeasure)
            item_scores["token-f1"].append(token_fscore)
            item_scores["tgt-len"].append(len(tgt_tokens))
            item_scores["tgt-cr"].append(len(tgt_tokens) / len(src_tokens))

        for k, values in item_scores.items():
            item_mean = np.mean(values)
            all_scores[k].append(item_mean)
        all_scores["pred-len"].append(len(pred_tokens))
        all_scores["src-len"].append(len(src_tokens))
        all_scores["pred-cr"].append(len(pred_tokens) / len(src_tokens))

        if args.verbose:
            print("SRC:", src)
            print("TGT:", tgts[0])
            print("PRED:", pred)
            print("=" * 100)

    # Report corpus-level means for every metric family.
    print("="*100)
    print("RESULTS:")
    print("="*20, "Length (#tokens):", "="*20)
    for metric in ("src-len", "tgt-len", "pred-len"):
        mean = np.mean(all_scores[metric])
        print(f"{metric}: {mean:.2f}")
    print()
    print("="*20, "Compression ratio:", "="*20)
    for metric in ("tgt-cr", "pred-cr"):
        mean = np.mean(all_scores[metric])
        print(f"{metric}: {mean:.2f}")
    print()
    print("="*20, "Token F1-Score:", "="*20)
    mean = np.mean(all_scores["token-f1"])
    print(f"f1-score: {mean:.3f}")
    print()
    print("="*20, "ROUGE F1-Scores:", "="*20)
    for rouge_type in ROUGE_TYPES:
        mean = np.mean(all_scores[f"{rouge_type}-f"])
        print(f"{rouge_type}: {mean:.4f}")
    print()
    print("="*20, "ROUGE Recall:", "="*20)
    for rouge_type in ROUGE_TYPES:
        mean = np.mean(all_scores[f"{rouge_type}-r"])
        print(f"{rouge_type}: {mean:.4f}")
    print()
def parse_args():
    """Define and parse the command-line options for the evaluation script.

    Returns an ``argparse.Namespace``; exactly one of ``--model-dir`` /
    ``--checkpoint`` is expected by ``main`` (enforced there, not here).
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--dataset', required=True)
    arg_parser.add_argument('--model-dir', required=False)
    arg_parser.add_argument('--checkpoint', required=False)
    arg_parser.add_argument('--device', default="cpu")
    arg_parser.add_argument('--pretokenized', action="store_true")
    arg_parser.add_argument('--max-chars', type=int, default=-1)
    arg_parser.add_argument('--verbose', action="store_true")
    arg_parser.add_argument('--lower-src', action="store_true")
    arg_parser.add_argument('--lower-summary', action="store_true")
    return arg_parser.parse_args()


if __name__ == '__main__':
    main(parse_args())