Fix crash if tf.transpose is used with tf.LossScaleGradientTape.

Thank you @benbarsdell for finding and debugging this issue.

PiperOrigin-RevId: 291014436
Change-Id: I8f82a5e35f0818c799332b29c3abcb35b5484b3d
Reed Wanderman-Milne 2020-01-22 13:07:39 -08:00 committed by TensorFlower Gardener
parent 82273f00d5
commit 93fbba3529
2 changed files with 39 additions and 3 deletions
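For context, a minimal sketch of the usage pattern this commit fixes, mirroring the test added below. The export path tf.mixed_precision.experimental.LossScaleGradientTape is an assumption for illustration only (the test imports the internal module directly); DynamicLossScale exercises the while_loop code path that previously crashed.

# Hypothetical repro sketch, not part of the commit. Assumes the experimental
# export path tf.mixed_precision.experimental.LossScaleGradientTape exists in
# this build.
import tensorflow as tf

loss_scale = tf.mixed_precision.experimental.DynamicLossScale()
x = tf.Variable(tf.ones((2, 3)))  # created outside the tf.function

@tf.function
def compute_grads():
  with tf.mixed_precision.experimental.LossScaleGradientTape(loss_scale) as g:
    # tf.transpose can lose static shape information when traced inside a
    # tf.function with inputs created outside it (b/132092188), which
    # previously crashed the tape's internal while_loop.
    y = tf.transpose(x) * 2.
  return g.gradient(y, x)

print(compute_grads())  # with this fix: a (2, 3) gradient filled with 2s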

View File

@@ -270,7 +270,8 @@ def _compute_gradients_until_finite(
   def body(grads, ready_to_update, is_first_iteration):
     """The body of the while loop."""
     del grads, ready_to_update, is_first_iteration
-    def replica_fn(gradient_tape, target, flattened_sources, output_gradients):
+    def replica_fn(gradient_tape, target, flattened_sources, output_gradients,
+                   initial_grads):
       """Scales the loss, computes the gradients, and unscales the gradients."""
       loss_scale_val = loss_scale()
       with gradient_tape:  # re-enter gradient tape so it sees the loss scaling
@@ -285,6 +286,12 @@ def _compute_gradients_until_finite(
       grads = []  # The unscaled gradients
       for g, initial_grad in zip(scaled_grads, initial_grads):
         if g is not None:
+          # We call ensure_shape as shape information can be lost for certain
+          # ops, such as tf.transpose, if the op is called in a tf.function and
+          # has inputs created outside the tf.function.
+          # TODO(b/132092188): Remove ensure_shape call after this has been
+          # fixed.
+          g = array_ops.ensure_shape(g, initial_grad.shape)
           grads.append(g * math_ops.cast(inv_loss_scale, g.dtype))
         else:
           # We cannot return None from a tf.while_loop, so we pass a dummy
@@ -297,7 +304,7 @@ def _compute_gradients_until_finite(
     # Switch to a replica-context to compute gradients once per replica.
     grads = distribution.experimental_run_v2(
         replica_fn, args=(loss_scale_gradient_tapes, target, flattened_sources,
-                          output_gradients))
+                          output_gradients, initial_grads))
     # Check for non-finite gradients possibly resulting from scaling.
     _, ready_to_update = loss_scale.update(grads)
     is_first_iteration = False
@@ -305,7 +312,8 @@ def _compute_gradients_until_finite(
   grads, _, _ = control_flow_ops.while_loop(
       cond, body, [initial_grads, initial_ready_to_update,
-                   initial_is_first_iteration])
+                   initial_is_first_iteration],
+      )
   grads = [None if is_none else g for g, is_none in zip(grads, is_nones)]
   grads = nest.pack_sequence_as(sources, grads)
   return grads

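The ensure_shape call added above is the heart of the fix: it reattaches the statically known shape of the corresponding initial gradient to a gradient whose static shape was lost inside the traced function, and fails at runtime if the actual shape ever disagrees. Below is a standalone sketch of that technique with illustrative shapes and a fixed 1/32 inverse scale; it is a generic example, not the tape's internal code.

import tensorflow as tf

@tf.function
def unscale(scaled_grad, initial_grad):
  # If scaled_grad arrives with an unknown static shape, ensure_shape restores
  # the known shape from initial_grad and adds a runtime shape check.
  scaled_grad = tf.ensure_shape(scaled_grad, initial_grad.shape)
  inv_loss_scale = tf.constant(1. / 32.)
  return scaled_grad * tf.cast(inv_loss_scale, scaled_grad.dtype)

initial = tf.zeros((2, 3))
print(unscale(tf.fill((2, 3), 64.), initial))  # -> (2, 3) tensor of 2s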
View File

@@ -28,6 +28,7 @@ from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import test_combinations
 from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -488,6 +489,33 @@ class LossScaleGradientTapeTest(test.TestCase, parameterized.TestCase):
     for dy_dx in dy_dx_list:
       self.assertAllClose(self.evaluate(dy_dx), 3.0)
+
+  @test_combinations.generate(
+      test_combinations.combine(
+          loss_scale=[
+              loss_scale_module.FixedLossScale,
+              loss_scale_module.DynamicLossScale
+          ],
+          strategy_fn=[default_strategy_fn, create_mirrored_strategy],
+          use_tf_function=[True, False]))
+  def test_transpose(self, loss_scale, strategy_fn, use_tf_function):
+    # Calling tf.transpose inside a tf.function can cause static shape
+    # information to be lost. This tests that LossScaleGradientTape can handle
+    # this.
+    loss_scale = loss_scale(32)
+    strategy = strategy_fn()
+    with strategy.scope():
+      x = variables.Variable(array_ops.ones((2, 3)))
+
+    def run_fn():
+      with lsgt.LossScaleGradientTape(loss_scale) as g:
+        y = array_ops.transpose(x) * 2.
+      return g.gradient(y, x)
+
+    dy_dx_list = self._run_with_strategy(run_fn, strategy, use_tf_function)
+    self.assertEqual(loss_scale(), 32)
+    for dy_dx in dy_dx_list:
+      self.assertAllEqual(dy_dx, np.full((2, 3), 2.))
 
   def test_passing_non_loss_scale_raises_error(self):
     with self.assertRaisesRegexp(
         ValueError,