Emit StatelessWhile for gradient while op if the cond and body functions are stateless.
We already emit the StatelessWhile op for the forward pass.

PiperOrigin-RevId: 290967707
Change-Id: I6d22ce36874f4191131e55902fb4404ebb3402a2
parent 466e818c1e
commit 5708850f62
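For context, here is a minimal sketch (not part of the commit) of the behavior this change enables, written against the public TF2 API rather than the internal test helpers below; it assumes a TensorFlow build that includes this change. With a stateless cond and body, both the forward while loop and its gradient loop should lower to StatelessWhile:

import tensorflow as tf

@tf.function
def loop_grad(x):
  with tf.GradientTape() as tape:
    tape.watch(x)
    # tf.while_loop in TF2 returns the loop vars with their original
    # structure, so unpack the single-element result.
    ret = tf.while_loop(lambda v: v < 4., lambda v: v * v, [x])[0]
  return tape.gradient(ret, x)

graph = loop_grad.get_concrete_function(tf.constant(2.)).graph
# Before this change the gradient loop was always a stateful While op;
# with it, both loops should appear as StatelessWhile.
print([op.type for op in graph.get_operations() if "While" in op.type])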
@@ -305,6 +305,34 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     self.assertEmpty(while_1.control_inputs)
     self.assertEmpty(while_2.control_inputs)
 
+  def testMultipleWhileLoopsGradStateless(self):
+
+    @def_function.function
+    def Fn():
+      x = constant_op.constant(2.)
+      with backprop.GradientTape() as tape:
+        tape.watch(x)
+        ret1 = while_loop_v2(
+            lambda v: v < 4.,
+            lambda v: v * v, [x],
+            return_same_structure=False,
+            name="while_1")  # x**2
+        ret2 = while_loop_v2(
+            lambda v: v < 16.,
+            lambda v: v * v, [x],
+            return_same_structure=False,
+            name="while_2")  # x**4
+        loss = ret1 + ret2
+      return tape.gradient(loss, x)
+
+    graph = Fn.get_concrete_function().graph
+    while_ops = [op for op in graph.get_operations() if "While" in op.type]
+    self.assertAllEqual([op.type for op in while_ops], ["StatelessWhile"] * 4,
+                        "Must have exactly 4 StatelessWhile ops.")
+    for op in while_ops:
+      self.assertEmpty(op.control_inputs,
+                       "{} should not have any control inputs".format(op.name))
+
   def testMultipleWhileLoopsWithDeps(self):
     x = variables.Variable(2.)
     c = constant_op.constant(2.)
@@ -207,8 +207,6 @@ def while_loop(cond,
     num_cond_captures = len(cond_graph.external_captures)
     assert (cond_graph.external_captures ==
             body_graph.external_captures[:num_cond_captures])
-    cond_graph_captures = object_identity.ObjectIdentitySet(
-        cond_graph.external_captures)
     _duplicate_body_captures_in_cond(
         cond_graph, body_graph.external_captures[num_cond_captures:])
@@ -266,21 +264,10 @@ def while_loop(cond,
       output_shapes[orig_loop_vars_range] = nest.flatten(
           shape_invariants, expand_composites=True)[orig_loop_vars_range]
 
-    cond_stateful_ops = [
-        op for op in cond_graph.get_operations() if op._is_stateful
-    ]
-    body_stateful_ops = [
-        op for op in body_graph.get_operations() if op._is_stateful
-    ]
-    if (cond_stateful_ops or body_stateful_ops):
-      op_fn = gen_functional_ops._while
-    else:
-      op_fn = gen_functional_ops.stateless_while
-
-    outputs = op_fn(
+    outputs = _build_while_op(
         flattened_loop_vars,
-        util.create_new_tf_function(cond_graph),
-        util.create_new_tf_function(body_graph),
+        cond_graph,
+        body_graph,
         output_shapes=output_shapes,
         parallel_iterations=parallel_iterations,
         name=scope)
@@ -406,10 +393,10 @@ def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
 
   _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))
 
-  outputs = gen_functional_ops._while(
+  outputs = _build_while_op(
       loop_vars,
-      util.create_new_tf_function(cond_grad_graph),
-      util.create_new_tf_function(body_grad_graph),
+      cond_grad_graph,
+      body_grad_graph,
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      parallel_iterations=parallel_iterations,
      name="%s_grad" % while_op.name)
@@ -424,6 +411,29 @@ def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
   return _get_structured_grad_output(outputs, grads, body_grad_graph)
 
 
+def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes,
+                    parallel_iterations, name):
+  """Builds the functional StatelessWhile/While op."""
+  cond_stateful_ops = [
+      op for op in cond_graph.get_operations() if op._is_stateful
+  ]
+  body_stateful_ops = [
+      op for op in body_graph.get_operations() if op._is_stateful
+  ]
+  if (cond_stateful_ops or body_stateful_ops):
+    op_fn = gen_functional_ops._while
+  else:
+    op_fn = gen_functional_ops.stateless_while
+
+  return op_fn(
+      loop_vars,
+      util.create_new_tf_function(cond_graph),
+      util.create_new_tf_function(body_graph),
+      output_shapes=output_shapes,
+      parallel_iterations=parallel_iterations,
+      name=name)
+
+
 def _get_intermediates(func_graph):
   """Returns all tensors in `func_graph` that should be accumulated."""
   # We currently accumulate output tensors of most ops in the function and rely
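Conversely, a hedged sketch (not from the commit): any stateful op in the body graph, such as tf.print, makes the _is_stateful check in _build_while_op pick gen_functional_ops._while, so the loop falls back to the stateful While op:

import tensorflow as tf

@tf.function
def stateful_loop(x):
  def body(v):
    tf.print(v)  # a stateful op, so the body graph is no longer stateless
    return v * v
  return tf.while_loop(lambda v: v < 4., body, [x])[0]

graph = stateful_loop.get_concrete_function(tf.constant(2.)).graph
# Expected: ["While"] rather than ["StatelessWhile"].
print([op.type for op in graph.get_operations() if "While" in op.type])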