Add single tensor gradient support for DynamicLossScale

PiperOrigin-RevId: 257209230
A. Unique TensorFlower 2019-07-09 09:29:29 -07:00 committed by TensorFlower Gardener
parent 67a48c79a3
commit 45925ad3d0
2 changed files with 15 additions and 4 deletions

@@ -31,6 +31,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
@@ -93,10 +94,10 @@ class LossScale(trackable.Trackable):
     cross-replica context.
 
     Args:
-      grads: A list of unscaled gradients, each which is the gradient of the
-        loss with respect to a weight. The gradients should have already been
-        divided by the loss scale being before passed to this function. 'None'
-        gradients are accepted, and are ignored.
+      grads: A nested structure of unscaled gradients, each of which is the
+        gradient of the loss with respect to a weight. The gradients should
+        have already been divided by the loss scale before being passed to
+        this function. 'None' gradients are accepted, and are ignored.
 
     Returns:
       update_op: In eager mode, None. In graph mode, an op to update the loss
@@ -328,6 +329,7 @@ class DynamicLossScale(LossScale):
   def update(self, grads):
     """Updates loss scale based on if gradients are finite in current step."""
+    grads = nest.flatten(grads)
     if distribution_strategy_context.has_strategy():
       distribution = distribution_strategy_context.get_cross_replica_context()
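
The nest.flatten(grads) call added above is what makes a lone tensor (or any nested structure) acceptable: it normalizes whatever the caller passes into a flat list before the finiteness check. A rough sketch of that behavior, using the public tf.nest alias rather than the internal import (the values below are made up purely for illustration):

import tensorflow as tf

# nest.flatten wraps a single element in a one-item list and flattens
# arbitrarily nested structures, so update() can iterate over gradients
# uniformly no matter how the caller structured them.
single = tf.constant(4.0)
tf.nest.flatten(single)                            # -> [single]
tf.nest.flatten([single, [single, None]])          # -> [single, single, None]
tf.nest.flatten({'bias': None, 'kernel': single})  # -> [None, single] (keys sorted)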

@@ -264,6 +264,15 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
     expected_outputs = [2, 2, 4, 4, 2, 2, 1, 1, 2, 2, 1]
     self._test_helper(inputs, expected_outputs, init_loss_scale)
 
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def test_single_tensor_gradient(self, strategy_fn):
+    with strategy_fn().scope():
+      loss_scale = loss_scale_module.DynamicLossScale()
+      grad = constant_op.constant(4.0)
+      _, should_apply = loss_scale.update(grad)
+      self.assertTrue(self.evaluate(should_apply))
+
   @test_util.run_in_graph_and_eager_modes
   def test_serialization(self):
     loss_scale = loss_scale_module.DynamicLossScale(
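
Taken together, a minimal usage sketch of the new behavior, with the module aliases used in the test file above (the dict keys and gradient values are purely illustrative):

# update() now accepts any nested structure of gradients, not just a flat
# list; a lone tensor also works, as the new test above exercises.
loss_scale = loss_scale_module.DynamicLossScale()
grads = {'kernel': constant_op.constant(1.0), 'bias': None}  # None gradients are ignored
update_op, should_apply = loss_scale.update(grads)  # update_op is None in eager mode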