Always increment iterations in LossScaleOptimizer.
self.iterations is now incremented in LossScaleOptimizer.apply_gradients even when gradients are not applied to the variables due to NaNs. Previously, self.iterations was not incremented in that case, because it was intended to represent the number of times gradients were actually applied to variables. This caused confusion, since apply_gradients() did not always increment self.iterations; it is now incremented unconditionally.

PiperOrigin-RevId: 315356072
Change-Id: I4e8620d1ef84af55eb7a1f70b20671a7d7af38b9
This commit is contained in:
parent
de901d9be9
commit
7152155517
@@ -395,13 +395,19 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
         self._apply_gradients,
         args=(grads, wrapped_vars, name, experimental_aggregate_gradients))
 
+    def do_not_apply_fn():
+      # Normally self._optimizer.iterations is incremented in
+      # self._optimizer.apply_gradients(). Since that is not called in this
+      # branch, we increment it here instead.
+      return self._optimizer.iterations.assign_add(1, read_value=False)
+
     # Note: We must call this cond() in a cross-replica context.
     # DistributionStrategy does not support having a cond in a replica context
     # with a branch that calls `merge_call`, and self._optimizer.apply_gradients
     # calls `merge_call`.
     maybe_apply_op = smart_cond.smart_cond(should_apply_grads,
                                            apply_fn,
-                                           control_flow_ops.no_op)
+                                           do_not_apply_fn)
     return control_flow_ops.group(maybe_apply_op, loss_scale_update_op)
 
   def _apply_gradients(self, grads, wrapped_vars, name,
@@ -285,6 +285,30 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
       self.assertEqual(lso.iterations, 7)
       self.assertEqual(opt.iterations, 7)
 
+  @parameterized.named_parameters(*TESTCASES)
+  def testIterationsIncremented(self, strategy_fn):
+    with strategy_fn().scope() as strategy:
+      # Test iterations is incremented in opt.minimize.
+      opt = gradient_descent.SGD(1.0)
+      opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale='dynamic')
+      var = variables.Variable([5.0])
+      loss = lambda: var * 2.0 / strategy.num_replicas_in_sync
+      run_fn = lambda: opt.minimize(loss, [var])
+      run_op = strategy.experimental_run(run_fn)
+      self.evaluate(variables.global_variables_initializer())
+      self._run_if_in_graph_mode(run_op)
+      self.assertEqual(self.evaluate(var), 3.0)  # Grad is 2, so var is 5 - 2
+      self.assertEqual(self.evaluate(opt.iterations), 1)
+
+      # Test iterations is incremented in opt.minimize even if gradients aren't
+      # applied to variables due to NaN gradients.
+      loss = lambda: var * float('NaN')
+      run_fn = lambda: opt.minimize(loss, [var])
+      run_op = strategy.experimental_run(run_fn)
+      self._run_if_in_graph_mode(run_op)
+      self.assertEqual(self.evaluate(var), 3.0)
+      self.assertEqual(self.evaluate(opt.iterations), 2)
+
   def testWeightMethods(self):
     with self.test_session():
       var = variables.Variable([1.0])
||||
|
Loading…
x
Reference in New Issue
Block a user