diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale.py
index e72983ee491..dfb69d3e4d1 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale.py
@@ -79,7 +79,8 @@ class LossScale(trackable.Trackable):
     Args:
       grads: A list of unscaled gradients, each which is the gradient of the
         loss with respect to a weight. The gradients should have already been
-        divided by the loss scale being before passed to this function.
+        divided by the loss scale being before passed to this function. 'None'
+        gradients are accepted, and are ignored.
 
     Returns:
       update_op: In eager mode, None. In graph mode, an op to update the loss
@@ -182,8 +183,9 @@ class FixedLossScale(LossScale):
 
 def _is_all_finite(grads):
   """Returns a scalar boolean tensor indicating if all gradients are finite."""
-  is_finite_per_grad = [math_ops.reduce_all(math_ops.is_finite(g))
-                        for g in grads]
+  is_finite_per_grad = [
+      math_ops.reduce_all(math_ops.is_finite(g)) for g in grads if g is not None
+  ]
   return math_ops.reduce_all(is_finite_per_grad)
 
 
diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_test.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_test.py
index f8dee5203bb..25622e2fff8 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_test.py
@@ -266,6 +266,11 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
     self.assertEqual(loss_scale.increment_period, 2)
     self.assertEqual(loss_scale.multiplier, 3)
 
+  @test_util.run_in_graph_and_eager_modes
+  def test_update_with_none_gradients(self):
+    loss_scale = loss_scale_module.DynamicLossScale()
+    loss_scale.update([None])
+
   @test_util.run_in_graph_and_eager_modes
   def test_get(self):
     scalar = loss_scale_module.get('dynamic')
diff --git a/tensorflow/python/training/experimental/loss_scale.py b/tensorflow/python/training/experimental/loss_scale.py
index c833241e26a..47f0bc67a03 100644
--- a/tensorflow/python/training/experimental/loss_scale.py
+++ b/tensorflow/python/training/experimental/loss_scale.py
@@ -83,7 +83,8 @@ class LossScale(trackable.Trackable):
     Args:
       grads: A list of unscaled gradients, each which is the gradient of the
        loss with respect to a weight. The gradients should have already been
-        divided by the loss scale being before passed to this function.
+        divided by the loss scale being before passed to this function. 'None'
+        gradients are accepted, and are ignored.
 
     Returns:
       update_op: In eager mode, None. In graph mode, an op to update the loss
@@ -196,7 +197,7 @@ class FixedLossScale(LossScale):
 def _is_all_finite(grads):
   """Returns a scalar boolean tensor indicating if all gradients are finite."""
   is_finite_per_grad = [
-      math_ops.reduce_all(math_ops.is_finite(g)) for g in grads
+      math_ops.reduce_all(math_ops.is_finite(g)) for g in grads if g is not None
   ]
   return math_ops.reduce_all(is_finite_per_grad)
 
diff --git a/tensorflow/python/training/experimental/loss_scale_test.py b/tensorflow/python/training/experimental/loss_scale_test.py
index f135161de07..d7d52f0050a 100644
--- a/tensorflow/python/training/experimental/loss_scale_test.py
+++ b/tensorflow/python/training/experimental/loss_scale_test.py
@@ -252,6 +252,11 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
     expected_outputs = [2, 2, 4, 4, 2, 2, 1, 1, 2, 2, 1]
     self._test_helper(inputs, expected_outputs, init_loss_scale)
 
+  @test_util.run_in_graph_and_eager_modes
+  def test_update_with_none_gradients(self):
+    loss_scale = loss_scale_module.DynamicLossScale()
+    loss_scale.update([None])
+
   @test_util.run_in_graph_and_eager_modes
   def test_get(self):
     scalar = loss_scale_module.get('dynamic')
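
Note (not part of the patch): a minimal standalone sketch of the behavior this change enables, written against the public tf.math API instead of the internal math_ops module; the free function is_all_finite below is a hypothetical stand-in for the private _is_all_finite helper. Gradients can be None (for example, when a loss does not depend on some variable), and the added "if g is not None" filter lets the finiteness check, and therefore DynamicLossScale.update, accept such lists instead of failing on None.

import tensorflow as tf

def is_all_finite(grads):
  # Mirror of the patched helper: skip None gradients, check the rest.
  is_finite_per_grad = [
      tf.reduce_all(tf.math.is_finite(g)) for g in grads if g is not None
  ]
  # Reducing an empty list yields True, so an all-None gradient list is
  # treated as finite, matching the new test_update_with_none_gradients case.
  return tf.reduce_all(is_finite_per_grad)

grads = [tf.constant([1.0, 2.0]), None, tf.constant([3.0])]
print(is_all_finite(grads))                                  # -> True
print(is_all_finite([None]))                                 # -> True
print(is_all_finite([tf.constant([float('inf')]), None]))    # -> False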