diff --git a/tensorflow/core/common_runtime/lower_functional_ops.cc b/tensorflow/core/common_runtime/lower_functional_ops.cc index 394477e2e89..8fe1635911c 100644 --- a/tensorflow/core/common_runtime/lower_functional_ops.cc +++ b/tensorflow/core/common_runtime/lower_functional_ops.cc @@ -20,9 +20,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/lower_function_call_op.h" #include "tensorflow/core/common_runtime/lower_if_op.h" #include "tensorflow/core/common_runtime/lower_while_op.h" -#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { @@ -122,6 +120,16 @@ Status LowerFunctionalOpsPass::Run( (options.session_options->config.experimental().executor_type() == "SINGLE_THREADED_EXECUTOR"); + // Returns true if `node` will be used for XLA compilation. + const auto used_by_xla = [](Node* node) -> bool { + return MarkedForTpuCompilation(node) || MarkedForXlaCompilation(node); + }; + + // Returns true if control flow `node` should be lowered to Switch/Merge. + const auto lower_control_flow = [&](Node* node) -> bool { + return LowerUsingSwitchMergeIsOn(node) && !used_by_xla(node); + }; + // Lower all If, Case, While ops that have the `kLowerUsingSwitchMergeAttr` // attr set and inline all function calls into the graph. // We start at `i` = 2 to skip the source and sink nodes. @@ -132,31 +140,34 @@ Status LowerFunctionalOpsPass::Run( for (int i = 2; i < g->num_node_ids(); ++i) { Node* n = g->FindNodeId(i); if (n == nullptr) continue; // deleted node - if (MarkedForTpuCompilation(n)) continue; - if (MarkedForXlaCompilation(n)) continue; - // Always lower function calls produces by lowering If/While nodes. - if (IsFunctionCall(*flib_def, *n) && + // Always lower function calls produced by lowering If/While nodes. + if (IsFunctionCall(*flib_def, *n) && !used_by_xla(n) && (lower_function_calls || LowerAsMultiDeviceFunctionIsOn(n))) { TF_RETURN_IF_ERROR(RewriteFunctionCallNode(n, g, *flib_def, keep_lowered_nodes_fetchable)); continue; } - if (!functional_control_flow && LowerUsingSwitchMergeIsOn(n)) { - if (n->IsIfNode()) { - TF_RETURN_IF_ERROR(RewriteIfNode(n, g, keep_lowered_nodes_fetchable)); - } else if (n->type_string() == "Case") { - TF_RETURN_IF_ERROR(RewriteCaseNode(n, g, keep_lowered_nodes_fetchable)); - } else if (n->IsWhileNode()) { - TF_RETURN_IF_ERROR( - RewriteWhileNode(n, g, keep_lowered_nodes_fetchable)); - } else { - return errors::Internal( - "Node ", FormatNodeForError(*n), " of type ", n->type_string(), - " has '", LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, - "' attr set but it does not support lowering.\n"); - } + // If we are allowed to used function control flow, we do not need to check + // for If/While/Case nodes in the graph. + if (functional_control_flow) continue; + + if (n->IsIfNode() && lower_control_flow(n)) { + TF_RETURN_IF_ERROR(RewriteIfNode(n, g, keep_lowered_nodes_fetchable)); + + } else if (n->type_string() == "Case" && lower_control_flow(n)) { + TF_RETURN_IF_ERROR(RewriteCaseNode(n, g, keep_lowered_nodes_fetchable)); + + } else if (n->IsWhileNode() && lower_control_flow(n)) { + TF_RETURN_IF_ERROR(RewriteWhileNode(n, g, keep_lowered_nodes_fetchable)); + + } else { + DCHECK(!lower_control_flow(n)) + << "Node " << FormatNodeForError(*n) << " of type " + << n->type_string() << " has '" + << LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr + << "' attr set but it does not support lowering.\n"; } } diff --git a/tensorflow/core/common_runtime/lower_functional_ops_test.cc b/tensorflow/core/common_runtime/lower_functional_ops_test.cc index 9bef90a01e4..21f2a5e82d8 100644 --- a/tensorflow/core/common_runtime/lower_functional_ops_test.cc +++ b/tensorflow/core/common_runtime/lower_functional_ops_test.cc @@ -326,25 +326,5 @@ TEST(LowerIfWhileTest, WhileInCond) { } } -TEST(LowerIfWhileTest, RaisesWhenLoweringUnhandledOpType) { - std::unique_ptr graph(new Graph(OpRegistry::Global())); - - Scope root = Scope::NewRootScope().ExitOnError(); - Node* const_node; - Tensor const_val(DT_INT32, TensorShape({})); - const_val.scalar()() = 1; - TF_ASSERT_OK(NodeBuilder("const", "Const") - .Attr("value", const_val) - .Attr("dtype", const_val.dtype()) - .Attr(kLowerUsingSwitchMergeAttr, true) - .Finalize(root.graph(), &const_node)); - TF_ASSERT_OK(root.DoShapeInference(const_node)); - TF_ASSERT_OK(root.ToGraph(graph.get())); - - Status s = Rewrite(&graph); - ASSERT_EQ(s.code(), error::INTERNAL); - AssertHasSubstr(s.error_message(), "does not support lowering"); -} - } // namespace } // namespace tensorflow