Support passing IndexedSlices to the tf.while_loop gradient function.

PiperOrigin-RevId: 326753804
Change-Id: I5376273f7bbc8e5187f88c2ec2f96b9d87a6ab14
This commit is contained in:
Saurabh Saxena 2020-08-14 16:45:04 -07:00 committed by TensorFlower Gardener
parent a9c09d46de
commit 2d1e9501e3
2 changed files with 18 additions and 0 deletions

View File

@ -1830,6 +1830,18 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
return grad_out
self.assertAllEqual(F(), 8.0)
def testIndexedSlicesInIncomingGrads(self):
@def_function.function
def F():
x = constant_op.constant([2.])
# Computes x^4
ret = while_loop_v2(
lambda _: True, lambda v: v * v, [x], return_same_structure=False,
maximum_iterations=2)
v = array_ops.gather(ret, [0])
return gradients_impl.gradients(v, [x])[0] # 4*x^3
self.assertAllEqual(self.evaluate(F()), [32.])
def ScalarShape():
return ops.convert_to_tensor([], dtype=dtypes.int32)

View File

@ -520,6 +520,12 @@ def _preprocess_grad(grad, body_graph_output, while_op_input, while_op_output):
default_gradient.supports_default_grad(while_op_input) and grad is None):
return _zeros_like(while_op_input, while_op_output)
# Convert IndexedSlices to dense tensors since it is unlikely that downstream
# gradient functions with properly handle indexed slices. This is similar to
# what we do in tf.function gradients.
if isinstance(grad, ops.IndexedSlices):
return ops.convert_to_tensor(grad)
return grad