Merge pull request #27898 from MattConley:loss_scale_patch

PiperOrigin-RevId: 243923342
TensorFlower Gardener 2019-04-16 19:56:23 -07:00
commit ffe054107f
4 changed files with 18 additions and 5 deletions

Changed file 1 of 4 (path not shown in this view):

@@ -79,7 +79,8 @@ class LossScale(trackable.Trackable):
     Args:
       grads: A list of unscaled gradients, each which is the gradient of the
        loss with respect to a weight. The gradients should have already been
-       divided by the loss scale being before passed to this function.
+       divided by the loss scale being before passed to this function. 'None'
+       gradients are accepted, and are ignored.
 
     Returns:
       update_op: In eager mode, None. In graph mode, an op to update the loss
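A rough usage sketch of the behavior documented in this hunk, reusing the names that appear in the new test further down; the comment about unreachable variables is an inference, not text from the commit:

# Illustration only, not part of the diff. loss_scale_module stands for the
# patched module; its import path is not visible in this view.
loss_scale = loss_scale_module.DynamicLossScale()

# The (already scale-divided) gradient list may now contain None entries,
# e.g. for variables the loss does not depend on; they are simply ignored.
loss_scale.update([None])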
@@ -182,8 +183,9 @@ class FixedLossScale(LossScale):
 def _is_all_finite(grads):
   """Returns a scalar boolean tensor indicating if all gradients are finite."""
-  is_finite_per_grad = [math_ops.reduce_all(math_ops.is_finite(g))
-                        for g in grads]
+  is_finite_per_grad = [
+      math_ops.reduce_all(math_ops.is_finite(g)) for g in grads if g is not None
+  ]
   return math_ops.reduce_all(is_finite_per_grad)
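A self-contained sketch of the filtering pattern used above, written against the public tf.* aliases rather than the internal math_ops module; the example gradient values are made up for illustration:

import tensorflow as tf

def is_all_finite(grads):
  # Mirrors the patched helper: skip None gradients instead of handing them
  # to is_finite(), which expects a tensor.
  is_finite_per_grad = [
      tf.reduce_all(tf.math.is_finite(g)) for g in grads if g is not None
  ]
  return tf.reduce_all(is_finite_per_grad)

# A None entry no longer breaks the finiteness check; the Inf still does.
grads = [tf.constant([1.0, 2.0]), None, tf.constant([float('inf')])]
print(is_all_finite(grads))  # => tf.Tensor(False, shape=(), dtype=bool)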

Changed file 2 of 4 (path not shown in this view):

@@ -266,6 +266,11 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
     self.assertEqual(loss_scale.increment_period, 2)
     self.assertEqual(loss_scale.multiplier, 3)
 
+  @test_util.run_in_graph_and_eager_modes
+  def test_update_with_none_gradients(self):
+    loss_scale = loss_scale_module.DynamicLossScale()
+    loss_scale.update([None])
+
   @test_util.run_in_graph_and_eager_modes
   def test_get(self):
     scalar = loss_scale_module.get('dynamic')

Changed file 3 of 4 (path not shown in this view):

@@ -83,7 +83,8 @@ class LossScale(trackable.Trackable):
     Args:
       grads: A list of unscaled gradients, each which is the gradient of the
        loss with respect to a weight. The gradients should have already been
-       divided by the loss scale being before passed to this function.
+       divided by the loss scale being before passed to this function. 'None'
+       gradients are accepted, and are ignored.
 
     Returns:
       update_op: In eager mode, None. In graph mode, an op to update the loss
@@ -196,7 +197,7 @@ class FixedLossScale(LossScale):
 def _is_all_finite(grads):
   """Returns a scalar boolean tensor indicating if all gradients are finite."""
   is_finite_per_grad = [
-      math_ops.reduce_all(math_ops.is_finite(g)) for g in grads
+      math_ops.reduce_all(math_ops.is_finite(g)) for g in grads if g is not None
   ]
   return math_ops.reduce_all(is_finite_per_grad)

Changed file 4 of 4 (path not shown in this view):

@@ -252,6 +252,11 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
     expected_outputs = [2, 2, 4, 4, 2, 2, 1, 1, 2, 2, 1]
     self._test_helper(inputs, expected_outputs, init_loss_scale)
 
+  @test_util.run_in_graph_and_eager_modes
+  def test_update_with_none_gradients(self):
+    loss_scale = loss_scale_module.DynamicLossScale()
+    loss_scale.update([None])
+
   @test_util.run_in_graph_and_eager_modes
   def test_get(self):
     scalar = loss_scale_module.get('dynamic')