Support passing IndexedSlices to the tf.while_loop gradient function.
PiperOrigin-RevId: 326753804 Change-Id: I5376273f7bbc8e5187f88c2ec2f96b9d87a6ab14
This commit is contained in:
parent
a9c09d46de
commit
2d1e9501e3
@ -1830,6 +1830,18 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
||||
return grad_out
|
||||
self.assertAllEqual(F(), 8.0)
|
||||
|
||||
def testIndexedSlicesInIncomingGrads(self):
|
||||
@def_function.function
|
||||
def F():
|
||||
x = constant_op.constant([2.])
|
||||
# Computes x^4
|
||||
ret = while_loop_v2(
|
||||
lambda _: True, lambda v: v * v, [x], return_same_structure=False,
|
||||
maximum_iterations=2)
|
||||
v = array_ops.gather(ret, [0])
|
||||
return gradients_impl.gradients(v, [x])[0] # 4*x^3
|
||||
self.assertAllEqual(self.evaluate(F()), [32.])
|
||||
|
||||
|
||||
def ScalarShape():
|
||||
return ops.convert_to_tensor([], dtype=dtypes.int32)
|
||||
|
@ -520,6 +520,12 @@ def _preprocess_grad(grad, body_graph_output, while_op_input, while_op_output):
|
||||
default_gradient.supports_default_grad(while_op_input) and grad is None):
|
||||
return _zeros_like(while_op_input, while_op_output)
|
||||
|
||||
# Convert IndexedSlices to dense tensors since it is unlikely that downstream
|
||||
# gradient functions with properly handle indexed slices. This is similar to
|
||||
# what we do in tf.function gradients.
|
||||
if isinstance(grad, ops.IndexedSlices):
|
||||
return ops.convert_to_tensor(grad)
|
||||
|
||||
return grad
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user