Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Test script to verify TrackioConfig update method fix | |
| """ | |
| import sys | |
| import os | |
| sys.path.append(os.path.dirname(os.path.abspath(__file__))) | |
def test_trackio_config_update():
    """Check that ``trackio.config.update`` accepts a dict and keyword arguments.

    The check runs against the live ``trackio.config`` object.  Every
    attribute the check writes is snapshotted first and put back afterwards
    (or removed again if it did not exist before), so the probe values such
    as ``project_name='test_project'`` do not leak into whatever runs next in
    the same process.

    Returns:
        bool: ``True`` when every check passes, ``False`` otherwise.  The
        failure reason is printed instead of being raised.
    """
    print("🧪 Testing TrackioConfig update method...")

    missing = object()  # sentinel: "attribute did not exist before the test"
    dict_updates = {
        'project_name': 'test_project',
        'experiment_name': 'test_experiment',
        'new_attribute': 'test_value',
    }
    kwarg_updates = {'allow_val_change': True, 'trl_setting': 'test_value'}

    config = None
    snapshot = {}
    try:
        # Import lazily so a missing module is reported as a failed check
        # rather than crashing the whole script at import time.
        import trackio

        assert hasattr(trackio, 'config'), "trackio.config not found"
        print("✅ trackio.config exists")

        config = trackio.config
        assert hasattr(config, 'update'), "TrackioConfig.update method not found"
        print("✅ TrackioConfig.update method exists")

        # Remember the current value of everything we are about to overwrite.
        snapshot = {
            name: getattr(config, name, missing)
            for name in (*dict_updates, *kwarg_updates)
        }

        # Positional-dictionary form.  A copy is passed so the expectations
        # below stay intact even if update() mutates its argument.
        config.update(dict(dict_updates))
        for name, expected in dict_updates.items():
            actual = getattr(config, name)
            assert actual == expected, f"Expected {expected!r}, got '{actual}'"
        print("✅ TrackioConfig.update method works correctly with dictionary")

        # Keyword-argument form (the way TRL calls it).
        config.update(**kwarg_updates)
        assert config.allow_val_change is True, f"Expected True, got '{config.allow_val_change}'"
        assert config.trl_setting == 'test_value', f"Expected 'test_value', got '{config.trl_setting}'"
        print("✅ TrackioConfig.update method works correctly with keyword arguments")

        print("✅ All attributes updated successfully")
        return True
    except Exception as e:
        print(f"❌ Test failed: {e}")
        return False
    finally:
        # Put the shared config back the way it was found.  Cleanup is best
        # effort: a failure here is reported but must not change the verdict.
        for name, previous in snapshot.items():
            try:
                if previous is missing:
                    if hasattr(config, name):
                        delattr(config, name)
                else:
                    setattr(config, name, previous)
            except Exception as cleanup_error:
                print(f"⚠️ Could not restore trackio.config.{name}: {cleanup_error}")
def test_trackio_trl_compatibility():
    """Check that ``trackio`` exposes the surface TRL calls into.

    Verifies that ``init``, ``log``, ``finish`` and ``config.update`` exist,
    then calls ``init()`` with no arguments, ``log()`` with one metric and
    ``finish()``, in that order.  If ``log()`` fails after ``init()`` has
    succeeded, ``finish()`` is still attempted so that whatever ``init()``
    started is not left open.

    Returns:
        bool: ``True`` when every step succeeds, ``False`` on the first
        failure.  The failure reason is printed instead of being raised.
    """
    print("\n🔍 Testing TRL Compatibility...")
    try:
        # Import lazily so a missing module is reported as a failed check.
        import trackio

        # Module-level API must be present.
        for func_name in ('init', 'log', 'finish'):
            assert hasattr(trackio, func_name), f"trackio.{func_name} not found"
            print(f"✅ trackio.{func_name} exists")

        # The config object and its update method must be present.
        assert hasattr(trackio, 'config'), "trackio.config not found"
        assert hasattr(trackio.config, 'update'), "trackio.config.update not found"
        print("✅ trackio.config.update exists")

        # init() must work with no arguments at all.
        try:
            experiment_id = trackio.init()
            print(f"✅ trackio.init() called successfully: {experiment_id}")
        except Exception as e:
            print(f"❌ trackio.init() failed: {e}")
            return False

        try:
            trackio.log({'test_metric': 1.0})
            print("✅ trackio.log() called successfully")
        except Exception as e:
            print(f"❌ trackio.log() failed: {e}")
            # init() already succeeded, so presumably something is open;
            # close it on a best-effort basis before bailing out.
            try:
                trackio.finish()
            except Exception as cleanup_error:
                print(f"⚠️ trackio.finish() during cleanup failed: {cleanup_error}")
            return False

        try:
            trackio.finish()
            print("✅ trackio.finish() called successfully")
        except Exception as e:
            print(f"❌ trackio.finish() failed: {e}")
            return False

        print("✅ All TRL compatibility tests passed")
        return True
    except Exception as e:
        print(f"❌ TRL compatibility test failed: {e}")
        return False
def main():
    """Run both checks and print a pass/fail summary.

    Both checks always run, even when the first one fails, so the summary
    reports on each of them.  Each summary line carries a mark that matches
    its own result.

    Returns:
        bool: ``True`` only when both checks passed.
    """
    print("🧪 TrackioConfig Update Fix Test")
    print("=" * 40)

    # The list literal is evaluated top to bottom, so the update check runs
    # before the TRL compatibility check.
    results = [
        ("Update Method Test", test_trackio_config_update()),
        ("TRL Compatibility Test", test_trackio_trl_compatibility()),
    ]

    print("\n" + "=" * 40)
    print("📊 Test Results Summary")
    print("=" * 40)
    for label, passed in results:
        mark = '✅' if passed else '❌'
        verdict = 'PASSED' if passed else 'FAILED'
        print(f"{mark} {label}: {verdict}")

    if all(passed for _, passed in results):
        print("\n🎉 All tests passed! TrackioConfig update fix is working correctly.")
        return True

    print("\n❌ Some tests failed. Please check the implementation.")
    return False
if __name__ == "__main__":
    # The process exit status mirrors the overall result:
    # 0 when main() reports success, 1 otherwise.
    exit_code = 0 if main() else 1
    sys.exit(exit_code)