Spaces:
Runtime error
Runtime error
| # Copyright 2022 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| import re | |
| import tempfile | |
| import unittest | |
| from pathlib import Path | |
| import transformers | |
| from transformers.commands.add_new_model_like import ( | |
| ModelPatterns, | |
| _re_class_func, | |
| add_content_to_file, | |
| add_content_to_text, | |
| clean_frameworks_in_init, | |
| duplicate_doc_file, | |
| duplicate_module, | |
| filter_framework_files, | |
| find_base_model_checkpoint, | |
| get_model_files, | |
| get_module_from_file, | |
| parse_module_content, | |
| replace_model_patterns, | |
| retrieve_info_for_model, | |
| retrieve_model_classes, | |
| simplify_replacements, | |
| ) | |
| from transformers.testing_utils import require_flax, require_tf, require_torch | |
# Source files that `get_model_files("bert")` is expected to report for BERT, relative to the repository root.
BERT_MODEL_FILES = {
    f"src/transformers/models/bert/{file_name}"
    for file_name in (
        "__init__.py",
        "configuration_bert.py",
        "tokenization_bert.py",
        "tokenization_bert_fast.py",
        "tokenization_bert_tf.py",
        "modeling_bert.py",
        "modeling_flax_bert.py",
        "modeling_tf_bert.py",
        "convert_bert_original_tf_checkpoint_to_pytorch.py",
        "convert_bert_original_tf2_checkpoint_to_pytorch.py",
        "convert_bert_pytorch_checkpoint_to_original_tf.py",
        "convert_bert_token_dropping_original_tf2_checkpoint_to_pytorch.py",
    )
}
# Source files that `get_model_files("vit")` is expected to report for ViT, relative to the repository root.
VIT_MODEL_FILES = {
    f"src/transformers/models/vit/{file_name}"
    for file_name in (
        "__init__.py",
        "configuration_vit.py",
        "convert_dino_to_pytorch.py",
        "convert_vit_timm_to_pytorch.py",
        "feature_extraction_vit.py",
        "image_processing_vit.py",
        "modeling_vit.py",
        "modeling_tf_vit.py",
        "modeling_flax_vit.py",
    )
}
# Source files that `get_model_files("wav2vec2")` is expected to report for Wav2Vec2, relative to the repository root.
WAV2VEC2_MODEL_FILES = {
    f"src/transformers/models/wav2vec2/{file_name}"
    for file_name in (
        "__init__.py",
        "configuration_wav2vec2.py",
        "convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py",
        "convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py",
        "feature_extraction_wav2vec2.py",
        "modeling_wav2vec2.py",
        "modeling_tf_wav2vec2.py",
        "modeling_flax_wav2vec2.py",
        "processing_wav2vec2.py",
        "tokenization_wav2vec2.py",
    )
}
# The expected file lists in this module are all of the form "src/transformers/...", so these tests assume
# transformers is imported from a source checkout whose package directory is <repo>/src/transformers.
# The repository root is therefore two directory levels above the package directory.
REPO_PATH = (
    Path(transformers.__path__[0])  # <repo>/src/transformers
    .parent  # <repo>/src
    .parent  # <repo>
)
| class TestAddNewModelLike(unittest.TestCase): | |
def init_file(self, file_name, content):
    """Create or overwrite `file_name` so that it contains exactly `content`, written as UTF-8 text."""
    with open(file_name, mode="w", encoding="utf-8") as handle:
        handle.write(content)
def check_result(self, file_name, expected_result):
    """Assert that the UTF-8 text stored in `file_name` equals `expected_result`."""
    with open(file_name, mode="r", encoding="utf-8") as handle:
        self.assertEqual(handle.read(), expected_result)
def test_re_class_func(self):
    """The class/function regex must capture the defined name for functions, bare classes and subclasses."""
    cases = [
        ("def my_function(x, y):", "my_function"),
        ("class MyClass:", "MyClass"),
        ("class MyClass(SuperClass):", "MyClass"),
    ]
    for source_line, expected_name in cases:
        self.assertEqual(_re_class_func.search(source_line).groups()[0], expected_name)
def test_model_patterns_defaults(self):
    """Check the names `ModelPatterns` derives when only a model name and a checkpoint are given."""
    patterns = ModelPatterns("GPT-New new", "huggingface/gpt-new-base")
    derived_names = {
        "model_type": "gpt-new-new",
        "model_lower_cased": "gpt_new_new",
        "model_camel_cased": "GPTNewNew",
        "model_upper_cased": "GPT_NEW_NEW",
        "config_class": "GPTNewNewConfig",
    }
    for attribute, expected in derived_names.items():
        self.assertEqual(getattr(patterns, attribute), expected)
    # Preprocessing classes are left as None when they are not provided.
    for attribute in ("tokenizer_class", "feature_extractor_class", "processor_class"):
        self.assertIsNone(getattr(patterns, attribute))
def test_parse_module_content(self):
    """Check how `parse_module_content` splits module source code into its top-level objects.

    The fixture is whitespace-sensitive. Each expected part begins at a non-indented line, and the blank
    lines decide where the empty part and the trailing newlines appear. The four-space indentation and the
    blank lines must therefore stay exactly as written.
    """
    test_code = """SOME_CONSTANT = a constant

CONSTANT_DEFINED_ON_SEVERAL_LINES = [
    first_item,
    second_item
]

def function(args):
    some code

# Copied from transformers.some_module
class SomeClass:
    some code
"""

    expected_parts = [
        "SOME_CONSTANT = a constant\n",
        # The closing bracket at column 0 stays with the constant it closes.
        "CONSTANT_DEFINED_ON_SEVERAL_LINES = [\n    first_item,\n    second_item\n]",
        # The blank line that follows the closed constant becomes an empty part of its own.
        "",
        "def function(args):\n    some code\n",
        # A `# Copied from` comment is kept with the class it annotates.
        "# Copied from transformers.some_module\nclass SomeClass:\n    some code\n",
    ]
    self.assertEqual(parse_module_content(test_code), expected_parts)
