diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index b8680c18f7b..dfb5fccf9af 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -781,6 +781,7 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index f9af5581a67..7fbd3dfa62c 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -22,10 +22,10 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/match.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" -#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/shape_refiner.h" @@ -285,8 +285,12 @@ string StateMap::AncestorStateToString(const Node* node) const { } FunctionalizeCond::FunctionalizeCond(Graph* graph, - FunctionLibraryDefinition* library) - : state_map_(graph), library_(library), graph_(graph) {} + FunctionLibraryDefinition* library, + const NodeFilter& node_filter) + : state_map_(graph), + library_(library), + graph_(graph), + node_filter_(node_filter) {} // Class representing the merge/switch nodes that will become a conditional. class Conditional { @@ -807,11 +811,13 @@ Status Conditional::BuildIfNode(Graph* graph, << PartialTensorShapeUtils::PartialShapeListString(output_shapes); builder.Attr("Tcond", DT_BOOL); - string outside_compilation; - if (GetNodeAttr(predicate_.node->def(), kXlaOutsideCompilationAttrName, - &outside_compilation) - .ok()) { - builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation); + // Add all underscore attributes, these need to be propagated. + for (const auto& attr : predicate_.node->def().attr()) { + const string& name(attr.first); + const AttrValue& value(attr.second); + if (absl::StartsWith(name, "_")) { + builder.Attr(name, value); + } } builder.Device(predicate_.node->assigned_device_name()); // Conditional should be the first input ... @@ -1076,7 +1082,7 @@ StatusOr FunctionalizeCond::JoinCondStatesMerge( // Determine the flow state when joining two states for a merge // node. Combining the two states for a merge node is effectively performing a // disjunction of the states along the different input edges. For a merge that - // can be transformed into a If the two inputs paths have to have a predicate + // can be transformed into an If the two inputs paths have to have a predicate // on which they differ (e.g., along one edge predicate `p` has to hold while // on another it should not). This function first determines this predicate // and then the resultant state is the common path between the two inputs @@ -1368,8 +1374,9 @@ void FunctionalizeCond::DeleteReachableAndDeadNodes( deleted[graph_->kSourceId] = true; deleted[graph_->kSinkId] = true; - // All remaining Switch nodes are not reachable from a Merge node and - // removed. This is to account for dead Switch nodes. + // All remaining switch nodes that were not excluded from functionalization + // according to `node_filter_` are not reachable from a merge node and + // removed. This is to account for dead switch nodes. for (int s_id : switch_ids_) { Node* s = graph_->FindNodeId(s_id); if (s == nullptr) continue; @@ -1379,11 +1386,17 @@ void FunctionalizeCond::DeleteReachableAndDeadNodes( // conditional. if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id()); } - deleted[s_id] = true; - graph_->RemoveNode(s); + // Only remove switch node if we have functionalized the corresponding + // condition before (according to `node_filter_`). + if (!node_filter_ || node_filter_(s)) { + VLOG(2) << "Removing obsolete switch node " << s->name(); + deleted[s_id] = true; + graph_->RemoveNode(s); + } } - // All merge nodes should have been transformed at this point and we remove + // All merge nodes that were not excluded from functionalization according to + // `node_filter_` should have been transformed at this point and we remove // them from the graph here. for (Node* m : merge_order) { for (const Edge* e : m->out_edges()) { @@ -1393,8 +1406,13 @@ void FunctionalizeCond::DeleteReachableAndDeadNodes( // being removed in AddOutputEdges. if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id()); } - deleted[m->id()] = true; - graph_->RemoveNode(m); + // Only remove merge node if we have functionalized the corresponding + // condition before (according to `node_filter_`). + if (!node_filter_ || node_filter_(m)) { + VLOG(2) << "Removing obsolete merge node " << m->name(); + deleted[m->id()] = true; + graph_->RemoveNode(m); + } } // Enqueue all the dead nodes. @@ -1403,7 +1421,7 @@ void FunctionalizeCond::DeleteReachableAndDeadNodes( delete_nodes.push_back(n->id()); } } - + // Remove dead nodes and nodes that are reachable from dead nodes. while (!delete_nodes.empty()) { int d_id = delete_nodes.front(); delete_nodes.pop_front(); @@ -1414,6 +1432,7 @@ void FunctionalizeCond::DeleteReachableAndDeadNodes( for (const Edge* e : d->out_edges()) { delete_nodes.push_back(e->dst()->id()); } + VLOG(2) << "Removing obsolete node " << d->name(); deleted[d_id] = true; graph_->RemoveNode(d); } @@ -1454,6 +1473,7 @@ Status FunctionalizeCond::FunctionalizeInternal() { // AncestorState from the innermost to the outermost into IfOps; // Note: In the above only nodes that feed into a merge node will be // considered for functionalization. + // Note: Nodes for which `node_filter_` returns false are excluded. // Perform a DFS over the graph and // * Determine the reverse topological order of the nodes (there should be no @@ -1463,12 +1483,18 @@ Status FunctionalizeCond::FunctionalizeInternal() { std::vector rev_topo_order; std::vector merge_order; DFS(*graph_, nullptr, [&](Node* n) { - if (IsSwitch(n)) { - AddSwitchId(n->id()); - } - if (IsMerge(n)) { - merge_order.push_back(n); + // Only collect switch and merge nodes that are not filtered out, those form + // the conditions that will be functionalized. + if (!node_filter_ || node_filter_(n)) { + if (IsSwitch(n)) { + AddSwitchId(n->id()); + } + if (IsMerge(n)) { + merge_order.push_back(n); + } } + // Collect all other nodes here, independent of `node_filter_`, because they + // might belong to a condition that should be functionalized. if (n->IsOp()) { rev_topo_order.push_back(n); } @@ -1571,19 +1597,22 @@ void FunctionalizeCond::AddSwitchId(int switch_id) { } Status FunctionalizeCond::Functionalize(Graph* graph, - FunctionLibraryDefinition* library) { + FunctionLibraryDefinition* library, + const NodeFilter& node_filter) { VLOG(1) << "FunctionalizeCond::Functionalize"; - FunctionalizeCond fc(graph, library); + FunctionalizeCond fc(graph, library, node_filter); return fc.FunctionalizeInternal(); } } // namespace functionalize_cond -Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) { +Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library, + const NodeFilter& node_filter) { // FunctionalizeControlFlow is invoked for every function, so the loops's // bodies and conditionals that were extracted into functions will be handled // in successive invocations. - return functionalize_cond::FunctionalizeCond::Functionalize(graph, library); + return functionalize_cond::FunctionalizeCond::Functionalize(graph, library, + node_filter); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h index 7940732a11d..741fe04a500 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.h +++ b/tensorflow/compiler/tf2xla/functionalize_cond.h @@ -17,6 +17,8 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_ #include + +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" @@ -26,8 +28,17 @@ namespace tensorflow { // Functionalize all the switch-merge nodes of a loop-free graph into If // nodes. That is, attempt to transform every remaining switch and merge nodes // in the graph into If nodes. -// Precondition: All while loops have been removed from graph. -Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library); +// +// If `node_filter` is defined, then only conditions for whose nodes +// `node_filter` returns true are functionalized. +// +// Preconditions: +// a) Same as for `FunctionalizeControlFlow` (see comment there). +// b) While loops must have been functionalized before according to +// `node_filter` (e.g., by calling `FunctionalizeWhileLoop` with the same +// filter before calling this function). +Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library, + const NodeFilter& node_filter = {}); // Internal functions/classes exposed for testing purposes. namespace functionalize_cond { @@ -172,11 +183,9 @@ class StateMap { // of the given graph together. class FunctionalizeCond { public: - // Functionalize all the switch-merge nodes of a loop-free graph into If - // nodes. That is, attempt to transform every remaining switch and merge nodes - // in the graph into If nodes. - // Precondition: All while loops have been removed from graph. - static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library); + // See comment for function `FunctionalizeCond`. + static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library, + const NodeFilter& node_filter); // Build identity node with the same name as the merge that will be replaced // in case the output is fetched/colocated. @@ -197,7 +206,8 @@ class FunctionalizeCond { void AddSwitchId(int switch_id); private: - FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library); + FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library, + const NodeFilter& node_filter); // Performs the actual cond functionalization. Iterate over groups of merge // nodes (linked by common predicates & ancestor IDs), from innermost to @@ -268,6 +278,9 @@ class FunctionalizeCond { friend class FunctionalizeCondTest; std::vector switch_ids_; + + // Controls which nodes are skipped for functionalization. + NodeFilter node_filter_ = {}; }; } // namespace functionalize_cond diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc index aba0b411b08..0438c41c5d6 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc @@ -40,8 +40,8 @@ class FunctionalizeCondTest : public ::testing::Test { graph_.reset(new Graph(OpRegistry::Global())); flib_def_.reset( new FunctionLibraryDefinition(OpRegistry::Global(), fdef_lib_)); - fc_.reset(new functionalize_cond::FunctionalizeCond(graph_.get(), - flib_def_.get())); + fc_.reset(new functionalize_cond::FunctionalizeCond( + graph_.get(), flib_def_.get(), NodeFilter{})); } StateMap::CondId GetUniqueId(const StateMap::StateMap::CondState& state) { diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 2fcfd20f49f..10b26f9801c 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -46,20 +46,19 @@ limitations under the License. namespace tensorflow { -// Transformation that converts TensorFlow's graph control flow constructs into -// functional equivalents. Status FunctionalizeControlFlow(Graph* graph, - FunctionLibraryDefinition* library) { + FunctionLibraryDefinition* library, + const NodeFilter& node_filter) { VLOG(2) << "FunctionalizeControlFlow (initial): " << DumpGraphToFile("functionalize_initial", *graph, library); // Functionalize and remove while loops from graph. - TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(graph, library)); + TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(graph, library, node_filter)); // FunctionalizeControlFlow is invoked for every function, so the loops's // bodies and conditionals that were extracted into functions will be handled // in successive invocations. - TF_RETURN_IF_ERROR(FunctionalizeCond(graph, library)); + TF_RETURN_IF_ERROR(FunctionalizeCond(graph, library, node_filter)); VLOG(2) << "FunctionalizeControlFlow (final): " << DumpGraphToFile("functionalize_final", *graph, library); @@ -68,12 +67,13 @@ Status FunctionalizeControlFlow(Graph* graph, } Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, - FunctionLibraryDefinition* library) { + FunctionLibraryDefinition* library, + const NodeFilter& node_filter) { FunctionDefLibrary function_lib = graph_def->library(); Graph graph(OpRegistry::Global()); TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *graph_def, &graph)); - TF_RETURN_IF_ERROR(FunctionalizeControlFlow(&graph, library)); + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(&graph, library, node_filter)); graph.ToGraphDef(graph_def); std::swap(*graph_def->mutable_library(), function_lib); return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index fb35d1b4198..f9e751e2d67 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ #define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/function.h" @@ -26,11 +27,27 @@ namespace tensorflow { // Transformation that converts tf.while_loop() loops into functional While // operators and tf.cond() conditionals into function If operators, suitable for // XLA compilation. +// +// If `node_filter` is defined, then only loops and conditions for whose +// nodes `node_filter` returns true are functionalized. +// +// Precondition: +// For any node in a loop or condition for which `node_filter` returns true, +// all nodes inside of the same loop or condition must also return true +// (including nodes in other nested loops and conditions inside of that loop or +// condition). +// This means that a "not to be functionalized" loop or condition is not allowed +// inside a "to be functionalized" loop or condition. +// +// The user of this function is responsible for using a node filter that +// satisfies the above conditions. Status FunctionalizeControlFlow(Graph* graph, - FunctionLibraryDefinition* library); + FunctionLibraryDefinition* library, + const NodeFilter& node_filter = {}); Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, - FunctionLibraryDefinition* library); + FunctionLibraryDefinition* library, + const NodeFilter& node_filter = {}); // This pass looks at the graph, and turns V1 control flow structure // (Switch/Merge/etc.) into V2 control flow structure (If/While). diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 8f53d227249..79a042ad680 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -62,7 +62,18 @@ Status FindIfThenAndElse(const GraphDef& graph, string* op_name, // z = control_flow_ops.cond( // math_ops.less(y, x), lambda: math_ops.multiply(y, 17), // lambda: math_ops.add(x, 23)) -TEST(FunctionalizeControlFlow, Conditional) { +// +// Tests different node filters. +class ConditionalTestFixture : public ::testing::TestWithParam { + protected: + void SetUp() override { restrict_to_tpu_nodes_ = GetParam(); } + void RunTest(); + + private: + bool restrict_to_tpu_nodes_ = false; +}; + +void ConditionalTestFixture::RunTest() { Graph graph(OpRegistry::Global()); { Scope scope = Scope::NewRootScope().ExitOnError(); @@ -92,14 +103,25 @@ TEST(FunctionalizeControlFlow, Conditional) { std::initializer_list{add, mul}); TF_EXPECT_OK(scope.ToGraph(&graph)); + + // Set `_tpu_replicate` attribute for all nodes. + for (Node* n : graph.nodes()) { + n->AddAttr("_tpu_replicate", "cluster"); + } } + // If `restrict_to_tpu_nodes_` is true let filter function return true for + // `_tpu_replicate` nodes. + NodeFilter node_filter = + restrict_to_tpu_nodes_ + ? [](const Node* n) { return n->attrs().Find("_tpu_replicate"); } + : NodeFilter{}; FunctionLibraryDefinition library(OpRegistry::Global(), {}); GraphDef optimized_graph_def; graph.ToGraphDef(&optimized_graph_def); - TF_ASSERT_OK( - FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); - TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef(&optimized_graph_def, + &library, node_filter)); + TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library, node_filter)); GraphDef converted_graph_def; graph.ToGraphDef(&converted_graph_def); @@ -180,6 +202,13 @@ TEST(FunctionalizeControlFlow, Conditional) { } } +TEST_P(ConditionalTestFixture, ConditionalTests) { RunTest(); } + +INSTANTIATE_TEST_SUITE_P( + FunctionalizeControlFlow, ConditionalTestFixture, ::testing::Bool(), + [](const ::testing::TestParamInfo& + info) { return info.param ? "with_filter" : "without_filter"; }); + // Returns the names of the "cond" and "body" functions for the While node // in a graph. Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, @@ -758,25 +787,75 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) { } } -// Example with nesting, loop-invariant arguments, and resource variables. -// -// accum = resource_variable_ops.ResourceVariable(1) -// x = array_ops.placeholder(2, dtype=dtypes.int32) -// y = 3 + x -// -// def inner_body(j, k): -// add = state_ops.assign_add(accum, k * j + x) -// with ops.control_dependencies([add]): -// return [j + 1, k] -// -// def body(i): -// m = control_flow_ops.while_loop(lambda j, k: j < 5, inner_body, -// [1, y], name="inner") -// with ops.control_dependencies(m): -// return [i + 1] -// -// z = control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="outer") -TEST(FunctionalizeControlFlow, Complex) { +// More complex example with nesting, loop-invariant arguments, and resource +// variables. Used for multiple tests with different node filters. +class ComplexTestFixture + : public ::testing::TestWithParam> { + protected: + void SetUp() override { + restrict_to_tpu_nodes_ = std::get<0>(GetParam()); + mark_inner_loop_tpu_ = std::get<1>(GetParam()); + mark_outer_loop_tpu_ = std::get<2>(GetParam()); + } + void RunTest(); + + private: + void CheckOuterNodesFunctionalized(const GraphDef& graph_def, + const FunctionLibraryDefinition& library, + NameAttrList& inner_cond_fn, + NameAttrList& inner_body_fn); + void CheckInnerNodesFunctionalized(const GraphDef& graph_def, + const FunctionLibraryDefinition& library, + const NameAttrList& inner_cond_fn, + const NameAttrList& inner_body_fn); + + bool restrict_to_tpu_nodes_ = false; + bool mark_inner_loop_tpu_ = false; + bool mark_outer_loop_tpu_ = false; +}; + +TEST_P(ComplexTestFixture, ComplexTests) { RunTest(); } + +INSTANTIATE_TEST_SUITE_P( + FunctionalizeControlFlow, ComplexTestFixture, + ::testing::Combine(::testing::Bool(), ::testing::Bool(), ::testing::Bool()), + [](const ::testing::TestParamInfo& info) { + bool restrict_to_tpu_nodes = std::get<0>(info.param); + bool mark_inner_loop_tpu = std::get<1>(info.param); + bool mark_outer_loop_tpu = std::get<2>(info.param); + + string node_string; + if (mark_inner_loop_tpu && mark_outer_loop_tpu) + node_string = "both_loops_tpu"; + else if (!mark_inner_loop_tpu && !mark_outer_loop_tpu) + node_string = "no_loop_tpu"; + else + node_string = mark_inner_loop_tpu ? "inner_loop_tpu" : "outer_loop_tpu"; + + string name = absl::StrCat( + restrict_to_tpu_nodes ? "restricted_" : "unrestricted_", node_string); + return name; + }); + +void ComplexTestFixture::RunTest() { + // Graph: + // + // accum = resource_variable_ops.ResourceVariable(1) + // x = array_ops.placeholder(2, dtype=dtypes.int32) + // y = 3 + x + // + // def inner_body(j, k): + // add = state_ops.assign_add(accum, k * j + x) + // with ops.control_dependencies([add]): + // return [j + 1, k] + // + // def body(i): + // m = control_flow_ops.while_loop(lambda j, k: j < 5, inner_body, + // [1, y], name="inner") + // with ops.control_dependencies(m): + // return [i + 1] + // + // z = control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="outer") Graph graph(OpRegistry::Global()); { Scope scope = Scope::NewRootScope().ExitOnError(); @@ -846,7 +925,8 @@ TEST(FunctionalizeControlFlow, Complex) { 5); auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), merge_j.output, five); - auto loop_cond = ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_j); + auto loop_cond = + ops::LoopCond(scope.WithOpName("outer/inner/LoopCond"), less_j); auto switch_j = ops::Switch(scope.WithOpName("outer/inner/Switch_j"), merge_j.output, loop_cond); @@ -906,193 +986,246 @@ TEST(FunctionalizeControlFlow, Complex) { TF_EXPECT_OK(scope.ToGraph(&graph)); } + // Add '_tpu_replicate' attributes as specified. + for (Node* n : graph.nodes()) { + string name = n->name(); + bool is_inner_node = name.find("outer/inner/") != string::npos; + bool is_outer_node = !is_inner_node && name.find("outer/") != string::npos; + if ((is_inner_node && mark_inner_loop_tpu_) || + (is_outer_node && mark_outer_loop_tpu_)) { + n->AddAttr("_tpu_replicate", "cluster"); + } + } FunctionLibraryDefinition library(OpRegistry::Global(), {}); - GraphDef optimized_graph_def; - graph.ToGraphDef(&optimized_graph_def); - TF_ASSERT_OK( - FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); - TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); - GraphDef converted_graph_def; - graph.ToGraphDef(&converted_graph_def); + GraphDef orig_graph_def, optimized_graph_def; + graph.ToGraphDef(&orig_graph_def); + optimized_graph_def = orig_graph_def; + // If `restrict_to_tpu_nodes_` is true let filter function return true for + // `_tpu_replicate` nodes, otherwise don't set filter. + NodeFilter node_filter = + restrict_to_tpu_nodes_ + ? [](const Node* n) { return n->attrs().Find("_tpu_replicate"); } + : NodeFilter{}; - for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { - NameAttrList outer_cond_fn, outer_body_fn; - TF_EXPECT_OK( - FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn)); + Status status1 = FunctionalizeControlFlowForGraphDef(&optimized_graph_def, + &library, node_filter); + Status status2 = FunctionalizeControlFlow(&graph, &library, node_filter); + ASSERT_EQ(status1, status2); + if (restrict_to_tpu_nodes_ && mark_outer_loop_tpu_ && !mark_inner_loop_tpu_) { + // This case violates the precondition of `FunctionalizeControlFlow`, we + // expect an internal error. + ASSERT_EQ(errors::IsInternal(status1), true); + return; + } else { + // Supported cases, no error expected. + TF_ASSERT_OK(status1); + } - // Outer graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); - auto three = ops::Const(scope.WithOpName("three"), 3); - auto y = ops::Add(scope.WithOpName("y"), x, three); - - auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, - TensorShape({})); - - auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); - - auto while_op = ops::While(scope.WithOpName("outer/LoopCond"), - std::initializer_list{zero, y, x, var}, - outer_cond_fn, outer_body_fn); - auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // Outer condition graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1); - auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3); - - auto ten = ops::Const( - scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output), - 10); - auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten); - auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(outer_cond_fn.name(), library, &result)); - - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // Outer body graph. + GraphDef optimized_converted_graph_def; + graph.ToGraphDef(&optimized_converted_graph_def); + for (const GraphDef& graph_def : + {optimized_graph_def, optimized_converted_graph_def}) { NameAttrList inner_cond_fn, inner_body_fn; - { - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(outer_body_fn.name(), library, &result)); - - // Find the inner condition and body names. - TF_EXPECT_OK( - FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn)); - - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1); - auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3); - - auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0); - auto one_j = ops::Const( - scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); - auto while_op = - ops::While(scope.WithOpName("outer/LoopCond_1"), - std::initializer_list{one_j, arg1, arg2, arg3}, - inner_cond_fn, inner_body_fn); - - auto one_outer = ops::Const( - scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), - 1); - auto add_i = - ops::Add(scope.WithOpName("outer/add") - .WithControlDependencies(absl::Span{ - while_op[0].op(), while_op[1].op()}), - identity_i, one_outer); - - auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add_i, 0); - auto retval1 = ops::_Retval(scope.WithOpName("retval1_RetVal"), arg1, 1); - auto retval2 = ops::_Retval(scope.WithOpName("retval2_RetVal"), arg2, 2); - auto retval3 = ops::_Retval(scope.WithOpName("retval3_RetVal"), arg3, 3); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.arg_types); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + if (!restrict_to_tpu_nodes_ || + (restrict_to_tpu_nodes_ && mark_outer_loop_tpu_ && + mark_inner_loop_tpu_)) { + // We expect that both inner and outer nodes have been functionalized. + CheckOuterNodesFunctionalized(graph_def, library, inner_cond_fn, + inner_body_fn); + CheckInnerNodesFunctionalized(graph_def, library, inner_cond_fn, + inner_body_fn); + } else /*restrict_to_tpu_nodes_ == true*/ { + if (!mark_outer_loop_tpu_ && !mark_inner_loop_tpu_) { + // Graph has no TPU nodes so we expect no functionalization. + TF_EXPECT_GRAPH_EQ(orig_graph_def, graph_def); + } else if (!mark_outer_loop_tpu_ && mark_inner_loop_tpu_) { + // We expect that only inner nodes have been functionalized. + TF_EXPECT_OK( + FindWhileCondAndBody(graph_def, &inner_cond_fn, &inner_body_fn)); + CheckInnerNodesFunctionalized(graph_def, library, inner_cond_fn, + inner_body_fn); + } } + } +} - // Inner condition graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1); - auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3); +void ComplexTestFixture::CheckOuterNodesFunctionalized( + const GraphDef& graph_def, const FunctionLibraryDefinition& library, + NameAttrList& inner_cond_fn, NameAttrList& inner_body_fn) { + NameAttrList outer_cond_fn, outer_body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn)); - auto five = ops::Const( - scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), - 5); - auto less_j = - ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five); - auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less_j, 0); + // Outer graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto three = ops::Const(scope.WithOpName("three"), 3); + auto y = ops::Add(scope.WithOpName("y"), x, three); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); + auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, + TensorShape({})); - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(inner_cond_fn.name(), library, &result)); + auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } + auto while_op = ops::While(scope.WithOpName("outer/LoopCond"), + std::initializer_list{zero, y, x, var}, + outer_cond_fn, outer_body_fn); + auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - // Inner body graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1); - auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3); + // Outer condition graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3); - auto identity_j = - ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0); - auto identity_k = - ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1); + auto ten = ops::Const( + scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output), + 10); + auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten); + auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0); - auto mul_jk = - ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k); - auto add_jkx = - ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2); - auto assign = ops::AssignAddVariableOp( - scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); - auto one = ops::Const( - scope.WithOpName("outer/inner/One") - .WithControlDependencies( - absl::Span{assign.operation}), - 1); - auto add_j = - ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(outer_cond_fn.name(), library, &result)); - auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add_j, 0); - auto retval1 = - ops::_Retval(scope.WithOpName("retval1_RetVal"), identity_k, 1); - auto retval2 = ops::_Retval(scope.WithOpName("retval2_RetVal"), arg2, 2); - auto retval3 = ops::_Retval(scope.WithOpName("retval3_RetVal"), arg3, 3); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); + // Outer body graph. + { + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(outer_body_fn.name(), library, &result)); - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(inner_body_fn.name(), library, &result)); + // Find the inner condition and body names. + TF_EXPECT_OK( + FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn)); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.arg_types); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3); + + auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0); + auto one_j = ops::Const( + scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); + auto while_op = + ops::While(scope.WithOpName("outer/inner/LoopCond"), + std::initializer_list{one_j, arg1, arg2, arg3}, + inner_cond_fn, inner_body_fn); + + auto one_outer = ops::Const( + scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); + auto add_i = + ops::Add(scope.WithOpName("outer/add") + .WithControlDependencies(absl::Span{ + while_op[0].op(), while_op[1].op()}), + identity_i, one_outer); + + auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add_i, 0); + auto retval1 = ops::_Retval(scope.WithOpName("retval1_RetVal"), arg1, 1); + auto retval2 = ops::_Retval(scope.WithOpName("retval2_RetVal"), arg2, 2); + auto retval3 = ops::_Retval(scope.WithOpName("retval3_RetVal"), arg3, 3); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + +void ComplexTestFixture::CheckInnerNodesFunctionalized( + const GraphDef& graph_def, const FunctionLibraryDefinition& library, + const NameAttrList& inner_cond_fn, const NameAttrList& inner_body_fn) { + // Inner condition graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3); + + auto five = ops::Const( + scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), 5); + auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five); + auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less_j, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(inner_cond_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // Inner body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3); + + auto identity_j = + ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0); + auto identity_k = + ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1); + + auto mul_jk = + ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k); + auto add_jkx = ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2); + auto assign = ops::AssignAddVariableOp( + scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx); + + auto one = ops::Const( + scope.WithOpName("outer/inner/One") + .WithControlDependencies( + absl::Span{assign.operation}), + 1); + auto add_j = + ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); + + auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add_j, 0); + auto retval1 = + ops::_Retval(scope.WithOpName("retval1_RetVal"), identity_k, 1); + auto retval2 = ops::_Retval(scope.WithOpName("retval2_RetVal"), arg2, 2); + auto retval3 = ops::_Retval(scope.WithOpName("retval3_RetVal"), arg3, 3); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(inner_body_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); } } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc index c31d2a4f07f..8df1c5f0c50 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc @@ -51,7 +51,8 @@ xla::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index) { Status ExtractWhileLoopFrames( const std::vector& cf_info, const Graph* graph, - std::unordered_map* frames) { + std::unordered_map* frames, + const NodeFilter& node_filter) { for (Node* node : graph->op_nodes()) { const ControlFlowInfo& cf = cf_info[node->id()]; @@ -81,6 +82,9 @@ Status ExtractWhileLoopFrames( frame.loop_cond = node; } frame.nodes.insert(node); + if (node->IsControlFlow() && node_filter && !node_filter(node)) { + frame.should_be_functionalized = false; + } } return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h index f986376c8e3..1152a14f961 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h @@ -26,6 +26,8 @@ limitations under the License. namespace tensorflow { +using NodeFilter = std::function; + // Information about a loop argument. struct WhileLoopArg { // Every loop argument has an Enter node. @@ -60,13 +62,22 @@ struct WhileLoopFrame { // Set of nodes that belong to the loop frame. std::unordered_set nodes; + + // After `ExtractWhileLoopFrames` this is true if for all control flow nodes + // of this frame `node_filter` returns true, i.e., the frame should be + // functionalized, and false otherwise. + bool should_be_functionalized = true; }; // Extracts v1 while loops within a graph and creates a map of // . +// If `node_filter` is defined, then we keep track of frames that should be +// functionalized according to the filter (see comment for +// `FunctionalizeControlFlow` for more details about node filters). Status ExtractWhileLoopFrames( const std::vector& cf_info, const Graph* graph, - std::unordered_map* frames); + std::unordered_map* frames, + const NodeFilter& node_filter = {}); // Check that the graph has no cycle containing the given node. Status CheckNodeNotInCycle(const Node* node, const int num_nodes); diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index 4e55aef3713..ab2a6958723 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -22,11 +22,10 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/match.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" -#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h" #include "tensorflow/compiler/tf2xla/functionalize_cond.h" -#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" @@ -162,7 +161,7 @@ Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame, *body_output = absl::make_unique(graph.op_registry()); Graph* output = body_output->get(); - // Map from nodes in the original graph to the condition graph. + // Map from nodes in the original graph to the body graph. std::vector node_map(graph.num_node_ids(), nullptr); std::vector squash_src_outputs(graph.num_node_ids(), false); @@ -212,7 +211,14 @@ Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame, } Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame, - FunctionLibraryDefinition* library) { + FunctionLibraryDefinition* library, + const NodeFilter& node_filter) { + if (node_filter && !frame->should_be_functionalized) { + VLOG(2) << "Skipping functionalization for frame " << frame->name + << " because it has control flow nodes that are filtered out by " + "the specified node filter."; + return Status::OK(); + } VLOG(2) << "Frame " << frame->name << " before: " << DumpGraphToFile("functionalize_before", *graph, library); @@ -349,10 +355,6 @@ Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame, return errors::InvalidArgument("Missing Switch successor to ", FormatNodeForError(*arg.merge)); } - - // Update the device on the Identity outputs of the switch to match their - // target. These Identity outputs do not - // Loop over the switch node's output to: // - Find the Exit successor. // - Set the sharding on all Identity outputs of the switch. These @@ -402,12 +404,12 @@ Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame, std::unique_ptr cond_graph; TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph)); FixupSourceAndSinkEdges(cond_graph.get()); - TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library)); + TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library, node_filter)); DataTypeVector arg_types; std::unique_ptr body_graph; TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph)); FixupSourceAndSinkEdges(body_graph.get()); - TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library)); + TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library, node_filter)); VLOG(2) << "Frame " << frame->name << " condition: " << DumpGraphToFile("loop_condition", *cond_graph, library) @@ -433,17 +435,13 @@ Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame, builder.Attr("T", arg_types); builder.Attr("cond", cond_name); builder.Attr("body", body_name); - string outside_compilation; - string frontend_attributes; - if (GetNodeAttr(frame->loop_cond->def(), kXlaFrontendAttributesAttrName, - &frontend_attributes) - .ok()) { - builder.Attr(kXlaFrontendAttributesAttrName, frontend_attributes); - } - if (GetNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName, - &outside_compilation) - .ok()) { - builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation); + // Add all underscore attributes, these need to be propagated. + for (const auto& attr : frame->loop_cond->def().attr()) { + const string& name(attr.first); + const AttrValue& value(attr.second); + if (absl::StartsWith(name, "_")) { + builder.Attr(name, value); + } } std::vector inputs; for (int i = 0; i < frame->args.size(); ++i) { @@ -495,6 +493,7 @@ Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame, // Remove the old nodes from the graph, and add the while node to the parent // frame. for (Node* node : frame->nodes) { + VLOG(2) << "Removing obsolete node " << node->name(); graph->RemoveNode(node); } frame->nodes.clear(); @@ -507,8 +506,8 @@ Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame, } } // namespace -Status FunctionalizeWhileLoop(Graph* graph, - FunctionLibraryDefinition* library) { +Status FunctionalizeWhileLoop(Graph* graph, FunctionLibraryDefinition* library, + const NodeFilter& node_filter) { // Note: BuildControlFlowInfo() requires that the graph's source node is // connected to all source nodes in the graph. Many graphs violate this // invariant. @@ -523,7 +522,8 @@ Status FunctionalizeWhileLoop(Graph* graph, // Builds Frames, indexed by name. std::unordered_map frames; - TF_RETURN_IF_ERROR(ExtractWhileLoopFrames(cf_info, graph, &frames)); + TF_RETURN_IF_ERROR( + ExtractWhileLoopFrames(cf_info, graph, &frames, node_filter)); // Adds frames with no children (i.e., the innermost frames) to a worklist. std::deque worklist; @@ -533,7 +533,9 @@ Status FunctionalizeWhileLoop(Graph* graph, } } - // Eliminate loops from innermost to outermost. + // Eliminate loops from innermost to outermost. Note that the precondition for + // `node_filter` in `FunctionalizeControlFlow` makes sure that this approach + // works. while (!worklist.empty()) { WhileLoopFrame* frame = worklist.front(); worklist.pop_front(); @@ -542,7 +544,7 @@ Status FunctionalizeWhileLoop(Graph* graph, continue; } - TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library)); + TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library, node_filter)); // If the parent has no remaining children, add it to the worklist. --frame->parent->num_children; @@ -551,14 +553,16 @@ Status FunctionalizeWhileLoop(Graph* graph, } } - // There should be no cycle at this point, since while loops have been removed - // from graph. - // Check that the newly added While nodes don't feed into themselves. - for (const Node* node : graph->op_nodes()) { - if (node->def().op() == "While") { - TF_RETURN_WITH_CONTEXT_IF_ERROR( - CheckNodeNotInCycle(node, graph->num_node_ids()), - "Functionalizing loop failed."); + if (!node_filter) { + // There should be no cycle at this point, since while loops have been + // removed from graph. Check that the newly added While nodes don't feed + // into themselves. + for (const Node* node : graph->op_nodes()) { + if (node->def().op() == "While") { + TF_RETURN_WITH_CONTEXT_IF_ERROR( + CheckNodeNotInCycle(node, graph->num_node_ids()), + "Functionalizing loop failed."); + } } } diff --git a/tensorflow/compiler/tf2xla/functionalize_while.h b/tensorflow/compiler/tf2xla/functionalize_while.h index 207b29b8498..ddd6b655cd5 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.h +++ b/tensorflow/compiler/tf2xla/functionalize_while.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_ #define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_ +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" @@ -24,7 +25,14 @@ namespace tensorflow { // Transformation that converts tf.while_loop() loops into functional While // operators, suitable for XLA compilation. If lookup_library is provided, use // it to make the library for control flow self-contained. -Status FunctionalizeWhileLoop(Graph* graph, FunctionLibraryDefinition* library); +// +// If `node_filter` is defined, then only loops for whose nodes `node_filter` +// returns true are functionalized. +// +// Preconditions: +// Same as for `FunctionalizeControlFlow` (see comment there). +Status FunctionalizeWhileLoop(Graph* graph, FunctionLibraryDefinition* library, + const NodeFilter& node_filter = {}); } // namespace tensorflow