Spaces:
Running
Running
Fix Title
Browse files- README.md +9 -5
- encoder_models.py +0 -2
- semncg.py +11 -13
- tests.py +19 -8
- type_aliases.py +1 -3
- utils.py +4 -5
README.md
CHANGED
|
@@ -30,11 +30,7 @@ Before using this metric, you need to install the dependencies:
|
|
| 30 |
pip install -U sentence-transformers nltk
|
| 31 |
```
|
| 32 |
|
| 33 |
-
|
| 34 |
-
- `predictions` - List of predictions
|
| 35 |
-
- `references` - List of references
|
| 36 |
-
- `documents` - List of input documents
|
| 37 |
-
|
| 38 |
```python
|
| 39 |
from evaluate import load
|
| 40 |
predictions = [
|
|
@@ -55,6 +51,14 @@ mean_score, scores = metric.compute(predictions=predictions, references=referenc
|
|
| 55 |
print(f"Mean SemnCG: {mean_score}")
|
| 56 |
```
|
| 57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
Sem-nCG also accepts several optional arguments:
|
| 59 |
- `tokenize_sentences (bool)`: Flag to indicate whether to tokenize the sentences in the input documents. Default: True
|
| 60 |
- `pre_compute_embeddings (bool)`: Flag to indicate whether to pre-compute embeddings for all sentences. Default: False
|
|
|
|
| 30 |
pip install -U sentence-transformers nltk
|
| 31 |
```
|
| 32 |
|
| 33 |
+
#### Python Usage
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
```python
|
| 35 |
from evaluate import load
|
| 36 |
predictions = [
|
|
|
|
| 51 |
print(f"Mean SemnCG: {mean_score}")
|
| 52 |
```
|
| 53 |
|
| 54 |
+
The first step is to initialize the metric with `metric = load("nbansal/semncg", model_name=model_name)`, where `model_name` is
|
| 55 |
+
the name of the sentence embedding model. The default value is `all-MiniLM-L6-v2`.
|
| 56 |
+
|
| 57 |
+
To `compute` the Sem-nCG scores, you need to provide three mandatory arguments:
|
| 58 |
+
- `predictions` - List of predictions
|
| 59 |
+
- `references` - List of references
|
| 60 |
+
- `documents` - List of input documents
|
| 61 |
+
|
| 62 |
Sem-nCG also accepts several optional arguments:
|
| 63 |
- `tokenize_sentences (bool)`: Flag to indicate whether to tokenize the sentences in the input documents. Default: True
|
| 64 |
- `pre_compute_embeddings (bool)`: Flag to indicate whether to pre-compute embeddings for all sentences. Default: False
|
encoder_models.py
CHANGED
|
@@ -125,5 +125,3 @@ def get_sbert_encoder(model_name: str) -> SentenceTransformer:
|
|
| 125 |
raise RuntimeError(str(err)) from None
|
| 126 |
|
| 127 |
return encoder
|
| 128 |
-
|
| 129 |
-
|
|
|
|
| 125 |
raise RuntimeError(str(err)) from None
|
| 126 |
|
| 127 |
return encoder
|
|
|
|
|
|
semncg.py
CHANGED
|
@@ -13,13 +13,12 @@
|
|
| 13 |
# limitations under the License.
|
| 14 |
"""Sem-NCG metric"""
|
| 15 |
|
| 16 |
-
from dataclasses import dataclass
|
| 17 |
-
import evaluate
|
| 18 |
-
import datasets
|
| 19 |
-
import re
|
| 20 |
import statistics
|
| 21 |
-
from
|
|
|
|
| 22 |
|
|
|
|
|
|
|
| 23 |
import nltk
|
| 24 |
import numpy as np
|
| 25 |
from sklearn.metrics.pairwise import cosine_similarity
|
|
@@ -27,7 +26,8 @@ from tqdm import tqdm
|
|
| 27 |
|
| 28 |
from .encoder_models import get_sbert_encoder, get_encoder
|
| 29 |
from .type_aliases import DEVICE_TYPE, NDArray, DOCUMENT_TYPE
|
| 30 |
-
from .utils import get_gpu,
|
|
|
|
| 31 |
|
| 32 |
_CITATION = """\
|
| 33 |
@inproceedings{akter-etal-2022-revisiting,
|
|
@@ -128,8 +128,6 @@ Examples:
|
|
| 128 |
"""
|
| 129 |
|
| 130 |
|
| 131 |
-
|
| 132 |
-
|
| 133 |
@dataclass
|
| 134 |
class RankedGains:
|
| 135 |
"""
|
|
@@ -154,7 +152,7 @@ class RankedGains:
|
|
| 154 |
k: int
|
| 155 |
ncg: float
|
| 156 |
|
| 157 |
-
|
| 158 |
def compute_cosine_similarity(doc_embeds: NDArray, ref_embeds: NDArray) -> List[float]:
|
| 159 |
"""
|
| 160 |
Compute cosine similarity scores between each document embedding and reference embeddings.
|
|
@@ -333,7 +331,7 @@ def _validate_input_format(
|
|
| 333 |
|
| 334 |
|
| 335 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
| 336 |
-
class
|
| 337 |
"""
|
| 338 |
SemnCG (Semantic Normalized Cumulative Gain) Metric.
|
| 339 |
|
|
@@ -454,7 +452,7 @@ class SemnCG(evaluate.Metric):
|
|
| 454 |
|
| 455 |
# This is only done for debug case
|
| 456 |
sent_tokenized_documents = documents
|
| 457 |
-
|
| 458 |
# Compute All Embeddings
|
| 459 |
all_sentences = flatten_list(documents) + flatten_list(references) + flatten_list(predictions)
|
| 460 |
embeddings = encoder.encode(all_sentences)
|
|
@@ -467,7 +465,7 @@ class SemnCG(evaluate.Metric):
|
|
| 467 |
doc_embeddings = slice_embeddings(embeddings, document_sentences_count)
|
| 468 |
ref_embeddings = slice_embeddings(embeddings[sum(document_sentences_count):], reference_sentences_count)
|
| 469 |
pred_embeddings = slice_embeddings(
|
| 470 |
-
embeddings[sum(document_sentences_count+reference_sentences_count):], prediction_sentences_count
|
| 471 |
)
|
| 472 |
|
| 473 |
iterable_obj = zip(pred_embeddings, ref_embeddings, doc_embeddings)
|
|
@@ -495,7 +493,7 @@ class SemnCG(evaluate.Metric):
|
|
| 495 |
doc_embeddings = doc
|
| 496 |
ref_embeddings = ref
|
| 497 |
pred_embeddings = pred
|
| 498 |
-
|
| 499 |
doc_sentences = sent_tokenized_documents[idx]
|
| 500 |
|
| 501 |
# Compute Pair-Wise Cosine Similarity
|
|
|
|
| 13 |
# limitations under the License.
|
| 14 |
"""Sem-NCG metric"""
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
import statistics
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import List, Tuple, Union
|
| 19 |
|
| 20 |
+
import datasets
|
| 21 |
+
import evaluate
|
| 22 |
import nltk
|
| 23 |
import numpy as np
|
| 24 |
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
|
| 26 |
|
| 27 |
from .encoder_models import get_sbert_encoder, get_encoder
|
| 28 |
from .type_aliases import DEVICE_TYPE, NDArray, DOCUMENT_TYPE
|
| 29 |
+
from .utils import get_gpu, flatten_list, slice_embeddings, is_nested_list_of_type, \
|
| 30 |
+
tokenize_and_prep_document
|
| 31 |
|
| 32 |
_CITATION = """\
|
| 33 |
@inproceedings{akter-etal-2022-revisiting,
|
|
|
|
| 128 |
"""
|
| 129 |
|
| 130 |
|
|
|
|
|
|
|
| 131 |
@dataclass
|
| 132 |
class RankedGains:
|
| 133 |
"""
|
|
|
|
| 152 |
k: int
|
| 153 |
ncg: float
|
| 154 |
|
| 155 |
+
|
| 156 |
def compute_cosine_similarity(doc_embeds: NDArray, ref_embeds: NDArray) -> List[float]:
|
| 157 |
"""
|
| 158 |
Compute cosine similarity scores between each document embedding and reference embeddings.
|
|
|
|
| 331 |
|
| 332 |
|
| 333 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
| 334 |
+
class SemNCG(evaluate.Metric):
|
| 335 |
"""
|
| 336 |
SemnCG (Semantic Normalized Cumulative Gain) Metric.
|
| 337 |
|
|
|
|
| 452 |
|
| 453 |
# This is only done for debug case
|
| 454 |
sent_tokenized_documents = documents
|
| 455 |
+
|
| 456 |
# Compute All Embeddings
|
| 457 |
all_sentences = flatten_list(documents) + flatten_list(references) + flatten_list(predictions)
|
| 458 |
embeddings = encoder.encode(all_sentences)
|
|
|
|
| 465 |
doc_embeddings = slice_embeddings(embeddings, document_sentences_count)
|
| 466 |
ref_embeddings = slice_embeddings(embeddings[sum(document_sentences_count):], reference_sentences_count)
|
| 467 |
pred_embeddings = slice_embeddings(
|
| 468 |
+
embeddings[sum(document_sentences_count + reference_sentences_count):], prediction_sentences_count
|
| 469 |
)
|
| 470 |
|
| 471 |
iterable_obj = zip(pred_embeddings, ref_embeddings, doc_embeddings)
|
|
|
|
| 493 |
doc_embeddings = doc
|
| 494 |
ref_embeddings = ref
|
| 495 |
pred_embeddings = pred
|
| 496 |
+
|
| 497 |
doc_sentences = sent_tokenized_documents[idx]
|
| 498 |
|
| 499 |
# Compute Pair-Wise Cosine Similarity
|
tests.py
CHANGED
|
@@ -1,16 +1,27 @@
|
|
| 1 |
-
import statistics
|
| 2 |
import unittest
|
| 3 |
-
from unittest.mock import patch, MagicMock
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
| 7 |
-
from numpy.testing import assert_almost_equal
|
| 8 |
from sentence_transformers import SentenceTransformer
|
| 9 |
-
from sklearn.metrics.pairwise import cosine_similarity
|
| 10 |
|
| 11 |
from .encoder_models import SBertEncoder, get_encoder, get_sbert_encoder
|
| 12 |
-
from .semncg import
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
class TestUtils(unittest.TestCase):
|
|
@@ -116,7 +127,7 @@ class TestUtils(unittest.TestCase):
|
|
| 116 |
doc_embeddings = slice_embeddings(all_embeddings, document_sentences_count)
|
| 117 |
ref_embeddings = slice_embeddings(all_embeddings[sum(document_sentences_count):], reference_sentences_count)
|
| 118 |
pred_embeddings = slice_embeddings(
|
| 119 |
-
all_embeddings[sum(document_sentences_count+reference_sentences_count):], pred_sentences_count
|
| 120 |
)
|
| 121 |
|
| 122 |
self.assertTrue(doc_embeddings, expected_doc_embeddings)
|
|
@@ -350,7 +361,7 @@ class TestValidateInputFormat(unittest.TestCase):
|
|
| 350 |
class TestSemnCG(unittest.TestCase):
|
| 351 |
def setUp(self):
|
| 352 |
self.model_name = "stsb-distilbert-base"
|
| 353 |
-
self.metric =
|
| 354 |
|
| 355 |
def _basic_assertion(self, result, debug: bool = False):
|
| 356 |
self.assertIsInstance(result, tuple)
|
|
|
|
|
|
|
| 1 |
import unittest
|
|
|
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
import torch
|
|
|
|
| 5 |
from sentence_transformers import SentenceTransformer
|
|
|
|
| 6 |
|
| 7 |
from .encoder_models import SBertEncoder, get_encoder, get_sbert_encoder
|
| 8 |
+
from .semncg import (
|
| 9 |
+
RankedGains,
|
| 10 |
+
compute_cosine_similarity,
|
| 11 |
+
compute_gain,
|
| 12 |
+
score_ncg,
|
| 13 |
+
compute_ncg,
|
| 14 |
+
_validate_input_format,
|
| 15 |
+
SemNCG
|
| 16 |
+
)
|
| 17 |
+
from .utils import (
|
| 18 |
+
get_gpu,
|
| 19 |
+
slice_embeddings,
|
| 20 |
+
is_nested_list_of_type,
|
| 21 |
+
flatten_list,
|
| 22 |
+
prep_sentences,
|
| 23 |
+
tokenize_and_prep_document
|
| 24 |
+
)
|
| 25 |
|
| 26 |
|
| 27 |
class TestUtils(unittest.TestCase):
|
|
|
|
| 127 |
doc_embeddings = slice_embeddings(all_embeddings, document_sentences_count)
|
| 128 |
ref_embeddings = slice_embeddings(all_embeddings[sum(document_sentences_count):], reference_sentences_count)
|
| 129 |
pred_embeddings = slice_embeddings(
|
| 130 |
+
all_embeddings[sum(document_sentences_count + reference_sentences_count):], pred_sentences_count
|
| 131 |
)
|
| 132 |
|
| 133 |
self.assertTrue(doc_embeddings, expected_doc_embeddings)
|
|
|
|
| 361 |
class TestSemnCG(unittest.TestCase):
|
| 362 |
def setUp(self):
|
| 363 |
self.model_name = "stsb-distilbert-base"
|
| 364 |
+
self.metric = SemNCG(self.model_name)
|
| 365 |
|
| 366 |
def _basic_assertion(self, result, debug: bool = False):
|
| 367 |
self.assertIsInstance(result, tuple)
|
type_aliases.py
CHANGED
|
@@ -1,9 +1,7 @@
|
|
| 1 |
-
|
| 2 |
-
from typing import List, Union, Tuple
|
| 3 |
|
| 4 |
from numpy.typing import NDArray
|
| 5 |
|
| 6 |
-
|
| 7 |
NumSentencesType = Union[List[int], List[List[int]]]
|
| 8 |
EmbeddingSlicesType = Union[List[NDArray], List[List[NDArray]]]
|
| 9 |
DEVICE_TYPE = Union[bool, str, int, List[Union[str, int]]]
|
|
|
|
| 1 |
+
from typing import List, Union
|
|
|
|
| 2 |
|
| 3 |
from numpy.typing import NDArray
|
| 4 |
|
|
|
|
| 5 |
NumSentencesType = Union[List[int], List[List[int]]]
|
| 6 |
EmbeddingSlicesType = Union[List[NDArray], List[List[NDArray]]]
|
| 7 |
DEVICE_TYPE = Union[bool, str, int, List[Union[str, int]]]
|
utils.py
CHANGED
|
@@ -1,11 +1,9 @@
|
|
| 1 |
-
|
| 2 |
import string
|
| 3 |
-
from typing import List,
|
| 4 |
|
| 5 |
import nltk
|
| 6 |
-
import numpy as np
|
| 7 |
-
from numpy.typing import NDArray
|
| 8 |
import torch
|
|
|
|
| 9 |
|
| 10 |
from .type_aliases import DEVICE_TYPE, ENCODER_DEVICE_TYPE, NumSentencesType, EmbeddingSlicesType
|
| 11 |
|
|
@@ -204,7 +202,8 @@ def is_nested_list_of_type(lst_obj, element_type, depth: int) -> bool:
|
|
| 204 |
if depth == 0:
|
| 205 |
return isinstance(lst_obj, element_type)
|
| 206 |
elif depth > 0:
|
| 207 |
-
return isinstance(lst_obj, list) and all(
|
|
|
|
| 208 |
else:
|
| 209 |
raise ValueError("Depth can't be negative")
|
| 210 |
|
|
|
|
|
|
|
| 1 |
import string
|
| 2 |
+
from typing import List, Union
|
| 3 |
|
| 4 |
import nltk
|
|
|
|
|
|
|
| 5 |
import torch
|
| 6 |
+
from numpy.typing import NDArray
|
| 7 |
|
| 8 |
from .type_aliases import DEVICE_TYPE, ENCODER_DEVICE_TYPE, NumSentencesType, EmbeddingSlicesType
|
| 9 |
|
|
|
|
| 202 |
if depth == 0:
|
| 203 |
return isinstance(lst_obj, element_type)
|
| 204 |
elif depth > 0:
|
| 205 |
+
return isinstance(lst_obj, list) and all(
|
| 206 |
+
is_nested_list_of_type(item, element_type, depth - 1) for item in lst_obj)
|
| 207 |
else:
|
| 208 |
raise ValueError("Depth can't be negative")
|
| 209 |
|