def test_add_content_to_text(self):
    """Check inserting a line before/after an anchor given as a substring, a whole line or a compiled regex."""
    test_text = """all_configs = {
    "gpt": "GPTConfig",
    "bert": "BertConfig",
    "t5": "T5Config",
}"""

    expected = """all_configs = {
    "gpt": "GPTConfig",
    "gpt2": "GPT2Config",
    "bert": "BertConfig",
    "t5": "T5Config",
}"""

    # Same four-space indentation as the fixture lines, so the exact-match anchors below line up with it.
    line = '    "gpt2": "GPT2Config",'

    # Each case is (keyword arguments, expected result). With `exact_match=True`, a partial anchor such as "bert"
    # matches no line and the text comes back unchanged, while the full line (indentation included) matches.
    # The regex patterns are raw strings: "\s" is an invalid escape sequence in a regular string literal.
    cases = [
        ({"add_before": "bert"}, expected),
        ({"add_before": "bert", "exact_match": True}, test_text),
        ({"add_before": '    "bert": "BertConfig",', "exact_match": True}, expected),
        ({"add_before": re.compile(r'^\s*"bert":')}, expected),
        ({"add_after": "gpt"}, expected),
        ({"add_after": "gpt", "exact_match": True}, test_text),
        ({"add_after": '    "gpt": "GPTConfig",', "exact_match": True}, expected),
        ({"add_after": re.compile(r'^\s*"gpt":')}, expected),
    ]
    for kwargs, result in cases:
        with self.subTest(kwargs=kwargs):
            self.assertEqual(add_content_to_text(test_text, line, **kwargs), result)
def test_add_content_to_file(self):
    """Check `add_content_to_file` inserting a line before/after a substring, whole-line or regex anchor on disk."""
    test_text = """all_configs = {
    "gpt": "GPTConfig",
    "bert": "BertConfig",
    "t5": "T5Config",
}"""

    expected = """all_configs = {
    "gpt": "GPTConfig",
    "gpt2": "GPT2Config",
    "bert": "BertConfig",
    "t5": "T5Config",
}"""

    # Same four-space indentation as the fixture lines, so the exact-match anchors below line up with it.
    line = '    "gpt2": "GPT2Config",'

    # Each case is (keyword arguments, expected file content). With `exact_match=True`, a partial anchor such as
    # "bert" matches no line and the file is left unchanged, while the full line (indentation included) matches.
    # The regex patterns are raw strings: "\s" is an invalid escape sequence in a regular string literal.
    cases = [
        ({"add_before": "bert"}, expected),
        ({"add_before": "bert", "exact_match": True}, test_text),
        ({"add_before": '    "bert": "BertConfig",', "exact_match": True}, expected),
        ({"add_before": re.compile(r'^\s*"bert":')}, expected),
        ({"add_after": "gpt"}, expected),
        ({"add_after": "gpt", "exact_match": True}, test_text),
        ({"add_after": '    "gpt": "GPTConfig",', "exact_match": True}, expected),
        ({"add_after": re.compile(r'^\s*"gpt":')}, expected),
    ]

    with tempfile.TemporaryDirectory() as tmp_dir:
        file_name = os.path.join(tmp_dir, "code.py")
        for kwargs, result in cases:
            with self.subTest(kwargs=kwargs):
                # Start every case from a fresh copy of the fixture, since the edit happens in place.
                self.init_file(file_name, test_text)
                add_content_to_file(file_name, line, **kwargs)
                self.check_result(file_name, result)
def test_simplify_replacements(self):
    """Redundant replacements are dropped, e.g. BertConfig->NewBertConfig, which Bert->NewBert already covers."""
    cases = [
        ([("Bert", "NewBert")], [("Bert", "NewBert")]),
        (
            [("Bert", "NewBert"), ("bert", "new-bert")],
            [("Bert", "NewBert"), ("bert", "new-bert")],
        ),
        (
            [("BertConfig", "NewBertConfig"), ("Bert", "NewBert"), ("bert", "new-bert")],
            [("Bert", "NewBert"), ("bert", "new-bert")],
        ),
    ]
    for replacements, simplified in cases:
        self.assertEqual(simplify_replacements(replacements), simplified)
def test_replace_model_patterns(self):
    """Check that `replace_model_patterns` renames every spelling of a model and reports the replacements used.

    The fixtures are whitespace-sensitive. The `str.replace` calls that strip the `model_type` line rely on
    the four-space class-body indentation, so the fixtures must keep it.
    """
    # Bert -> New Bert. In the expected output, the prefix becomes "new_bert" and the model type "new-bert".
    bert_model_patterns = ModelPatterns("Bert", "bert-base-cased")
    new_bert_model_patterns = ModelPatterns("New Bert", "huggingface/bert-new-base")
    bert_test = '''class TFBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"
    is_parallelizable = True
    supports_gradient_checkpointing = True
    model_type = "bert"

BERT_CONSTANT = "value"
'''

    bert_expected = '''class TFNewBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = NewBertConfig
    load_tf_weights = load_tf_weights_in_new_bert
    base_model_prefix = "new_bert"
    is_parallelizable = True
    supports_gradient_checkpointing = True
    model_type = "new-bert"

NEW_BERT_CONSTANT = "value"
'''

    bert_converted, replacements = replace_model_patterns(bert_test, bert_model_patterns, new_bert_model_patterns)
    self.assertEqual(bert_converted, bert_expected)
    # Replacements are empty here: "bert" has been replaced by "new_bert" in some places and by "new-bert" in
    # others, so no single set of simple replacements describes the conversion.
    self.assertEqual(replacements, "")

    # Once the model type line is removed, the conversion is expressible as simple replacements.
    bert_test = bert_test.replace('    model_type = "bert"\n', "")
    bert_expected = bert_expected.replace('    model_type = "new-bert"\n', "")
    bert_converted, replacements = replace_model_patterns(bert_test, bert_model_patterns, new_bert_model_patterns)
    self.assertEqual(bert_converted, bert_expected)
    self.assertEqual(replacements, "BERT->NEW_BERT,Bert->NewBert,bert->new_bert")

    gpt_model_patterns = ModelPatterns("GPT2", "gpt2")
    new_gpt_model_patterns = ModelPatterns("GPT-New new", "huggingface/gpt-new-base")
    gpt_test = '''class GPT2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GPT2Config
    load_tf_weights = load_tf_weights_in_gpt2
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True

GPT2_CONSTANT = "value"
'''

    gpt_expected = '''class GPTNewNewPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GPTNewNewConfig
    load_tf_weights = load_tf_weights_in_gpt_new_new
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True

GPT_NEW_NEW_CONSTANT = "value"
'''

    gpt_converted, replacements = replace_model_patterns(gpt_test, gpt_model_patterns, new_gpt_model_patterns)
    self.assertEqual(gpt_converted, gpt_expected)
    # Replacements are empty here: "GPT2" has been replaced by "GPTNewNew" in some places and by "GPT_NEW_NEW"
    # in others.
    self.assertEqual(replacements, "")

    # An explicit camel-cased name is used for the class names while the display name ("RoBERTa") is kept in
    # the docstring and the "Copied from" replacement target is updated.
    roberta_model_patterns = ModelPatterns("RoBERTa", "roberta-base", model_camel_cased="Roberta")
    new_roberta_model_patterns = ModelPatterns(
        "RoBERTa-New", "huggingface/roberta-new-base", model_camel_cased="RobertaNew"
    )
    roberta_test = '''# Copied from transformers.models.bert.BertModel with Bert->Roberta
class RobertaModel(RobertaPreTrainedModel):
    """ The base RoBERTa model. """
    checkpoint = roberta-base
    base_model_prefix = "roberta"
'''
    roberta_expected = '''# Copied from transformers.models.bert.BertModel with Bert->RobertaNew
class RobertaNewModel(RobertaNewPreTrainedModel):
    """ The base RoBERTa-New model. """
    checkpoint = huggingface/roberta-new-base
    base_model_prefix = "roberta_new"
'''
    roberta_converted, _ = replace_model_patterns(roberta_test, roberta_model_patterns, new_roberta_model_patterns)
    self.assertEqual(roberta_converted, roberta_expected)
