Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Test script to verify the training pipeline fixes | |
| """ | |
| import os | |
| import sys | |
| import logging | |
| from pathlib import Path | |
# Add project root to path.
# resolve() makes the script location absolute before walking upwards:
# Path.parent is purely lexical, so with a relative __file__ (possible when
# a script is run directly on Python < 3.9) "script.py".parent.parent
# collapses to "." and the wrong directory would be put on sys.path.
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
def test_imports():
    """Smoke-check that each pipeline module can be imported.

    The imports are attempted one after another; the first one that raises
    is reported on stdout and aborts the check.

    Returns:
        bool: True when all five imports succeed, False at the first failure.
    """
    print("π Testing imports...")

    # One tiny importer per module keeps the real ``from ... import ...``
    # statements intact while a single loop handles the reporting.
    def _load_config():
        from src.config import get_config  # noqa: F401

    def _load_model():
        from src.model import SmolLM3Model  # noqa: F401

    def _load_data():
        from src.data import SmolLM3Dataset  # noqa: F401

    def _load_trainer():
        from src.trainer import SmolLM3Trainer  # noqa: F401

    def _load_monitoring():
        from src.monitoring import create_monitor_from_config  # noqa: F401

    importers = (
        ("config.py", _load_config),
        ("model.py", _load_model),
        ("data.py", _load_data),
        ("trainer.py", _load_trainer),
        ("monitoring.py", _load_monitoring),
    )

    for label, importer in importers:
        try:
            importer()
            print(f"β {label} imported successfully")
        except Exception as exc:
            print(f"β {label} import failed: {exc}")
            return False

    return True
def test_config_loading():
    """Load the H100 lightweight training config and print a short summary.

    Returns:
        bool: True when the config loads and all four summary attributes can
        be read; False when anything in that sequence raises.
    """
    print("\nπ Testing configuration loading...")

    # (label shown to the user, attribute read from the config object)
    summary_fields = (
        ("Model", "model_name"),
        ("Dataset", "dataset_name"),
        ("Batch size", "batch_size"),
        ("Learning rate", "learning_rate"),
    )

    try:
        from src.config import get_config

        # Test loading the H100 lightweight config
        loaded = get_config("config/train_smollm3_h100_lightweight.py")
        print("β Configuration loaded successfully")
        for label, attribute in summary_fields:
            print(f" {label}: {getattr(loaded, attribute)}")
    except Exception as exc:
        print(f"β Configuration loading failed: {exc}")
        return False

    return True
def test_monitoring_setup():
    """Build a monitor from the config with an unreachable Trackio URL.

    The config's ``trackio_url`` and ``experiment_name`` are overridden
    before the monitor is created, then three monitor attributes are printed.

    Returns:
        bool: True when the monitor is created and its attributes can be
        read; False when any step raises.
    """
    print("\nπ Testing monitoring setup...")

    try:
        from src.config import get_config
        from src.monitoring import create_monitor_from_config

        settings = get_config("config/train_smollm3_h100_lightweight.py")

        # Point at a Space that does not exist; per the original author's
        # note this is meant to exercise the monitor's fallback path.
        settings.trackio_url = "https://non-existent-space.hf.space"
        settings.experiment_name = "test_experiment"

        tracker = create_monitor_from_config(settings)
        print("β Monitoring setup successful")

        report = (
            ("Experiment", "experiment_name"),
            ("Tracking enabled", "enable_tracking"),
            ("HF Dataset", "dataset_repo"),
        )
        for label, attribute in report:
            print(f" {label}: {getattr(tracker, attribute)}")
    except Exception as exc:
        print(f"β Monitoring setup failed: {exc}")
        return False

    return True
def test_trainer_creation():
    """Construct model, dataset and trainer objects from the H100 config.

    The three objects are built in dependency order (the dataset uses the
    model's tokenizer, the trainer uses both). Nothing is trained.

    Returns:
        bool: True when all three constructors succeed; False when any
        import or constructor raises.
    """
    print("\nπ Testing trainer creation...")

    try:
        from src.config import get_config
        from src.model import SmolLM3Model
        from src.data import SmolLM3Dataset
        from src.trainer import SmolLM3Trainer

        cfg = get_config("config/train_smollm3_h100_lightweight.py")

        # Model wrapper (original note: "without loading the actual model"
        # - NOTE(review): not verifiable from this file).
        model_kwargs = {
            "model_name": cfg.model_name,
            "max_seq_length": cfg.max_seq_length,
            "config": cfg,
        }
        net = SmolLM3Model(**model_kwargs)
        print("β Model created successfully")

        # Dataset wrapper (original note: "without loading actual data").
        dataset_kwargs = {
            "data_path": cfg.dataset_name,
            "tokenizer": net.tokenizer,
            "max_seq_length": cfg.max_seq_length,
            "config": cfg,
        }
        data = SmolLM3Dataset(**dataset_kwargs)
        print("β Dataset created successfully")

        # Trainer tying the two together.
        trainer_kwargs = {
            "model": net,
            "dataset": data,
            "config": cfg,
            "output_dir": "/tmp/test_output",
            "init_from": "scratch",
        }
        SmolLM3Trainer(**trainer_kwargs)
        print("β Trainer created successfully")
    except Exception as exc:
        print(f"β Trainer creation failed: {exc}")
        return False

    return True
def test_format_string_fix():
    """Exercise the console-callback log formatting with mixed value types.

    A callback mirroring the trainer's console logging is defined and then
    actually invoked: numeric values must be rendered with a float format
    spec, while non-numeric values (missing keys, strings, None) must fall
    back to ``str()`` instead of raising a format error.

    Previously the callback class was only defined and never called, so the
    check reported success without testing anything.

    Returns:
        bool: True when every sample log is formatted as expected; False
        when an import fails, formatting raises, or the output differs.
    """
    from types import SimpleNamespace

    print("\nπ Testing format string fix...")
    try:
        # Imported only to confirm the trainer module itself loads.
        from src.trainer import SmolLM3Trainer  # noqa: F401
        # Test the SimpleConsoleCallback format string handling
        from transformers import TrainerCallback

        class TestCallback(TrainerCallback):
            @staticmethod
            def format_line(step, logs):
                """Return the console line for one logging event."""
                loss = logs.get('loss', 'N/A')
                lr = logs.get('learning_rate', 'N/A')
                # Only numbers accept float format specs; anything else is
                # shown verbatim.
                if isinstance(loss, (int, float)):
                    loss_str = f"{loss:.4f}"
                else:
                    loss_str = str(loss)
                if isinstance(lr, (int, float)):
                    lr_str = f"{lr:.2e}"
                else:
                    lr_str = str(lr)
                return f"Step {step}: loss={loss_str}, lr={lr_str}"

            def on_log(self, args, state, control, logs=None, **kwargs):
                if logs and isinstance(logs, dict):
                    step = getattr(state, 'global_step', 'unknown')
                    print(self.format_line(step, logs))

        callback = TestCallback()
        state = SimpleNamespace(global_step=10)
        cases = [
            # numeric values -> formatted
            ({'loss': 1.23456, 'learning_rate': 2e-5},
             "Step 10: loss=1.2346, lr=2.00e-05"),
            # non-numeric values -> str() fallback, no format error
            ({'loss': 'N/A', 'learning_rate': None},
             "Step 10: loss=N/A, lr=None"),
            # keys absent -> defaults
            ({'eval_loss': 0.5},
             "Step 10: loss=N/A, lr=N/A"),
        ]
        for logs, expected in cases:
            callback.on_log(None, state, None, logs=logs)
            actual = callback.format_line(state.global_step, logs)
            if actual != expected:
                raise ValueError(
                    f"unexpected log line {actual!r}, expected {expected!r}"
                )

        # A missing logs payload must simply be ignored.
        callback.on_log(None, state, None, logs=None)

        print("β Format string fix works correctly")
        return True
    except Exception as e:
        print(f"β Format string fix test failed: {e}")
        return False
def main():
    """Run every check in order and print a pass/fail summary.

    A check counts as passed only when it returns a truthy value; a check
    that raises is reported as crashed and counted as not passed.

    Returns:
        bool: True when all checks passed, otherwise False.
    """
    print("π Testing SmolLM3 Training Pipeline Fixes")
    print("=" * 50)

    checks = [
        test_imports,
        test_config_loading,
        test_monitoring_setup,
        test_trainer_creation,
        test_format_string_fix,
    ]
    total = len(checks)
    passed = 0

    for check in checks:
        try:
            succeeded = check()
        except Exception as exc:
            print(f"β Test {check.__name__} crashed: {exc}")
            continue
        if succeeded:
            passed += 1

    print(f"\nπ Test Results: {passed}/{total} tests passed")

    if passed != total:
        print("β Some tests failed. Please check the errors above.")
        return False

    print("β All tests passed! The training pipeline should work correctly.")
    return True
if __name__ == "__main__":
    # Map the overall result onto the process exit status:
    # 0 when every check passed, 1 otherwise.
    exit_status = 0 if main() else 1
    sys.exit(exit_status)