Spaces:
Running
on
Zero
Running
on
Zero
| import unittest | |
| from detectron2.solver.build import _expand_param_groups, reduce_param_groups | |
| class TestOptimizer(unittest.TestCase): | |
| def testExpandParamsGroups(self): | |
| params = [ | |
| { | |
| "params": ["p1", "p2", "p3", "p4"], | |
| "lr": 1.0, | |
| "weight_decay": 3.0, | |
| }, | |
| { | |
| "params": ["p2", "p3", "p5"], | |
| "lr": 2.0, | |
| "momentum": 2.0, | |
| }, | |
| { | |
| "params": ["p1"], | |
| "weight_decay": 4.0, | |
| }, | |
| ] | |
| out = _expand_param_groups(params) | |
| gt = [ | |
| dict(params=["p1"], lr=1.0, weight_decay=4.0), # noqa | |
| dict(params=["p2"], lr=2.0, weight_decay=3.0, momentum=2.0), # noqa | |
| dict(params=["p3"], lr=2.0, weight_decay=3.0, momentum=2.0), # noqa | |
| dict(params=["p4"], lr=1.0, weight_decay=3.0), # noqa | |
| dict(params=["p5"], lr=2.0, momentum=2.0), # noqa | |
| ] | |
| self.assertEqual(out, gt) | |
| def testReduceParamGroups(self): | |
| params = [ | |
| dict(params=["p1"], lr=1.0, weight_decay=4.0), # noqa | |
| dict(params=["p2", "p6"], lr=2.0, weight_decay=3.0, momentum=2.0), # noqa | |
| dict(params=["p3"], lr=2.0, weight_decay=3.0, momentum=2.0), # noqa | |
| dict(params=["p4"], lr=1.0, weight_decay=3.0), # noqa | |
| dict(params=["p5"], lr=2.0, momentum=2.0), # noqa | |
| ] | |
| gt_groups = [ | |
| { | |
| "lr": 1.0, | |
| "weight_decay": 4.0, | |
| "params": ["p1"], | |
| }, | |
| { | |
| "lr": 2.0, | |
| "weight_decay": 3.0, | |
| "momentum": 2.0, | |
| "params": ["p2", "p6", "p3"], | |
| }, | |
| { | |
| "lr": 1.0, | |
| "weight_decay": 3.0, | |
| "params": ["p4"], | |
| }, | |
| { | |
| "lr": 2.0, | |
| "momentum": 2.0, | |
| "params": ["p5"], | |
| }, | |
| ] | |
| out = reduce_param_groups(params) | |
| self.assertEqual(out, gt_groups) | |