Calculate only necessary gradients of captures in while loop

PiperOrigin-RevId: 308722418
Change-Id: Id6603a23e51f27eb7727d3df018db39614528d14
This commit is contained in:
A. Unique TensorFlower 2020-04-27 16:32:17 -07:00 committed by TensorFlower Gardener
parent 7cd479b8fc
commit cfb0250c51
2 changed files with 34 additions and 0 deletions

View File

@@ -39,8 +39,10 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
@@ -1366,6 +1368,30 @@ class WhileLoopTestCase(test_util.TensorFlowTestCase):
c, b, [i], return_same_structure=True, maximum_iterations=50)
self.assertEqual(self.evaluate(r), [10])
@test_util.enable_control_flow_v2
@test_util.run_in_graph_and_eager_modes
def testSkipsUnnecessaryCaptureGradients(self):
  """While-loop grad must skip captures whose gradient is not requested."""

  @custom_gradient.custom_gradient
  def fail_on_grad(tensor):
    """Identity whose gradient fails the test if it is ever evaluated."""

    def grad_fn(upstream):
      # Computing this gradient should fail the test.
      check_ops.assert_equal(0, 1)
      return upstream

    return tensor, grad_fn

  watched = array_ops.constant(0.0, name="x")
  trapped = array_ops.constant(1.0, name="y")

  def loop_cond(state):
    return state < 10.0

  def loop_body(state):
    # `trapped` is captured by the body, but its gradient is never needed
    # to differentiate the loop output w.r.t. `watched`.
    return state + 2 * watched + fail_on_grad(trapped)

  with backprop.GradientTape() as tape:
    tape.watch(watched)
    result = control_flow_ops.while_loop(
        loop_cond, loop_body, (array_ops.constant(0.0),))
  gradient = tape.gradient(result, watched)
  # 10 iterations, each contributing d(2*x)/dx = 2.
  self.assertAllEqual(gradient, 20.0)
class AssertTest(test_util.TensorFlowTestCase):

View File

@@ -335,6 +335,14 @@ def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
while_op.outputs[:num_original_outputs])
] + [None] * num_intermediates
# Skip gradients with respect to the captures whenever possible.
if "skip_input_indices" in op.__dict__ and op.skip_input_indices is not None:
captures_start_index = (
len(body_graph.inputs) - len(body_graph.internal_captures))
for i in op.skip_input_indices:
if i >= captures_start_index:
grads[i] = None
# We compute the gradient for the sub-graph between trainable ys and xs
# with non-None incoming gradients. We later pad the None's to the list of
# outputs.