Fix crash if tf.transpose used with tf.LossScaleGradientTape.

Thank you @benbarsdell for finding and debugging this issue.

PiperOrigin-RevId: 291014436
Change-Id: I8f82a5e35f0818c799332b29c3abcb35b5484b3d
commit 93fbba3529
parent 82273f00d5
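A minimal usage sketch of the scenario being fixed (not part of the commit). It mirrors the new test_transpose test added below; the internal import paths and the initial_loss_scale argument are assumptions chosen to match the aliases (loss_scale_module, lsgt) used in the test file, not something the commit itself confirms.

import tensorflow as tf
# Assumed internal module paths; the test below refers to them as
# loss_scale_module and lsgt.
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.training.experimental import loss_scaling_gradient_tape as lsgt

loss_scale = loss_scale_module.DynamicLossScale(initial_loss_scale=32)
x = tf.Variable(tf.ones((2, 3)))  # created outside the tf.function

@tf.function  # wrapping the tape in a tf.function is what used to crash
def run_fn():
  with lsgt.LossScaleGradientTape(loss_scale) as g:
    y = tf.transpose(x) * 2.  # tf.transpose can lose x's static shape here
  return g.gradient(y, x)

print(run_fn())  # with this fix: a (2, 3) gradient of 2s, loss scale unchanged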
@@ -270,7 +270,8 @@ def _compute_gradients_until_finite(
   def body(grads, ready_to_update, is_first_iteration):
     """The body of the while loop."""
     del grads, ready_to_update, is_first_iteration
-    def replica_fn(gradient_tape, target, flattened_sources, output_gradients):
+    def replica_fn(gradient_tape, target, flattened_sources, output_gradients,
+                   initial_grads):
       """Scales the loss, computes the gradients, and unscales the gradients."""
       loss_scale_val = loss_scale()
       with gradient_tape:  # re-enter gradient tape so it sees the loss scaling
@@ -285,6 +286,12 @@ def _compute_gradients_until_finite(
       grads = []  # The unscaled gradients
+      for g, initial_grad in zip(scaled_grads, initial_grads):
         if g is not None:
+          # We call ensure_shape as shape information can be lost for certain
+          # ops, such as tf.transpose, if the op is called in a tf.function and
+          # has inputs created outside the tf.function.
+          # TODO(b/132092188): Remove ensure_shape call after this has been
+          # fixed.
+          g = array_ops.ensure_shape(g, initial_grad.shape)
           grads.append(g * math_ops.cast(inv_loss_scale, g.dtype))
         else:
           # We cannot return None from a tf.while_loop, so we pass a dummy
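Not part of the commit: a small standalone illustration of what the ensure_shape call above provides, using only public TF APIs. Inside a tf.function a tensor's static shape can be unknown; tf.ensure_shape (the public wrapper over array_ops.ensure_shape) re-asserts it, and fails at runtime if the actual shape disagrees.

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def pin_shape(g):
  print("static shape before:", g.shape)   # <unknown> at trace time
  g = tf.ensure_shape(g, (2, 3))           # static shape is now (2, 3)
  print("static shape after:", g.shape)
  return g

pin_shape(tf.ones((2, 3)))  # raises at runtime if the input is not (2, 3)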
@@ -297,7 +304,7 @@ def _compute_gradients_until_finite(
     # Switch to a replica-context to compute gradients once per replica.
     grads = distribution.experimental_run_v2(
         replica_fn, args=(loss_scale_gradient_tapes, target, flattened_sources,
-                          output_gradients))
+                          output_gradients, initial_grads))
     # Check for non-finite gradients possibly resulting from scaling.
     _, ready_to_update = loss_scale.update(grads)
     is_first_iteration = False
@@ -305,7 +312,8 @@ def _compute_gradients_until_finite(

   grads, _, _ = control_flow_ops.while_loop(
       cond, body, [initial_grads, initial_ready_to_update,
-                   initial_is_first_iteration])
+                   initial_is_first_iteration],
+      )
   grads = [None if is_none else g for g, is_none in zip(grads, is_nones)]
   grads = nest.pack_sequence_as(sources, grads)
   return grads
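For context, the while loop whose call site changes above drives the retry logic of _compute_gradients_until_finite. Below is a toy eager-mode sketch of that idea, not part of the commit; the helper name and its arguments are illustrative only, while loss_scale() and loss_scale.update(grads) follow the LossScale API already visible in the hunks above.

import tensorflow as tf

def compute_until_finite(scaled_grads_fn, loss_scale):
  """Toy sketch: retry gradient computation until the scaled gradients are finite."""
  while True:
    scale = loss_scale()                      # scale used for this attempt
    scaled = scaled_grads_fn(scale)           # gradients of (scale * loss)
    grads = [g * tf.cast(1. / scale, g.dtype) for g in scaled]  # unscale
    # For a DynamicLossScale, update() shrinks the scale on inf/NaN gradients
    # and reports whether the gradients were finite.
    _, ready_to_update = loss_scale.update(grads)
    if ready_to_update:
      return grads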
@@ -28,6 +28,7 @@ from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import test_combinations
 from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -488,6 +489,33 @@ class LossScaleGradientTapeTest(test.TestCase, parameterized.TestCase):
     for dy_dx in dy_dx_list:
       self.assertAllClose(self.evaluate(dy_dx), 3.0)

+  @test_combinations.generate(
+      test_combinations.combine(
+          loss_scale=[
+              loss_scale_module.FixedLossScale,
+              loss_scale_module.DynamicLossScale
+          ],
+          strategy_fn=[default_strategy_fn, create_mirrored_strategy],
+          use_tf_function=[True, False]))
+  def test_transpose(self, loss_scale, strategy_fn, use_tf_function):
+    # Calling tf.transpose inside a tf.function can cause static shape
+    # information to be lost. This tests that LossScaleGradientTape can handle
+    # this.
+    loss_scale = loss_scale(32)
+    strategy = strategy_fn()
+    with strategy.scope():
+      x = variables.Variable(array_ops.ones((2, 3)))
+
+    def run_fn():
+      with lsgt.LossScaleGradientTape(loss_scale) as g:
+        y = array_ops.transpose(x) * 2.
+      return g.gradient(y, x)
+
+    dy_dx_list = self._run_with_strategy(run_fn, strategy, use_tf_function)
+    self.assertEqual(loss_scale(), 32)
+    for dy_dx in dy_dx_list:
+      self.assertAllEqual(dy_dx, np.full((2, 3), 2.))
+
   def test_passing_non_loss_scale_raises_error(self):
     with self.assertRaisesRegexp(
         ValueError,
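Not part of the commit: a quick check, using only the public tf.GradientTape, of the value the new test expects. With a non-scalar target, gradient() differentiates the implicit sum over y, so dy/dx for y = transpose(x) * 2 is a (2, 3) tensor filled with 2s, and since the gradients are finite the loss scale stays at 32.

import tensorflow as tf

x = tf.Variable(tf.ones((2, 3)))
with tf.GradientTape() as tape:
  y = tf.transpose(x) * 2.    # non-scalar target: gradient() sums over y
print(tape.gradient(y, x))    # a (2, 3) tensor filled with 2.0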