def test_get_module_from_file(self):
    """File paths map to dotted module names that start at the innermost `transformers` directory."""
    expected_modules = {
        "/git/transformers/src/transformers/models/bert/modeling_tf_bert.py": (
            "transformers.models.bert.modeling_tf_bert"
        ),
        "/transformers/models/gpt2/modeling_gpt2.py": "transformers.models.gpt2.modeling_gpt2",
    }
    for path, module_name in expected_modules.items():
        self.assertEqual(get_module_from_file(path), module_name)

    # Without any `transformers` directory in the path there is no module name to build.
    with self.assertRaises(ValueError):
        get_module_from_file("/models/gpt2/modeling_gpt2.py")
def test_duplicate_module(self):
    """Check that `duplicate_module` writes a renamed copy of a module, optionally tagged with `# Copied from`.

    The fixtures are whitespace-sensitive. The class body must keep its indentation, and the blank lines must
    stay where they are, so the module is read as the same top-level objects as real code.
    """
    bert_model_patterns = ModelPatterns("Bert", "bert-base-cased")
    new_bert_model_patterns = ModelPatterns("New Bert", "huggingface/bert-new-base")
    bert_test = '''class TFBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"
    is_parallelizable = True
    supports_gradient_checkpointing = True

BERT_CONSTANT = "value"
'''

    bert_expected = '''class TFNewBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = NewBertConfig
    load_tf_weights = load_tf_weights_in_new_bert
    base_model_prefix = "new_bert"
    is_parallelizable = True
    supports_gradient_checkpointing = True

NEW_BERT_CONSTANT = "value"
'''

    # By default the renamed class is annotated with its origin and the replacements that were applied.
    bert_expected_with_copied_from = (
        "# Copied from transformers.bert_module.TFBertPreTrainedModel with Bert->NewBert,bert->new_bert\n"
        + bert_expected
    )
    with tempfile.TemporaryDirectory() as tmp_dir:
        # The module sits in a directory named `transformers` so that its dotted name, which appears in the
        # expected `# Copied from` line, is `transformers.bert_module`.
        work_dir = os.path.join(tmp_dir, "transformers")
        os.makedirs(work_dir)
        file_name = os.path.join(work_dir, "bert_module.py")
        # Where the duplicate is expected to be written: the source name with `bert` -> `new_bert`.
        dest_file_name = os.path.join(work_dir, "new_bert_module.py")

        self.init_file(file_name, bert_test)
        duplicate_module(file_name, bert_model_patterns, new_bert_model_patterns)
        self.check_result(dest_file_name, bert_expected_with_copied_from)

        self.init_file(file_name, bert_test)
        duplicate_module(file_name, bert_model_patterns, new_bert_model_patterns, add_copied_from=False)
        self.check_result(dest_file_name, bert_expected)
def test_duplicate_module_with_copied_from(self):
    """An existing `# Copied from` line is adapted in the duplicate instead of a second one being added.

    The fixtures are whitespace-sensitive. The class body must keep its indentation, and the blank lines must
    stay where they are, so the module is read as the same top-level objects as real code.
    """
    bert_model_patterns = ModelPatterns("Bert", "bert-base-cased")
    new_bert_model_patterns = ModelPatterns("New Bert", "huggingface/bert-new-base")
    bert_test = '''# Copied from transformers.models.xxx.XxxModel with Xxx->Bert
class TFBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"
    is_parallelizable = True
    supports_gradient_checkpointing = True

BERT_CONSTANT = "value"
'''

    bert_expected = '''# Copied from transformers.models.xxx.XxxModel with Xxx->NewBert
class TFNewBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = NewBertConfig
    load_tf_weights = load_tf_weights_in_new_bert
    base_model_prefix = "new_bert"
    is_parallelizable = True
    supports_gradient_checkpointing = True

NEW_BERT_CONSTANT = "value"
'''
    with tempfile.TemporaryDirectory() as tmp_dir:
        # The module sits in a directory named `transformers` so it has a `transformers.*` dotted name.
        work_dir = os.path.join(tmp_dir, "transformers")
        os.makedirs(work_dir)
        file_name = os.path.join(work_dir, "bert_module.py")
        # Where the duplicate is expected to be written: the source name with `bert` -> `new_bert`.
        dest_file_name = os.path.join(work_dir, "new_bert_module.py")

        self.init_file(file_name, bert_test)
        duplicate_module(file_name, bert_model_patterns, new_bert_model_patterns)
        # There should not be a new "Copied from" statement; the existing one should be adapted.
        self.check_result(dest_file_name, bert_expected)

        self.init_file(file_name, bert_test)
        duplicate_module(file_name, bert_model_patterns, new_bert_model_patterns, add_copied_from=False)
        self.check_result(dest_file_name, bert_expected)
