Spaces:
Running
Running
| """ | |
| Unit tests for FileService | |
| """ | |
| import pytest | |
| import os | |
| import tempfile | |
| from unittest.mock import Mock, patch, mock_open | |
| from werkzeug.datastructures import FileStorage | |
| from io import BytesIO | |
| from app.services.file_service import FileService | |
class TestFileService:
    """Test cases for FileService.

    NOTE(review): several tests take a ``temp_file`` argument — this is
    assumed to be a pytest fixture defined in conftest.py that yields the
    path of an existing temporary text file; confirm it exists there.

    FIX(review): the ``mock_*`` parameters below had no corresponding
    ``@patch`` decorators (pytest would error with "fixture 'mock_makedirs'
    not found"). The decorators have been restored; remember the bottom-most
    decorator maps to the first mock argument after ``self``.
    TODO(review): confirm ``app.services.file_service`` exposes its
    collaborators as module-level ``tokenizer_service`` / ``stats_service``
    names, so the patch targets below resolve.
    """

    def setup_method(self):
        """Create a fresh FileService before every test."""
        self.service = FileService()

    def test_is_allowed_file_valid_extensions(self):
        """Test allowed file extension checking."""
        # Valid extensions
        assert self.service.is_allowed_file('test.txt') is True
        assert self.service.is_allowed_file('document.md') is True
        assert self.service.is_allowed_file('script.py') is True
        assert self.service.is_allowed_file('code.js') is True
        assert self.service.is_allowed_file('data.json') is True
        assert self.service.is_allowed_file('styles.css') is True
        assert self.service.is_allowed_file('page.html') is True
        assert self.service.is_allowed_file('data.csv') is True
        assert self.service.is_allowed_file('app.log') is True

    def test_is_allowed_file_invalid_extensions(self):
        """Test invalid file extensions."""
        # Invalid extensions
        assert self.service.is_allowed_file('virus.exe') is False
        assert self.service.is_allowed_file('archive.zip') is False
        assert self.service.is_allowed_file('image.jpg') is False
        assert self.service.is_allowed_file('document.pdf') is False
        assert self.service.is_allowed_file('data.xlsx') is False

    def test_is_allowed_file_edge_cases(self):
        """Test edge cases for file extension checking."""
        # Empty filename
        assert self.service.is_allowed_file('') is False
        assert self.service.is_allowed_file(None) is False
        # No extension
        assert self.service.is_allowed_file('filename') is False
        # Multiple dots
        assert self.service.is_allowed_file('file.backup.txt') is True
        # Case sensitivity
        assert self.service.is_allowed_file('FILE.TXT') is True
        assert self.service.is_allowed_file('Document.MD') is True

    def test_generate_secure_filename_basic(self):
        """Test basic secure filename generation."""
        filename = self.service.generate_secure_filename('test.txt')
        assert filename.endswith('_test.txt')
        assert len(filename) > len('test.txt')  # Should have UUID prefix
        # Should be different each time
        filename2 = self.service.generate_secure_filename('test.txt')
        assert filename != filename2

    def test_generate_secure_filename_special_characters(self):
        """Test secure filename with special characters."""
        # Test filename with spaces and special chars
        filename = self.service.generate_secure_filename('my file name.txt')
        assert 'my_file_name.txt' in filename
        # Test with path separators (should be removed)
        filename = self.service.generate_secure_filename('../../../etc/passwd')
        assert '..' not in filename
        assert '/' not in filename
        assert '\\' not in filename

    def test_generate_secure_filename_empty_input(self):
        """Test secure filename generation with empty input."""
        filename = self.service.generate_secure_filename('')
        assert filename.endswith('.txt')
        assert len(filename) > 4  # Should have UUID
        filename = self.service.generate_secure_filename(None)
        assert filename.endswith('.txt')
        assert len(filename) > 4

    @patch('os.makedirs')
    def test_save_uploaded_file_basic(self, mock_makedirs, temp_file):
        """Test basic file upload saving."""
        # Create a mock uploaded file
        file_content = b"Hello world!"
        uploaded_file = FileStorage(
            stream=BytesIO(file_content),
            filename='test.txt',
            content_type='text/plain'
        )
        upload_folder = '/tmp/test_uploads'
        # Patch open() so nothing is actually written to disk.
        with patch('builtins.open', mock_open()):
            file_path = self.service.save_uploaded_file(uploaded_file, upload_folder)
            # Check that directory creation was attempted
            mock_makedirs.assert_called_once_with(upload_folder, exist_ok=True)
            # Check that file path has correct structure
            assert file_path.startswith(upload_folder)
            assert file_path.endswith('_test.txt')

    def test_cleanup_file_existing(self, temp_file):
        """Test cleanup of existing file."""
        # Verify file exists
        assert os.path.exists(temp_file)
        # Cleanup
        self.service.cleanup_file(temp_file)
        # Verify file is deleted
        assert not os.path.exists(temp_file)

    def test_cleanup_file_nonexistent(self):
        """Test cleanup of non-existent file (should not raise error)."""
        # Should not raise an exception
        self.service.cleanup_file('/path/that/does/not/exist.txt')

    @patch('app.services.file_service.tokenizer_service')
    @patch('app.services.file_service.stats_service')
    def test_process_file_for_tokenization_basic(self, mock_stats, mock_tokenizer, temp_file):
        """Test basic file processing for tokenization."""
        # Mock tokenizer service
        mock_tokenizer_obj = Mock()
        mock_tokenizer_obj.tokenize.return_value = ['Hello', ' world', '!']
        mock_tokenizer.load_tokenizer.return_value = (mock_tokenizer_obj, {}, None)
        # Mock stats service
        mock_stats.get_token_stats.return_value = {
            'basic_stats': {'total_tokens': 3},
            'length_stats': {'avg_length': '2.0'}
        }
        mock_stats.format_tokens_for_display.return_value = [
            {'display': 'Hello', 'original': 'Hello', 'token_id': 1, 'colors': {}, 'newline': False}
        ]
        result = self.service.process_file_for_tokenization(
            file_path=temp_file,
            model_id_or_name='gpt2',
            preview_char_limit=1000,
            max_display_tokens=100,
            chunk_size=1024
        )
        assert isinstance(result, dict)
        assert 'tokens' in result
        assert 'stats' in result
        assert 'display_limit_reached' in result
        assert 'total_tokens' in result
        assert 'preview_only' in result
        assert 'tokenizer_info' in result

    @patch('app.services.file_service.tokenizer_service')
    def test_process_file_tokenizer_error(self, mock_tokenizer, temp_file):
        """Test file processing with tokenizer error."""
        # Mock tokenizer service to return error
        mock_tokenizer.load_tokenizer.return_value = (None, {}, "Tokenizer error")
        with pytest.raises(Exception) as excinfo:
            self.service.process_file_for_tokenization(
                file_path=temp_file,
                model_id_or_name='invalid-model',
                preview_char_limit=1000,
                max_display_tokens=100
            )
        assert "Tokenizer error" in str(excinfo.value)

    @patch('app.services.file_service.tokenizer_service')
    @patch('app.services.file_service.stats_service')
    def test_process_text_for_tokenization_basic(self, mock_stats, mock_tokenizer):
        """Test basic text processing for tokenization."""
        # Mock tokenizer service
        mock_tokenizer_obj = Mock()
        mock_tokenizer_obj.tokenize.return_value = ['Hello', ' world']
        mock_tokenizer.load_tokenizer.return_value = (mock_tokenizer_obj, {'vocab_size': 1000}, None)
        # Mock stats service
        mock_stats.get_token_stats.return_value = {
            'basic_stats': {'total_tokens': 2},
            'length_stats': {'avg_length': '3.0'}
        }
        mock_stats.format_tokens_for_display.return_value = [
            {'display': 'Hello', 'original': 'Hello', 'token_id': 1, 'colors': {}, 'newline': False},
            {'display': ' world', 'original': ' world', 'token_id': 2, 'colors': {}, 'newline': False}
        ]
        result = self.service.process_text_for_tokenization(
            text="Hello world",
            model_id_or_name='gpt2',
            max_display_tokens=100
        )
        assert isinstance(result, dict)
        assert 'tokens' in result
        assert 'stats' in result
        assert result['display_limit_reached'] is False
        assert result['total_tokens'] == 2
        assert result['tokenizer_info']['vocab_size'] == 1000

    @patch('app.services.file_service.tokenizer_service')
    @patch('app.services.file_service.stats_service')
    def test_process_text_display_limit(self, mock_stats, mock_tokenizer):
        """Test text processing with display limit."""
        # Create a large number of tokens
        tokens = [f'token{i}' for i in range(200)]
        # Mock tokenizer service
        mock_tokenizer_obj = Mock()
        mock_tokenizer_obj.tokenize.return_value = tokens
        mock_tokenizer.load_tokenizer.return_value = (mock_tokenizer_obj, {}, None)
        # Mock stats service
        mock_stats.get_token_stats.return_value = {
            'basic_stats': {'total_tokens': 200},
            'length_stats': {'avg_length': '6.0'}
        }
        mock_stats.format_tokens_for_display.return_value = []
        result = self.service.process_text_for_tokenization(
            text="Long text",
            model_id_or_name='gpt2',
            max_display_tokens=100  # Limit lower than token count
        )
        assert result['display_limit_reached'] is True
        assert result['total_tokens'] == 200

    @patch('app.services.file_service.tokenizer_service')
    def test_process_text_tokenizer_error(self, mock_tokenizer):
        """Test text processing with tokenizer error."""
        # Mock tokenizer service to return error
        mock_tokenizer.load_tokenizer.return_value = (None, {}, "Model not found")
        with pytest.raises(Exception) as excinfo:
            self.service.process_text_for_tokenization(
                text="Hello world",
                model_id_or_name='invalid-model'
            )
        assert "Model not found" in str(excinfo.value)

    @patch('app.services.file_service.tokenizer_service')
    @patch('app.services.file_service.stats_service')
    def test_process_text_preview_mode(self, mock_stats, mock_tokenizer):
        """Test text processing in preview mode."""
        long_text = "A" * 10000  # Long text
        # Mock tokenizer service
        mock_tokenizer_obj = Mock()
        mock_tokenizer_obj.tokenize.return_value = ['A'] * 5000  # Many tokens
        mock_tokenizer.load_tokenizer.return_value = (mock_tokenizer_obj, {}, None)
        # Mock stats service
        mock_stats.get_token_stats.return_value = {
            'basic_stats': {'total_tokens': 5000},
            'length_stats': {'avg_length': '1.0'}
        }
        mock_stats.format_tokens_for_display.return_value = []
        result = self.service.process_text_for_tokenization(
            text=long_text,
            model_id_or_name='gpt2',
            is_preview=True,
            preview_char_limit=100
        )
        assert result['preview_only'] is True

    def test_allowed_extensions_constant(self):
        """Test that ALLOWED_EXTENSIONS contains expected extensions."""
        extensions = self.service.ALLOWED_EXTENSIONS
        assert isinstance(extensions, set)
        # Check for required extensions
        required_extensions = {'.txt', '.md', '.py', '.js', '.json', '.html', '.css', '.csv', '.log'}
        assert required_extensions.issubset(extensions)
        # All extensions should start with dot
        for ext in extensions:
            assert ext.startswith('.')
            assert len(ext) > 1