From 170a95de67f266c9fd7fea3ceedc5a7ecb0c80c3 Mon Sep 17 00:00:00 2001 From: Yanhua Sun Date: Wed, 31 Jul 2019 08:05:16 -0700 Subject: [PATCH] In while_v2 emit a StatelessWhile op if the cond and body are stateless. PiperOrigin-RevId: 260927755 --- .../compiler/jit/compilability_check_util.cc | 2 +- .../jit/extract_outside_compilation_pass.cc | 8 +- .../compiler/jit/mark_for_compilation_pass.cc | 3 +- tensorflow/compiler/tf2xla/const_analysis.cc | 2 +- .../tf2xla/rearrange_function_argument.cc | 2 +- .../compiler/tf2xla/side_effect_util.cc | 2 +- tensorflow/compiler/tf2xla/tf2xla_util.cc | 7 +- .../common_runtime/lower_functional_ops.cc | 2 +- tensorflow/core/graph/graph.cc | 4 +- tensorflow/core/graph/graph.h | 2 + .../grappler/optimizers/function_optimizer.cc | 2 +- tensorflow/core/ops/functional_ops.cc | 59 ++++--- tensorflow/lite/python/util_test.py | 2 +- tensorflow/python/framework/ops_test.py | 46 +++--- .../python/kernel_tests/while_v2_test.py | 150 +++++++++++++++--- tensorflow/python/ops/gradients_util.py | 3 +- tensorflow/python/ops/while_v2.py | 20 ++- .../api/golden/v1/tensorflow.raw_ops.pbtxt | 2 +- .../api/golden/v2/tensorflow.raw_ops.pbtxt | 2 +- 19 files changed, 227 insertions(+), 93 deletions(-) diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 5e3b93d30e5..049a38976ee 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -318,7 +318,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( return false; } - if (node.type_string() == "While" && + if (node.IsWhileNode() && !IsCompilableWhile(node, lib_runtime, stack_trace, uncompilable_nodes)) { LogNotCompilable(node, "unsupported while"); return false; diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index 85fb69b620d..05b1e6626e5 100644 --- 
a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -440,7 +440,7 @@ Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) { n->ClearAttr(attr_name); n->AddAttr(attr_name, branch_func); } - } else if (n->type_string() == "While") { + } else if (n->IsWhileNode()) { for (const string& attr_name : std::vector{"cond", "body"}) { NameAttrList branch_func; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func)); @@ -595,7 +595,7 @@ void ReplaceLiftedArgNodePlaceholderWithArg( Status PostprocessLiftedArgsForWhile( const std::unordered_map& outside_compilation_attr_to_node, Graph* g, Node* n, FunctionLibraryDefinition* fld) { - TF_RET_CHECK(n->type_string() == "While"); + TF_RET_CHECK(n->IsWhileNode()); // Check if there is any lifted args in body function. NameAttrList body_func; @@ -936,7 +936,7 @@ Status PostprocessLiftedArgs(Graph* g, FunctionLibraryDefinition* fld) { continue; } - if (n->type_string() == "While") { + if (n->IsWhileNode()) { TF_RETURN_IF_ERROR(PostprocessLiftedArgsForWhile( outside_compilation_attr_to_node, g, n, fld)); } @@ -1782,7 +1782,7 @@ Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( for (Node* n : g->nodes()) { if (n->IsIfNode()) { if_nodes.push_back(n); - } else if (n->type_string() == "While") { + } else if (n->IsWhileNode()) { while_nodes.push_back(n); } else if (IsFunctionCall(*fld, *n)) { func_call_nodes.push_back(n); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 91423f63d28..41a2bf6d964 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -970,8 +970,7 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() { int effective_cluster_size = (node->IsIdentity() || node->IsConstant()) ? 
0 : 1; - bool has_functional_control_flow = - node->type_string() == "While" || node->IsIfNode(); + bool has_functional_control_flow = node->IsWhileNode() || node->IsIfNode(); absl::optional deadness_predicate; if (deadness_analysis_) { diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index ad2cc7b32f0..48513a43fb3 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -91,7 +91,7 @@ Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel, FunctionLibraryRuntime* flib_runtime) { DCHECK(op_def != nullptr || op_kernel != nullptr); // TODO(b/124403063): Implement similar functionality for function call nodes. - if (node.op() == "While") { + if (node.op() == "While" || node.op() == "StatelessWhile") { // For While nodes, recurse into the body and cond graphs. const FunctionBody* fcond = nullptr; const FunctionBody* fbody = nullptr; diff --git a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc index b376fe94743..b6f8928f31e 100644 --- a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc +++ b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc @@ -527,7 +527,7 @@ Status RearrangeFunctionArguments( // Rewrite If/While nodes. 
for (Node* n : g->nodes()) { - if (n->type_string() == "While") { + if (n->IsWhileNode()) { bool node_rewritten; TF_RETURN_IF_ERROR(MaybeRewriteWhileNode(get_function_body_fn, g, n, fld, &node_rewritten)); diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc index eebeec87b60..fb8b4815be2 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.cc +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -50,7 +50,7 @@ Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) { node->ClearAttr(attr_name); node->AddAttr(attr_name, branch_func); } - } else if (node->type_string() == "While") { + } else if (node->IsWhileNode()) { AttrValue device_ordinal_value; device_ordinal_value.set_i(device_ordinal); for (const string& attr_name : std::vector{"cond", "body"}) { diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 3e8b9eb79d8..e82546def46 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -765,7 +765,7 @@ Status PropagateConstIntoFunctionalNodes( for (Node* n : g->op_nodes()) { if (n->IsIfNode()) { TF_RETURN_IF_ERROR(PropagateConstIntoIfNode(g, n, lookup_fld, fld)); - } else if (n->type_string() == "While") { + } else if (n->IsWhileNode()) { TF_RETURN_IF_ERROR(PropagateConstIntoWhileNode(g, n, lookup_fld, fld)); } } @@ -796,7 +796,7 @@ Status RewriteTensorListWithConstElement(Graph* g, // Find the forward While op. 
std::vector fwd_while_edges; for (const Edge* e : n->out_edges()) { - if (!e->IsControlEdge() && e->dst()->type_string() == "While") { + if (!e->IsControlEdge() && e->dst()->IsWhileNode()) { fwd_while_edges.push_back(e); } } @@ -810,8 +810,7 @@ Status RewriteTensorListWithConstElement(Graph* g, int fwd_while_dst_input = fwd_while_edges[0]->dst_input(); std::vector bwd_while_edges; for (const Edge* e : fwd_while->out_edges()) { - if (e->src_output() == fwd_while_dst_input && - e->dst()->type_string() == "While") { + if (e->src_output() == fwd_while_dst_input && e->dst()->IsWhileNode()) { bwd_while_edges.push_back(e); } } diff --git a/tensorflow/core/common_runtime/lower_functional_ops.cc b/tensorflow/core/common_runtime/lower_functional_ops.cc index 2b8d941a295..30bec353da9 100644 --- a/tensorflow/core/common_runtime/lower_functional_ops.cc +++ b/tensorflow/core/common_runtime/lower_functional_ops.cc @@ -143,7 +143,7 @@ Status LowerFunctionalOpsPass::Run( } else if (n->type_string() == "Case") { TF_RETURN_IF_ERROR( RewriteCaseNode(n, g, *flib_def, keep_lowered_nodes_fetchable)); - } else if (n->type_string() == "While") { + } else if (n->IsWhileNode()) { TF_RETURN_IF_ERROR( RewriteWhileNode(n, g, *flib_def, keep_lowered_nodes_fetchable)); } else { diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index cc8e18a685d..b2137020c77 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -90,6 +90,8 @@ const std::unordered_map& Node::kNodeClassTable = {"StatefulPartitionedCall", NC_PARTITIONED_CALL}, {"If", NC_IF}, {"StatelessIf", NC_IF}, + {"While", NC_WHILE}, + {"StatelessWhile", NC_WHILE}, // Not using the constants defined in FunctionLibraryDefinition for the // 4 ops below because android inference library does not link // tf.function related files. 
@@ -592,7 +594,7 @@ Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst, } Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, Node* dst) { - if (dst->type_string() != "While") { + if (!dst->IsWhileNode()) { return errors::Internal( "dst argument to AddWhileEdgeHack should be a While op, got: ", dst->DebugString()); diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 1d9a45b562e..b4343c9ee98 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -177,6 +177,7 @@ class Node { bool IsFakeParam() const { return class_ == NC_FAKE_PARAM; } bool IsPartitionedCall() const { return class_ == NC_PARTITIONED_CALL; } bool IsIfNode() const { return class_ == NC_IF; } + bool IsWhileNode() const { return class_ == NC_WHILE; } // Is this node a function input bool IsArg() const { return class_ == NC_ARG; } // Is this node a function output @@ -264,6 +265,7 @@ class Node { NC_FAKE_PARAM, NC_PARTITIONED_CALL, NC_IF, + NC_WHILE, NC_ARG, NC_RETVAL, NC_OTHER // Not a special kind of node diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc index ca8f7a2e05f..012431b491b 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc @@ -1219,7 +1219,7 @@ Status InlineFunctionCalls(const GrapplerItem& item, TF_RETURN_IF_ERROR(RewriteIfNode(n, graph.get(), flib_def, false)); } else if (n->type_string() == "Case") { TF_RETURN_IF_ERROR(RewriteCaseNode(n, graph.get(), flib_def, false)); - } else if (n->type_string() == "While") { + } else if (n->IsWhileNode()) { TF_RETURN_IF_ERROR(RewriteWhileNode(n, graph.get(), flib_def, false)); } continue; diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc index 8f1ac77af7d..f5f7244d306 100644 --- a/tensorflow/core/ops/functional_ops.cc +++ b/tensorflow/core/ops/functional_ops.cc 
@@ -195,6 +195,31 @@ body: A function that takes a list of tensors and returns another by T. )doc"); +Status WhileShapeInferenceFn(shape_inference::InferenceContext* c) { + std::vector output_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + // If `output_shapes` attr is set use that as the shapes of the outputs + // else use the input shapes. + if (!output_shapes.empty()) { + if (output_shapes.size() != c->num_outputs()) { + return errors::InvalidArgument( + "`output_shapes` must be the same length as num outputs (", + output_shapes.size(), " vs. ", c->num_outputs()); + } + for (size_t i = 0; i < output_shapes.size(); ++i) { + shape_inference::ShapeHandle output_shape_handle; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + output_shapes[i], &output_shape_handle)); + c->set_output(static_cast(i), output_shape_handle); + } + } else { + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->input(i)); + } + } + return Status::OK(); +} + REGISTER_OP("While") .Input("input: T") .Output("output: T") @@ -204,30 +229,7 @@ REGISTER_OP("While") .Attr("output_shapes: list(shape) = []") .Attr("parallel_iterations: int = 10") .SetIsStateful() - .SetShapeFn([](shape_inference::InferenceContext* c) { - std::vector output_shapes; - TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); - // If `output_shapes` attr is set use that as the shapes of the outputs - // else use the input shapes. - if (!output_shapes.empty()) { - if (output_shapes.size() != c->num_outputs()) { - return errors::InvalidArgument( - "`output_shapes` must be the same length as num outputs (", - output_shapes.size(), " vs. 
", c->num_outputs()); - } - for (size_t i = 0; i < output_shapes.size(); ++i) { - shape_inference::ShapeHandle output_shape_handle; - TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( - output_shapes[i], &output_shape_handle)); - c->set_output(static_cast(i), output_shape_handle); - } - } else { - for (int i = 0; i < c->num_outputs(); ++i) { - c->set_output(i, c->input(i)); - } - } - return Status::OK(); - }); + .SetShapeFn(WhileShapeInferenceFn); REGISTER_OP("StatelessWhile") .Input("input: T") @@ -235,12 +237,9 @@ REGISTER_OP("StatelessWhile") .Attr("T: list(type) >= 0") .Attr("cond: func") .Attr("body: func") - .SetShapeFn([](shape_inference::InferenceContext* c) { - for (int i = 0; i < c->num_outputs(); ++i) { - c->set_output(i, c->input(i)); - } - return Status::OK(); - }); + .Attr("output_shapes: list(shape) = []") + .Attr("parallel_iterations: int = 10") + .SetShapeFn(WhileShapeInferenceFn); REGISTER_OP("For") .Input("start: int32") diff --git a/tensorflow/lite/python/util_test.py b/tensorflow/lite/python/util_test.py index f13fad5e821..05e402f01d2 100644 --- a/tensorflow/lite/python/util_test.py +++ b/tensorflow/lite/python/util_test.py @@ -80,7 +80,7 @@ class UtilTest(test_util.TensorFlowTestCase): sess.graph_def) lower_using_switch_merge_is_removed = False for node in new_graph_def.node: - if node.op == "While": + if node.op == "While" or node.op == "StatelessWhile": if not node.attr["_lower_using_switch_merge"].b: lower_using_switch_merge_is_removed = True self.assertEqual(lower_using_switch_merge_is_removed, True) diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 0495c9d5be5..865294073eb 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -29,6 +29,7 @@ from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.autograph.core import ag_ctx from tensorflow.python.client import session 
+from tensorflow.python.compat import compat as forward_compat from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function @@ -793,33 +794,34 @@ class OperationTest(test_util.TensorFlowTestCase): @test_util.enable_control_flow_v2 @test_util.run_v1_only("b/120545219") def testAddWhileInput(self): - @eager_function.defun - def test(): - output = control_flow_ops.while_loop(lambda x: x < 3, lambda x: x + 1, - [1]) - while_op = output.op.inputs[0].op - self.assertEqual(while_op.type, "While") - orig_num_inputs = len(while_op.inputs) + if forward_compat.forward_compatible(2019, 8, 23): + @eager_function.defun + def test(): + output = control_flow_ops.while_loop(lambda x: x < 3, lambda x: x + 1, + [1]) + while_op = output.op.inputs[0].op + self.assertEqual(while_op.type, "StatelessWhile") + orig_num_inputs = len(while_op.inputs) - # Make sure we can handle the while op having a control input. - while_op._add_control_input(constant_op.constant(0).op) + # Make sure we can handle the while op having a control input. + while_op._add_control_input(constant_op.constant(0).op) - new_input1 = constant_op.constant(1.0) - new_input2 = constant_op.constant(True) + new_input1 = constant_op.constant(1.0) + new_input2 = constant_op.constant(True) - # Clear output shapes to bypass shape checking. - while_op._set_shape_list_attr("output_shapes", []) - while_op._set_type_list_attr("T", - [t.dtype for t in while_op.inputs] + - [new_input1.dtype, new_input2.dtype]) + # Clear output shapes to bypass shape checking. 
+ while_op._set_shape_list_attr("output_shapes", []) + while_op._set_type_list_attr("T", + [t.dtype for t in while_op.inputs] + + [new_input1.dtype, new_input2.dtype]) - while_op._add_while_inputs([new_input1, new_input2]) - # Can't add an edge beyond what's specified by "T" - with self.assertRaises(errors.OutOfRangeError): - while_op._add_while_inputs([new_input2]) - self.assertEqual(len(while_op.inputs), orig_num_inputs + 2) # pylint: disable=g-deprecated-assert + while_op._add_while_inputs([new_input1, new_input2]) + # Can't add an edge beyond what's specified by "T" + with self.assertRaises(errors.OutOfRangeError): + while_op._add_while_inputs([new_input2]) + self.assertEqual(len(while_op.inputs), orig_num_inputs + 2) # pylint: disable=g-deprecated-assert - test() + test() @test_util.run_deprecated_v1 def testOpDef(self): diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py index 8f62bd6d90f..4a44b9ee8d2 100644 --- a/tensorflow/python/kernel_tests/while_v2_test.py +++ b/tensorflow/python/kernel_tests/while_v2_test.py @@ -22,10 +22,9 @@ from absl.testing import parameterized from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.compat import compat from tensorflow.python.eager import backprop from tensorflow.python.eager import context -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import control_flow_v2_toggles from tensorflow.python.ops import random_ops from tensorflow.python.eager import def_function @@ -36,6 +35,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.grappler import tf_optimizer from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import 
gradients_impl from tensorflow.python.ops import list_ops from tensorflow.python.ops import map_fn @@ -196,6 +197,115 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): self.assertSequenceEqual(self.evaluate(grad), [32.]) self.assertSequenceEqual(self.evaluate(grad_grad), [48.]) + def testMultipleWhileLoopsWithFunc(self): + if compat.forward_compatible(2019, 8, 23): + x = constant_op.constant(2.) + + @def_function.function + def Fn(): + ret1 = while_loop_v2( + lambda v: v < 4., + lambda v: v * v, [x], + return_same_structure=False, + name="while_1") # x**2 + ret2 = while_loop_v2( + lambda v: v < 16., + lambda v: v * v, [x], + return_same_structure=False, + name="while_2") # x**4 + return ret1, ret2 + + concrete_fn = Fn.get_concrete_function() + while_1 = concrete_fn.graph.get_operation_by_name("while_1") + while_2 = concrete_fn.graph.get_operation_by_name("while_2") + self.assertEqual(while_1.type, "StatelessWhile") + self.assertEqual(while_2.type, "StatelessWhile") + self.assertEmpty(while_1.control_inputs) + self.assertEmpty(while_2.control_inputs) + + def testMultipleWhileLoopsWithDeps(self): + if compat.forward_compatible(2019, 8, 23): + x = variables.Variable(2.) + c = constant_op.constant(2.) 
+ + @def_function.function + def Fn(): + ret1 = while_loop_v2( + lambda v: v < 4., + lambda v: v * x, [c], + return_same_structure=False, + name="while_1") # 2x + ret2 = while_loop_v2( + lambda v: v < 16., + lambda v: v * x * x, [c], + return_same_structure=False, + name="while_2") # 4x + return ret1, ret2 + + concrete_fn = Fn.get_concrete_function() + while_1 = concrete_fn.graph.get_operation_by_name("while_1") + while_2 = concrete_fn.graph.get_operation_by_name("while_2") + self.assertEqual(while_1.type, "While") + self.assertEqual(while_2.type, "While") + self.assertEmpty(while_1.control_inputs) + self.assertLen(while_2.control_inputs, 1) + self.assertIs(while_2.control_inputs[0], while_1) + + def testMultipleWhileLoopsWithVarsDeps(self): + if compat.forward_compatible(2019, 8, 23): + x1 = variables.Variable(2.) + x2 = variables.Variable(3.) + c = constant_op.constant(2.) + + @def_function.function + def Fn(): + ret1 = while_loop_v2( + lambda v: v < 4., + lambda v: v * x1, [c], + return_same_structure=False, + name="while_1") # 2x + ret2 = while_loop_v2( + lambda v: v < 16., + lambda v: v * x1 * x1, [c], + return_same_structure=False, + name="while_2") # 4x + ret3 = while_loop_v2( + lambda v: v < 4., + lambda v: v * x2, [c], + return_same_structure=False, + name="while_3") # 3x + ret4 = while_loop_v2( + lambda v: v < 16., + lambda v: v * x2 * x2, [c], + return_same_structure=False, + name="while_4") # 9x + ret5 = while_loop_v2( + lambda v: v < 16., + lambda v: v * v, [c], + return_same_structure=False, + name="while_stateless") # x**2 + return ret1, ret2, ret3, ret4, ret5 + + concrete_fn = Fn.get_concrete_function() + while_1 = concrete_fn.graph.get_operation_by_name("while_1") + while_2 = concrete_fn.graph.get_operation_by_name("while_2") + while_3 = concrete_fn.graph.get_operation_by_name("while_3") + while_4 = concrete_fn.graph.get_operation_by_name("while_4") + while_stateless = concrete_fn.graph.get_operation_by_name( + "while_stateless") + 
self.assertEqual(while_1.type, "While") + self.assertEqual(while_2.type, "While") + self.assertEqual(while_3.type, "While") + self.assertEqual(while_4.type, "While") + self.assertEqual(while_stateless.type, "StatelessWhile") + self.assertEmpty(while_1.control_inputs) + self.assertLen(while_2.control_inputs, 1) + self.assertIs(while_2.control_inputs[0], while_1) + self.assertEmpty(while_3.control_inputs) + self.assertLen(while_4.control_inputs, 1) + self.assertIs(while_4.control_inputs[0], while_3) + self.assertEmpty(while_stateless.control_inputs) + @test_util.run_deprecated_v1 def testDoubleDerivative(self): x = constant_op.constant(2.) @@ -360,7 +470,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): Cond, Body, [x, tensor_list], return_same_structure=False) for op in ops.get_default_graph().get_operations(): - if op.type == "While": + if op.type == "While" or op.type == "StatelessWhile": while_op = op body_graph = while_v2._get_graph(while_op, "body") @@ -443,7 +553,8 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): lambda i: i + 1, [constant_op.constant(0)], return_same_structure=False) while_op = output.op.inputs[0].op - self.assertEqual(while_op.type, "While") + if compat.forward_compatible(2019, 8, 23): + self.assertEqual(while_op.type, "StatelessWhile") return while_op def testDefaultName(self): @@ -524,23 +635,24 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): @test_util.run_deprecated_v1 def testForwardPassRewrite(self): - x = constant_op.constant(1.0, name="x") - output = while_v2.while_loop(lambda x: x < 10.0, - lambda x: x * 2.0, - [x])[0] - while_op = output.op.inputs[0].op - self.assertEqual(while_op.type, "While") - # outputs = [loop_counter, max_iters, x] - self.assertLen(while_op.outputs, 3) + if compat.forward_compatible(2019, 8, 23): + x = constant_op.constant(1.0, name="x") + output = while_v2.while_loop(lambda x: x < 10.0, + lambda x: x * 2.0, + [x])[0] + while_op = output.op.inputs[0].op + 
self.assertEqual(while_op.type, "StatelessWhile") + # outputs = [loop_counter, max_iters, x] + self.assertLen(while_op.outputs, 3) - gradients_impl.gradients(output, x) - # while_op should have been rewritten to output 2.0 intermediate. - # outputs = [loop_counter, max_iters, x, 2.0_accumulator, x_accumulator] - self.assertLen(while_op.outputs, 5) + gradients_impl.gradients(output, x) + # while_op should have been rewritten to output 2.0 intermediate. + # outputs = [loop_counter, max_iters, x, 2.0_accumulator, x_accumulator] + self.assertLen(while_op.outputs, 5) - gradients_impl.gradients(output, x) - # Computing the gradient again shouldn't rewrite while_op again. - self.assertLen(while_op.outputs, 5) + gradients_impl.gradients(output, x) + # Computing the gradient again shouldn't rewrite while_op again. + self.assertLen(while_op.outputs, 5) @test_util.run_deprecated_v1 def testRandomUniformShape(self): diff --git a/tensorflow/python/ops/gradients_util.py b/tensorflow/python/ops/gradients_util.py index ca4f0406360..231d9584779 100644 --- a/tensorflow/python/ops/gradients_util.py +++ b/tensorflow/python/ops/gradients_util.py @@ -257,7 +257,8 @@ def _VerifyGeneratedGradients(grads, op): """ # While ops have inputs added to them during the gradient computation, so we # skip the below check. See while_v2 for details. 
- if op.type == "While": return + if op.type == "While" or op.type == "StatelessWhile": + return if len(grads) != len(op.inputs): raise ValueError("Num gradients %d generated for op %s do not match num " diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py index 174fcb97bb9..42d307059c3 100644 --- a/tensorflow/python/ops/while_v2.py +++ b/tensorflow/python/ops/while_v2.py @@ -23,6 +23,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.compat import compat from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import func_graph as func_graph_module @@ -236,7 +237,23 @@ def while_loop(cond, first_loop_var_index + num_flattened_outputs) output_shapes[orig_loop_vars_range] = nest.flatten( shape_invariants, expand_composites=True)[orig_loop_vars_range] - outputs = gen_functional_ops._while( + + cond_stateful_ops = [ + op for op in cond_graph.get_operations() if op._is_stateful + ] + body_stateful_ops = [ + op for op in body_graph.get_operations() if op._is_stateful + ] + # TODO(yanhuasun): Remove this after Aug 23, 2019. This is required to + # abide by 3-week forward compat window of new TF python op generating + # code with stale runtime binaries. 
+ if (cond_stateful_ops or body_stateful_ops or + not compat.forward_compatible(2019, 8, 23)): + op_fn = gen_functional_ops._while + else: + op_fn = gen_functional_ops.stateless_while + + outputs = op_fn( flattened_loop_vars, util.create_new_tf_function(cond_graph), util.create_new_tf_function(body_graph), @@ -270,6 +287,7 @@ def while_loop(cond, return outputs +@ops.RegisterGradient("StatelessWhile") @ops.RegisterGradient("While") def _WhileGrad(op, *grads): # pylint: disable=invalid-name """The gradient of a While op produced by while_loop.""" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 6ee387d3353..cf8868af342 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -4062,7 +4062,7 @@ tf_module { } member_method { name: "StatelessWhile" - argspec: "args=[\'input\', \'cond\', \'body\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'input\', \'cond\', \'body\', \'output_shapes\', \'parallel_iterations\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'10\', \'None\'], " } member_method { name: "StaticRegexFullMatch" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 6ee387d3353..cf8868af342 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -4062,7 +4062,7 @@ tf_module { } member_method { name: "StatelessWhile" - argspec: "args=[\'input\', \'cond\', \'body\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'input\', \'cond\', \'body\', \'output_shapes\', \'parallel_iterations\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'10\', \'None\'], " } member_method { name: "StaticRegexFullMatch"