From 93fbba3529493c7d4bf436bb96e7c7cba045a639 Mon Sep 17 00:00:00 2001
From: Reed Wanderman-Milne
Date: Wed, 22 Jan 2020 13:07:39 -0800
Subject: [PATCH] Fix crash if tf.transpose used with tf.LossScaleGradientTape.

Thank you @benbarsdell for finding and debugging this issue

PiperOrigin-RevId: 291014436
Change-Id: I8f82a5e35f0818c799332b29c3abcb35b5484b3d
---
 .../loss_scaling_gradient_tape.py      | 14 ++++++++--
 .../loss_scaling_gradient_tape_test.py | 28 +++++++++++++++++++
 2 files changed, 39 insertions(+), 3 deletions(-)

diff --git a/tensorflow/python/training/experimental/loss_scaling_gradient_tape.py b/tensorflow/python/training/experimental/loss_scaling_gradient_tape.py
index 356431358bb..43bed153bd6 100644
--- a/tensorflow/python/training/experimental/loss_scaling_gradient_tape.py
+++ b/tensorflow/python/training/experimental/loss_scaling_gradient_tape.py
@@ -270,7 +270,8 @@ def _compute_gradients_until_finite(
   def body(grads, ready_to_update, is_first_iteration):
     """The body of the while loop."""
     del grads, ready_to_update, is_first_iteration
-    def replica_fn(gradient_tape, target, flattened_sources, output_gradients):
+    def replica_fn(gradient_tape, target, flattened_sources, output_gradients,
+                   initial_grads):
       """Scales the loss, computes the gradients, and unscales the gradients."""
       loss_scale_val = loss_scale()
       with gradient_tape:  # re-enter gradient tape so it sees the loss scaling
@@ -285,6 +286,12 @@ def _compute_gradients_until_finite(
       grads = []  # The unscaled gradients
       for g, initial_grad in zip(scaled_grads, initial_grads):
         if g is not None:
+          # We call ensure_shape as shape information can be lost for certain
+          # ops, such as tf.transpose, if the op is called in a tf.function and
+          # has inputs created outside the tf.function.
+          # TODO(b/132092188): Remove ensure_shape call after this has been
+          # fixed.
+          g = array_ops.ensure_shape(g, initial_grad.shape)
           grads.append(g * math_ops.cast(inv_loss_scale, g.dtype))
         else:
           # We cannot return None from a tf.while_loop, so we pass a dummy
@@ -297,7 +304,7 @@ def _compute_gradients_until_finite(
     # Switch to a replica-context to compute gradients once per replica.
     grads = distribution.experimental_run_v2(
         replica_fn, args=(loss_scale_gradient_tapes, target, flattened_sources,
-                          output_gradients))
+                          output_gradients, initial_grads))
     # Check for non-finite gradients possibly resulting from scaling.
     _, ready_to_update = loss_scale.update(grads)
     is_first_iteration = False
@@ -305,7 +312,8 @@ def _compute_gradients_until_finite(
 
   grads, _, _ = control_flow_ops.while_loop(
       cond, body, [initial_grads, initial_ready_to_update,
-                   initial_is_first_iteration])
+                   initial_is_first_iteration],
+      )
   grads = [None if is_none else g for g, is_none in zip(grads, is_nones)]
   grads = nest.pack_sequence_as(sources, grads)
   return grads
diff --git a/tensorflow/python/training/experimental/loss_scaling_gradient_tape_test.py b/tensorflow/python/training/experimental/loss_scaling_gradient_tape_test.py
index 4278d55d530..c1394a17307 100644
--- a/tensorflow/python/training/experimental/loss_scaling_gradient_tape_test.py
+++ b/tensorflow/python/training/experimental/loss_scaling_gradient_tape_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import test_combinations
 from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -488,6 +489,33 @@ class LossScaleGradientTapeTest(test.TestCase, parameterized.TestCase):
     for dy_dx in dy_dx_list:
       self.assertAllClose(self.evaluate(dy_dx), 3.0)
 
+  @test_combinations.generate(
+      test_combinations.combine(
+          loss_scale=[
+              loss_scale_module.FixedLossScale,
+              loss_scale_module.DynamicLossScale
+          ],
+          strategy_fn=[default_strategy_fn, create_mirrored_strategy],
+          use_tf_function=[True, False]))
+  def test_transpose(self, loss_scale, strategy_fn, use_tf_function):
+    # Calling tf.transpose inside a tf.function can cause static shape
+    # information to be lost. This tests that LossScaleGradientTape can handle
+    # this.
+    loss_scale = loss_scale(32)
+    strategy = strategy_fn()
+    with strategy.scope():
+      x = variables.Variable(array_ops.ones((2, 3)))
+
+    def run_fn():
+      with lsgt.LossScaleGradientTape(loss_scale) as g:
+        y = array_ops.transpose(x) * 2.
+      return g.gradient(y, x)
+
+    dy_dx_list = self._run_with_strategy(run_fn, strategy, use_tf_function)
+    self.assertEqual(loss_scale(), 32)
+    for dy_dx in dy_dx_list:
+      self.assertAllEqual(dy_dx, np.full((2, 3), 2.))
+
   def test_passing_non_loss_scale_raises_error(self):
     with self.assertRaisesRegexp(
         ValueError,
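
Note (not part of the patch): a minimal standalone sketch of the pattern the patch guards against, using only public TensorFlow 2.x APIs. A gradient computed through tf.transpose inside a tf.function, on a variable created outside it, could come back without static shape information; tf.ensure_shape (the public counterpart of the array_ops.ensure_shape call added above) re-attaches the expected shape before the gradient is used further.

import tensorflow as tf

x = tf.Variable(tf.ones((2, 3)))  # variable created outside the tf.function


@tf.function
def grad_of_transpose():
  with tf.GradientTape() as tape:
    # tf.transpose inside a tf.function, on a variable created outside it,
    # is the pattern that could drop static shape information (b/132092188).
    y = tf.transpose(x) * 2.
  g = tape.gradient(y, x)
  # Re-attach the static shape, mirroring the patch's ensure_shape call:
  # dy/dx necessarily has the same shape as x.
  return tf.ensure_shape(g, x.shape)


print(grad_of_transpose())  # expected: a (2, 3) tensor filled with 2.0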