diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 1aee747476c..8cae193fa30 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -786,4 +786,144 @@ Status PruneUnreachableFunctionsFromGraph(const Graph& g, } return Status::OK(); } + +Status RewriteTensorListWithConstElement(Graph* g, + FunctionLibraryDefinition* fld) { + for (Node* n : g->nodes()) { + if (n->type_string() != "EmptyTensorList") { + continue; + } + + // Find the forward While op. + std::vector fwd_while_edges; + for (const Edge* e : n->out_edges()) { + if (!e->IsControlEdge() && e->dst()->type_string() == "While") { + fwd_while_edges.push_back(e); + } + } + if (fwd_while_edges.size() != 1) { + // No forward While op found, or multiple forward While ops. + continue; + } + + // Find the backward While op. + Node* fwd_while = fwd_while_edges[0]->dst(); + int fwd_while_dst_input = fwd_while_edges[0]->dst_input(); + std::vector bwd_while_edges; + for (const Edge* e : fwd_while->out_edges()) { + if (e->src_output() == fwd_while_dst_input && + e->dst()->type_string() == "While") { + bwd_while_edges.push_back(e); + } + } + if (bwd_while_edges.size() != 1) { + // No backward While op found, or multiple backward While ops. + continue; + } + + Node* bwd_while = bwd_while_edges[0]->dst(); + int bwd_while_dst_input = bwd_while_edges[0]->dst_input(); + + // Look into forward While body function and check if TensorListPushBack op + // has a Const input. + NameAttrList fwd_body_attr; + TF_CHECK_OK(GetNodeAttr(fwd_while->def(), "body", &fwd_body_attr)); + const FunctionDef* fwd_body = fld->Find(fwd_body_attr.name()); + if (!fwd_body) { + return errors::InvalidArgument("Cannot find function ", + fwd_body_attr.name(), " for While node ", + fwd_while->DebugString()); + } + std::unique_ptr fwd_fbody; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fwd_body, AttrSlice(&fwd_body_attr.attr()), fld, &fwd_fbody)); + + // Find the TensorListPushBack node; it's one of fwd_arg's successors. + Node* fwd_arg = fwd_fbody->arg_nodes[fwd_while_dst_input]; + std::vector tl_push_nodes; + for (const Edge* out_edge : fwd_arg->out_edges()) { + if (out_edge->dst()->type_string() == "TensorListPushBack") { + tl_push_nodes.push_back(out_edge->dst()); + } + } + if (tl_push_nodes.size() != 1) { + // No TensorListPushBack found, or multiple TensorListPushBack. + continue; + } + + // Get input for the TensorListPushBack node. + Node* input_node; + TF_CHECK_OK(tl_push_nodes[0]->input_node(1, &input_node)); + if (input_node->type_string() != "Const") { + // Input for the TensorList is not Const node. + continue; + } + + NodeDef const_input_nodedef = input_node->def(); + + // Rewrite backward While body function, replace usages of + // TensorListPopBack with a Const node. + NameAttrList bwd_body_attr; + TF_CHECK_OK(GetNodeAttr(bwd_while->def(), "body", &bwd_body_attr)); + const FunctionDef* bwd_body = fld->Find(bwd_body_attr.name()); + if (!bwd_body) { + return errors::InvalidArgument("Cannot find function ", + bwd_body_attr.name(), " for While node ", + bwd_while->DebugString()); + } + std::unique_ptr bwd_fbody; + TF_CHECK_OK(FunctionDefToBodyHelper( + *bwd_body, AttrSlice(&bwd_body_attr.attr()), fld, &bwd_fbody)); + + // Find the TensorListPopBack node; it's one of bwd_arg's successors. + Node* bwd_arg = bwd_fbody->arg_nodes[bwd_while_dst_input]; + std::vector tl_pop_nodes; + for (const Edge* out_edge : bwd_arg->out_edges()) { + if (out_edge->dst()->type_string() == "TensorListPopBack") { + tl_pop_nodes.push_back(out_edge->dst()); + } + } + if (tl_pop_nodes.size() != 1) { + // No TensorListPopBack found, or multiple TensorListPopBack. + continue; + } + + // Replace TensorListPopBack usages with Const node. + std::vector edges_to_replace; + for (const Edge* e : tl_pop_nodes[0]->out_edges()) { + if (e->src_output() == 1) { + edges_to_replace.push_back(e); + } + } + if (edges_to_replace.empty()) { + continue; + } + Status s; + const_input_nodedef.set_name( + bwd_fbody->graph->NewName(const_input_nodedef.name())); + Node* const_node = bwd_fbody->graph->AddNode(const_input_nodedef, &s); + TF_RETURN_IF_ERROR(s); + for (const Edge* e : edges_to_replace) { + Node* dst = e->dst(); + int dst_input = e->dst_input(); + bwd_fbody->graph->RemoveEdge(e); + bwd_fbody->graph->AddEdge(const_node, 0, dst, dst_input); + } + + // Add rewritten backward While body function. + FunctionDef new_fdef; + string new_name = fld->UniqueFunctionName( + absl::StrCat(bwd_body_attr.name(), "_tl_rewrite_")); + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*bwd_fbody->graph, new_name, &new_fdef)); + TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef)); + + // Change backward While op to use the new body function. + bwd_body_attr.set_name(new_name); + bwd_while->ClearAttr("body"); + bwd_while->AddAttr("body", bwd_body_attr); + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index 0b78631fd24..c9d73450425 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -202,6 +202,16 @@ Status PropagateConstIntoFunctionalNodes( Status PruneUnreachableFunctionsFromGraph(const Graph& g, FunctionLibraryDefinition* fld); +// Finds the following pattern in the graph: +// 1) EmptyTensorList -> forward While op -> backward While op, +// 2) in forward While op, a Const node is pushed, +// 3) in backward While op, data is popped from the tensor list. +// And rewrites backward While op to use Const node instead of TensorListPopBack +// result. +// TODO(b/128633174) remove the TensorList and related TensorList ops. +Status RewriteTensorListWithConstElement(Graph* g, + FunctionLibraryDefinition* fld); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index 28b4744470e..0fde45c2696 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -22,8 +22,10 @@ limitations under the License. #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/functional_ops.h" +#include "tensorflow/cc/ops/list_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/function.h" @@ -416,5 +418,86 @@ TEST(PropagateConstIntoFunctionalNodes, CopiedConstNodeHasUniqueName) { EXPECT_EQ(const_def->second.op(), "Const"); } +TEST(PropagateConstIntoFunctionalNodes, RewriteTensorListWithConstMember) { + FunctionLibraryDefinition fld(OpRegistry::Global(), {}); + { + // Cond graph + Scope scope = Scope::NewRootScope().ExitOnError(); + auto input = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0); + auto result = + ops::Const(scope.WithOpName("result"), false, TensorShape({})); + auto ret = ops::_Retval(scope.WithOpName("ret"), result, 0); + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(scope.ToGraph(&graph)); + FunctionDef fdef; + TF_ASSERT_OK(GraphToFunctionDef(graph, "cond", &fdef)); + TF_ASSERT_OK(fld.AddFunctionDef(fdef)); + } + { + // Forward body graph + Scope scope = Scope::NewRootScope().ExitOnError(); + auto input = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0); + auto element = ops::Const(scope.WithOpName("element"), 0, TensorShape({})); + auto push = + ops::TensorListPushBack(scope.WithOpName("push"), input, element); + auto ret = ops::_Retval(scope.WithOpName("ret"), push.output_handle, 0); + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(scope.ToGraph(&graph)); + FunctionDef fdef; + TF_ASSERT_OK(GraphToFunctionDef(graph, "fwd_body", &fdef)); + TF_ASSERT_OK(fld.AddFunctionDef(fdef)); + } + { + // Backward body graph + Scope scope = Scope::NewRootScope().ExitOnError(); + auto input = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0); + auto shape = ops::Const(scope.WithOpName("element"), -1, TensorShape({})); + auto pop = + ops::TensorListPopBack(scope.WithOpName("pop"), input, shape, DT_INT32); + auto identity = ops::Identity(scope.WithOpName("identity"), pop.tensor); + auto ret = ops::_Retval(scope.WithOpName("ret"), pop.output_handle, 0); + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(scope.ToGraph(&graph)); + FunctionDef fdef; + TF_ASSERT_OK(GraphToFunctionDef(graph, "bwd_body", &fdef)); + TF_ASSERT_OK(fld.AddFunctionDef(fdef)); + } + Scope scope = Scope::NewRootScope().ExitOnError(); + auto shape = ops::Const(scope.WithOpName("element"), -1, TensorShape({})); + auto max_num_elements = + ops::Const(scope.WithOpName("max_num_elements"), 10, TensorShape({})); + auto tl = ops::EmptyTensorList(scope.WithOpName("tl"), shape, + max_num_elements, DT_INT32); + NameAttrList cond_fn, fwd_body_fn, bwd_body_fn; + cond_fn.set_name("cond"); + fwd_body_fn.set_name("fwd_body"); + bwd_body_fn.set_name("bwd_body"); + auto fwd_while_op = + ops::While(scope.WithOpName("fwd_while"), + std::initializer_list{tl}, cond_fn, fwd_body_fn); + auto bwd_while_op = + ops::While(scope.WithOpName("bwd_while"), + std::initializer_list{fwd_while_op.output[0]}, cond_fn, + bwd_body_fn); + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(scope.ToGraph(&graph)); + + TF_EXPECT_OK(RewriteTensorListWithConstElement(&graph, &fld)); + + // Check that in rewritten backward While body function, the Identity node now + // has Const node as input. + const FunctionDef* bwd_body = fld.Find("bwd_body_tl_rewrite_0"); + ASSERT_NE(bwd_body, nullptr); + std::unique_ptr bwd_fbody; + TF_CHECK_OK( + FunctionDefToBodyHelper(*bwd_body, AttrSlice(), &fld, &bwd_fbody)); + auto node_name_index = bwd_fbody->graph->BuildNodeNameIndex(); + const Node* identity = node_name_index.at("identity"); + ASSERT_NE(identity, nullptr); + const Node* input; + TF_ASSERT_OK(identity->input_node(0, &input)); + EXPECT_EQ(input->type_string(), "Const"); +} + } // namespace } // namespace tensorflow