Apply changes to v2 loss scale

-Also update docstring and test
Matt Conley 2019-04-16 12:47:27 -07:00
parent 4896cc21ab
commit 3b486c4464
4 changed files with 12 additions and 5 deletions

View File

@@ -79,7 +79,8 @@ class LossScale(trackable.Trackable):
    Args:
      grads: A list of unscaled gradients, each of which is the gradient of the
        loss with respect to a weight. The gradients should have already been
        divided by the loss scale before being passed to this function.
        divided by the loss scale before being passed to this function. 'None'
        gradients are accepted, and should be ignored.

    Returns:
      update_op: In eager mode, None. In graph mode, an op to update the loss
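
To make the documented contract concrete, here is a minimal usage sketch, assuming the class is exported as tf.train.experimental.DynamicLossScale (as in TF builds from this period) and using made-up gradient values; a variable that does not affect the loss yields a None gradient, which update() is now documented to accept and ignore:

import tensorflow as tf

loss_scale = tf.train.experimental.DynamicLossScale()

# Hypothetical gradients: the second variable did not contribute to the loss,
# so its gradient is None; per the updated docstring, update() ignores it.
grads = [tf.constant([1.0, 2.0]), None]
loss_scale.update(grads)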
@@ -183,7 +184,7 @@ class FixedLossScale(LossScale):
def _is_all_finite(grads):
  """Returns a scalar boolean tensor indicating if all gradients are finite."""
  is_finite_per_grad = [math_ops.reduce_all(math_ops.is_finite(g))
                        for g in grads]
                        for g in grads if g is not None]
  return math_ops.reduce_all(is_finite_per_grad)

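As a standalone restatement (not the module's own code; all_finite and the sample tensors below are illustrative only), the filtered finite-check can be expressed with public TensorFlow ops:

import tensorflow as tf

def all_finite(grads):
  # Mirrors the patched _is_all_finite: skip None entries, then check the
  # remaining gradients for Infs/NaNs and AND the results together.
  per_grad = [tf.reduce_all(tf.math.is_finite(g)) for g in grads if g is not None]
  return tf.reduce_all(per_grad)

# With eager execution these evaluate to True and False respectively.
ok = all_finite([tf.constant([1.0, 2.0]), None])
bad = all_finite([tf.constant([float('inf'), 1.0]), None])
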
View File

@@ -266,6 +266,11 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
    self.assertEqual(loss_scale.increment_period, 2)
    self.assertEqual(loss_scale.multiplier, 3)

  @test_util.run_in_graph_and_eager_modes
  def test_update_with_none_gradients(self):
    loss_scale = loss_scale_module.DynamicLossScale()
    loss_scale.update([None])

  @test_util.run_in_graph_and_eager_modes
  def test_get(self):
    scalar = loss_scale_module.get('dynamic')

View File

@@ -83,7 +83,8 @@ class LossScale(trackable.Trackable):
    Args:
      grads: A list of unscaled gradients, each of which is the gradient of the
        loss with respect to a weight. The gradients should have already been
        divided by the loss scale before being passed to this function.
        divided by the loss scale before being passed to this function. 'None'
        gradients are accepted, and should be ignored.

    Returns:
      update_op: In eager mode, None. In graph mode, an op to update the loss

View File

@@ -254,8 +254,8 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
  @test_util.run_in_graph_and_eager_modes
  def test_update_with_none_gradients(self):
    loss_scaler = loss_scale_module.DynamicLossScale()
    loss_scaler.update([None])
    loss_scale = loss_scale_module.DynamicLossScale()
    loss_scale.update([None])

  @test_util.run_in_graph_and_eager_modes
  def test_get(self):