def test_filter_framework_files(self):
    """Framework-agnostic files are always kept; modeling files only for the requested frameworks."""
    files = ["modeling_bert.py", "modeling_tf_bert.py", "modeling_flax_bert.py", "configuration_bert.py"]
    pt_file, tf_file, flax_file, config_file = files

    # Without a framework list, every file is kept, in the original order.
    self.assertEqual(filter_framework_files(files), files)

    expected_by_frameworks = [
        (["pt", "tf", "flax"], set(files)),
        (["pt"], {pt_file, config_file}),
        (["tf"], {tf_file, config_file}),
        (["flax"], {flax_file, config_file}),
        (["pt", "tf"], {tf_file, pt_file, config_file}),
        (["tf", "flax"], {tf_file, flax_file, config_file}),
        (["pt", "flax"], {pt_file, flax_file, config_file}),
    ]
    for frameworks, kept in expected_by_frameworks:
        self.assertEqual(set(filter_framework_files(files, frameworks)), kept)
def test_get_model_files(self):
    """Without a framework filter, `get_model_files` reports the doc page, source files and tests of a model."""

    def relative(path):
        # Compare paths relative to the repository root.
        return str(Path(path).relative_to(REPO_PATH))

    def check(model_type, doc_file, model_files, test_files):
        found = get_model_files(model_type)
        self.assertEqual(relative(found["doc_file"]), doc_file)
        self.assertEqual({relative(f) for f in found["model_files"]}, model_files)
        self.assertEqual(found["module_name"], model_type)
        self.assertEqual({relative(f) for f in found["test_files"]}, test_files)

    check(
        "bert",
        "docs/source/en/model_doc/bert.mdx",
        BERT_MODEL_FILES,
        {
            "tests/models/bert/test_tokenization_bert.py",
            "tests/models/bert/test_modeling_bert.py",
            "tests/models/bert/test_modeling_tf_bert.py",
            "tests/models/bert/test_modeling_flax_bert.py",
        },
    )
    check(
        "vit",
        "docs/source/en/model_doc/vit.mdx",
        VIT_MODEL_FILES,
        {
            "tests/models/vit/test_image_processing_vit.py",
            "tests/models/vit/test_modeling_vit.py",
            "tests/models/vit/test_modeling_tf_vit.py",
            "tests/models/vit/test_modeling_flax_vit.py",
        },
    )
    check(
        "wav2vec2",
        "docs/source/en/model_doc/wav2vec2.mdx",
        WAV2VEC2_MODEL_FILES,
        {
            "tests/models/wav2vec2/test_feature_extraction_wav2vec2.py",
            "tests/models/wav2vec2/test_modeling_wav2vec2.py",
            "tests/models/wav2vec2/test_modeling_tf_wav2vec2.py",
            "tests/models/wav2vec2/test_modeling_flax_wav2vec2.py",
            "tests/models/wav2vec2/test_processor_wav2vec2.py",
            "tests/models/wav2vec2/test_tokenization_wav2vec2.py",
        },
    )
def test_get_model_files_only_pt(self):
    """With `frameworks=["pt"]`, TensorFlow and Flax modeling files and tests are left out."""

    def relative(path):
        # Compare paths relative to the repository root.
        return str(Path(path).relative_to(REPO_PATH))

    def check(model_type, doc_file, model_files, test_files):
        found = get_model_files(model_type, frameworks=["pt"])
        self.assertEqual(relative(found["doc_file"]), doc_file)
        self.assertEqual({relative(f) for f in found["model_files"]}, model_files)
        self.assertEqual(found["module_name"], model_type)
        self.assertEqual({relative(f) for f in found["test_files"]}, test_files)

    check(
        "bert",
        "docs/source/en/model_doc/bert.mdx",
        BERT_MODEL_FILES
        - {
            "src/transformers/models/bert/modeling_tf_bert.py",
            "src/transformers/models/bert/modeling_flax_bert.py",
        },
        {
            "tests/models/bert/test_tokenization_bert.py",
            "tests/models/bert/test_modeling_bert.py",
        },
    )
    check(
        "vit",
        "docs/source/en/model_doc/vit.mdx",
        VIT_MODEL_FILES
        - {
            "src/transformers/models/vit/modeling_tf_vit.py",
            "src/transformers/models/vit/modeling_flax_vit.py",
        },
        {
            "tests/models/vit/test_image_processing_vit.py",
            "tests/models/vit/test_modeling_vit.py",
        },
    )
    check(
        "wav2vec2",
        "docs/source/en/model_doc/wav2vec2.mdx",
        WAV2VEC2_MODEL_FILES
        - {
            "src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py",
            "src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py",
        },
        {
            "tests/models/wav2vec2/test_feature_extraction_wav2vec2.py",
            "tests/models/wav2vec2/test_modeling_wav2vec2.py",
            "tests/models/wav2vec2/test_processor_wav2vec2.py",
            "tests/models/wav2vec2/test_tokenization_wav2vec2.py",
        },
    )
