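"""Unit tests for the SemNCG metric, its encoder models, and helper utilities."""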
import unittest

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

from .encoder_models import SBertEncoder, get_encoder, get_sbert_encoder
from .semncg import (
    RankedGains,
    compute_cosine_similarity,
    compute_gain,
    score_ncg,
    compute_ncg,
    _validate_input_format,
    SemNCG
)
from .utils import (
    get_gpu,
    slice_embeddings,
    is_nested_list_of_type,
    flatten_list,
    prep_sentences,
    tokenize_and_prep_document
)
class TestUtils(unittest.TestCase):

    def test_get_gpu(self):
        gpu_count = torch.cuda.device_count()
        gpu_available = torch.cuda.is_available()

        # Test single boolean input
        self.assertEqual(get_gpu(True), 0 if gpu_available else "cpu")
        self.assertEqual(get_gpu(False), "cpu")

        # Test single string input
        self.assertEqual(get_gpu("cpu"), "cpu")
        self.assertEqual(get_gpu("gpu"), 0 if gpu_available else "cpu")
        self.assertEqual(get_gpu("cuda"), 0 if gpu_available else "cpu")

        # Test single integer input
        self.assertEqual(get_gpu(0), 0 if gpu_available else "cpu")
        self.assertEqual(get_gpu(1), 1 if gpu_available else "cpu")

        # Test list input with unique elements
        self.assertEqual(get_gpu([True, "cpu", 0]), [0, "cpu"] if gpu_available else ["cpu", "cpu", "cpu"])

        # Test list input with duplicate elements
        self.assertEqual(get_gpu([0, 0, "gpu"]), 0 if gpu_available else ["cpu", "cpu", "cpu"])

        # Test list input with duplicate elements of different types
        self.assertEqual(get_gpu([True, 0, "gpu"]), 0 if gpu_available else ["cpu", "cpu", "cpu"])

        # Test list input with only one element
        self.assertEqual(get_gpu([True]), 0 if gpu_available else "cpu")

        # Test list input with all integers
        self.assertEqual(get_gpu(list(range(gpu_count))),
                         list(range(gpu_count)) if gpu_available else gpu_count * ["cpu"])

        # Invalid inputs
        with self.assertRaises(ValueError):
            get_gpu("invalid")
        with self.assertRaises(ValueError):
            get_gpu(torch.cuda.device_count())
    def test_prep_sentences(self):
        # Test normal case
        self.assertEqual(prep_sentences(["Hello, world!", " This is a test. ", "!!!"]),
                         ['Hello, world!', 'This is a test.'])

        # Test case with only punctuation
        with self.assertRaises(ValueError):
            prep_sentences(["!!!", "..."])

        # Test case with empty list
        with self.assertRaises(ValueError):
            prep_sentences([])

    def test_tokenize_and_prep_document(self):
        # Test tokenize=True with string input
        self.assertEqual(tokenize_and_prep_document("Hello, world! This is a test.", True),
                         ['Hello, world!', 'This is a test.'])

        # Test tokenize=False with list of strings input
        self.assertEqual(tokenize_and_prep_document(["Hello, world!", "This is a test."], False),
                         ['Hello, world!', 'This is a test.'])

        # Test tokenize=True with a punctuation-only document (empty after cleaning)
        with self.assertRaises(ValueError):
            tokenize_and_prep_document("!!! ...", True)
    def test_slice_embeddings(self):
        # Case 1: flat list of sentence counts
        embeddings = np.random.rand(10, 5)
        num_sentences = [3, 2, 5]
        expected_output = [embeddings[:3], embeddings[3:5], embeddings[5:]]
        self.assertTrue(
            all(np.array_equal(a, b) for a, b in zip(slice_embeddings(embeddings, num_sentences),
                                                     expected_output))
        )

        # Case 2: nested list of sentence counts
        num_sentences_nested = [[2, 1], [3, 4]]
        expected_output_nested = [[embeddings[:2], embeddings[2:3]], [embeddings[3:6], embeddings[6:]]]
        output_nested = slice_embeddings(embeddings, num_sentences_nested)
        self.assertTrue(
            all(
                np.array_equal(out, exp)
                for out_group, exp_group in zip(output_nested, expected_output_nested)
                for out, exp in zip(out_group, exp_group)
            )
        )

        # Case 3: document, reference, and prediction embeddings sliced from one matrix
        document_sentences_count = [10, 8, 7]
        reference_sentences_count = [5, 3, 2]
        pred_sentences_count = [2, 2, 1]
        all_embeddings = np.random.rand(
            sum(document_sentences_count + reference_sentences_count + pred_sentences_count), 5,
        )

        embeddings = all_embeddings
        expected_doc_embeddings = [embeddings[:10], embeddings[10:18], embeddings[18:25]]

        embeddings = all_embeddings[25:]
        expected_ref_embeddings = [embeddings[:5], embeddings[5:8], embeddings[8:10]]

        embeddings = all_embeddings[35:]
        expected_pred_embeddings = [embeddings[:2], embeddings[2:4], embeddings[4:5]]

        doc_embeddings = slice_embeddings(all_embeddings, document_sentences_count)
        ref_embeddings = slice_embeddings(all_embeddings[sum(document_sentences_count):], reference_sentences_count)
        pred_embeddings = slice_embeddings(
            all_embeddings[sum(document_sentences_count + reference_sentences_count):], pred_sentences_count
        )

        for output, expected in [(doc_embeddings, expected_doc_embeddings),
                                 (ref_embeddings, expected_ref_embeddings),
                                 (pred_embeddings, expected_pred_embeddings)]:
            self.assertTrue(all(np.array_equal(a, b) for a, b in zip(output, expected)))

        # Invalid sentence-count argument
        with self.assertRaises(TypeError):
            slice_embeddings(embeddings, "invalid")
    def test_is_nested_list_of_type(self):
        # Test case: Depth 0, single element matching element_type
        self.assertEqual(is_nested_list_of_type("test", str, 0), (True, ""))

        # Test case: Depth 0, single element not matching element_type
        is_valid, err_msg = is_nested_list_of_type("test", int, 0)
        self.assertEqual(is_valid, False)

        # Test case: Depth 1, list of elements matching element_type
        self.assertEqual(is_nested_list_of_type(["apple", "banana"], str, 1), (True, ""))

        # Test case: Depth 1, list of elements not matching element_type
        is_valid, err_msg = is_nested_list_of_type([1, 2, 3], str, 1)
        self.assertEqual(is_valid, False)

        # Test case: Depth 0 (wrong depth), list of elements matching element_type
        is_valid, err_msg = is_nested_list_of_type([1, 2, 3], str, 0)
        self.assertEqual(is_valid, False)

        # Depth 2
        self.assertEqual(is_nested_list_of_type([[1, 2], [3, 4]], int, 2), (True, ""))
        self.assertEqual(is_nested_list_of_type([['1', '2'], ['3', '4']], str, 2), (True, ""))
        is_valid, err_msg = is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2)
        self.assertEqual(is_valid, False)

        # Depth 3
        is_valid, err_msg = is_nested_list_of_type([[[1], [2]], [[3], [4]]], list, 3)
        self.assertEqual(is_valid, False)
        self.assertEqual(is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3), (True, ""))

        # Test case: Depth is negative, expecting ValueError
        with self.assertRaises(ValueError):
            is_nested_list_of_type([1, 2], int, -1)

    def test_flatten_list(self):
        self.assertEqual(flatten_list([1, [2, 3], [[4], 5]]), [1, 2, 3, 4, 5])
        self.assertEqual(flatten_list([]), [])
        self.assertEqual(flatten_list([1, 2, 3]), [1, 2, 3])
        self.assertEqual(flatten_list([[[[1]]]]), [1])
class TestSBertEncoder(unittest.TestCase):

    def setUp(self) -> None:
        # Set up a test SentenceTransformer model
        self.model_name = "paraphrase-distilroberta-base-v1"
        self.sbert_model = get_sbert_encoder(self.model_name)
        self.device = "cpu"  # For testing on CPU
        self.batch_size = 32
        self.verbose = False
        self.encoder = SBertEncoder(self.sbert_model, self.device, self.batch_size, self.verbose)

    def test_encode_single_sentence(self):
        sentence = "Hello, world!"
        embeddings = self.encoder.encode([sentence])
        self.assertEqual(embeddings.shape, (1, 768))  # Adjust shape based on your model's embedding dimension

    def test_encode_multiple_sentences(self):
        sentences = ["Hello, world!", "This is a test."]
        embeddings = self.encoder.encode(sentences)
        self.assertEqual(embeddings.shape, (2, 768))  # Adjust shape based on your model's embedding dimension

    def test_get_sbert_encoder(self):
        model_name = "paraphrase-distilroberta-base-v1"
        sbert_model = get_sbert_encoder(model_name)
        self.assertIsInstance(sbert_model, SentenceTransformer)

    def test_encode_with_gpu(self):
        if torch.cuda.is_available():
            device = "cuda"
            encoder = get_encoder(self.sbert_model, device, self.batch_size, self.verbose)
            sentences = ["Hello, world!", "This is a test."]
            embeddings = encoder.encode(sentences)
            self.assertEqual(embeddings.shape, (2, 768))  # Adjust shape based on your model's embedding dimension
        else:
            self.skipTest("CUDA not available, skipping GPU test.")

    def test_encode_multi_device(self):
        if torch.cuda.device_count() < 2:
            self.skipTest("Multi-GPU test requires at least 2 GPUs.")
        else:
            devices = ["cuda:0", "cuda:1"]
            encoder = get_encoder(self.sbert_model, devices, self.batch_size, self.verbose)
            sentences = ["This is a test sentence.", "Here is another sentence.", "This is a test sentence."]
            embeddings = encoder.encode(sentences)
            self.assertIsInstance(embeddings, np.ndarray)
            self.assertEqual(embeddings.shape[0], 3)
            self.assertEqual(embeddings.shape[1], self.encoder.model.get_sentence_embedding_dimension())
class TestGetEncoder(unittest.TestCase):

    def setUp(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.batch_size = 8
        self.verbose = False

    def _base_test(self, model_name):
        sbert_model = get_sbert_encoder(model_name)
        encoder = get_encoder(sbert_model, self.device, self.batch_size, self.verbose)

        # Assert
        self.assertIsInstance(encoder, SBertEncoder)
        self.assertEqual(encoder.device, self.device)
        self.assertEqual(encoder.batch_size, self.batch_size)
        self.assertEqual(encoder.verbose, self.verbose)

    def test_get_sbert_encoder(self):
        model_name = "stsb-roberta-large"
        self._base_test(model_name)

    def test_sbert_model(self):
        model_name = "all-mpnet-base-v2"
        self._base_test(model_name)
    def test_huggingface_model(self):
        """Test Hugging Face models that work with the SentenceTransformers library."""
        model_name = "roberta-base"
        self._base_test(model_name)

    def test_get_encoder_environment_error(self):
        model_name = "abc"  # Non-existent model name
        with self.assertRaises(EnvironmentError):
            get_sbert_encoder(model_name)

    def test_get_encoder_other_exception(self):
        model_name = "apple/OpenELM-270M"  # This model is not supported by the SentenceTransformers library
        with self.assertRaises(RuntimeError):
            get_sbert_encoder(model_name)
class TestRankedGainsDataclass(unittest.TestCase):

    def test_ranked_gains_dataclass(self):
        # Test initialization and attribute access
        gt_gains = [("doc1", 0.8), ("doc2", 0.6)]
        pred_gains = [("doc2", 0.7), ("doc1", 0.5)]
        k = 2
        ncg = 0.75
        ranked_gains = RankedGains(gt_gains, pred_gains, k, ncg)
        self.assertEqual(ranked_gains.gt_gains, gt_gains)
        self.assertEqual(ranked_gains.pred_gains, pred_gains)
        self.assertEqual(ranked_gains.k, k)
        self.assertEqual(ranked_gains.ncg, ncg)
class TestComputeCosineSimilarity(unittest.TestCase):

    def test_compute_cosine_similarity(self):
        doc_embeds = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
        ref_embeds = np.array([[0.2, 0.3, 0.4], [0.5, 0.6, 0.7]])

        # Test compute_cosine_similarity function
        similarity_scores = compute_cosine_similarity(doc_embeds, ref_embeds)
        print(similarity_scores)

        # Example values, change as per actual function output
        expected_scores = [0.980, 0.997]
        self.assertAlmostEqual(similarity_scores[0], expected_scores[0], places=3)
        self.assertAlmostEqual(similarity_scores[1], expected_scores[1], places=3)


class TestComputeGain(unittest.TestCase):

    def test_compute_gain(self):
        # Test compute_gain function
        sim_scores = [0.8, 0.6, 0.7]
        gains = compute_gain(sim_scores)
        print(gains)

        # Example values, change as per actual function output
        expected_gains = [(0, 0.5), (2, 0.3333333333333333), (1, 0.16666666666666666)]
        self.assertEqual(gains, expected_gains)
class TestScoreNcg(unittest.TestCase):

    def test_score_ncg(self):
        # Test score_ncg function
        model_relevance = [0.8, 0.7, 0.6]
        gt_relevance = [1.0, 0.9, 0.8]
        ncg_score = score_ncg(model_relevance, gt_relevance)
        expected_ncg = 0.778  # Example value, change as per actual function output
        self.assertAlmostEqual(ncg_score, expected_ncg, places=3)


class TestComputeNcg(unittest.TestCase):

    def test_compute_ncg(self):
        # Test compute_ncg function
        pred_gains = [(0, 0.8), (2, 0.7), (1, 0.6)]
        gt_gains = [(0, 1.0), (1, 0.9), (2, 0.8)]
        k = 3
        ncg_score = compute_ncg(pred_gains, gt_gains, k)
        expected_ncg = 1.0  # TODO: Confirm this with Dr. Santu
        self.assertAlmostEqual(ncg_score, expected_ncg, places=6)
class TestValidateInputFormat(unittest.TestCase):

    def test_validate_input_format(self):
        # Test _validate_input_format function
        tokenize_sentences = True
        predictions = ["Prediction 1", "Prediction 2"]
        references = ["Reference 1", "Reference 2"]
        documents = ["Document 1", "Document 2"]

        # No exception should be raised for valid input
        try:
            _validate_input_format(tokenize_sentences, predictions, references, documents)
        except ValueError as e:
            self.fail(f"_validate_input_format raised ValueError unexpectedly: {str(e)}")

        # Test invalid input format
        predictions_invalid = [["Sentence 1 in prediction 1.", "Sentence 2 in prediction 1."],
                               ["Sentence 1 in prediction 2.", "Sentence 2 in prediction 2."]]
        references_invalid = [["Sentences in reference 1."], ["Sentences in reference 2."]]
        documents_invalid = [["Sentence 1 in document 1.", "Sentence 2 in document 1."],
                             ["Sentence 1 in document 2.", "Sentence 2 in document 2."]]

        with self.assertRaises(ValueError):
            _validate_input_format(tokenize_sentences, predictions_invalid, references, documents)
        with self.assertRaises(ValueError):
            _validate_input_format(tokenize_sentences, predictions, references_invalid, documents)
        with self.assertRaises(ValueError):
            _validate_input_format(tokenize_sentences, predictions, references, documents_invalid)
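# The tests below assume SemNCG.compute returns a (mean_ncg, per_sample_gains) tuple,
# where per_sample_gains contains floats by default and RankedGains objects when
# debug=True; this mirrors exactly what _basic_assertion checks.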
class TestSemNCG(unittest.TestCase):

    def setUp(self):
        self.model_name = "stsb-distilbert-base"
        self.metric = SemNCG(self.model_name)

    def _basic_assertion(self, result, debug: bool = False):
        self.assertIsInstance(result, tuple)
        self.assertEqual(len(result), 2)
        self.assertIsInstance(result[0], float)
        self.assertTrue(0.0 <= result[0] <= 1.0)
        self.assertIsInstance(result[1], list)
        if debug:
            for ranked_gain in result[1]:
                self.assertTrue(isinstance(ranked_gain, RankedGains))
                self.assertTrue(0.0 <= ranked_gain.ncg <= 1.0)
        else:
            for gain in result[1]:
                self.assertTrue(isinstance(gain, float))
                self.assertTrue(0.0 <= gain <= 1.0)

    def test_compute_basic(self):
        predictions = ["The cat sat on the mat.", "The quick brown fox jumps over the lazy dog."]
        references = ["A cat was sitting on a mat.", "A quick brown fox jumped over a lazy dog."]
        documents = ["There was a cat on a mat.", "The quick brown fox jumped over the lazy dog."]
        result = self.metric.compute(predictions=predictions, references=references, documents=documents)
        self._basic_assertion(result)

    def test_compute_with_tokenization(self):
        predictions = [["The cat sat on the mat."], ["The quick brown fox jumps over the lazy dog."]]
        references = [["A cat was sitting on a mat."], ["A quick brown fox jumped over a lazy dog."]]
        documents = [["There was a cat on a mat."], ["The quick brown fox jumped over the lazy dog."]]
        result = self.metric.compute(
            predictions=predictions, references=references, documents=documents, tokenize_sentences=False
        )
        self._basic_assertion(result)

    def test_compute_with_pre_compute_embeddings(self):
        predictions = ["The cat sat on the mat.", "The quick brown fox jumps over the lazy dog."]
        references = ["A cat was sitting on a mat.", "A quick brown fox jumped over a lazy dog."]
        documents = ["There was a cat on a mat.", "The quick brown fox jumped over the lazy dog."]
        result = self.metric.compute(
            predictions=predictions, references=references, documents=documents, pre_compute_embeddings=True
        )
        self._basic_assertion(result)

    def test_compute_with_debug(self):
        predictions = ["The cat sat on the mat.", "The quick brown fox jumps over the lazy dog."]
        references = ["A cat was sitting on a mat.", "A quick brown fox jumped over a lazy dog."]
        documents = ["There was a cat on a mat.", "The quick brown fox jumped over the lazy dog."]
        result = self.metric.compute(
            predictions=predictions, references=references, documents=documents, debug=True
        )
        self._basic_assertion(result, debug=True)

    def test_compute_invalid_input_format(self):
        predictions = "The cat sat on the mat."
        references = ["A cat was sitting on a mat."]
        documents = ["There was a cat on a mat."]
        with self.assertRaises(ValueError):
            self.metric.compute(predictions=predictions, references=references, documents=documents)
    def test_bad_inputs(self):
        def _call_metric(preds, refs, docs, tok):
            with self.assertRaises(Exception) as ctx:
                _ = self.metric.compute(
                    predictions=preds,
                    references=refs,
                    documents=docs,
                    tokenize_sentences=tok,
                    pre_compute_embeddings=True,
                )
            print(f"Raised Exception with message: {ctx.exception}")
            return ""

        # None inputs
        # Case I
        tokenize_sentences = True
        predictions = [None]
        references = ["A cat was sitting on a mat."]
        documents = ["There was a cat on a mat."]
        print(f"Case I\n{_call_metric(predictions, references, documents, tokenize_sentences)}\n")

        # Case II
        tokenize_sentences = False
        predictions = [["A cat was sitting on a mat.", None]]
        references = [["A cat was sitting on a mat.", "A cat was sitting on a mat."]]
        documents = [["There was a cat on a mat.", "There was a cat on a mat."]]
        print(f"Case II\n{_call_metric(predictions, references, documents, tokenize_sentences)}\n")

        # Empty input
        tokenize_sentences = True
        predictions = []
        references = ["A cat was sitting on a mat."]
        documents = ["There was a cat on a mat."]
        print(f"Case: Empty Input\n{_call_metric(predictions, references, documents, tokenize_sentences)}\n")

        # Empty string input
        tokenize_sentences = True
        predictions = [""]
        references = ["A cat was sitting on a mat."]
        documents = ["There was a cat on a mat."]
        print(f"Case: Empty String Input\n{_call_metric(predictions, references, documents, tokenize_sentences)}\n")
    def _test_check_verbose(self):
        """UNUSED: previously used to manually check the progress bar.

        This test should not be run automatically since it relies on files that
        are not kept in version control. It is left here purely for historical
        purposes; the leading underscore in the method name prevents it from
        being collected by unittest.
        """
        import sqlite3
        import string

        con = sqlite3.connect('sem_ncg_samples.db')
        cur = con.cursor()
        data = cur.execute(
            'SELECT * FROM sem_ncg_samples').fetchmany(100)
        data = list(filter(
            lambda x: x[0].translate(
                str.maketrans('', '', string.punctuation)
            ).strip() != '',
            data
        ))
        preds, refs, docs = list(zip(*data))

        result = self.metric.compute(
            predictions=preds, references=refs,
            documents=docs, verbose=True,
            gpu=2
        )
        breakpoint()
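# Usage note: this module relies on relative imports (e.g. "from .semncg import ..."),
# so it is meant to be run as part of its package, e.g. via
# "python -m unittest <package>.tests -v" (the exact module path depends on the
# repository layout), rather than invoked directly as a standalone script.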
if __name__ == '__main__':
    unittest.main(verbosity=2)