Spaces:
Running
Running
Ian Borrego Obrador
committed on
Commit
·
faf189f
1
Parent(s):
1b5f36c
generalization in eval_args
Browse files
- generation_evaluator.py +20 -10
- requirements.txt +2 -1
generation_evaluator.py
CHANGED
|
@@ -5,6 +5,7 @@ import numpy as np
|
|
| 5 |
import spacy
|
| 6 |
import torch
|
| 7 |
from alignscore import AlignScore
|
|
|
|
| 8 |
|
| 9 |
_CITATION = """\
|
| 10 |
@inproceedings{lin-2004-rouge,
|
|
@@ -155,9 +156,7 @@ class GenerationEvaluator(evaluate.Metric):
|
|
| 155 |
# Download AlignScore model and move to GPU if possible
|
| 156 |
model_path = dl_manager.download(ALIGNSCORE_ARGS["ckpt_path"])
|
| 157 |
ALIGNSCORE_ARGS["ckpt_path"] = model_path
|
| 158 |
-
ALIGNSCORE_ARGS["device"] = (
|
| 159 |
-
"cuda:0" if torch.cuda.is_available() else "cpu"
|
| 160 |
-
)
|
| 161 |
self.align_scorer = AlignScore(**ALIGNSCORE_ARGS)
|
| 162 |
|
| 163 |
# Prepare scorers
|
|
@@ -167,20 +166,33 @@ class GenerationEvaluator(evaluate.Metric):
|
|
| 167 |
self.bert_scorer = evaluate.load("bertscore")
|
| 168 |
self.chrf_scorer = evaluate.load("chrf")
|
| 169 |
|
| 170 |
-
def _compute(self, predictions, references,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
# Compute ROUGE
|
| 172 |
rouge_results = self.rouge_scorer.compute(
|
| 173 |
-
predictions=predictions,
|
|
|
|
|
|
|
|
|
|
| 174 |
)
|
| 175 |
|
| 176 |
# Compute BLEU
|
| 177 |
if tokenizer is None:
|
| 178 |
bleu_results = self.bleu_scorer.compute(
|
| 179 |
-
predictions=predictions, references=references
|
| 180 |
)
|
| 181 |
else:
|
| 182 |
bleu_results = self.bleu_scorer.compute(
|
| 183 |
-
predictions=predictions,
|
|
|
|
|
|
|
|
|
|
| 184 |
)
|
| 185 |
|
| 186 |
# Compute Exact Match
|
|
@@ -203,9 +215,7 @@ class GenerationEvaluator(evaluate.Metric):
|
|
| 203 |
|
| 204 |
# Compute AlignScore
|
| 205 |
align_score = round(
|
| 206 |
-
np.mean(
|
| 207 |
-
self.align_scorer.score(contexts=references, claims=predictions)
|
| 208 |
-
),
|
| 209 |
4,
|
| 210 |
)
|
| 211 |
|
|
|
|
| 5 |
import spacy
|
| 6 |
import torch
|
| 7 |
from alignscore import AlignScore
|
| 8 |
+
from transformers import AutoTokenizer
|
| 9 |
|
| 10 |
_CITATION = """\
|
| 11 |
@inproceedings{lin-2004-rouge,
|
|
|
|
| 156 |
# Download AlignScore model and move to GPU if possible
|
| 157 |
model_path = dl_manager.download(ALIGNSCORE_ARGS["ckpt_path"])
|
| 158 |
ALIGNSCORE_ARGS["ckpt_path"] = model_path
|
| 159 |
+
ALIGNSCORE_ARGS["device"] = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
| 160 |
self.align_scorer = AlignScore(**ALIGNSCORE_ARGS)
|
| 161 |
|
| 162 |
# Prepare scorers
|
|
|
|
| 166 |
self.bert_scorer = evaluate.load("bertscore")
|
| 167 |
self.chrf_scorer = evaluate.load("chrf")
|
| 168 |
|
| 169 |
+
def _compute(self, predictions, references, **eval_kwargs):
|
| 170 |
+
tokenizer_name = eval_kwargs.pop("tokenizer_name", None)
|
| 171 |
+
tokenizer = None
|
| 172 |
+
|
| 173 |
+
if tokenizer_name is not None:
|
| 174 |
+
tks = AutoTokenizer.from_pretrained(tokenizer_name)
|
| 175 |
+
tokenizer = tks.tokenize
|
| 176 |
+
|
| 177 |
# Compute ROUGE
|
| 178 |
rouge_results = self.rouge_scorer.compute(
|
| 179 |
+
predictions=predictions,
|
| 180 |
+
references=references,
|
| 181 |
+
tokenizer=tokenizer,
|
| 182 |
+
**eval_kwargs
|
| 183 |
)
|
| 184 |
|
| 185 |
# Compute BLEU
|
| 186 |
if tokenizer is None:
|
| 187 |
bleu_results = self.bleu_scorer.compute(
|
| 188 |
+
predictions=predictions, references=references, **eval_kwargs
|
| 189 |
)
|
| 190 |
else:
|
| 191 |
bleu_results = self.bleu_scorer.compute(
|
| 192 |
+
predictions=predictions,
|
| 193 |
+
references=references,
|
| 194 |
+
tokenizer=tokenizer,
|
| 195 |
+
**eval_kwargs
|
| 196 |
)
|
| 197 |
|
| 198 |
# Compute Exact Match
|
|
|
|
| 215 |
|
| 216 |
# Compute AlignScore
|
| 217 |
align_score = round(
|
| 218 |
+
np.mean(self.align_scorer.score(contexts=references, claims=predictions)),
|
|
|
|
|
|
|
| 219 |
4,
|
| 220 |
)
|
| 221 |
|
requirements.txt
CHANGED
|
@@ -7,4 +7,5 @@ rouge_score
|
|
| 7 |
numpy
|
| 8 |
sacrebleu
|
| 9 |
git+https://github.com/yuh-zha/AlignScore.git
|
| 10 |
-
spacy
|
|
|
|
|
|
| 7 |
numpy
|
| 8 |
sacrebleu
|
| 9 |
git+https://github.com/yuh-zha/AlignScore.git
|
| 10 |
+
spacy
|
| 11 |
+
transformers
|