diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py
index 267362dcba6..7c321329272 100644
--- a/tensorflow/python/kernel_tests/while_v2_test.py
+++ b/tensorflow/python/kernel_tests/while_v2_test.py
@@ -305,6 +305,34 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     self.assertEmpty(while_1.control_inputs)
     self.assertEmpty(while_2.control_inputs)
 
+  def testMultipleWhileLoopsGradStateless(self):
+
+    @def_function.function
+    def Fn():
+      x = constant_op.constant(2.)
+      with backprop.GradientTape() as tape:
+        tape.watch(x)
+        ret1 = while_loop_v2(
+            lambda v: v < 4.,
+            lambda v: v * v, [x],
+            return_same_structure=False,
+            name="while_1")  # x**2
+        ret2 = while_loop_v2(
+            lambda v: v < 16.,
+            lambda v: v * v, [x],
+            return_same_structure=False,
+            name="while_2")  # x**4
+        loss = ret1 + ret2
+      return tape.gradient(loss, x)
+
+    graph = Fn.get_concrete_function().graph
+    while_ops = [op for op in graph.get_operations() if "While" in op.type]
+    self.assertAllEqual([op.type for op in while_ops], ["StatelessWhile"] * 4,
+                        "Must have exactly 4 StatelessWhile ops.")
+    for op in while_ops:
+      self.assertEmpty(op.control_inputs,
+                       "{} should not have any control inputs".format(op.name))
+
   def testMultipleWhileLoopsWithDeps(self):
     x = variables.Variable(2.)
     c = constant_op.constant(2.)
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index e0964cd4482..c0d982eede2 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -207,8 +207,6 @@ def while_loop(cond,
     num_cond_captures = len(cond_graph.external_captures)
     assert (cond_graph.external_captures ==
             body_graph.external_captures[:num_cond_captures])
-    cond_graph_captures = object_identity.ObjectIdentitySet(
-        cond_graph.external_captures)
     _duplicate_body_captures_in_cond(
         cond_graph, body_graph.external_captures[num_cond_captures:])
 
@@ -266,21 +264,10 @@ def while_loop(cond,
     output_shapes[orig_loop_vars_range] = nest.flatten(
         shape_invariants, expand_composites=True)[orig_loop_vars_range]
 
-    cond_stateful_ops = [
-        op for op in cond_graph.get_operations() if op._is_stateful
-    ]
-    body_stateful_ops = [
-        op for op in body_graph.get_operations() if op._is_stateful
-    ]
-    if (cond_stateful_ops or body_stateful_ops):
-      op_fn = gen_functional_ops._while
-    else:
-      op_fn = gen_functional_ops.stateless_while
-
-    outputs = op_fn(
+    outputs = _build_while_op(
         flattened_loop_vars,
-        util.create_new_tf_function(cond_graph),
-        util.create_new_tf_function(body_graph),
+        cond_graph,
+        body_graph,
         output_shapes=output_shapes,
         parallel_iterations=parallel_iterations,
         name=scope)
@@ -406,10 +393,10 @@ def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
 
   _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))
 
-  outputs = gen_functional_ops._while(
+  outputs = _build_while_op(
       loop_vars,
-      util.create_new_tf_function(cond_grad_graph),
-      util.create_new_tf_function(body_grad_graph),
+      cond_grad_graph,
+      body_grad_graph,
      output_shapes=[t.shape for t in body_grad_graph.outputs],
       parallel_iterations=parallel_iterations,
       name="%s_grad" % while_op.name)
@@ -424,6 +411,29 @@ def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
   return _get_structured_grad_output(outputs, grads, body_grad_graph)
 
 
+def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes,
+                    parallel_iterations, name):
+  """Builds the functional StatelessWhile/While op."""
+  cond_stateful_ops = [
+      op for op in cond_graph.get_operations() if op._is_stateful
+  ]
+  body_stateful_ops = [
+      op for op in body_graph.get_operations() if op._is_stateful
+  ]
+  if (cond_stateful_ops or body_stateful_ops):
+    op_fn = gen_functional_ops._while
+  else:
+    op_fn = gen_functional_ops.stateless_while
+
+  return op_fn(
+      loop_vars,
+      util.create_new_tf_function(cond_graph),
+      util.create_new_tf_function(body_graph),
+      output_shapes=output_shapes,
+      parallel_iterations=parallel_iterations,
+      name=name)
+
+
 def _get_intermediates(func_graph):
   """Returns all tensors in `func_graph` that should be accumulated."""
   # We currently accumulate output tensors of most ops in the function and rely
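
To check the behavior this patch asserts without running the test suite, the sketch below (not part of the diff) uses only the public TF2 API. It assumes tf.while_loop lowers through while_v2, the TF2 default; in that configuration the traced graph should contain only StatelessWhile ops, two for the forward loops and two for the gradient loops built by _build_while_op.

# Minimal sketch (assumption: TF2 defaults, where tf.while_loop uses while_v2).
import tensorflow as tf


@tf.function
def fn():
  x = tf.constant(2.)
  with tf.GradientTape() as tape:
    tape.watch(x)
    # Both loops are stateless: no variables or other stateful ops inside.
    ret1 = tf.while_loop(lambda v: v < 4., lambda v: v * v, [x])   # x**2
    ret2 = tf.while_loop(lambda v: v < 16., lambda v: v * v, [x])  # x**4
    loss = ret1[0] + ret2[0]
  return tape.gradient(loss, x)


graph = fn.get_concrete_function().graph
while_op_types = [op.type for op in graph.get_operations() if "While" in op.type]
print(while_op_types)
# Expected (mirroring testMultipleWhileLoopsGradStateless): four "StatelessWhile"
# ops, none of which carry control inputs.

The control-input assertion in the new test reflects the point of emitting StatelessWhile: when neither the cond nor the body graph contains stateful ops, there is no need to serialize independent loops with control dependencies, which is why both the forward and the gradient While ops are expected to have empty control_inputs.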