Add single tensor gradient support for DynamicLossScale

PiperOrigin-RevId: 257209230
Authored by A. Unique TensorFlower on 2019-07-09 09:29:29 -07:00; committed by TensorFlower Gardener.
parent 67a48c79a3
commit 45925ad3d0
2 changed files with 15 additions and 4 deletions


@@ -31,6 +31,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
@@ -93,10 +94,10 @@ class LossScale(trackable.Trackable):
     cross-replica context.

     Args:
-      grads: A list of unscaled gradients, each which is the gradient of the
-        loss with respect to a weight. The gradients should have already been
-        divided by the loss scale being before passed to this function. 'None'
-        gradients are accepted, and are ignored.
+      grads: A nested structure of unscaled gradients, each of which is the
+        gradient of the loss with respect to a weight. The gradients should
+        have already been divided by the loss scale before being passed to
+        this function. 'None' gradients are accepted, and are ignored.

     Returns:
       update_op: In eager mode, None. In graph mode, an op to update the loss
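
Illustrative sketch (not part of this diff): under the relaxed contract above, update() accepts any gradient structure that TensorFlow's nest utility can flatten, including a bare tensor. The loss_scale module path below is assumed from the internal layout of this era; the gradient values are arbitrary.

    from tensorflow.python.framework import constant_op
    from tensorflow.python.training.experimental import loss_scale as loss_scale_module

    loss_scale = loss_scale_module.DynamicLossScale()

    # A flat list of per-weight gradients, as before; 'None' entries are ignored.
    update_op, should_apply = loss_scale.update([constant_op.constant(1.0), None])

    # With this change, a single bare tensor is accepted as well.
    update_op, should_apply = loss_scale.update(constant_op.constant(4.0))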
@@ -328,6 +329,7 @@ class DynamicLossScale(LossScale):
   def update(self, grads):
     """Updates loss scale based on if gradients are finite in current step."""
+    grads = nest.flatten(grads)
     if distribution_strategy_context.has_strategy():
       distribution = distribution_strategy_context.get_cross_replica_context()
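
The single nest.flatten call is what enables the relaxed contract: it canonicalizes whatever the caller passes into a flat list before the finiteness check, so the rest of update() is unchanged. A minimal sketch of the flattening semantics (the tensors here are hypothetical stand-ins for real gradients):

    from tensorflow.python.framework import constant_op
    from tensorflow.python.util import nest

    g1, g2, g3 = [constant_op.constant(x) for x in (1.0, 2.0, 3.0)]

    nest.flatten(g1)                     # a lone tensor becomes [g1]
    nest.flatten([g1, (g2, {'w': g3})])  # arbitrary nesting flattens to [g1, g2, g3]
    nest.flatten([g1, None])             # 'None' leaves are preserved: [g1, None]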


@@ -264,6 +264,15 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
     expected_outputs = [2, 2, 4, 4, 2, 2, 1, 1, 2, 2, 1]
     self._test_helper(inputs, expected_outputs, init_loss_scale)

+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def test_single_tensor_gradient(self, strategy_fn):
+    with strategy_fn().scope():
+      loss_scale = loss_scale_module.DynamicLossScale()
+      grad = constant_op.constant(4.0)
+      _, should_apply = loss_scale.update(grad)
+      self.assertTrue(self.evaluate(should_apply))
+
   @test_util.run_in_graph_and_eager_modes
   def test_serialization(self):
     loss_scale = loss_scale_module.DynamicLossScale(
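
The new test exercises exactly the case the flatten enables: a bare finite tensor passed straight to update(), run under every strategy in TESTCASES and in both graph and eager modes. Because the gradient 4.0 is finite, should_apply evaluates to True; a non-finite gradient (e.g. NaN) would instead yield False and, with the default multiplier of 2, halve the loss scale.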