Spaces:
Running
Running
Commit
·
80d7919
1
Parent(s):
d5750c7
Compute the mean for BERTScore and BLEURT results
Browse files- generation_evaluator.py +17 -6
- requirements.txt +2 -1
generation_evaluator.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import datasets
|
| 2 |
import evaluate
|
|
|
|
| 3 |
|
| 4 |
_CITATION = """\
|
| 5 |
@inproceedings{lin-2004-rouge,
|
|
@@ -109,7 +110,7 @@ BLEU:{
|
|
| 109 |
EXACT_MATCH:{
|
| 110 |
"exact_match": exact_match rate. Possible values are between 0.0 and 1.0, inclusive.
|
| 111 |
},
|
| 112 |
-
BERT_SCORE:{
|
| 113 |
"precision": Precision.
|
| 114 |
"recall": Recall.
|
| 115 |
"f1": F1 score.
|
|
@@ -158,22 +159,32 @@ class GenerationEvaluator(evaluate.Metric):
|
|
| 158 |
exact_match_results = exact_match_score.compute(
|
| 159 |
predictions=predictions, references=references
|
| 160 |
)
|
| 161 |
-
|
| 162 |
bert_score = evaluate.load("bertscore")
|
| 163 |
bert_score_results = bert_score.compute(
|
| 164 |
-
predictions=predictions, references=references,
|
| 165 |
-
lang="en"
|
| 166 |
)
|
| 167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
bleurt_score = evaluate.load("bleurt", module_type="metric")
|
| 169 |
bleurt_results = bleurt_score.compute(
|
| 170 |
predictions=predictions, references=references
|
| 171 |
)
|
| 172 |
|
|
|
|
|
|
|
|
|
|
| 173 |
return {
|
| 174 |
"ROUGE": rouge_results,
|
| 175 |
"BLEU": bleu_results,
|
| 176 |
"EXACT_MATCH": exact_match_results,
|
| 177 |
-
"BERT_SCORE":bert_score_results,
|
| 178 |
-
"BLEURT":bleurt_results
|
| 179 |
}
|
|
|
|
| 1 |
import datasets
|
| 2 |
import evaluate
|
| 3 |
+
import numpy as np
|
| 4 |
|
| 5 |
_CITATION = """\
|
| 6 |
@inproceedings{lin-2004-rouge,
|
|
|
|
| 110 |
EXACT_MATCH:{
|
| 111 |
"exact_match": exact_match rate. Possible values are between 0.0 and 1.0, inclusive.
|
| 112 |
},
|
| 113 |
+
BERT_SCORE:{
|
| 114 |
"precision": Precision.
|
| 115 |
"recall": Recall.
|
| 116 |
"f1": F1 score.
|
|
|
|
| 159 |
exact_match_results = exact_match_score.compute(
|
| 160 |
predictions=predictions, references=references
|
| 161 |
)
|
| 162 |
+
|
| 163 |
bert_score = evaluate.load("bertscore")
|
| 164 |
bert_score_results = bert_score.compute(
|
| 165 |
+
predictions=predictions, references=references, lang="en"
|
|
|
|
| 166 |
)
|
| 167 |
|
| 168 |
+
mean_precision = np.mean(bert_score_results['precision'])
|
| 169 |
+
mean_recall = np.mean(bert_score_results['recall'])
|
| 170 |
+
mean_f1 = np.mean(bert_score_results['f1'])
|
| 171 |
+
|
| 172 |
+
bert_score_results['precision'] = round(mean_precision, 4)
|
| 173 |
+
bert_score_results['recall'] = round(mean_recall, 4)
|
| 174 |
+
bert_score_results['f1'] = round(mean_f1, 4)
|
| 175 |
+
|
| 176 |
bleurt_score = evaluate.load("bleurt", module_type="metric")
|
| 177 |
bleurt_results = bleurt_score.compute(
|
| 178 |
predictions=predictions, references=references
|
| 179 |
)
|
| 180 |
|
| 181 |
+
mean_bleurt_score = np.mean(bleurt_results['scores'])
|
| 182 |
+
bleurt_results['scores'] = round(mean_bleurt_score, 4)
|
| 183 |
+
|
| 184 |
return {
|
| 185 |
"ROUGE": rouge_results,
|
| 186 |
"BLEU": bleu_results,
|
| 187 |
"EXACT_MATCH": exact_match_results,
|
| 188 |
+
"BERT_SCORE": bert_score_results,
|
| 189 |
+
"BLEURT": bleurt_results,
|
| 190 |
}
|
requirements.txt
CHANGED
|
@@ -3,4 +3,5 @@ datasets
|
|
| 3 |
scikit-learn
|
| 4 |
gradio
|
| 5 |
bert_score
|
| 6 |
-
git+https://github.com/google-research/bleurt.git
|
|
|
|
|
|
| 3 |
scikit-learn
|
| 4 |
gradio
|
| 5 |
bert_score
|
| 6 |
+
git+https://github.com/google-research/bleurt.git
|
| 7 |
+
numpy
|