From 45925ad3d0c749db7a08abee38c0200a14728c6c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 9 Jul 2019 09:29:29 -0700
Subject: [PATCH] Add single tensor gradient support for DynamicLossScale

PiperOrigin-RevId: 257209230
---
 tensorflow/python/training/experimental/loss_scale.py | 10 ++++++----
 .../python/training/experimental/loss_scale_test.py   |  9 +++++++++
 2 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/tensorflow/python/training/experimental/loss_scale.py b/tensorflow/python/training/experimental/loss_scale.py
index ebe80050b73..bbbd0cd7ec4 100644
--- a/tensorflow/python/training/experimental/loss_scale.py
+++ b/tensorflow/python/training/experimental/loss_scale.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -93,10 +94,10 @@ class LossScale(trackable.Trackable):
     cross-replica context.
 
     Args:
-      grads: A list of unscaled gradients, each which is the gradient of the
-        loss with respect to a weight. The gradients should have already been
-        divided by the loss scale being before passed to this function. 'None'
-        gradients are accepted, and are ignored.
+      grads: A nested structure of unscaled gradients, each of which is the
+        gradient of the loss with respect to a weight. The gradients should
+        have already been divided by the loss scale before being passed to
+        this function. 'None' gradients are accepted and are ignored.
 
     Returns:
       update_op: In eager mode, None. In graph mode, an op to update the loss
@@ -328,6 +329,7 @@ class DynamicLossScale(LossScale):
 
   def update(self, grads):
     """Updates loss scale based on if gradients are finite in current step."""
+    grads = nest.flatten(grads)
     if distribution_strategy_context.has_strategy():
       distribution = distribution_strategy_context.get_cross_replica_context()
 
diff --git a/tensorflow/python/training/experimental/loss_scale_test.py b/tensorflow/python/training/experimental/loss_scale_test.py
index 7891156539e..c3e18a18422 100644
--- a/tensorflow/python/training/experimental/loss_scale_test.py
+++ b/tensorflow/python/training/experimental/loss_scale_test.py
@@ -264,6 +264,15 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
     expected_outputs = [2, 2, 4, 4, 2, 2, 1, 1, 2, 2, 1]
     self._test_helper(inputs, expected_outputs, init_loss_scale)
 
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def test_single_tensor_gradient(self, strategy_fn):
+    with strategy_fn().scope():
+      loss_scale = loss_scale_module.DynamicLossScale()
+      grad = constant_op.constant(4.0)
+      _, should_apply = loss_scale.update(grad)
+      self.assertTrue(self.evaluate(should_apply))
+
   @test_util.run_in_graph_and_eager_modes
   def test_serialization(self):
     loss_scale = loss_scale_module.DynamicLossScale(
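
Note (illustration only, not part of the patch): the change works because
nest.flatten canonicalizes any input into a flat Python list, so a lone
tensor becomes a one-element list and the finiteness checks downstream of
update() can iterate over it unchanged. A minimal sketch of that behavior,
using the same modules the patch and its test already import; the example
values here are made up:

    from tensorflow.python.framework import constant_op
    from tensorflow.python.util import nest

    # A single tensor is wrapped in a one-element list.
    single_grad = constant_op.constant(4.0)
    assert len(nest.flatten(single_grad)) == 1

    # Arbitrary nesting is flattened; dict entries come out in sorted-key
    # order, and None leaves are preserved (matching the docstring's note
    # that 'None' gradients are accepted).
    nested_grads = {'w': constant_op.constant(1.0),
                    'b': [constant_op.constant(2.0), None]}
    flat = nest.flatten(nested_grads)  # [2.0 tensor, None, 1.0 tensor]
    assert flat[1] is None and len(flat) == 3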