def test_get_model_files_tf_and_flax(self):
    """With `frameworks=["tf", "flax"]`, only the PyTorch modeling file and its tests are left out."""

    def relative(path):
        # Compare paths relative to the repository root.
        return str(Path(path).relative_to(REPO_PATH))

    def check(model_type, doc_file, model_files, test_files):
        found = get_model_files(model_type, frameworks=["tf", "flax"])
        self.assertEqual(relative(found["doc_file"]), doc_file)
        self.assertEqual({relative(f) for f in found["model_files"]}, model_files)
        self.assertEqual(found["module_name"], model_type)
        self.assertEqual({relative(f) for f in found["test_files"]}, test_files)

    check(
        "bert",
        "docs/source/en/model_doc/bert.mdx",
        BERT_MODEL_FILES - {"src/transformers/models/bert/modeling_bert.py"},
        {
            "tests/models/bert/test_tokenization_bert.py",
            "tests/models/bert/test_modeling_tf_bert.py",
            "tests/models/bert/test_modeling_flax_bert.py",
        },
    )
    check(
        "vit",
        "docs/source/en/model_doc/vit.mdx",
        VIT_MODEL_FILES - {"src/transformers/models/vit/modeling_vit.py"},
        {
            "tests/models/vit/test_image_processing_vit.py",
            "tests/models/vit/test_modeling_tf_vit.py",
            "tests/models/vit/test_modeling_flax_vit.py",
        },
    )
    check(
        "wav2vec2",
        "docs/source/en/model_doc/wav2vec2.mdx",
        WAV2VEC2_MODEL_FILES - {"src/transformers/models/wav2vec2/modeling_wav2vec2.py"},
        {
            "tests/models/wav2vec2/test_feature_extraction_wav2vec2.py",
            "tests/models/wav2vec2/test_modeling_tf_wav2vec2.py",
            "tests/models/wav2vec2/test_modeling_flax_wav2vec2.py",
            "tests/models/wav2vec2/test_processor_wav2vec2.py",
            "tests/models/wav2vec2/test_tokenization_wav2vec2.py",
        },
    )
def test_find_base_model_checkpoint(self):
    """The base checkpoint found for a model type matches the expected hub name."""
    for model_type, checkpoint in (("bert", "bert-base-uncased"), ("gpt2", "gpt2")):
        self.assertEqual(find_base_model_checkpoint(model_type), checkpoint)
def test_retrieve_model_classes(self):
    """`retrieve_model_classes` groups GPT-2 classes by framework and honours a `frameworks` restriction."""

    def classes_as_sets(found):
        return {framework: set(names) for framework, names in found.items()}

    all_gpt_classes = {
        "pt": {"GPT2ForTokenClassification", "GPT2Model", "GPT2LMHeadModel", "GPT2ForSequenceClassification"},
        "tf": {"TFGPT2Model", "TFGPT2ForSequenceClassification", "TFGPT2LMHeadModel"},
        "flax": {"FlaxGPT2Model", "FlaxGPT2LMHeadModel"},
    }
    # No restriction: all three frameworks are reported.
    self.assertEqual(classes_as_sets(retrieve_model_classes("gpt2")), all_gpt_classes)

    # With a restriction, only the requested frameworks appear in the result.
    for frameworks in (["pt", "tf"], ["tf"]):
        found = classes_as_sets(retrieve_model_classes("gpt2", frameworks=frameworks))
        self.assertEqual(found, {framework: all_gpt_classes[framework] for framework in frameworks})
def test_retrieve_info_for_model_with_bert(self):
    """Check the frameworks, classes, files and name patterns reported for BERT when no framework list is passed."""
    bert_info = retrieve_info_for_model("bert")

    def relative(path):
        # Compare paths relative to the repository root.
        return str(Path(path).relative_to(REPO_PATH))

    pt_classes = [
        "BertForTokenClassification",
        "BertForQuestionAnswering",
        "BertForNextSentencePrediction",
        "BertForSequenceClassification",
        "BertForMaskedLM",
        "BertForMultipleChoice",
        "BertModel",
        "BertForPreTraining",
        "BertLMHeadModel",
    ]
    # The expected Flax classes swap the last PyTorch class (BertLMHeadModel) for BertForCausalLM.
    flax_classes = pt_classes[:-1] + ["BertForCausalLM"]
    expected_model_classes = {
        "pt": set(pt_classes),
        "tf": {"TF" + name for name in pt_classes},
        "flax": {"Flax" + name for name in flax_classes},
    }

    self.assertEqual(set(bert_info["frameworks"]), {"pt", "tf", "flax"})
    self.assertEqual(
        {framework: set(names) for framework, names in bert_info["model_classes"].items()},
        expected_model_classes,
    )

    files = bert_info["model_files"]
    self.assertEqual({relative(f) for f in files["model_files"]}, BERT_MODEL_FILES)
    self.assertEqual(
        {relative(f) for f in files["test_files"]},
        {
            "tests/models/bert/test_tokenization_bert.py",
            "tests/models/bert/test_modeling_bert.py",
            "tests/models/bert/test_modeling_tf_bert.py",
            "tests/models/bert/test_modeling_flax_bert.py",
        },
    )
    self.assertEqual(relative(files["doc_file"]), "docs/source/en/model_doc/bert.mdx")
    self.assertEqual(files["module_name"], "bert")

    patterns = bert_info["model_patterns"]
    for attribute, expected in (
        ("model_name", "BERT"),
        ("checkpoint", "bert-base-uncased"),
        ("model_type", "bert"),
        ("model_lower_cased", "bert"),
        ("model_camel_cased", "Bert"),
        ("model_upper_cased", "BERT"),
        ("config_class", "BertConfig"),
        ("tokenizer_class", "BertTokenizer"),
    ):
        self.assertEqual(getattr(patterns, attribute), expected)
    for attribute in ("feature_extractor_class", "processor_class"):
        self.assertIsNone(getattr(patterns, attribute))
