Support clipping in LossScaleOptimizer.
This has been made possible by previous refactors to the base Optimizer.

PiperOrigin-RevId: 333336359
Change-Id: Ie31d96ff9b27e3961940c9a221486963be02ced7
Commit c2926ab744 (parent 2c1693cfcf)
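With this change, an optimizer that already has gradient clipping configured can be wrapped directly instead of raising a ValueError. A minimal usage sketch, assuming the experimental mixed-precision API as it exists at this revision (exact import paths vary across TF versions):

import tensorflow as tf

# SGD with clipnorm set; before this commit, wrapping it raised a ValueError.
opt = tf.keras.optimizers.SGD(learning_rate=2.0, clipnorm=2.0)
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    opt, loss_scale='dynamic')

# clipnorm is now delegated to the wrapped optimizer and can be read or updated.
print(opt.clipnorm)  # 2.0, per the new test's assertEqual(getattr(opt, clip_type), 2.0)
opt.clipnorm = 3.0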
@@ -248,20 +248,8 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
     if not isinstance(optimizer, optimizer_v2.OptimizerV2):
       raise ValueError('"optimizer" must be an instance of OptimizerV2, but '
                        'got: %s' % optimizer)
-    if optimizer.clipnorm is not None:
-      raise ValueError('LossScaleOptimizer does not support wrapping '
-                       'optimizers with a clipnorm. Optimizer %s has clipnorm '
-                       '%s' % (optimizer, optimizer.clipnorm))
-
-    if optimizer.clipvalue is not None:
-      raise ValueError('LossScaleOptimizer does not support wrapping '
-                       'optimizers with a clipvalue. Optimizer %s has '
-                       'clipvalue %s' % (optimizer, optimizer.clipvalue))
-
     self._raise_if_strategy_unsupported()
-
-    self.clipnorm = None
-    self.clipvalue = None

     self._optimizer = optimizer
     self._loss_scale = keras_loss_scale_module.get(loss_scale)
     if self._loss_scale is None:
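For reference, the clipping modes the wrapped optimizer may carry behave differently: clipvalue clamps each gradient element to [-v, v], clipnorm rescales each gradient tensor so its L2 norm is at most v, and global_clipnorm rescales all gradients jointly by their global norm. A small illustrative sketch of the first two using plain Python lists (my own helpers, not the Keras implementation):

import math

def clip_by_value(grad, v):
  # Element-wise clamp to [-v, v].
  return [max(-v, min(v, g)) for g in grad]

def clip_by_norm(grad, v):
  # Rescale the whole tensor if its L2 norm exceeds v.
  norm = math.sqrt(sum(g * g for g in grad))
  return grad if norm <= v else [g * v / norm for g in grad]

grad = [4.0]
assert clip_by_value(grad, 2.0) == [2.0]
assert clip_by_norm(grad, 2.0) == [2.0]  # matches the test below: gradient 4 clipped to 2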
@@ -486,6 +474,30 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
   def set_weights(self, weights):
     return self._optimizer.set_weights(weights)

+  @property
+  def clipnorm(self):
+    return self._optimizer.clipnorm
+
+  @clipnorm.setter
+  def clipnorm(self, val):
+    self._optimizer.clipnorm = val
+
+  @property
+  def global_clipnorm(self):
+    return self._optimizer.global_clipnorm
+
+  @global_clipnorm.setter
+  def global_clipnorm(self, val):
+    self._optimizer.global_clipnorm = val
+
+  @property
+  def clipvalue(self):
+    return self._optimizer.clipvalue
+
+  @clipvalue.setter
+  def clipvalue(self, val):
+    self._optimizer.clipvalue = val
+
   def _aggregate_gradients(self, grads_and_vars):
     return self._optimizer._aggregate_gradients(grads_and_vars)  # pylint: disable=protected-access

@@ -189,6 +189,52 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
       # 1.
       self.assertAllClose([1.], self.evaluate(var))

+  # pylint: disable=cell-var-from-loop
+  @parameterized.named_parameters(*TESTCASES)
+  def testClipping(self, strategy_fn):
+    strategy = strategy_fn()
+    learning_rate = 2.
+    for clip_type in ('clipnorm', 'global_clipnorm', 'clipvalue'):
+      with strategy.scope(), self.subTest(clip_type=clip_type):
+        var = variables.Variable([5.0])
+        opt = gradient_descent.SGD(learning_rate, **{clip_type: 2.0})
+        loss_scale = loss_scale_module.DynamicLossScale(
+            initial_loss_scale=2, increment_period=1, multiplier=2)
+        opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
+        self.assertEqual(getattr(opt, clip_type), 2.0)
+        self.assertEqual(
+            loss_scale.initial_loss_scale % strategy.num_replicas_in_sync, 0)
+
+        loss = lambda: var * 4 / strategy.num_replicas_in_sync
+        run_fn = lambda: opt.minimize(loss, var_list=[var])
+
+        # Test running with clipped gradients
+        run_op = strategy.experimental_run(run_fn)
+        self.evaluate(variables.global_variables_initializer())
+        self._run_if_in_graph_mode(run_op)
+        # The gradient is 4 but is clipped to 2, so the variable will be
+        # init_val - clipped_grad * lr == 5 - 2 * 2 == 1
+        self.assertAllClose([1.], self.evaluate(var))
+        self.assertEqual(self.evaluate(opt.loss_scale()), 4)
+
+        # Test changing the clip amount and running again
+        setattr(opt, clip_type, 3.0)
+        run_op = strategy.experimental_run(run_fn)
+        self._run_if_in_graph_mode(run_op)
+        # The gradient is 4 but is clipped to 3, so the variable will be
+        # prev_var - clipped_grad * lr == 1 - 3 * 2 == -5
+        self.assertAllClose([-5.], self.evaluate(var))
+        self.assertEqual(self.evaluate(opt.loss_scale()), 8)
+
+        # Test Inf gradients are still skipped instead of being clipped
+        loss = lambda: var * float('Inf')
+        run_fn = lambda: opt.minimize(loss, var_list=[var])
+        run_op = strategy.experimental_run(run_fn)
+        self._run_if_in_graph_mode(run_op)
+        self.assertAllClose([-5.], self.evaluate(var))  # Var does not change
+        self.assertEqual(self.evaluate(opt.loss_scale()), 4)
+  # pylint: enable=cell-var-from-loop
+
   @parameterized.named_parameters(*TESTCASES)
   def testDynamicUpdate(self, strategy_fn):
     with strategy_fn().scope() as strategy:
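The expected loss-scale values in the test follow from the DynamicLossScale settings (initial_loss_scale=2, increment_period=1, multiplier=2): each step with finite gradients doubles the scale, and a non-finite step halves it and skips the variable update. A small TensorFlow-free sketch of that bookkeeping, using a hypothetical helper that mimics the behavior rather than the real implementation:

def update_loss_scale(scale, grads_finite, multiplier=2.0):
  # Double on good steps, halve and skip the update on non-finite steps
  # (increment_period=1 means every good step counts).
  if grads_finite:
    return scale * multiplier, True    # apply the update
  return scale / multiplier, False     # skip the update

scale = 2.0
scale, applied = update_loss_scale(scale, grads_finite=True)   # -> 4.0, applied (var: 5 - 2*2 = 1)
scale, applied = update_loss_scale(scale, grads_finite=True)   # -> 8.0, applied (var: 1 - 3*2 = -5)
scale, applied = update_loss_scale(scale, grads_finite=False)  # -> 4.0, skipped (var unchanged)
assert scale == 4.0 and not applied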