diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 2420033fc9e..c17e75fdfa6 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -322,7 +322,6 @@ cc_library( deps = [ ":compilation_passes", "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration", - "//tensorflow/compiler/tf2xla:rearrange_function_argument_pass_registration", "//tensorflow/core:core_cpu_internal", ], alwayslink = 1, @@ -702,7 +701,7 @@ tf_cc_test( "//tensorflow/cc:scope", "//tensorflow/cc:sendrecv_ops", "//tensorflow/compiler/jit/kernels:xla_ops", - "//tensorflow/compiler/tf2xla:rearrange_function_argument_pass", + "//tensorflow/compiler/tf2xla:rearrange_function_argument", "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:test_util", "//tensorflow/compiler/tf2xla:xla_compiler", diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index 6c470fa51c8..69186da38f2 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -39,10 +39,6 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 25, // third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc // FunctionalizeControlFlowPass: 27 // -// from -// third_party/tensorflow/compiler/tf2xla/rearrange_function_argument_pass_registration.cc -// RearrangeFunctionArgumentPass: 28 -// // This pass looks at the graph and all associated FunctionDefs, and turns // traditional control flow structure (Switch/Merge/etc.) into functional // control flow structure (XlaIf/XlaWhile). 
Following passes must diff --git a/tensorflow/compiler/jit/rearrange_function_argument_pass_test.cc b/tensorflow/compiler/jit/rearrange_function_argument_pass_test.cc index cd97671750a..fb56ff2ddf5 100644 --- a/tensorflow/compiler/jit/rearrange_function_argument_pass_test.cc +++ b/tensorflow/compiler/jit/rearrange_function_argument_pass_test.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/rearrange_function_argument_pass.h" - #include "absl/strings/match.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/array_ops.h" @@ -22,6 +20,7 @@ limitations under the License. #include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/encapsulate_util.h" +#include "tensorflow/compiler/tf2xla/rearrange_function_argument.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" @@ -37,37 +36,7 @@ limitations under the License. 
namespace tensorflow { -class RearrangeFunctionArgumentForFunctionTest : public ::testing::Test { - public: - void SetUp() override { - SessionOptions session_options; - std::vector> devices; - TF_CHECK_OK(DeviceFactory::AddDevices( - session_options, "/job:localhost/replica:0/task:0", &devices)); - device_mgr_ = absl::make_unique(std::move(devices)); - } - - Status RearrangeFunctionArgumentTest( - const string &func_name, const string &new_func_name, - const protobuf::Map &attrs, - FunctionLibraryDefinition *fld, bool *modified) { - OptimizerOptions opts; - pflr_ = absl::make_unique( - device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, fld, opts, - /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); - std::map> canonicalized_name_to_new_name; - auto flr = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); - return RearrangeFunctionArgumentForFunction( - func_name, new_func_name, attrs, fld, flr, - &canonicalized_name_to_new_name, modified); - } - - private: - std::unique_ptr device_mgr_; - std::unique_ptr pflr_; -}; - -TEST_F(RearrangeFunctionArgumentForFunctionTest, Basic) { +TEST(RearrangeFunctionArgumentForFunctionTest, Basic) { FunctionDefLibrary fdl; { // Function for StatefulPartitionedCall's "f", If's @@ -113,40 +82,45 @@ TEST_F(RearrangeFunctionArgumentForFunctionTest, Basic) { FunctionDef *xla_fdef = fdl.add_function(); TF_CHECK_OK(GraphToFunctionDef(*g, "f3", xla_fdef)); } - { - // Build the XLA computation func. 
- // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_INT32) - // "arg0", "arg1" -> "if" (If) -> "ret0", "ret1" - // "arg0", "arg1" -> "while" (While) -> "ret2", "ret3" - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0); - Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1); - NameAttrList f; - f.set_name("f1"); - auto if_op = ops::If(s.WithOpName("if"), arg1, - std::initializer_list{arg0, arg1}, - {DT_BOOL, DT_RESOURCE}, f, f); - auto ret0 = ops::_Retval(s.WithOpName("ret0"), if_op.output[0], 0); - auto ret1 = ops::_Retval(s.WithOpName("ret1"), if_op.output[1], 1); - NameAttrList cond_fn, body_fn; - cond_fn.set_name("f3"); - body_fn.set_name("f2"); - auto while_op = - ops::While(s.WithOpName("while"), - std::initializer_list{arg0, arg1}, cond_fn, body_fn); - auto ret2 = ops::_Retval(s.WithOpName("ret2"), while_op.output[0], 2); - auto ret3 = ops::_Retval(s.WithOpName("ret3"), while_op.output[1], 3); - std::unique_ptr g(new Graph(OpRegistry::Global())); - TF_CHECK_OK(s.ToGraph(g.get())); - FunctionDef *xla_fdef = fdl.add_function(); - TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef)); - } FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); - bool modified; - protobuf::Map attrs; - TF_CHECK_OK(RearrangeFunctionArgumentTest("cluster", "cluster_rewritten", - attrs, &fld, &modified)); + // Build the XLA computation graph. 
+ // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_BOOL) + // "arg0", "arg1" -> "if" (If) -> "ret0", "ret1" + // "arg0", "arg1" -> "while" (While) -> "ret2", "ret3" + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0); + Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1); + NameAttrList f; + f.set_name("f1"); + auto if_op = ops::If(s.WithOpName("if"), arg1, + std::initializer_list{arg0, arg1}, + {DT_BOOL, DT_RESOURCE}, f, f); + auto ret0 = ops::_Retval(s.WithOpName("ret0"), if_op.output[0], 0); + auto ret1 = ops::_Retval(s.WithOpName("ret1"), if_op.output[1], 1); + NameAttrList cond_fn, body_fn; + cond_fn.set_name("f3"); + body_fn.set_name("f2"); + auto while_op = + ops::While(s.WithOpName("while"), + std::initializer_list{arg0, arg1}, cond_fn, body_fn); + auto ret2 = ops::_Retval(s.WithOpName("ret2"), while_op.output[0], 2); + auto ret3 = ops::_Retval(s.WithOpName("ret3"), while_op.output[1], 3); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + + std::vector> fbodies; + TF_CHECK_OK(RearrangeFunctionArguments( + [&](const NameAttrList &function, const FunctionBody **fbody) { + std::unique_ptr new_fbody; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld.Find(function.name()), + AttrSlice(&function.attr()), + &fld, &new_fbody)); + *fbody = new_fbody.get(); + fbodies.push_back(std::move(new_fbody)); + return Status::OK(); + }, + g.get(), &fld)); // Check function f1_rearrange_0, input types should be {DT_BOOL, DT_RESOURCE} // and output types should be {DT_BOOL}. @@ -159,10 +133,7 @@ TEST_F(RearrangeFunctionArgumentForFunctionTest, Basic) { EXPECT_EQ(f1_rewritten->signature().output_arg(0).type(), DT_BOOL); // Check node "if" input and output edges. 
- std::unique_ptr xla_fbody; - TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"), - AttrSlice(), &fld, &xla_fbody)); - auto node_name_index = xla_fbody->graph->BuildNodeNameIndex(); + auto node_name_index = g->BuildNodeNameIndex(); const Node *if_node = node_name_index.at("if"); ASSERT_NE(if_node, nullptr); const Node *input_node; @@ -170,11 +141,13 @@ TEST_F(RearrangeFunctionArgumentForFunctionTest, Basic) { EXPECT_EQ(input_node->name(), "arg1"); TF_CHECK_OK(if_node->input_node(2, &input_node)); EXPECT_EQ(input_node->name(), "arg0"); - const Node *ret2_node = xla_fbody->ret_nodes[0]; - TF_CHECK_OK(ret2_node->input_node(0, &input_node)); + const Node *ret0_node = node_name_index.at("ret0"); + ASSERT_NE(ret0_node, nullptr); + TF_CHECK_OK(ret0_node->input_node(0, &input_node)); EXPECT_EQ(input_node->name(), "if"); - const Node *ret3_node = xla_fbody->ret_nodes[1]; - TF_CHECK_OK(ret3_node->input_node(0, &input_node)); + const Node *ret1_node = node_name_index.at("ret1"); + ASSERT_NE(ret1_node, nullptr); + TF_CHECK_OK(ret1_node->input_node(0, &input_node)); EXPECT_EQ(input_node->name(), "arg0"); // Check node "while" input and output edges. 
@@ -184,16 +157,18 @@ TEST_F(RearrangeFunctionArgumentForFunctionTest, Basic) { EXPECT_EQ(input_node->name(), "arg1"); TF_CHECK_OK(while_node->input_node(1, &input_node)); EXPECT_EQ(input_node->name(), "arg0"); - const Node *ret4_node = xla_fbody->ret_nodes[2]; - TF_CHECK_OK(ret4_node->input_node(0, &input_node)); + const Node *ret2_node = node_name_index.at("ret2"); + ASSERT_NE(ret2_node, nullptr); + TF_CHECK_OK(ret2_node->input_node(0, &input_node)); EXPECT_EQ(input_node->name(), "arg0"); - const Node *ret5_node = xla_fbody->ret_nodes[3]; - TF_CHECK_OK(ret5_node->input_node(0, &input_node)); + const Node *ret3_node = node_name_index.at("ret3"); + ASSERT_NE(ret3_node, nullptr); + TF_CHECK_OK(ret3_node->input_node(0, &input_node)); EXPECT_EQ(input_node->name(), "while"); } -TEST_F(RearrangeFunctionArgumentForFunctionTest, - WhileResourceRetvalFromDifferentArgUnimplemented) { +TEST(RearrangeFunctionArgumentForFunctionTest, + WhileResourceRetvalFromDifferentArgUnimplemented) { FunctionDefLibrary fdl; { // Function for While's "body". @@ -227,32 +202,37 @@ TEST_F(RearrangeFunctionArgumentForFunctionTest, FunctionDef *xla_fdef = fdl.add_function(); TF_CHECK_OK(GraphToFunctionDef(*g, "f1", xla_fdef)); } - { - // Build the XLA computation func. 
- // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32) - // "arg0", "arg1" -> "while" (While) - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0); - Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1); - Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2); - NameAttrList cond_fn, body_fn; - cond_fn.set_name("f1"); - body_fn.set_name("f2"); - auto while_op = ops::While(s.WithOpName("while"), - std::initializer_list{arg0, arg1, arg2}, - cond_fn, body_fn); - std::unique_ptr g(new Graph(OpRegistry::Global())); - TF_CHECK_OK(s.ToGraph(g.get())); - FunctionDef *xla_fdef = fdl.add_function(); - TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef)); - } FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); - bool modified; - protobuf::Map attrs; - Status s = RearrangeFunctionArgumentTest("cluster", "cluster_rewritten", - attrs, &fld, &modified); - EXPECT_EQ(s.code(), error::UNIMPLEMENTED); + // Build the XLA computation graph. 
+ // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32) + // "arg0", "arg1", "arg2" -> "while" (While) + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0); + Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1); + Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2); + NameAttrList cond_fn, body_fn; + cond_fn.set_name("f1"); + body_fn.set_name("f2"); + auto while_op = ops::While(s.WithOpName("while"), + std::initializer_list{arg0, arg1, arg2}, + cond_fn, body_fn); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + + std::vector> fbodies; + Status status = RearrangeFunctionArguments( + [&](const NameAttrList &function, const FunctionBody **fbody) { + std::unique_ptr new_fbody; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld.Find(function.name()), + AttrSlice(&function.attr()), + &fld, &new_fbody)); + *fbody = new_fbody.get(); + fbodies.push_back(std::move(new_fbody)); + return Status::OK(); + }, + g.get(), &fld); + EXPECT_EQ(status.code(), error::UNIMPLEMENTED); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index a2f25724b71..dcce43cbe70 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -196,6 +196,7 @@ cc_library( ":tf2xla_util", "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:xla_cluster_util", + "//tensorflow/compiler/tf2xla:rearrange_function_argument", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -514,12 +515,12 @@ cc_library( ) cc_library( - name = "rearrange_function_argument_pass", + name = "rearrange_function_argument", srcs = [ - "rearrange_function_argument_pass.cc", + "rearrange_function_argument.cc", ], hdrs = [ - "rearrange_function_argument_pass.h", + "rearrange_function_argument.h", ], deps = [ 
"//tensorflow/compiler/tf2xla:tf2xla_util", @@ -535,17 +536,6 @@ cc_library( ], ) -cc_library( - name = "rearrange_function_argument_pass_registration", - srcs = [ - "rearrange_function_argument_pass_registration.cc", - ], - deps = [ - ":rearrange_function_argument_pass", - ], - alwayslink = 1, -) - cc_library( name = "functionalize_control_flow_pass_registration", srcs = [ diff --git a/tensorflow/compiler/tf2xla/rearrange_function_argument_pass.cc b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc similarity index 67% rename from tensorflow/compiler/tf2xla/rearrange_function_argument_pass.cc rename to tensorflow/compiler/tf2xla/rearrange_function_argument.cc index fb86df07e37..5bef118c633 100644 --- a/tensorflow/compiler/tf2xla/rearrange_function_argument_pass.cc +++ b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/rearrange_function_argument_pass.h" +#include "tensorflow/compiler/tf2xla/rearrange_function_argument.h" #include @@ -158,8 +158,9 @@ Status ReorderOutputEdges(Graph* g, Node* n, int input_count, // Given mapping between original input index and rearranged input index, change // "index" attribute for _Arg nodes. 
-void RearrangeArgNodes(gtl::InlinedVector* arg_nodes, // non-absl ok - const std::vector& index_mapping) { +void RearrangeArgNodes( + const gtl::InlinedVector* arg_nodes, // non-absl ok + const std::vector& index_mapping) { for (int i = 0; i < arg_nodes->size(); i++) { Node* n = (*arg_nodes)[i]; int new_index = index_mapping.at(i); @@ -271,8 +272,10 @@ void RearrangeRetvalNodes( } } -Status MaybeRewriteWhileNode(Graph* g, Node* n, FunctionLibraryDefinition* fld, - bool* node_rewritten) { +Status MaybeRewriteWhileNode( + std::function + get_function_body_fn, + Graph* g, Node* n, FunctionLibraryDefinition* fld, bool* node_rewritten) { // Check if this While node needs rewrite. std::vector types; TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &types)); @@ -303,11 +306,8 @@ Status MaybeRewriteWhileNode(Graph* g, Node* n, FunctionLibraryDefinition* fld, for (auto const& attr_name : std::vector{"cond", "body"}) { NameAttrList attr_value; TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &attr_value)); - const FunctionDef* fdef = fld->Find(attr_value.name()); - TF_RET_CHECK(fdef != nullptr); - std::unique_ptr fbody; - TF_RETURN_IF_ERROR( - FunctionDefToBodyHelper(*fdef, AttrSlice(), fld, &fbody)); + const FunctionBody* fbody; + TF_RETURN_IF_ERROR(get_function_body_fn(attr_value, &fbody)); // Check that resource _Arg nodes for While node are always returned with // the same index, and we don't have cases like this: @@ -375,8 +375,10 @@ Status MaybeRewriteWhileNode(Graph* g, Node* n, FunctionLibraryDefinition* fld, return Status::OK(); } -Status MaybeRewriteIfNode(Graph* g, Node* n, FunctionLibraryDefinition* fld, - bool* node_rewritten) { +Status MaybeRewriteIfNode( + std::function + get_function_body_fn, + Graph* g, Node* n, FunctionLibraryDefinition* fld, bool* node_rewritten) { // This node needs rewrite when either of these is true: // 1) Tin has DT_RESOURCE which requires rearrange; // 2) Tout has DT_RESOURCE. 
@@ -428,11 +430,8 @@ Status MaybeRewriteIfNode(Graph* g, Node* n, FunctionLibraryDefinition* fld, std::vector{"then_branch", "else_branch"}) { NameAttrList f; TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &f)); - const FunctionDef* fdef = fld->Find(f.name()); - TF_RET_CHECK(fdef != nullptr); - std::unique_ptr fbody; - TF_RETURN_IF_ERROR( - FunctionDefToBodyHelper(*fdef, AttrSlice(), fld, &fbody)); + const FunctionBody* fbody; + TF_RETURN_IF_ERROR(get_function_body_fn(f, &fbody)); if (input_need_rearrange) { // Change _Arg node index. @@ -501,95 +500,10 @@ Status MaybeRewriteIfNode(Graph* g, Node* n, FunctionLibraryDefinition* fld, } // namespace -Status RearrangeFunctionArgumentForFunction( - const string& func_name, const string& new_func_name, - const protobuf::Map& attrs, - FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, - std::map>* canonicalized_name_to_new_name, - bool* modified) { - *modified = false; - - // Convert the function to Graph. - FunctionLibraryRuntime::Handle handle; - TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); - Status ret_status = Status::OK(); - auto cleanup_handle = gtl::MakeCleanup([&]() { - auto s = flr->ReleaseHandle(handle); - if (!s.ok()) { - ret_status.Update(s); - } - }); - const FunctionBody* body = flr->GetFunctionBody(handle); - Graph* g = body->graph; - - // If any node has associated functions, rewrite them first. - // Gather nodes with associated functions first, because rewriting those nodes - // might involve node deletion/addition. Avoid modifying nodes while iterating - // it. 
- std::vector>> - nodes_to_associated_functions; - for (auto* n : g->nodes()) { - auto associated_functions = GetAssociatedFunctions(*n, fld); - if (!associated_functions.empty()) { - nodes_to_associated_functions.push_back({n, associated_functions}); - } - } - for (auto iter : nodes_to_associated_functions) { - Node* n = iter.first; - auto associated_functions = iter.second; - for (auto& associated_function : associated_functions) { - string name = associated_function.func_name(); - string canonicalized_name = - Canonicalize(name, AttrSlice(&associated_function.attrs())); - auto iter = canonicalized_name_to_new_name->find(canonicalized_name); - string new_name; - bool function_modified; - if (iter != canonicalized_name_to_new_name->end()) { - // If we already processed this function, check if it was rewritten. If - // the function was rewritten, the entry will be non-empty. Otherwise - // the entry will be empty. - function_modified = iter->second.has_value(); - if (function_modified) { - new_name = iter->second.value(); - } - } else { - if (associated_function.type() == - AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) { - // For SymbolicGradient, `name` is always "SymbolicGradient", - // which is not very informative. Use node name instead. - new_name = - fld->UniqueFunctionName(absl::StrCat(n->name(), "_rearrange_")); - } else { - new_name = fld->UniqueFunctionName(absl::StrCat(name, "_rearrange_")); - } - TF_RETURN_IF_ERROR(RearrangeFunctionArgumentForFunction( - name, new_name, associated_function.attrs(), fld, flr, - canonicalized_name_to_new_name, &function_modified)); - if (function_modified) { - // If the function was rewritten, add an non-empty entry. So later we - // know we have processed this function, and it was rewritten into - // another function. - (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; - } else { - // If the function was not rewritten, add an empty entry. 
So later - // we know we have processed this function, and it does not need to be - // rewritten. - (*canonicalized_name_to_new_name)[canonicalized_name] = absl::nullopt; - } - } - if (function_modified) { - *modified = true; - - // Notice that if "n" is a function call, RewriteAssociatedFunction() - // will delete it and create a new node instead, making "n" an invalid - // pointer. That's fine because in that case, associated_functions will - // only have one member and the loop will only run once. - TF_RETURN_IF_ERROR(RewriteAssociatedFunction( - g, n, fld, associated_function, new_name)); - } - } - } - +Status RearrangeFunctionArguments( + std::function + get_function_body_fn, + Graph* g, FunctionLibraryDefinition* fld) { // Inline StatefulPartitionedCall nodes. std::vector call_nodes; for (Node* n : g->nodes()) { @@ -598,114 +512,30 @@ Status RearrangeFunctionArgumentForFunction( } } for (Node* n : call_nodes) { - *modified = true; - NameAttrList func_name_attrs; TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &func_name_attrs)); - const FunctionDef* fdef = fld->Find(func_name_attrs.name()); - if (!fdef) { - return errors::InvalidArgument("Cannot find function ", - func_name_attrs.name(), " for node ", - n->DebugString()); - } - std::unique_ptr fbody; - TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fdef, AttrSlice(&func_name_attrs.attr()), fld, &fbody)); + const FunctionBody* fbody; + TF_RETURN_IF_ERROR(get_function_body_fn(func_name_attrs, &fbody)); InlineFunctionBodyOptions opts; - TF_RETURN_IF_ERROR(InlineFunctionBody(*fld, g, n, fbody.get(), opts)); + Status s = InlineFunctionBody(*fld, g, n, fbody, opts); + // Inlining might fail because the function is marked with attribute + // _noinline. + s.IgnoreError(); } + // Rewrite If/While nodes. 
for (Node* n : g->nodes()) { if (n->type_string() == "While") { bool node_rewritten; - TF_RETURN_IF_ERROR(MaybeRewriteWhileNode(g, n, fld, &node_rewritten)); - if (node_rewritten) { - *modified = true; - } + TF_RETURN_IF_ERROR(MaybeRewriteWhileNode(get_function_body_fn, g, n, fld, + &node_rewritten)); } else if (n->type_string() == "If") { bool node_rewritten; - TF_RETURN_IF_ERROR(MaybeRewriteIfNode(g, n, fld, &node_rewritten)); - if (node_rewritten) { - *modified = true; - } - } - } - - if (*modified) { - // Add rewritten FunctionDef into library. - FunctionDef functionalized_fdef; - TF_RETURN_IF_ERROR( - GraphToFunctionDef(*g, new_func_name, &functionalized_fdef)); - if (func_name == new_func_name) { - VLOG(2) << "Replacing function " << func_name; TF_RETURN_IF_ERROR( - fld->ReplaceFunction(new_func_name, functionalized_fdef)); - } else { - VLOG(2) << "Adding function " << new_func_name; - TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); + MaybeRewriteIfNode(get_function_body_fn, g, n, fld, &node_rewritten)); } } - return ret_status; -} // namespace tensorflow - -Status RearrangeFunctionArgumentPass::Run( - const GraphOptimizationPassOptions& options) { - Graph* graph = options.graph->get(); - if (VLOG_IS_ON(4)) { - DumpGraphToFile("rearrange_function_argument_before", *graph, - options.flib_def); - } - std::unique_ptr pflr( - new ProcessFunctionLibraryRuntime( - /*device_mgr=*/nullptr, options.session_options->env, - TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions())); - FunctionLibraryRuntime* flr = - pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); - - // Find XLA compile ops and its corresponding FunctionDef. - static std::map* kNodeTypeToFunctionAttrMapping = - new std::map{ - // TPUReplicate ops are generated by EncapsulateTPUComputationsPass. - {"TPUReplicate", "computation"}, - // XlaLaunch ops are generated by EncapsulateXlaComputationsPass. 
- {"XlaLaunch", "function"}, - }; - std::map> canonicalized_name_to_new_name; - bool fld_modified = false; - for (Node* n : graph->nodes()) { - auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string()); - if (it == kNodeTypeToFunctionAttrMapping->end()) { - continue; - } - const string func_attr = it->second; - NameAttrList func; - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func)); - VLOG(2) << "Graph has node " << n->type_string() - << ". Corresponding function: " << func.name(); - string new_func_name = options.flib_def->UniqueFunctionName( - absl::StrCat(func.name(), "_rearrange_")); - bool modified = false; - TF_RETURN_IF_ERROR(RearrangeFunctionArgumentForFunction( - func.name(), new_func_name, func.attr(), options.flib_def, flr, - &canonicalized_name_to_new_name, &modified)); - if (modified) { - n->ClearAttr(func_attr); - func.set_name(new_func_name); - n->AddAttr(func_attr, func); - - fld_modified = true; - } - } - if (fld_modified) { - TF_RETURN_IF_ERROR( - PruneUnreachableFunctionsFromGraph(**options.graph, options.flib_def)); - } - - if (VLOG_IS_ON(4)) { - DumpGraphToFile("rearrange_function_argument_after", *graph, - options.flib_def); - } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/rearrange_function_argument.h b/tensorflow/compiler/tf2xla/rearrange_function_argument.h new file mode 100644 index 00000000000..c553d8b6e41 --- /dev/null +++ b/tensorflow/compiler/tf2xla/rearrange_function_argument.h @@ -0,0 +1,39 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_H_ +#define TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_H_ + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// For the given graph `g`: +// 1. Rewrite If/While node functions to rearrange arguments and return values, +// so that all resource arguments/return values are placed in the end (as +// required by XlaCompiler), +// 2. Inline StatefulPartitionedCall nodes so we do not need to rearrange +// arguments and return values. +// `get_function_body_fn` is used to instantiate FunctionDef. +// `fld` is used to store rewritten functions. +Status RearrangeFunctionArguments( + std::function + get_function_body_fn, + Graph* g, FunctionLibraryDefinition* fld); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_H_ diff --git a/tensorflow/compiler/tf2xla/rearrange_function_argument_pass.h b/tensorflow/compiler/tf2xla/rearrange_function_argument_pass.h deleted file mode 100644 index 98ffd628c0e..00000000000 --- a/tensorflow/compiler/tf2xla/rearrange_function_argument_pass.h +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_PASS_H_ -#define TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_PASS_H_ - -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/common_runtime/optimization_registry.h" -#include "tensorflow/core/framework/function.h" -#include "tensorflow/core/graph/graph.h" - -namespace tensorflow { - -// For the function with `func_name`, rewrite any -// StatefulPartitionedCall/If/While node that does not satisfy the rules. -// We will rewrite related FunctionDef to rearrange arguments and return values, -// also adjust node's input/output edges accordingly. -Status RearrangeFunctionArgumentForFunction( - const string& func_name, const string& new_func_name, - const protobuf::Map& attrs, - FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, - std::map>* canonicalized_name_to_new_name, - bool* modified); - -// TF/XLA bridge expects FunctionDef to satisfy the following rules: -// 1. DT_RESOURCE arguments are always in the last; -// 2. Do not return DT_RESOURCE as return values. -// But functions defined by Tensorflow might not satisfy them. -// This rewrite pass rewrites the function for TPUCompile/XlaLaunch node -// to follow the rules, using RearrangeFunctionArgumentForFunction() above. 
-class RearrangeFunctionArgumentPass : public GraphOptimizationPass { - public: - Status Run(const GraphOptimizationPassOptions& options) override; -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_PASS_H_ diff --git a/tensorflow/compiler/tf2xla/rearrange_function_argument_pass_registration.cc b/tensorflow/compiler/tf2xla/rearrange_function_argument_pass_registration.cc deleted file mode 100644 index 0661902d85e..00000000000 --- a/tensorflow/compiler/tf2xla/rearrange_function_argument_pass_registration.cc +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/rearrange_function_argument_pass.h" - -namespace tensorflow { - -// This pass is required for some AOT backends and all JIT backends, so this -// file exists as a separate lib and will be linked to both AOT and JIT. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 28, - RearrangeFunctionArgumentPass); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 3b87b52355c..7548442d1ad 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/types/variant.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" +#include "tensorflow/compiler/tf2xla/rearrange_function_argument.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" @@ -1097,6 +1098,11 @@ Status XlaCompiler::CompileGraph( TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes( graph.get(), options_.flib_def, local_flib_def_.get())); + TF_RETURN_IF_ERROR(RearrangeFunctionArguments( + [this](const NameAttrList& function, const FunctionBody** fbody) { + return FindFunctionBody(function, fbody); + }, + graph.get(), local_flib_def_.get())); if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileGraph: " << DumpGraphToFile(absl::StrCat("xla_compile_graph_", name), *graph,