Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| import math | |
| import numpy as np | |
| from unittest import TestCase | |
| import torch | |
| from fvcore.common.param_scheduler import ( | |
| CosineParamScheduler, | |
| MultiStepParamScheduler, | |
| StepWithFixedGammaParamScheduler, | |
| ) | |
| from torch import nn | |
| from detectron2.solver import LRMultiplier, WarmupParamScheduler, build_lr_scheduler | |
class TestScheduler(TestCase):
    """Tests for LR schedules built with ``LRMultiplier``/``WarmupParamScheduler``
    and ``build_lr_scheduler``."""

    @staticmethod
    def _make_optimizer(lr):
        """Return ``(param, optimizer)``: an empty parameter and an SGD optimizer over it."""
        p = nn.Parameter(torch.zeros(0))
        return p, torch.optim.SGD([p], lr=lr)

    @staticmethod
    def _optimizer_step(p, opt):
        """Run one (no-op) backward pass followed by ``opt.step()``.

        NOTE(review): presumably done so that ``optimizer.step()`` precedes the first
        ``scheduler.step()`` and PyTorch does not warn about the call order.
        """
        p.sum().backward()
        opt.step()

    @staticmethod
    def _record_lrs(sched, opt, num_steps):
        """Return the optimizer's current LR followed by its LR after each of
        ``num_steps`` calls to ``sched.step()`` (``num_steps + 1`` values total)."""
        lrs = [opt.param_groups[0]["lr"]]
        for _ in range(num_steps):
            sched.step()
            lrs.append(opt.param_groups[0]["lr"])
        return lrs

    def test_warmup_multistep(self):
        """Linear warmup over the first 5 of 30 iterations, then 10x decays at 10/15/20."""
        p, opt = self._make_optimizer(lr=5)
        multiplier = WarmupParamScheduler(
            MultiStepParamScheduler(
                [1, 0.1, 0.01, 0.001],
                milestones=[10, 15, 20],
                num_updates=30,
            ),
            0.001,
            5 / 30,
        )
        sched = LRMultiplier(opt, multiplier, 30)
        # This is an equivalent of:
        # sched = WarmupMultiStepLR(
        #     opt, milestones=[10, 15, 20], gamma=0.1, warmup_factor=0.001, warmup_iters=5)

        self._optimizer_step(p, opt)
        # lrs[0] is the LR the scheduler actually set (not a hard-coded constant),
        # so the first element of the warmup check below is a real assertion.
        lrs = self._record_lrs(sched, opt, 30)
        self.assertTrue(np.allclose(lrs[:5], [0.005, 1.004, 2.003, 3.002, 4.001]))
        self.assertTrue(np.allclose(lrs[5:10], 5.0))
        self.assertTrue(np.allclose(lrs[10:15], 0.5))
        self.assertTrue(np.allclose(lrs[15:20], 0.05))
        self.assertTrue(np.allclose(lrs[20:], 0.005))

    def test_warmup_cosine(self):
        """After a 5-iteration warmup, the LR follows a cosine from 5 down to 0 over 30 iterations."""
        p, opt = self._make_optimizer(lr=5)
        multiplier = WarmupParamScheduler(
            CosineParamScheduler(1, 0),
            0.001,
            5 / 30,
        )
        sched = LRMultiplier(opt, multiplier, 30)
        self._optimizer_step(p, opt)
        self.assertEqual(opt.param_groups[0]["lr"], 0.005)

        lrs = self._record_lrs(sched, opt, 30)
        for idx, lr in enumerate(lrs):
            expected_cosine = 2.5 * (1.0 + math.cos(math.pi * idx / 30))
            if idx >= 5:
                self.assertAlmostEqual(lr, expected_cosine)
            else:
                # During warmup the LR must differ from the pure cosine curve.
                self.assertNotAlmostEqual(lr, expected_cosine)

    def test_warmup_cosine_end_value(self):
        """With WarmupCosineLR, the LR after MAX_ITER steps equals SOLVER.BASE_LR_END."""
        from detectron2.config import CfgNode, get_cfg

        def _test_end_value(cfg_dict):
            cfg = get_cfg()
            cfg.merge_from_other_cfg(CfgNode(cfg_dict))

            p, opt = self._make_optimizer(lr=cfg.SOLVER.BASE_LR)
            scheduler = build_lr_scheduler(cfg, opt)
            self._optimizer_step(p, opt)
            self.assertEqual(
                opt.param_groups[0]["lr"], cfg.SOLVER.BASE_LR * cfg.SOLVER.WARMUP_FACTOR
            )
            lrs = self._record_lrs(scheduler, opt, cfg.SOLVER.MAX_ITER)
            self.assertAlmostEqual(lrs[-1], cfg.SOLVER.BASE_LR_END)

        for base_lr_end in (0.0, 0.5):
            with self.subTest(base_lr_end=base_lr_end):
                _test_end_value(
                    {
                        "SOLVER": {
                            "LR_SCHEDULER_NAME": "WarmupCosineLR",
                            "MAX_ITER": 100,
                            "WARMUP_ITERS": 10,
                            "WARMUP_FACTOR": 0.1,
                            "BASE_LR": 5.0,
                            "BASE_LR_END": base_lr_end,
                        }
                    }
                )

    def test_warmup_stepwithfixedgamma(self):
        """Warmup with ``rescale_interval=True``, then the LR drops 10x every 5 iterations;
        stepping past the final iteration raises IndexError."""
        p, opt = self._make_optimizer(lr=5)
        multiplier = WarmupParamScheduler(
            StepWithFixedGammaParamScheduler(
                base_value=1.0,
                gamma=0.1,
                num_decays=4,
                num_updates=30,
            ),
            0.001,
            5 / 30,
            rescale_interval=True,
        )
        sched = LRMultiplier(opt, multiplier, 30)
        self._optimizer_step(p, opt)

        # lrs[0] is read from the optimizer rather than hard-coded.
        lrs = self._record_lrs(sched, opt, 29)
        self.assertTrue(np.allclose(lrs[:5], [0.005, 1.004, 2.003, 3.002, 4.001]))
        self.assertTrue(np.allclose(lrs[5:10], 5.0))
        self.assertTrue(np.allclose(lrs[10:15], 0.5))
        self.assertTrue(np.allclose(lrs[15:20], 0.05))
        self.assertTrue(np.allclose(lrs[20:25], 0.005))
        self.assertTrue(np.allclose(lrs[25:], 0.0005))

        # Calling sched.step() after the last training iteration is done triggers
        # IndexError. assertRaisesRegex checks the error text; assertRaises(msg=...)
        # would only customize the failure message and not check anything.
        with self.assertRaisesRegex(IndexError, "list index out of range"):
            sched.step()