Merge pull request #27898 from MattConley:loss_scale_patch
PiperOrigin-RevId: 243923342
commit ffe054107f
@@ -79,7 +79,8 @@ class LossScale(trackable.Trackable):
     Args:
       grads: A list of unscaled gradients, each of which is the gradient of the
        loss with respect to a weight. The gradients should have already been
-       divided by the loss scale before being passed to this function.
+       divided by the loss scale before being passed to this function. 'None'
+       gradients are accepted, and are ignored.

     Returns:
       update_op: In eager mode, None. In graph mode, an op to update the loss
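For context, a 'None' gradient typically arises when a variable does not participate in the loss: tf.GradientTape.gradient (and tf.gradients) returns None for such variables. The following is a minimal illustrative sketch, not part of the patch, using public TF APIs and made-up variable names, of how a None entry can end up in the list handed to LossScale.update:

import tensorflow as tf

v_used = tf.Variable(2.0)
v_unused = tf.Variable(3.0)  # plays no part in the loss below

with tf.GradientTape() as tape:
  loss = v_used * v_used

# gradient() yields None for variables not connected to the loss.
grads = tape.gradient(loss, [v_used, v_unused])
print(grads)  # [<tf.Tensor: ... numpy=4.0>, None]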
@@ -182,8 +183,9 @@ class FixedLossScale(LossScale):

 def _is_all_finite(grads):
   """Returns a scalar boolean tensor indicating if all gradients are finite."""
-  is_finite_per_grad = [math_ops.reduce_all(math_ops.is_finite(g))
-                        for g in grads]
+  is_finite_per_grad = [
+      math_ops.reduce_all(math_ops.is_finite(g)) for g in grads if g is not None
+  ]
   return math_ops.reduce_all(is_finite_per_grad)

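To see what the new "if g is not None" filter does, here is a self-contained sketch of the helper, with the internal math_ops calls swapped for their public tf.* equivalents purely for illustration. None entries are skipped, so the finiteness check only sees real gradient tensors:

import tensorflow as tf

def _is_all_finite(grads):
  """Returns a scalar boolean tensor indicating if all gradients are finite."""
  # None gradients are skipped rather than passed to is_finite, which
  # cannot handle a None input.
  is_finite_per_grad = [
      tf.reduce_all(tf.math.is_finite(g)) for g in grads if g is not None
  ]
  return tf.reduce_all(is_finite_per_grad)

print(_is_all_finite([tf.constant([1.0, 2.0]), None]))      # True
print(_is_all_finite([tf.constant([float('inf')]), None]))  # False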
@@ -266,6 +266,11 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
     self.assertEqual(loss_scale.increment_period, 2)
     self.assertEqual(loss_scale.multiplier, 3)

+  @test_util.run_in_graph_and_eager_modes
+  def test_update_with_none_gradients(self):
+    loss_scale = loss_scale_module.DynamicLossScale()
+    loss_scale.update([None])
+
   @test_util.run_in_graph_and_eager_modes
   def test_get(self):
     scalar = loss_scale_module.get('dynamic')
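The new test_update_with_none_gradients case exercises the path that would previously fail: without the filter above, every entry in grads, including None, reached is_finite. A quick hedged illustration of that failure mode, using the public tf.math.is_finite in place of the internal math_ops call:

import tensorflow as tf

try:
  tf.math.is_finite(None)  # None cannot be converted to a tensor
except (TypeError, ValueError) as e:
  print('is_finite rejects None gradients:', e)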
@@ -83,7 +83,8 @@ class LossScale(trackable.Trackable):
     Args:
       grads: A list of unscaled gradients, each of which is the gradient of the
        loss with respect to a weight. The gradients should have already been
-       divided by the loss scale before being passed to this function.
+       divided by the loss scale before being passed to this function. 'None'
+       gradients are accepted, and are ignored.

     Returns:
       update_op: In eager mode, None. In graph mode, an op to update the loss
@@ -196,7 +197,7 @@ class FixedLossScale(LossScale):
 def _is_all_finite(grads):
   """Returns a scalar boolean tensor indicating if all gradients are finite."""
   is_finite_per_grad = [
-      math_ops.reduce_all(math_ops.is_finite(g)) for g in grads
+      math_ops.reduce_all(math_ops.is_finite(g)) for g in grads if g is not None
   ]
   return math_ops.reduce_all(is_finite_per_grad)

@@ -252,6 +252,11 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
     expected_outputs = [2, 2, 4, 4, 2, 2, 1, 1, 2, 2, 1]
     self._test_helper(inputs, expected_outputs, init_loss_scale)

+  @test_util.run_in_graph_and_eager_modes
+  def test_update_with_none_gradients(self):
+    loss_scale = loss_scale_module.DynamicLossScale()
+    loss_scale.update([None])
+
   @test_util.run_in_graph_and_eager_modes
   def test_get(self):
     scalar = loss_scale_module.get('dynamic')