Support clipping in LossScaleOptimizer.

This has been made possible by previous refactors to the base Optimizer.

PiperOrigin-RevId: 333336359
Change-Id: Ie31d96ff9b27e3961940c9a221486963be02ced7
Author: Reed Wanderman-Milne, 2020-09-23 11:35:46 -07:00, committed by TensorFlower Gardener
parent 2c1693cfcf
commit c2926ab744
2 changed files with 70 additions and 12 deletions
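
As a quick illustration of what this change enables, here is a hedged usage sketch (not part of the commit; it assumes the TF 2.3-era tf.keras.mixed_precision.experimental API):

    import tensorflow as tf

    sgd = tf.keras.optimizers.SGD(learning_rate=2.0, clipnorm=2.0)
    # Before this change the next line raised ValueError because the wrapped
    # optimizer already had clipnorm set; now the wrapper accepts it and
    # delegates clipping to the inner optimizer.
    opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
        sgd, loss_scale='dynamic')
    print(opt.clipnorm)  # 2.0, read straight through from the wrapped SGD
    opt.clipnorm = 3.0   # writes are forwarded to the wrapped SGD as well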


@@ -248,20 +248,8 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
     if not isinstance(optimizer, optimizer_v2.OptimizerV2):
       raise ValueError('"optimizer" must be an instance of OptimizerV2, but '
                        'got: %s' % optimizer)
-    if optimizer.clipnorm is not None:
-      raise ValueError('LossScaleOptimizer does not support wrapping '
-                       'optimizers with a clipnorm. Optimizer %s has clipnorm '
-                       '%s' % (optimizer, optimizer.clipnorm))
-
-    if optimizer.clipvalue is not None:
-      raise ValueError('LossScaleOptimizer does not support wrapping '
-                       'optimizers with a clipvalue. Optimizer %s has '
-                       'clipvalue %s' % (optimizer, optimizer.clipvalue))
-
     self._raise_if_strategy_unsupported()
-    self.clipnorm = None
-    self.clipvalue = None
 
     self._optimizer = optimizer
     self._loss_scale = keras_loss_scale_module.get(loss_scale)
     if self._loss_scale is None:
@@ -486,6 +474,30 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
   def set_weights(self, weights):
     return self._optimizer.set_weights(weights)
 
+  @property
+  def clipnorm(self):
+    return self._optimizer.clipnorm
+
+  @clipnorm.setter
+  def clipnorm(self, val):
+    self._optimizer.clipnorm = val
+
+  @property
+  def global_clipnorm(self):
+    return self._optimizer.global_clipnorm
+
+  @global_clipnorm.setter
+  def global_clipnorm(self, val):
+    self._optimizer.global_clipnorm = val
+
+  @property
+  def clipvalue(self):
+    return self._optimizer.clipvalue
+
+  @clipvalue.setter
+  def clipvalue(self, val):
+    self._optimizer.clipvalue = val
+
   def _aggregate_gradients(self, grads_and_vars):
     return self._optimizer._aggregate_gradients(grads_and_vars)  # pylint: disable=protected-access
 
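
These properties keep a single source of truth: the wrapper stores no clip settings of its own and forwards every read and write to the wrapped optimizer. A stripped-down sketch of that delegation pattern, using a hypothetical class name and a stand-in inner object rather than the real Keras classes:

    from types import SimpleNamespace

    class ClipDelegatingWrapper(object):
      """Toy wrapper: clip settings live only on the inner optimizer."""

      def __init__(self, inner_optimizer):
        self._inner = inner_optimizer

      @property
      def clipnorm(self):
        # Reads always reflect the wrapped optimizer's current value.
        return self._inner.clipnorm

      @clipnorm.setter
      def clipnorm(self, val):
        # Writes land on the wrapped optimizer, so there is nothing for the
        # wrapper to keep in sync.
        self._inner.clipnorm = val

    inner = SimpleNamespace(clipnorm=2.0)  # stand-in for a real optimizer
    wrapper = ClipDelegatingWrapper(inner)
    wrapper.clipnorm = 3.0
    assert inner.clipnorm == 3.0           # the write reached the inner object

Because the attribute is now a property backed by the wrapped optimizer, the self.clipnorm = None and self.clipvalue = None assignments deleted in the first hunk are no longer needed.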


@@ -189,6 +189,52 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
       # 1.
       self.assertAllClose([1.], self.evaluate(var))
 
+  # pylint: disable=cell-var-from-loop
+  @parameterized.named_parameters(*TESTCASES)
+  def testClipping(self, strategy_fn):
+    strategy = strategy_fn()
+    learning_rate = 2.
+    for clip_type in ('clipnorm', 'global_clipnorm', 'clipvalue'):
+      with strategy.scope(), self.subTest(clip_type=clip_type):
+        var = variables.Variable([5.0])
+        opt = gradient_descent.SGD(learning_rate, **{clip_type: 2.0})
+        loss_scale = loss_scale_module.DynamicLossScale(
+            initial_loss_scale=2, increment_period=1, multiplier=2)
+        opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
+        self.assertEqual(getattr(opt, clip_type), 2.0)
+        self.assertEqual(
+            loss_scale.initial_loss_scale % strategy.num_replicas_in_sync, 0)
+
+        loss = lambda: var * 4 / strategy.num_replicas_in_sync
+        run_fn = lambda: opt.minimize(loss, var_list=[var])
+
+        # Test running with clipped gradients
+        run_op = strategy.experimental_run(run_fn)
+        self.evaluate(variables.global_variables_initializer())
+        self._run_if_in_graph_mode(run_op)
+        # The gradient is 4 but is clipped to 2, so the variable will be
+        # init_val - clipped_grad * lr == 5 - 2 * 2 == 1
+        self.assertAllClose([1.], self.evaluate(var))
+        self.assertEqual(self.evaluate(opt.loss_scale()), 4)
+
+        # Test changing the clip amount and running again
+        setattr(opt, clip_type, 3.0)
+        run_op = strategy.experimental_run(run_fn)
+        self._run_if_in_graph_mode(run_op)
+        # The gradient is 4 but is clipped to 3, so the variable will be
+        # prev_var - clipped_grad * lr == 1 - 3 * 2 == -5
+        self.assertAllClose([-5.], self.evaluate(var))
+        self.assertEqual(self.evaluate(opt.loss_scale()), 8)
+
+        # Test Inf gradients are still skipped instead of being clipped
+        loss = lambda: var * float('Inf')
+        run_fn = lambda: opt.minimize(loss, var_list=[var])
+        run_op = strategy.experimental_run(run_fn)
+        self._run_if_in_graph_mode(run_op)
+        self.assertAllClose([-5.], self.evaluate(var))  # Var does not change
+        self.assertEqual(self.evaluate(opt.loss_scale()), 4)
+  # pylint: enable=cell-var-from-loop
+
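For readers following the arithmetic in testClipping, here is a back-of-the-envelope check in plain Python (single replica, first assertion block; it assumes clipping is applied to the already-unscaled gradient, which is what the expected values above imply):

    initial_var, lr, clip_to, loss_scale = 5.0, 2.0, 2.0, 2.0
    grad = 4.0                                  # d(4 * var)/d(var)
    scaled_grad = grad * loss_scale             # 8.0 while gradients are computed
    unscaled_grad = scaled_grad / loss_scale    # 4.0 once the wrapper unscales them
    clipped_grad = min(unscaled_grad, clip_to)  # 2.0; for a one-element gradient,
                                                # all three clip types give the same value
    new_var = initial_var - lr * clipped_grad   # 5 - 2 * 2 == 1
    assert new_var == 1.0
    # All gradients were finite, so the dynamic loss scale doubles from 2 to 4,
    # matching assertEqual(self.evaluate(opt.loss_scale()), 4)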
   @parameterized.named_parameters(*TESTCASES)
   def testDynamicUpdate(self, strategy_fn):
     with strategy_fn().scope() as strategy: