Add single tensor gradient support for DynamicLossScale
PiperOrigin-RevId: 257209230
parent 67a48c79a3
commit 45925ad3d0
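
In effect, DynamicLossScale.update() now accepts a single gradient tensor, or any tf.nest-compatible structure of gradients, where it previously required a Python list. A minimal before/after sketch; the import path is assumed from this era of TF and is illustrative only:

    # Illustrative sketch; assumes the experimental module path used
    # around this revision.
    from tensorflow.python.framework import constant_op
    from tensorflow.python.training.experimental import loss_scale as loss_scale_module

    loss_scale = loss_scale_module.DynamicLossScale()

    # Before this change, update() expected a list of gradient tensors:
    loss_scale.update([constant_op.constant(4.0)])

    # After it, a bare tensor (or any nested structure) is accepted as well:
    loss_scale.update(constant_op.constant(4.0))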
@@ -31,6 +31,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
@@ -93,10 +94,10 @@ class LossScale(trackable.Trackable):
     cross-replica context.

     Args:
-      grads: A list of unscaled gradients, each which is the gradient of the
-        loss with respect to a weight. The gradients should have already been
-        divided by the loss scale being before passed to this function. 'None'
-        gradients are accepted, and are ignored.
+      grads: A nested structure of unscaled gradients, each of which is the
+        gradient of the loss with respect to a weight. The gradients should
+        have already been divided by the loss scale before being passed to
+        this function. 'None' gradients are accepted, and are ignored.

     Returns:
       update_op: In eager mode, None. In graph mode, an op to update the loss
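
For context on the new wording: "nested structure" follows tf.nest semantics, so a lone tensor, a list, a tuple, or a dict of tensors all qualify. A quick hypothetical illustration, not part of the diff:

    from tensorflow.python.framework import constant_op
    from tensorflow.python.util import nest

    single = constant_op.constant(4.0)
    nested = {'bias': constant_op.constant(3.0),
              'kernel': [constant_op.constant(1.0), constant_op.constant(2.0)]}

    nest.flatten(single)  # [<4.0>]: a lone tensor is a one-element structure
    nest.flatten(nested)  # [<3.0>, <1.0>, <2.0>]: dicts flatten in sorted key order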
@@ -328,6 +329,7 @@ class DynamicLossScale(LossScale):

   def update(self, grads):
     """Updates loss scale based on if gradients are finite in current step."""
+    grads = nest.flatten(grads)
     if distribution_strategy_context.has_strategy():
       distribution = distribution_strategy_context.get_cross_replica_context()
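
The whole feature is this one-line normalization at the top of update(): nest.flatten() turns whatever the caller passes into a flat list, so every code path below it is unchanged. The same boundary-flattening pattern in isolation, as a hypothetical standalone sketch (is_all_finite here is illustrative, not the commit's helper):

    from tensorflow.python.framework import constant_op
    from tensorflow.python.ops import math_ops
    from tensorflow.python.util import nest

    def is_all_finite(grads):
      """Returns a scalar bool tensor; accepts a tensor or any nest of tensors."""
      grads = nest.flatten(grads)  # single tensor -> [tensor]; nests -> flat list
      checks = [math_ops.reduce_all(math_ops.is_finite(g))
                for g in grads if g is not None]  # None gradients are ignored
      return math_ops.reduce_all(checks)

    is_all_finite(constant_op.constant(4.0))             # bare scalar works
    is_all_finite([constant_op.constant(float('inf'))])  # -> False tensor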
@@ -264,6 +264,15 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
     expected_outputs = [2, 2, 4, 4, 2, 2, 1, 1, 2, 2, 1]
     self._test_helper(inputs, expected_outputs, init_loss_scale)

+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def test_single_tensor_gradient(self, strategy_fn):
+    with strategy_fn().scope():
+      loss_scale = loss_scale_module.DynamicLossScale()
+      grad = constant_op.constant(4.0)
+      _, should_apply = loss_scale.update(grad)
+      self.assertTrue(self.evaluate(should_apply))
+
   @test_util.run_in_graph_and_eager_modes
   def test_serialization(self):
     loss_scale = loss_scale_module.DynamicLossScale(
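
One usage note on the test: it exercises only the single-tensor case, but the same flattening covers arbitrary nests. A hypothetical companion call, reusing the loss_scale object from the sketch above (not part of the commit's test file):

    grads = {'bias': constant_op.constant(1.0),
             'kernel': [constant_op.constant(2.0), constant_op.constant(3.0)]}
    _, should_apply = loss_scale.update(grads)  # nests flatten transparently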