Emit StatelessWhile for gradient while op if the cond and body functions are stateless.

We already emit the StatelessWhile op for the forward pass.

PiperOrigin-RevId: 290967707
Change-Id: I6d22ce36874f4191131e55902fb4404ebb3402a2
This commit is contained in:
Saurabh Saxena 2020-01-22 09:22:26 -08:00 committed by TensorFlower Gardener
parent 466e818c1e
commit 5708850f62
2 changed files with 57 additions and 19 deletions

View File

@@ -305,6 +305,34 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
self.assertEmpty(while_1.control_inputs)
self.assertEmpty(while_2.control_inputs)
def testMultipleWhileLoopsGradStateless(self):
  """Gradients of stateless while loops should also be StatelessWhile ops."""

  @def_function.function
  def Fn():
    x = constant_op.constant(2.)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      # Two independent loops over the same input; each squares its carried
      # value until the bound is reached.
      ret1 = while_loop_v2(
          lambda v: v < 4.,
          lambda v: v * v, [x],
          return_same_structure=False,
          name="while_1")  # x**2
      ret2 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * v, [x],
          return_same_structure=False,
          name="while_2")  # x**4
      loss = ret1 + ret2
    return tape.gradient(loss, x)

  graph = Fn.get_concrete_function().graph
  # 2 forward loops + 2 backward (gradient) loops, all expected stateless.
  while_ops = [op for op in graph.get_operations() if "While" in op.type]
  self.assertAllEqual([op.type for op in while_ops], ["StatelessWhile"] * 4,
                      "Must have exactly 4 StatelessWhile ops.")
  for op in while_ops:
    self.assertEmpty(op.control_inputs,
                     "{} should not have any control inputs".format(op.name))
def testMultipleWhileLoopsWithDeps(self):
x = variables.Variable(2.)
c = constant_op.constant(2.)

View File

@@ -207,8 +207,6 @@ def while_loop(cond,
num_cond_captures = len(cond_graph.external_captures)
assert (cond_graph.external_captures ==
body_graph.external_captures[:num_cond_captures])
cond_graph_captures = object_identity.ObjectIdentitySet(
cond_graph.external_captures)
_duplicate_body_captures_in_cond(
cond_graph, body_graph.external_captures[num_cond_captures:])
@@ -266,21 +264,10 @@ def while_loop(cond,
output_shapes[orig_loop_vars_range] = nest.flatten(
shape_invariants, expand_composites=True)[orig_loop_vars_range]
cond_stateful_ops = [
op for op in cond_graph.get_operations() if op._is_stateful
]
body_stateful_ops = [
op for op in body_graph.get_operations() if op._is_stateful
]
if (cond_stateful_ops or body_stateful_ops):
op_fn = gen_functional_ops._while
else:
op_fn = gen_functional_ops.stateless_while
outputs = op_fn(
outputs = _build_while_op(
flattened_loop_vars,
util.create_new_tf_function(cond_graph),
util.create_new_tf_function(body_graph),
cond_graph,
body_graph,
output_shapes=output_shapes,
parallel_iterations=parallel_iterations,
name=scope)
@@ -406,10 +393,10 @@ def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
_check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))
outputs = gen_functional_ops._while(
outputs = _build_while_op(
loop_vars,
util.create_new_tf_function(cond_grad_graph),
util.create_new_tf_function(body_grad_graph),
cond_grad_graph,
body_grad_graph,
output_shapes=[t.shape for t in body_grad_graph.outputs],
parallel_iterations=parallel_iterations,
name="%s_grad" % while_op.name)
@@ -424,6 +411,29 @@ def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
return _get_structured_grad_output(outputs, grads, body_grad_graph)
def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes,
                    parallel_iterations, name):
  """Builds the functional StatelessWhile/While op.

  Emits `StatelessWhile` when neither the cond nor the body graph contains a
  stateful op, and the stateful `While` otherwise.

  Args:
    loop_vars: Flattened list of loop-variable tensors passed to the op.
    cond_graph: FuncGraph for the loop condition.
    body_graph: FuncGraph for the loop body.
    output_shapes: Shapes for the op's outputs.
    parallel_iterations: Number of iterations allowed to run in parallel.
    name: Name for the created op.

  Returns:
    The outputs of the created While/StatelessWhile op.
  """
  # Short-circuit scan: one stateful op in either graph forces the stateful
  # While variant.
  has_stateful_ops = (
      any(op._is_stateful for op in cond_graph.get_operations()) or
      any(op._is_stateful for op in body_graph.get_operations()))
  op_fn = (gen_functional_ops._while
           if has_stateful_ops else gen_functional_ops.stateless_while)
  return op_fn(
      loop_vars,
      util.create_new_tf_function(cond_graph),
      util.create_new_tf_function(body_graph),
      output_shapes=output_shapes,
      parallel_iterations=parallel_iterations,
      name=name)
def _get_intermediates(func_graph):
"""Returns all tensors in `func_graph` that should be accumulated."""
# We currently accumulate output tensors of most ops in the function and rely