Apply changes to v2 loss scale

- Also update docstring and test

commit 3b486c4464
parent 4896cc21ab
@@ -79,7 +79,8 @@ class LossScale(trackable.Trackable):
     Args:
       grads: A list of unscaled gradients, each which is the gradient of the
         loss with respect to a weight. The gradients should have already been
-        divided by the loss scale being before passed to this function.
+        divided by the loss scale being before passed to this function. 'None'
+        gradients are accepted, and should be ignored.
 
     Returns:
       update_op: In eager mode, None. In graph mode, an op to update the loss
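The docstring change above documents that `grads` may contain `None` entries. A minimal sketch, not part of the commit, of how such entries arise: `tf.GradientTape.gradient` returns `None` for any variable the loss does not depend on, so the list handed to `LossScale.update` can legitimately contain `None`.

    import tensorflow as tf

    v_used = tf.Variable(1.0)
    v_unused = tf.Variable(2.0)  # does not influence the loss below

    with tf.GradientTape() as tape:
      loss = 3.0 * v_used

    # gradient() yields None for variables disconnected from the loss.
    grads = tape.gradient(loss, [v_used, v_unused])
    print(grads)  # [<tf.Tensor: ... numpy=3.0>, None]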
@@ -183,7 +184,7 @@ class FixedLossScale(LossScale):
 def _is_all_finite(grads):
   """Returns a scalar boolean tensor indicating if all gradients are finite."""
   is_finite_per_grad = [math_ops.reduce_all(math_ops.is_finite(g))
-                        for g in grads]
+                        for g in grads if g is not None]
   return math_ops.reduce_all(is_finite_per_grad)
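With the `if g is not None` filter, only real gradient tensors are checked for NaN/Inf. An illustrative re-implementation of the same idea using the public `tf.*` API, a sketch rather than the TF source (which uses the internal `math_ops` module):

    import tensorflow as tf

    def is_all_finite(grads):
      # Skip None gradients, as in the diff above; only actual tensors
      # are tested for NaN/Inf values.
      is_finite_per_grad = [tf.reduce_all(tf.math.is_finite(g))
                            for g in grads if g is not None]
      return tf.reduce_all(is_finite_per_grad)

    print(is_all_finite([tf.constant([1.0, 2.0]), None]))      # True
    print(is_all_finite([tf.constant([float('nan')]), None]))  # False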
@@ -266,6 +266,11 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
     self.assertEqual(loss_scale.increment_period, 2)
     self.assertEqual(loss_scale.multiplier, 3)
 
+  @test_util.run_in_graph_and_eager_modes
+  def test_update_with_none_gradients(self):
+    loss_scale = loss_scale_module.DynamicLossScale()
+    loss_scale.update([None])
+
   @test_util.run_in_graph_and_eager_modes
   def test_get(self):
     scalar = loss_scale_module.get('dynamic')
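The new test only checks that `update` tolerates a gradient list that is entirely `None` without raising; with every gradient filtered out, the finiteness check above reduces over an empty list (which `reduce_all` treats as all-true), so the step counts as finite. A hedged standalone version of the same check; the public alias `tf.mixed_precision.experimental.DynamicLossScale` is an assumption about where the class was exported at the time:

    import tensorflow as tf

    loss_scale = tf.mixed_precision.experimental.DynamicLossScale()
    loss_scale.update([None])  # must not raise, per the new test
    print(loss_scale())        # the current scale value is still readable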
@@ -83,7 +83,8 @@ class LossScale(trackable.Trackable):
     Args:
       grads: A list of unscaled gradients, each which is the gradient of the
         loss with respect to a weight. The gradients should have already been
-        divided by the loss scale being before passed to this function.
+        divided by the loss scale being before passed to this function. 'None'
+        gradients are accepted, and should be ignored.
 
     Returns:
       update_op: In eager mode, None. In graph mode, an op to update the loss
@@ -254,8 +254,8 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def test_update_with_none_gradients(self):
-    loss_scaler = loss_scale_module.DynamicLossScale()
-    loss_scaler.update([None])
+    loss_scale = loss_scale_module.DynamicLossScale()
+    loss_scale.update([None])
 
   @test_util.run_in_graph_and_eager_modes
   def test_get(self):