Spaces:
Runtime error
Runtime error
| from tempfile import TemporaryDirectory | |
| from unittest import TestCase | |
| from unittest.mock import MagicMock, patch | |
| from transformers import AutoModel, TFAutoModel | |
| from transformers.onnx import FeaturesManager | |
| from transformers.testing_utils import SMALL_MODEL_IDENTIFIER, require_tf, require_torch | |
@require_torch
@require_tf
class DetermineFrameworkTest(TestCase):
    """
    Test `FeaturesManager.determine_framework`.

    The tests save real PyTorch and TensorFlow checkpoints, so the whole class
    is skipped unless both frameworks are installed.
    """

    def setUp(self):
        """Store the model identifier and the framework codes the tests expect."""
        self.test_model = SMALL_MODEL_IDENTIFIER
        self.framework_pt = "pt"
        self.framework_tf = "tf"

    def _setup_pt_ckpt(self, save_dir):
        """Save a PyTorch checkpoint of the test model into `save_dir`."""
        model_pt = AutoModel.from_pretrained(self.test_model)
        model_pt.save_pretrained(save_dir)

    def _setup_tf_ckpt(self, save_dir):
        """Save a TensorFlow checkpoint of the test model into `save_dir`."""
        # `from_pt=True` converts the PyTorch weights while loading.
        model_tf = TFAutoModel.from_pretrained(self.test_model, from_pt=True)
        model_tf.save_pretrained(save_dir)

    def test_framework_provided(self):
        """
        Ensure that the provided framework is returned.
        """
        mock_framework = "mock_framework"

        # Framework provided - return whatever the user provides
        result = FeaturesManager.determine_framework(self.test_model, mock_framework)
        self.assertEqual(result, mock_framework)

        # Local checkpoint and framework provided - return provided framework
        # PyTorch checkpoint
        with TemporaryDirectory() as local_pt_ckpt:
            self._setup_pt_ckpt(local_pt_ckpt)
            result = FeaturesManager.determine_framework(local_pt_ckpt, mock_framework)
            self.assertEqual(result, mock_framework)

        # TensorFlow checkpoint
        with TemporaryDirectory() as local_tf_ckpt:
            self._setup_tf_ckpt(local_tf_ckpt)
            result = FeaturesManager.determine_framework(local_tf_ckpt, mock_framework)
            self.assertEqual(result, mock_framework)

    def test_checkpoint_provided(self):
        """
        Ensure that the determined framework is the one used for the local checkpoint.

        For the functionality to execute, local checkpoints are provided but framework is not.
        """
        # PyTorch checkpoint
        with TemporaryDirectory() as local_pt_ckpt:
            self._setup_pt_ckpt(local_pt_ckpt)
            result = FeaturesManager.determine_framework(local_pt_ckpt)
            self.assertEqual(result, self.framework_pt)

        # TensorFlow checkpoint
        with TemporaryDirectory() as local_tf_ckpt:
            self._setup_tf_ckpt(local_tf_ckpt)
            result = FeaturesManager.determine_framework(local_tf_ckpt)
            self.assertEqual(result, self.framework_tf)

        # Invalid local checkpoint: an empty directory holds no weights file
        with TemporaryDirectory() as local_invalid_ckpt:
            with self.assertRaises(FileNotFoundError):
                FeaturesManager.determine_framework(local_invalid_ckpt)

    def test_from_environment(self):
        """
        Ensure that the determined framework is the one available in the environment.

        For the functionality to execute, framework and local checkpoints are not provided.
        """
        # Framework not provided, hub model is used (no local checkpoint directory)
        # TensorFlow not in environment -> use PyTorch
        mock_tf_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available):
            result = FeaturesManager.determine_framework(self.test_model)
            self.assertEqual(result, self.framework_pt)

        # PyTorch not in environment -> use TensorFlow
        mock_torch_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_torch_available", mock_torch_available):
            result = FeaturesManager.determine_framework(self.test_model)
            self.assertEqual(result, self.framework_tf)

        # Both in environment -> use PyTorch
        mock_tf_available = MagicMock(return_value=True)
        mock_torch_available = MagicMock(return_value=True)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
            "transformers.onnx.features.is_torch_available", mock_torch_available
        ):
            result = FeaturesManager.determine_framework(self.test_model)
            self.assertEqual(result, self.framework_pt)

        # Both not in environment -> raise error
        mock_tf_available = MagicMock(return_value=False)
        mock_torch_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
            "transformers.onnx.features.is_torch_available", mock_torch_available
        ):
            with self.assertRaises(EnvironmentError):
                FeaturesManager.determine_framework(self.test_model)