Emit StatelessWhile for the gradient while op if the cond and body functions are stateless.
We already emit the StatelessWhile op for the forward pass.
PiperOrigin-RevId: 290967707
Change-Id: I6d22ce36874f4191131e55902fb4404ebb3402a2
This commit is contained in:
parent
466e818c1e
commit
5708850f62
@ -305,6 +305,34 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertEmpty(while_1.control_inputs)
|
self.assertEmpty(while_1.control_inputs)
|
||||||
self.assertEmpty(while_2.control_inputs)
|
self.assertEmpty(while_2.control_inputs)
|
||||||
|
|
||||||
|
def testMultipleWhileLoopsGradStateless(self):
  """Checks that gradient while ops are emitted as StatelessWhile.

  Builds two stateless forward while loops under a GradientTape inside a
  tf.function, takes the gradient, and verifies that all four While ops in
  the resulting graph (2 forward + 2 gradient) are StatelessWhile with no
  control inputs.
  """

  @def_function.function
  def Fn():
    x = constant_op.constant(2.)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      # Two independent stateless loops over the same input.
      ret1 = while_loop_v2(
          lambda v: v < 4.,
          lambda v: v * v, [x],
          return_same_structure=False,
          name="while_1")  # x**2
      ret2 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * v, [x],
          return_same_structure=False,
          name="while_2")  # x**4
      loss = ret1 + ret2
    return tape.gradient(loss, x)

  graph = Fn.get_concrete_function().graph
  while_ops = [op for op in graph.get_operations() if "While" in op.type]
  # Both forward loops and both gradient loops must be stateless.
  self.assertAllEqual([op.type for op in while_ops], ["StatelessWhile"] * 4,
                      "Must have exactly 4 StatelessWhile ops.")
  for op in while_ops:
    self.assertEmpty(op.control_inputs,
                     "{} should not have any control inputs".format(op.name))
||||||
def testMultipleWhileLoopsWithDeps(self):
|
def testMultipleWhileLoopsWithDeps(self):
|
||||||
x = variables.Variable(2.)
|
x = variables.Variable(2.)
|
||||||
c = constant_op.constant(2.)
|
c = constant_op.constant(2.)
|
||||||
|
@ -207,8 +207,6 @@ def while_loop(cond,
|
|||||||
num_cond_captures = len(cond_graph.external_captures)
|
num_cond_captures = len(cond_graph.external_captures)
|
||||||
assert (cond_graph.external_captures ==
|
assert (cond_graph.external_captures ==
|
||||||
body_graph.external_captures[:num_cond_captures])
|
body_graph.external_captures[:num_cond_captures])
|
||||||
cond_graph_captures = object_identity.ObjectIdentitySet(
|
|
||||||
cond_graph.external_captures)
|
|
||||||
_duplicate_body_captures_in_cond(
|
_duplicate_body_captures_in_cond(
|
||||||
cond_graph, body_graph.external_captures[num_cond_captures:])
|
cond_graph, body_graph.external_captures[num_cond_captures:])
|
||||||
|
|
||||||
@ -266,21 +264,10 @@ def while_loop(cond,
|
|||||||
output_shapes[orig_loop_vars_range] = nest.flatten(
|
output_shapes[orig_loop_vars_range] = nest.flatten(
|
||||||
shape_invariants, expand_composites=True)[orig_loop_vars_range]
|
shape_invariants, expand_composites=True)[orig_loop_vars_range]
|
||||||
|
|
||||||
cond_stateful_ops = [
|
outputs = _build_while_op(
|
||||||
op for op in cond_graph.get_operations() if op._is_stateful
|
|
||||||
]
|
|
||||||
body_stateful_ops = [
|
|
||||||
op for op in body_graph.get_operations() if op._is_stateful
|
|
||||||
]
|
|
||||||
if (cond_stateful_ops or body_stateful_ops):
|
|
||||||
op_fn = gen_functional_ops._while
|
|
||||||
else:
|
|
||||||
op_fn = gen_functional_ops.stateless_while
|
|
||||||
|
|
||||||
outputs = op_fn(
|
|
||||||
flattened_loop_vars,
|
flattened_loop_vars,
|
||||||
util.create_new_tf_function(cond_graph),
|
cond_graph,
|
||||||
util.create_new_tf_function(body_graph),
|
body_graph,
|
||||||
output_shapes=output_shapes,
|
output_shapes=output_shapes,
|
||||||
parallel_iterations=parallel_iterations,
|
parallel_iterations=parallel_iterations,
|
||||||
name=scope)
|
name=scope)
|
||||||
@ -406,10 +393,10 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name
|
|||||||
|
|
||||||
_check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))
|
_check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))
|
||||||
|
|
||||||
outputs = gen_functional_ops._while(
|
outputs = _build_while_op(
|
||||||
loop_vars,
|
loop_vars,
|
||||||
util.create_new_tf_function(cond_grad_graph),
|
cond_grad_graph,
|
||||||
util.create_new_tf_function(body_grad_graph),
|
body_grad_graph,
|
||||||
output_shapes=[t.shape for t in body_grad_graph.outputs],
|
output_shapes=[t.shape for t in body_grad_graph.outputs],
|
||||||
parallel_iterations=parallel_iterations,
|
parallel_iterations=parallel_iterations,
|
||||||
name="%s_grad" % while_op.name)
|
name="%s_grad" % while_op.name)
|
||||||
@ -424,6 +411,29 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name
|
|||||||
return _get_structured_grad_output(outputs, grads, body_grad_graph)
|
return _get_structured_grad_output(outputs, grads, body_grad_graph)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes,
                    parallel_iterations, name):
  """Builds the functional StatelessWhile/While op.

  Selects `gen_functional_ops.stateless_while` when neither the cond nor the
  body graph contains a stateful op, and `gen_functional_ops._while`
  otherwise, so both forward and gradient while loops can be stateless.

  Args:
    loop_vars: Flattened list of loop-variable tensors to pass to the op.
    cond_graph: FuncGraph for the loop condition.
    body_graph: FuncGraph for the loop body.
    output_shapes: Shapes for the op's `output_shapes` attr.
    parallel_iterations: Number of iterations allowed to run in parallel.
    name: Name for the created op.

  Returns:
    The outputs of the created While/StatelessWhile op.
  """
  cond_stateful_ops = [
      op for op in cond_graph.get_operations() if op._is_stateful
  ]
  body_stateful_ops = [
      op for op in body_graph.get_operations() if op._is_stateful
  ]
  # Any stateful op in either graph forces the stateful While variant.
  if cond_stateful_ops or body_stateful_ops:
    op_fn = gen_functional_ops._while
  else:
    op_fn = gen_functional_ops.stateless_while

  return op_fn(
      loop_vars,
      util.create_new_tf_function(cond_graph),
      util.create_new_tf_function(body_graph),
      output_shapes=output_shapes,
      parallel_iterations=parallel_iterations,
      name=name)
|
||||||
def _get_intermediates(func_graph):
|
def _get_intermediates(func_graph):
|
||||||
"""Returns all tensors in `func_graph` that should be accumulated."""
|
"""Returns all tensors in `func_graph` that should be accumulated."""
|
||||||
# We currently accumulate output tensors of most ops in the function and rely
|
# We currently accumulate output tensors of most ops in the function and rely
|
||||||
|
Loading…
Reference in New Issue
Block a user