# Copyright 2020 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import doctest
import glob
import importlib
import inspect
import os
import re
from contextlib import contextmanager
from functools import wraps
from unittest.mock import patch

import numpy as np
import pytest
from absl.testing import parameterized

import evaluate
from evaluate import load

from .utils import _run_slow_tests, for_all_test_methods, local, slow
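# Skip conditions used by the parameterized tests below:
# - REQUIRE_FAIRSEQ: modules whose doctests need fairseq installed
# - UNSUPPORTED_ON_WINDOWS: modules that cannot run on Windows
# - SLOW_METRIC: modules whose doctests are too slow for the default test run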
REQUIRE_FAIRSEQ = {"comet"}
_has_fairseq = importlib.util.find_spec("fairseq") is not None

UNSUPPORTED_ON_WINDOWS = {"code_eval"}
_on_windows = os.name == "nt"

SLOW_METRIC = {"perplexity", "regard", "toxicity"}
def skip_if_metric_requires_fairseq(test_case):
    @wraps(test_case)
    def wrapper(self, evaluation_module_name, evaluation_module_type):
        if not _has_fairseq and evaluation_module_name in REQUIRE_FAIRSEQ:
            self.skipTest("test requires Fairseq")
        else:
            test_case(self, evaluation_module_name, evaluation_module_type)

    return wrapper


def skip_on_windows_if_not_windows_compatible(test_case):
    @wraps(test_case)
    def wrapper(self, evaluation_module_name, evaluation_module_type):
        if _on_windows and evaluation_module_name in UNSUPPORTED_ON_WINDOWS:
            self.skipTest("test not supported on Windows")
        else:
            test_case(self, evaluation_module_name, evaluation_module_type)

    return wrapper


def skip_slow_metrics(test_case):
    @wraps(test_case)
    def wrapper(self, evaluation_module_name, evaluation_module_type):
        if not _run_slow_tests and evaluation_module_name in SLOW_METRIC:
            self.skipTest("test is slow")
        else:
            test_case(self, evaluation_module_name, evaluation_module_type)

    return wrapper
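# Build one named test case per local evaluation module, discovered by scanning
# the ./metrics, ./comparisons and ./measurements directories of the repository.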
def get_local_module_names():
    metrics = [metric_dir.split(os.sep)[-2] for metric_dir in glob.glob("./metrics/*/")]
    comparisons = [metric_dir.split(os.sep)[-2] for metric_dir in glob.glob("./comparisons/*/")]
    measurements = [metric_dir.split(os.sep)[-2] for metric_dir in glob.glob("./measurements/*/")]
    evaluation_modules = metrics + comparisons + measurements
    evaluation_module_types = (
        ["metric"] * len(metrics) + ["comparison"] * len(comparisons) + ["measurement"] * len(measurements)
    )
    return [
        {"testcase_name": f"{t}_{x}", "evaluation_module_name": x, "evaluation_module_type": t}
        for x, t in zip(evaluation_modules, evaluation_module_types)
        if x != "gleu"  # gleu is unfinished
    ]
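# Runs each local module's doctests (the examples embedded in its docstrings),
# with expensive model downloads and forward passes replaced by the patchers
# registered at the bottom of this file.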
@parameterized.named_parameters(get_local_module_names())
@for_all_test_methods(skip_if_metric_requires_fairseq, skip_on_windows_if_not_windows_compatible, skip_slow_metrics)
@local
class LocalModuleTest(parameterized.TestCase):
    INTENSIVE_CALLS_PATCHER = {}
    evaluation_module_name = None
    evaluation_module_type = None

    def test_load(self, evaluation_module_name, evaluation_module_type):
        doctest.ELLIPSIS_MARKER = "[...]"
        evaluation_module = importlib.import_module(
            evaluate.loading.evaluation_module_factory(
                os.path.join(evaluation_module_type + "s", evaluation_module_name), module_type=evaluation_module_type
            ).module_path
        )
        evaluation_instance = evaluate.loading.import_main_class(evaluation_module.__name__)
        # check parameters
        parameters = inspect.signature(evaluation_instance._compute).parameters
        self.assertTrue(all(p.kind != p.VAR_KEYWORD for p in parameters.values()))  # no **kwargs
        # run doctest
        with self.patch_intensive_calls(evaluation_module_name, evaluation_module.__name__):
            with self.use_local_metrics(evaluation_module_type):
                try:
                    results = doctest.testmod(evaluation_module, verbose=True, raise_on_error=True)
                except doctest.UnexpectedException as e:
                    raise e.exc_info[1]  # raise the exception that doctest caught
        self.assertEqual(results.failed, 0)
        self.assertGreater(results.attempted, 1)
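    # Unlike test_load, this variant does not patch the intensive model calls, so the
    # doctests exercise the real models; it is therefore only run as a slow test.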
    @slow
    def test_load_real_metric(self, evaluation_module_name, evaluation_module_type):
        doctest.ELLIPSIS_MARKER = "[...]"
        metric_module = importlib.import_module(
            evaluate.loading.evaluation_module_factory(
                os.path.join(evaluation_module_type, evaluation_module_name)
            ).module_path
        )
        # run doctest
        with self.use_local_metrics(evaluation_module_type):
            results = doctest.testmod(metric_module, verbose=True, raise_on_error=True)
        self.assertEqual(results.failed, 0)
        self.assertGreater(results.attempted, 1)
    @contextmanager
    def patch_intensive_calls(self, evaluation_module_name, module_name):
        if evaluation_module_name in self.INTENSIVE_CALLS_PATCHER:
            with self.INTENSIVE_CALLS_PATCHER[evaluation_module_name](module_name):
                yield
        else:
            yield

    @contextmanager
    def use_local_metrics(self, evaluation_module_type):
        def load_local_metric(evaluation_module_name, *args, **kwargs):
            return load(os.path.join(evaluation_module_type + "s", evaluation_module_name), *args, **kwargs)

        with patch("evaluate.load") as mock_load:
            mock_load.side_effect = load_local_metric
            yield
    @classmethod
    def register_intensive_calls_patcher(cls, evaluation_module_name):
        def wrapper(patcher):
            patcher = contextmanager(patcher)
            cls.INTENSIVE_CALLS_PATCHER[evaluation_module_name] = patcher
            return patcher

        return wrapper
# Metrics intensive calls patchers
# --------------------------------
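# Each patcher is registered for a single module and replaces that module's costly
# model download / forward pass with a lightweight mock, so that the doctests in
# LocalModuleTest.test_load run quickly and without network access.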
@LocalModuleTest.register_intensive_calls_patcher("bleurt")
def patch_bleurt(module_name):
    import tensorflow.compat.v1 as tf
    from bleurt.score import Predictor

    tf.flags.DEFINE_string("sv", "", "")  # handle pytest cli flags

    class MockedPredictor(Predictor):
        def predict(self, input_dict):
            assert len(input_dict["input_ids"]) == 2
            return np.array([1.03, 1.04])

    # mock predict_fn which is supposed to do a forward pass with a bleurt model
    with patch("bleurt.score._create_predictor") as mock_create_predictor:
        mock_create_predictor.return_value = MockedPredictor()
        yield
@LocalModuleTest.register_intensive_calls_patcher("bertscore")
def patch_bertscore(module_name):
    import torch

    def bert_cos_score_idf(model, refs, *args, **kwargs):
        return torch.tensor([[1.0, 1.0, 1.0]] * len(refs))

    # mock get_model which would otherwise download a bert model
    # mock bert_cos_score_idf which is supposed to do a forward pass with a bert model
    with patch("bert_score.scorer.get_model"), patch(
        "bert_score.scorer.bert_cos_score_idf"
    ) as mock_bert_cos_score_idf:
        mock_bert_cos_score_idf.side_effect = bert_cos_score_idf
        yield
@LocalModuleTest.register_intensive_calls_patcher("comet")
def patch_comet(module_name):
    def load_from_checkpoint(model_path):
        class Model:
            def predict(self, data, *args, **kwargs):
                assert len(data) == 2
                scores = [0.19, 0.92]
                return scores, sum(scores) / len(scores)

        return Model()

    # mock download_model and load_from_checkpoint, which would otherwise download and load a comet model
    with patch("comet.download_model") as mock_download_model:
        mock_download_model.return_value = None
        with patch("comet.load_from_checkpoint") as mock_load_from_checkpoint:
            mock_load_from_checkpoint.side_effect = load_from_checkpoint
            yield
def test_seqeval_raises_when_incorrect_scheme():
    metric = load(os.path.join("metrics", "seqeval"))
    wrong_scheme = "ERROR"
    error_message = f"Scheme should be one of [IOB1, IOB2, IOE1, IOE2, IOBES, BILOU], got {wrong_scheme}"
    with pytest.raises(ValueError, match=re.escape(error_message)):
        metric.compute(predictions=[], references=[], scheme=wrong_scheme)
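# Example invocation (hypothetical file path -- adjust to where this module lives in the repo):
#   python -m pytest tests/test_metric_common.py -k "bleu"
# Doctests for the modules listed in SLOW_METRIC, as well as test_load_real_metric,
# only run when slow tests are enabled (see _run_slow_tests in .utils).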