def test_retrieve_info_for_model_pt_tf_with_bert(self):
    """Restricting BERT to PyTorch and TensorFlow drops every Flax class, file and test from the report."""
    bert_info = retrieve_info_for_model("bert", frameworks=["pt", "tf"])

    def relative(path):
        # Compare paths relative to the repository root.
        return str(Path(path).relative_to(REPO_PATH))

    pt_classes = [
        "BertForTokenClassification",
        "BertForQuestionAnswering",
        "BertForNextSentencePrediction",
        "BertForSequenceClassification",
        "BertForMaskedLM",
        "BertForMultipleChoice",
        "BertModel",
        "BertForPreTraining",
        "BertLMHeadModel",
    ]
    expected_model_classes = {"pt": set(pt_classes), "tf": {"TF" + name for name in pt_classes}}

    self.assertEqual(set(bert_info["frameworks"]), {"pt", "tf"})
    self.assertEqual(
        {framework: set(names) for framework, names in bert_info["model_classes"].items()},
        expected_model_classes,
    )

    files = bert_info["model_files"]
    self.assertEqual(
        {relative(f) for f in files["model_files"]},
        BERT_MODEL_FILES - {"src/transformers/models/bert/modeling_flax_bert.py"},
    )
    self.assertEqual(
        {relative(f) for f in files["test_files"]},
        {
            "tests/models/bert/test_tokenization_bert.py",
            "tests/models/bert/test_modeling_bert.py",
            "tests/models/bert/test_modeling_tf_bert.py",
        },
    )
    self.assertEqual(relative(files["doc_file"]), "docs/source/en/model_doc/bert.mdx")
    self.assertEqual(files["module_name"], "bert")

    patterns = bert_info["model_patterns"]
    for attribute, expected in (
        ("model_name", "BERT"),
        ("checkpoint", "bert-base-uncased"),
        ("model_type", "bert"),
        ("model_lower_cased", "bert"),
        ("model_camel_cased", "Bert"),
        ("model_upper_cased", "BERT"),
        ("config_class", "BertConfig"),
        ("tokenizer_class", "BertTokenizer"),
    ):
        self.assertEqual(getattr(patterns, attribute), expected)
    for attribute in ("feature_extractor_class", "processor_class"):
        self.assertIsNone(getattr(patterns, attribute))
def test_retrieve_info_for_model_with_vit(self):
    """Check what `retrieve_info_for_model` reports for ViT across PyTorch, TensorFlow and Flax.

    `ViTForMaskedImageModeling` only exists on the PyTorch side; ViT has an image
    processor and a feature extractor but no tokenizer and no processor.
    """
    info = retrieve_info_for_model("vit")

    def repo_relative(path):
        # Compare file locations relative to the repository root.
        return str(Path(path).relative_to(REPO_PATH))

    shared_names = ["ViTForImageClassification", "ViTModel"]
    wanted_classes = {
        "pt": {*shared_names, "ViTForMaskedImageModeling"},
        "tf": {"TF" + name for name in shared_names},
        "flax": {"Flax" + name for name in shared_names},
    }
    self.assertEqual(set(info["frameworks"]), {"pt", "tf", "flax"})
    found_classes = {framework: set(names) for framework, names in info["model_classes"].items()}
    self.assertEqual(found_classes, wanted_classes)

    files = info["model_files"]
    self.assertEqual({repo_relative(path) for path in files["model_files"]}, VIT_MODEL_FILES)
    wanted_test_files = {
        "tests/models/vit/test_image_processing_vit.py",
        "tests/models/vit/test_modeling_vit.py",
        "tests/models/vit/test_modeling_tf_vit.py",
        "tests/models/vit/test_modeling_flax_vit.py",
    }
    self.assertEqual({repo_relative(path) for path in files["test_files"]}, wanted_test_files)
    self.assertEqual(repo_relative(files["doc_file"]), "docs/source/en/model_doc/vit.mdx")
    self.assertEqual(files["module_name"], "vit")

    patterns = info["model_patterns"]
    wanted_patterns = {
        "model_name": "ViT",
        "checkpoint": "google/vit-base-patch16-224-in21k",
        "model_type": "vit",
        "model_lower_cased": "vit",
        "model_camel_cased": "ViT",
        "model_upper_cased": "VIT",
        "config_class": "ViTConfig",
        "feature_extractor_class": "ViTFeatureExtractor",
        "image_processor_class": "ViTImageProcessor",
    }
    for attribute, wanted in wanted_patterns.items():
        self.assertEqual(getattr(patterns, attribute), wanted)
    for attribute in ("tokenizer_class", "processor_class"):
        self.assertIsNone(getattr(patterns, attribute))
def test_retrieve_info_for_model_with_wav2vec2(self):
    """Check what `retrieve_info_for_model` reports for Wav2Vec2.

    Every class exists in PyTorch; TensorFlow only has the base model and Flax has the
    base model plus the pretraining head. Wav2Vec2 has a feature extractor, a
    processor and a CTC tokenizer.
    """
    info = retrieve_info_for_model("wav2vec2")

    def repo_relative(path):
        # Compare file locations relative to the repository root.
        return str(Path(path).relative_to(REPO_PATH))

    pt_names = {
        "Wav2Vec2Model",
        "Wav2Vec2ForPreTraining",
        "Wav2Vec2ForAudioFrameClassification",
        "Wav2Vec2ForCTC",
        "Wav2Vec2ForMaskedLM",
        "Wav2Vec2ForSequenceClassification",
        "Wav2Vec2ForXVector",
    }
    wanted_classes = {
        "pt": pt_names,
        "tf": {"TFWav2Vec2Model"},
        "flax": {"FlaxWav2Vec2Model", "FlaxWav2Vec2ForPreTraining"},
    }
    self.assertEqual(set(info["frameworks"]), {"pt", "tf", "flax"})
    found_classes = {framework: set(names) for framework, names in info["model_classes"].items()}
    self.assertEqual(found_classes, wanted_classes)

    files = info["model_files"]
    self.assertEqual({repo_relative(path) for path in files["model_files"]}, WAV2VEC2_MODEL_FILES)
    wanted_test_files = {
        "tests/models/wav2vec2/test_feature_extraction_wav2vec2.py",
        "tests/models/wav2vec2/test_modeling_wav2vec2.py",
        "tests/models/wav2vec2/test_modeling_tf_wav2vec2.py",
        "tests/models/wav2vec2/test_modeling_flax_wav2vec2.py",
        "tests/models/wav2vec2/test_processor_wav2vec2.py",
        "tests/models/wav2vec2/test_tokenization_wav2vec2.py",
    }
    self.assertEqual({repo_relative(path) for path in files["test_files"]}, wanted_test_files)
    self.assertEqual(repo_relative(files["doc_file"]), "docs/source/en/model_doc/wav2vec2.mdx")
    self.assertEqual(files["module_name"], "wav2vec2")

    patterns = info["model_patterns"]
    wanted_patterns = {
        "model_name": "Wav2Vec2",
        "checkpoint": "facebook/wav2vec2-base-960h",
        "model_type": "wav2vec2",
        "model_lower_cased": "wav2vec2",
        "model_camel_cased": "Wav2Vec2",
        "model_upper_cased": "WAV_2_VEC_2",
        "config_class": "Wav2Vec2Config",
        "feature_extractor_class": "Wav2Vec2FeatureExtractor",
        "processor_class": "Wav2Vec2Processor",
        "tokenizer_class": "Wav2Vec2CTCTokenizer",
    }
    for attribute, wanted in wanted_patterns.items():
        self.assertEqual(getattr(patterns, attribute), wanted)
