Calculate only necessary gradients of captures in while loop
PiperOrigin-RevId: 308722418
Change-Id: Id6603a23e51f27eb7727d3df018db39614528d14
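Context for the change: `_WhileGrad` previously built backward computation for every input of the While op, including body captures whose gradients the caller never requested. After this change it consults the op's `skip_input_indices` (populated by the surrounding gradient machinery when an input's gradient is known to be unneeded) and nulls out the incoming gradients for such captures, so no gradient subgraph is constructed for them. A minimal sketch of the user-visible behavior, using the public TF 2.x API rather than the internal modules touched below (the numeric values are specific to this toy loop):

import tensorflow as tf

@tf.function  # stage in a graph so the loop becomes a single While op
def loop_grad():
  x = tf.constant(3.0)
  y = tf.constant(1.0)  # captured by the body; its gradient is never requested
  with tf.GradientTape() as tape:
    tape.watch(x)
    # Each iteration adds 2*x + y = 7.0, so the loop runs twice: 0 -> 7 -> 14.
    out = tf.while_loop(lambda s: s < 10.0,
                        lambda s: s + 2.0 * x + y,
                        (tf.constant(0.0),))
  # Only d(out)/dx is requested; with this change no backward computation
  # is built for the capture y. Two iterations contribute 2.0 each.
  return tape.gradient(out, x)

print(loop_grad())  # tf.Tensor(4.0, shape=(), dtype=float32)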
commit cfb0250c51
parent 7cd479b8fc
@@ -39,8 +39,10 @@ from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import control_flow_v2_toggles
+from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
@@ -1366,6 +1368,30 @@ class WhileLoopTestCase(test_util.TensorFlowTestCase):
         c, b, [i], return_same_structure=True, maximum_iterations=50)
     self.assertEqual(self.evaluate(r), [10])
 
+  @test_util.enable_control_flow_v2
+  @test_util.run_in_graph_and_eager_modes
+  def testSkipsUnnecessaryCaptureGradients(self):
+    @custom_gradient.custom_gradient
+    def gradient_trap(t):
+      def grad(w):
+        # Computing this gradient should fail the test
+        check_ops.assert_equal(0, 1)
+        return w
+      return t, grad
+
+    x = array_ops.constant(0.0, name="x")
+    y = array_ops.constant(1.0, name="y")
+    def cond(s):
+      return s < 10.0
+    def body(s):
+      return s + 2*x + gradient_trap(y)
+    with backprop.GradientTape() as tape:
+      tape.watch(x)
+      out = control_flow_ops.while_loop(cond, body, (array_ops.constant(0.0),))
+
+    grad = tape.gradient(out, x)
+    self.assertAllEqual(grad, 20.0)
+
 
 class AssertTest(test_util.TensorFlowTestCase):
 
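The test's key device is a "gradient trap": `gradient_trap` is an identity op whose custom gradient fails an assertion, so the test passes only if backprop never differentiates through the capture `y`. A standalone sketch of the same pattern against the public API (illustrative only, not part of the diff; `tf.custom_gradient` is the public counterpart of `custom_gradient.custom_gradient`):

import tensorflow as tf

@tf.custom_gradient
def gradient_trap(t):
  def grad(w):
    # Fails loudly if anything ever differentiates through this op.
    tf.debugging.assert_equal(0, 1)
    return w
  return t, grad

x = tf.constant(2.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  out = x + gradient_trap(tf.constant(1.0))
# Only d(out)/dx is requested and the trap is not on the path from out
# to x, so its grad function should never run.
print(tape.gradient(out, x))  # tf.Tensor(1.0, shape=(), dtype=float32)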
@@ -335,6 +335,14 @@ def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
           while_op.outputs[:num_original_outputs])
   ] + [None] * num_intermediates
 
+  # Skip gradients with respect to the captures whenever possible.
+  if "skip_input_indices" in op.__dict__ and op.skip_input_indices is not None:
+    captures_start_index = (
+        len(body_graph.inputs) - len(body_graph.internal_captures))
+    for i in op.skip_input_indices:
+      if i >= captures_start_index:
+        grads[i] = None
+
   # We compute the gradient for the sub-graph between trainable ys and xs
   # with non-None incoming gradients. We later pad the None's to the list of
   # outputs.
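A note on the index arithmetic above: `body_graph.inputs` lists the loop variables first and the internal captures last, so any skipped index at or beyond `captures_start_index` names a capture whose incoming gradient can be dropped outright, while skipped loop-variable indices are left untouched. A plain-Python sketch of that bookkeeping (names here are hypothetical, not TensorFlow API):

def drop_skipped_capture_grads(grads, num_body_inputs, num_internal_captures,
                               skip_input_indices):
  """Nulls out incoming grads for captures listed in skip_input_indices."""
  captures_start_index = num_body_inputs - num_internal_captures
  grads = list(grads)
  for i in skip_input_indices:
    if i >= captures_start_index:  # index refers to a capture, not a loop var
      grads[i] = None
  return grads

# 3 loop variables followed by 2 captures; input 4 (the second capture) is
# known to be unneeded, while input 1 (a loop variable) is left as-is.
assert drop_skipped_capture_grads(["g0", "g1", "g2", "g3", "g4"], 5, 2,
                                  [1, 4]) == ["g0", "g1", "g2", "g3", None]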