From 7152155517fbda482b4bffe66ddf56fd06b6aa04 Mon Sep 17 00:00:00 2001
From: Reed Wanderman-Milne
Date: Mon, 8 Jun 2020 14:40:01 -0700
Subject: [PATCH] Always increment iterations in LossScaleOptimizer.

Now self.iterations is incremented in LossScaleOptimizer.apply_gradients even
when gradients are not applied to the variables due to NaNs. Before,
self.iterations was not incremented because I considered self.iterations to
represent the number of times gradients were applied to variables. But this
caused confusion because apply_gradients() did not always increment
self.iterations, so now it is incremented unconditionally.

PiperOrigin-RevId: 315356072
Change-Id: I4e8620d1ef84af55eb7a1f70b20671a7d7af38b9
---
 .../experimental/loss_scale_optimizer.py           |  8 ++++++-
 .../experimental/loss_scale_optimizer_test.py      | 24 +++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
index d6a786aa4e4..72d9d0390fd 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
@@ -395,13 +395,19 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
           self._apply_gradients,
           args=(grads, wrapped_vars, name, experimental_aggregate_gradients))
 
+    def do_not_apply_fn():
+      # Normally self._optimizer.iterations is incremented in
+      # self._optimizer.apply_gradients(). Since that is not called in this
+      # branch, we increment it here instead.
+      return self._optimizer.iterations.assign_add(1, read_value=False)
+
     # Note: We must call this cond() in a cross-replica context.
     # DistributionStrategy does not support having a cond in a replica context
     # with a branch that calls `merge_call`, and self._optimizer.apply_gradients
     # calls `merge_call`.
     maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
-                                           control_flow_ops.no_op)
+                                           do_not_apply_fn)
     return control_flow_ops.group(maybe_apply_op, loss_scale_update_op)
 
   def _apply_gradients(self, grads, wrapped_vars, name,
diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
index 20252ff3885..992ed17f0c6 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
@@ -285,6 +285,30 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
     self.assertEqual(lso.iterations, 7)
     self.assertEqual(opt.iterations, 7)
 
+  @parameterized.named_parameters(*TESTCASES)
+  def testIterationsIncremented(self, strategy_fn):
+    with strategy_fn().scope() as strategy:
+      # Test iterations is incremented in opt.minimize.
+      opt = gradient_descent.SGD(1.0)
+      opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale='dynamic')
+      var = variables.Variable([5.0])
+      loss = lambda: var * 2.0 / strategy.num_replicas_in_sync
+      run_fn = lambda: opt.minimize(loss, [var])
+      run_op = strategy.experimental_run(run_fn)
+      self.evaluate(variables.global_variables_initializer())
+      self._run_if_in_graph_mode(run_op)
+      self.assertEqual(self.evaluate(var), 3.0)  # Grad is 2, so var is 5 - 2
+      self.assertEqual(self.evaluate(opt.iterations), 1)
+
+      # Test iterations is incremented in opt.minimize even if gradients aren't
+      # applied to variables due to NaN gradients.
+      loss = lambda: var * float('NaN')
+      run_fn = lambda: opt.minimize(loss, [var])
+      run_op = strategy.experimental_run(run_fn)
+      self._run_if_in_graph_mode(run_op)
+      self.assertEqual(self.evaluate(var), 3.0)
+      self.assertEqual(self.evaluate(opt.iterations), 2)
+
   def testWeightMethods(self):
     with self.test_session():
       var = variables.Variable([1.0])
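
For context, a minimal eager-mode sketch (not part of the patch itself) of the
behavior the change guarantees, assuming TensorFlow 2.x with the experimental
mixed-precision API touched above
(tf.keras.mixed_precision.experimental.LossScaleOptimizer); the values simply
mirror the new test:

import tensorflow as tf

base_opt = tf.keras.optimizers.SGD(1.0)
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    base_opt, loss_scale='dynamic')
var = tf.Variable([5.0])

# Finite gradients: the update is applied and iterations advances.
opt.minimize(lambda: var * 2.0, var_list=[var])
assert var.numpy()[0] == 3.0          # 5.0 - lr(1.0) * grad(2.0)
assert opt.iterations.numpy() == 1

# NaN gradients: the update is skipped and the loss scale is reduced, but
# with this change opt.iterations still advances (previously it stayed at 1).
opt.minimize(lambda: var * float('nan'), var_list=[var])
assert var.numpy()[0] == 3.0          # variable left unchanged
assert opt.iterations.numpy() == 2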