def test_clean_frameworks_in_init_with_gpt(self):
    """Check `clean_frameworks_in_init` on a GPT-2 style lazy `__init__.py`.

    The same input file is cleaned three ways:

    * no `frameworks` argument, ``keep_processing=False``: the tokenizer entries, the
      `is_tokenizers_available` guard and its import disappear, while the PyTorch,
      TensorFlow and Flax blocks are all expected to survive;
    * ``frameworks=["pt"]``: the TensorFlow and Flax guarded blocks disappear;
    * ``frameworks=["pt"], keep_processing=False``: both of the above.

    The fixtures are Python source, so their indentation is part of the expected output.
    """
    test_init = """
from typing import TYPE_CHECKING
from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
_import_structure = {
    "configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"],
    "tokenization_gpt2": ["GPT2Tokenizer"],
}
try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tokenization_gpt2_fast"] = ["GPT2TokenizerFast"]
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_gpt2"] = ["GPT2Model"]
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_tf_gpt2"] = ["TFGPT2Model"]
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_flax_gpt2"] = ["FlaxGPT2Model"]
if TYPE_CHECKING:
    from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
    from .tokenization_gpt2 import GPT2Tokenizer
    try:
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_gpt2_fast import GPT2TokenizerFast
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_gpt2 import GPT2Model
    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_tf_gpt2 import TFGPT2Model
    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_flax_gpt2 import FlaxGPT2Model
else:
    import sys
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
"""
    init_no_tokenizer = """
from typing import TYPE_CHECKING
from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
_import_structure = {
    "configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"],
}
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_gpt2"] = ["GPT2Model"]
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_tf_gpt2"] = ["TFGPT2Model"]
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_flax_gpt2"] = ["FlaxGPT2Model"]
if TYPE_CHECKING:
    from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_gpt2 import GPT2Model
    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_tf_gpt2 import TFGPT2Model
    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_flax_gpt2 import FlaxGPT2Model
else:
    import sys
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
"""
    init_pt_only = """
from typing import TYPE_CHECKING
from ...utils import _LazyModule, is_tokenizers_available, is_torch_available
_import_structure = {
    "configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"],
    "tokenization_gpt2": ["GPT2Tokenizer"],
}
try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tokenization_gpt2_fast"] = ["GPT2TokenizerFast"]
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_gpt2"] = ["GPT2Model"]
if TYPE_CHECKING:
    from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
    from .tokenization_gpt2 import GPT2Tokenizer
    try:
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_gpt2_fast import GPT2TokenizerFast
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_gpt2 import GPT2Model
else:
    import sys
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
"""
    init_pt_only_no_tokenizer = """
from typing import TYPE_CHECKING
from ...utils import _LazyModule, is_torch_available
_import_structure = {
    "configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"],
}
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_gpt2"] = ["GPT2Model"]
if TYPE_CHECKING:
    from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_gpt2 import GPT2Model
else:
    import sys
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
"""
    # (keyword arguments for clean_frameworks_in_init, expected file content afterwards)
    scenarios = [
        ({"keep_processing": False}, init_no_tokenizer),
        ({"frameworks": ["pt"]}, init_pt_only),
        ({"frameworks": ["pt"], "keep_processing": False}, init_pt_only_no_tokenizer),
    ]
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the file inside tmp_dir so it is removed on exit; a "../" path would land in
        # the parent directory, outlive the test and be shared between concurrent runs.
        file_name = os.path.join(tmp_dir, "__init__.py")
        for kwargs, expected in scenarios:
            # Start every scenario from the pristine input.
            self.init_file(file_name, test_init)
            clean_frameworks_in_init(file_name, **kwargs)
            self.check_result(file_name, expected)
def test_clean_frameworks_in_init_with_vit(self):
    """Check `clean_frameworks_in_init` on a ViT style lazy `__init__.py`.

    The same input file is cleaned three ways:

    * no `frameworks` argument, ``keep_processing=False``: the image-processing block,
      its `is_vision_available` guard and import disappear, while the PyTorch,
      TensorFlow and Flax blocks are all expected to survive;
    * ``frameworks=["pt"]``: the TensorFlow and Flax guarded blocks disappear;
    * ``frameworks=["pt"], keep_processing=False``: both of the above.

    The fixtures are Python source, so their indentation is part of the expected output.
    """
    test_init = """
from typing import TYPE_CHECKING
from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available, is_vision_available
_import_structure = {
    "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
}
try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["image_processing_vit"] = ["ViTImageProcessor"]
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_vit"] = ["ViTModel"]
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_tf_vit"] = ["TFViTModel"]
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_flax_vit"] = ["FlaxViTModel"]
if TYPE_CHECKING:
    from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .image_processing_vit import ViTImageProcessor
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_vit import ViTModel
    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_tf_vit import TFViTModel
    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_flax_vit import FlaxViTModel
else:
    import sys
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
"""
    init_no_feature_extractor = """
from typing import TYPE_CHECKING
from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
_import_structure = {
    "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
}
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_vit"] = ["ViTModel"]
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_tf_vit"] = ["TFViTModel"]
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_flax_vit"] = ["FlaxViTModel"]
if TYPE_CHECKING:
    from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_vit import ViTModel
    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_tf_vit import TFViTModel
    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_flax_vit import FlaxViTModel
else:
    import sys
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
"""
    init_pt_only = """
from typing import TYPE_CHECKING
from ...utils import _LazyModule, is_torch_available, is_vision_available
_import_structure = {
    "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
}
try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["image_processing_vit"] = ["ViTImageProcessor"]
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_vit"] = ["ViTModel"]
if TYPE_CHECKING:
    from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .image_processing_vit import ViTImageProcessor
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_vit import ViTModel
else:
    import sys
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
"""
    init_pt_only_no_feature_extractor = """
from typing import TYPE_CHECKING
from ...utils import _LazyModule, is_torch_available
_import_structure = {
    "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
}
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_vit"] = ["ViTModel"]
if TYPE_CHECKING:
    from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_vit import ViTModel
else:
    import sys
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
"""
    # (keyword arguments for clean_frameworks_in_init, expected file content afterwards)
    scenarios = [
        ({"keep_processing": False}, init_no_feature_extractor),
        ({"frameworks": ["pt"]}, init_pt_only),
        ({"frameworks": ["pt"], "keep_processing": False}, init_pt_only_no_feature_extractor),
    ]
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the file inside tmp_dir so it is removed on exit; a "../" path would land in
        # the parent directory, outlive the test and be shared between concurrent runs.
        file_name = os.path.join(tmp_dir, "__init__.py")
        for kwargs, expected in scenarios:
            # Start every scenario from the pristine input.
            self.init_file(file_name, test_init)
            clean_frameworks_in_init(file_name, **kwargs)
            self.check_result(file_name, expected)
def test_duplicate_doc_file(self):
    """Check `duplicate_doc_file` when copying the GPT-2 doc page to a "GPT-New New" model.

    Four scenarios are checked against the same source page:

    1. all frameworks, new tokenizer class: every section is renamed and the
       overview is replaced by the template text;
    2. ``frameworks=["pt"]``: the TF and Flax model sections are dropped;
    3. the new model reuses ``GPT2Tokenizer``: the tokenizer sections are dropped;
    4. both restrictions at once.
    """
    test_doc = """
# GPT2
## Overview
Overview of the model.
## GPT2Config
[[autodoc]] GPT2Config
## GPT2Tokenizer
[[autodoc]] GPT2Tokenizer
- save_vocabulary
## GPT2TokenizerFast
[[autodoc]] GPT2TokenizerFast
## GPT2 specific outputs
[[autodoc]] models.gpt2.modeling_gpt2.GPT2DoubleHeadsModelOutput
[[autodoc]] models.gpt2.modeling_tf_gpt2.TFGPT2DoubleHeadsModelOutput
## GPT2Model
[[autodoc]] GPT2Model
- forward
## TFGPT2Model
[[autodoc]] TFGPT2Model
- call
## FlaxGPT2Model
[[autodoc]] FlaxGPT2Model
- __call__
"""
    # NOTE(review): the overview paragraph below must reproduce, character for character,
    # the template that duplicate_doc_file writes; update it whenever that template changes.
    test_new_doc = """
# GPT-New New
## Overview
The GPT-New New model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
<INSERT SHORT SUMMARY HERE>
The abstract from the paper is the following:
*<INSERT PAPER ABSTRACT HERE>*
Tips:
<INSERT TIPS ABOUT MODEL HERE>
This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
## GPTNewNewConfig
[[autodoc]] GPTNewNewConfig
## GPTNewNewTokenizer
[[autodoc]] GPTNewNewTokenizer
- save_vocabulary
## GPTNewNewTokenizerFast
[[autodoc]] GPTNewNewTokenizerFast
## GPTNewNew specific outputs
[[autodoc]] models.gpt_new_new.modeling_gpt_new_new.GPTNewNewDoubleHeadsModelOutput
[[autodoc]] models.gpt_new_new.modeling_tf_gpt_new_new.TFGPTNewNewDoubleHeadsModelOutput
## GPTNewNewModel
[[autodoc]] GPTNewNewModel
- forward
## TFGPTNewNewModel
[[autodoc]] TFGPTNewNewModel
- call
## FlaxGPTNewNewModel
[[autodoc]] FlaxGPTNewNewModel
- __call__
"""
    # Sections that must vanish when only PyTorch is requested (used by scenarios 2 and 4).
    tf_flax_sections = """
## TFGPTNewNewModel
[[autodoc]] TFGPTNewNewModel
- call
## FlaxGPTNewNewModel
[[autodoc]] FlaxGPTNewNewModel
- __call__
"""
    # Sections that must vanish when the new model keeps the old tokenizer (scenarios 3 and 4).
    tokenizer_sections = """
## GPTNewNewTokenizer
[[autodoc]] GPTNewNewTokenizer
- save_vocabulary
## GPTNewNewTokenizerFast
[[autodoc]] GPTNewNewTokenizerFast
"""
    test_new_doc_pt_only = test_new_doc.replace(tf_flax_sections, "")
    test_new_doc_no_tok = test_new_doc.replace(tokenizer_sections, "")
    test_new_doc_pt_only_no_tok = test_new_doc_no_tok.replace(tf_flax_sections, "")

    with tempfile.TemporaryDirectory() as tmp_dir:
        doc_file = os.path.join(tmp_dir, "gpt2.mdx")
        # The duplicate is expected next to the source page, named after the new model type.
        new_doc_file = os.path.join(tmp_dir, "gpt-new-new.mdx")
        gpt2_model_patterns = ModelPatterns("GPT2", "gpt2", tokenizer_class="GPT2Tokenizer")
        new_model_patterns = ModelPatterns(
            "GPT-New New", "huggingface/gpt-new-new", tokenizer_class="GPTNewNewTokenizer"
        )

        # 1. All frameworks, new tokenizer.
        self.init_file(doc_file, test_doc)
        duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns)
        self.check_result(new_doc_file, test_new_doc)

        # 2. PyTorch only.
        self.init_file(doc_file, test_doc)
        duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns, frameworks=["pt"])
        self.check_result(new_doc_file, test_new_doc_pt_only)

        # 3. The new model reuses GPT2Tokenizer, so no tokenizer sections are generated.
        new_model_patterns = ModelPatterns(
            "GPT-New New", "huggingface/gpt-new-new", tokenizer_class="GPT2Tokenizer"
        )
        self.init_file(doc_file, test_doc)
        duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns)
        self.check_result(new_doc_file, test_new_doc_no_tok)

        # 4. Reused tokenizer and PyTorch only.
        self.init_file(doc_file, test_doc)
        duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns, frameworks=["pt"])
        self.check_result(new_doc_file, test_new_doc_pt_only_no_tok)