Move function argument rearrangement from a graph pass to XlaCompiler.

1. This is only required for XLA, so it makes sense to move it into XlaCompiler;
2. We need it inside XLA compiler so TPU eager mode works (TPU eager mode does not call graph rewrite passes).

PiperOrigin-RevId: 248432264
This commit is contained in:
Tong Shen 2019-05-15 16:50:19 -07:00 committed by TensorFlower Gardener
parent 5166de391c
commit 83668b0826
9 changed files with 163 additions and 398 deletions

View File

@ -322,7 +322,6 @@ cc_library(
deps = [
":compilation_passes",
"//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
"//tensorflow/compiler/tf2xla:rearrange_function_argument_pass_registration",
"//tensorflow/core:core_cpu_internal",
],
alwayslink = 1,
@ -702,7 +701,7 @@ tf_cc_test(
"//tensorflow/cc:scope",
"//tensorflow/cc:sendrecv_ops",
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:rearrange_function_argument_pass",
"//tensorflow/compiler/tf2xla:rearrange_function_argument",
"//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:test_util",
"//tensorflow/compiler/tf2xla:xla_compiler",

View File

@ -39,10 +39,6 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 25,
// third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
// FunctionalizeControlFlowPass: 27
//
// from
// third_party/tensorflow/compiler/tf2xla/rearrange_function_argument_pass_registration.cc
// RearrangeFunctionArgumentPass: 28
//
// This pass looks at the graph and all associated FunctionDefs, and turns
// traditional control flow structure (Switch/Merge/etc.) into functional
// control flow structure (XlaIf/XlaWhile). Following passes must

View File

@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/rearrange_function_argument_pass.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/array_ops.h"
@ -22,6 +20,7 @@ limitations under the License.
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"
@ -37,37 +36,7 @@ limitations under the License.
namespace tensorflow {
class RearrangeFunctionArgumentForFunctionTest : public ::testing::Test {
public:
void SetUp() override {
SessionOptions session_options;
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(
session_options, "/job:localhost/replica:0/task:0", &devices));
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
}
Status RearrangeFunctionArgumentTest(
const string &func_name, const string &new_func_name,
const protobuf::Map<string, tensorflow::AttrValue> &attrs,
FunctionLibraryDefinition *fld, bool *modified) {
OptimizerOptions opts;
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, fld, opts,
/*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
std::map<string, absl::optional<string>> canonicalized_name_to_new_name;
auto flr = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
return RearrangeFunctionArgumentForFunction(
func_name, new_func_name, attrs, fld, flr,
&canonicalized_name_to_new_name, modified);
}
private:
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
};
TEST_F(RearrangeFunctionArgumentForFunctionTest, Basic) {
TEST(RearrangeFunctionArgumentForFunctionTest, Basic) {
FunctionDefLibrary fdl;
{
// Function for StatefulPartitionedCall's "f", If's
@ -113,40 +82,45 @@ TEST_F(RearrangeFunctionArgumentForFunctionTest, Basic) {
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "f3", xla_fdef));
}
{
// Build the XLA computation func.
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_INT32)
// "arg0", "arg1" -> "if" (If) -> "ret0", "ret1"
// "arg0", "arg1" -> "while" (While) -> "ret2", "ret3"
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
NameAttrList f;
f.set_name("f1");
auto if_op = ops::If(s.WithOpName("if"), arg1,
std::initializer_list<Input>{arg0, arg1},
{DT_BOOL, DT_RESOURCE}, f, f);
auto ret0 = ops::_Retval(s.WithOpName("ret0"), if_op.output[0], 0);
auto ret1 = ops::_Retval(s.WithOpName("ret1"), if_op.output[1], 1);
NameAttrList cond_fn, body_fn;
cond_fn.set_name("f3");
body_fn.set_name("f2");
auto while_op =
ops::While(s.WithOpName("while"),
std::initializer_list<Input>{arg0, arg1}, cond_fn, body_fn);
auto ret2 = ops::_Retval(s.WithOpName("ret2"), while_op.output[0], 2);
auto ret3 = ops::_Retval(s.WithOpName("ret3"), while_op.output[1], 3);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
}
FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
bool modified;
protobuf::Map<string, tensorflow::AttrValue> attrs;
TF_CHECK_OK(RearrangeFunctionArgumentTest("cluster", "cluster_rewritten",
attrs, &fld, &modified));
// Build the XLA computation graph.
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_INT32)
// "arg0", "arg1" -> "if" (If) -> "ret0", "ret1"
// "arg0", "arg1" -> "while" (While) -> "ret2", "ret3"
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
NameAttrList f;
f.set_name("f1");
auto if_op = ops::If(s.WithOpName("if"), arg1,
std::initializer_list<Input>{arg0, arg1},
{DT_BOOL, DT_RESOURCE}, f, f);
auto ret0 = ops::_Retval(s.WithOpName("ret0"), if_op.output[0], 0);
auto ret1 = ops::_Retval(s.WithOpName("ret1"), if_op.output[1], 1);
NameAttrList cond_fn, body_fn;
cond_fn.set_name("f3");
body_fn.set_name("f2");
auto while_op =
ops::While(s.WithOpName("while"),
std::initializer_list<Input>{arg0, arg1}, cond_fn, body_fn);
auto ret2 = ops::_Retval(s.WithOpName("ret2"), while_op.output[0], 2);
auto ret3 = ops::_Retval(s.WithOpName("ret3"), while_op.output[1], 3);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
std::vector<std::unique_ptr<FunctionBody>> fbodies;
TF_CHECK_OK(RearrangeFunctionArguments(
[&](const NameAttrList &function, const FunctionBody **fbody) {
std::unique_ptr<FunctionBody> new_fbody;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld.Find(function.name()),
AttrSlice(&function.attr()),
&fld, &new_fbody));
*fbody = new_fbody.get();
fbodies.push_back(std::move(new_fbody));
return Status::OK();
},
g.get(), &fld));
// Check function f1_rearrange_0, input types should be {DT_BOOL, DT_RESOURCE}
// and output types should be {DT_BOOL}.
@ -159,10 +133,7 @@ TEST_F(RearrangeFunctionArgumentForFunctionTest, Basic) {
EXPECT_EQ(f1_rewritten->signature().output_arg(0).type(), DT_BOOL);
// Check node "if" input and output edges.
std::unique_ptr<FunctionBody> xla_fbody;
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
AttrSlice(), &fld, &xla_fbody));
auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();
auto node_name_index = g->BuildNodeNameIndex();
const Node *if_node = node_name_index.at("if");
ASSERT_NE(if_node, nullptr);
const Node *input_node;
@ -170,11 +141,13 @@ TEST_F(RearrangeFunctionArgumentForFunctionTest, Basic) {
EXPECT_EQ(input_node->name(), "arg1");
TF_CHECK_OK(if_node->input_node(2, &input_node));
EXPECT_EQ(input_node->name(), "arg0");
const Node *ret2_node = xla_fbody->ret_nodes[0];
TF_CHECK_OK(ret2_node->input_node(0, &input_node));
const Node *ret0_node = node_name_index.at("ret0");
ASSERT_NE(ret0_node, nullptr);
TF_CHECK_OK(ret0_node->input_node(0, &input_node));
EXPECT_EQ(input_node->name(), "if");
const Node *ret3_node = xla_fbody->ret_nodes[1];
TF_CHECK_OK(ret3_node->input_node(0, &input_node));
const Node *ret1_node = node_name_index.at("ret1");
ASSERT_NE(ret1_node, nullptr);
TF_CHECK_OK(ret1_node->input_node(0, &input_node));
EXPECT_EQ(input_node->name(), "arg0");
// Check node "while" input and output edges.
@ -184,16 +157,18 @@ TEST_F(RearrangeFunctionArgumentForFunctionTest, Basic) {
EXPECT_EQ(input_node->name(), "arg1");
TF_CHECK_OK(while_node->input_node(1, &input_node));
EXPECT_EQ(input_node->name(), "arg0");
const Node *ret4_node = xla_fbody->ret_nodes[2];
TF_CHECK_OK(ret4_node->input_node(0, &input_node));
const Node *ret2_node = node_name_index.at("ret2");
ASSERT_NE(ret2_node, nullptr);
TF_CHECK_OK(ret2_node->input_node(0, &input_node));
EXPECT_EQ(input_node->name(), "arg0");
const Node *ret5_node = xla_fbody->ret_nodes[3];
TF_CHECK_OK(ret5_node->input_node(0, &input_node));
const Node *ret3_node = node_name_index.at("ret3");
ASSERT_NE(ret3_node, nullptr);
TF_CHECK_OK(ret3_node->input_node(0, &input_node));
EXPECT_EQ(input_node->name(), "while");
}
TEST_F(RearrangeFunctionArgumentForFunctionTest,
WhileResourceRetvalFromDifferentArgUnimplemented) {
TEST(RearrangeFunctionArgumentForFunctionTest,
WhileResourceRetvalFromDifferentArgUnimplemented) {
FunctionDefLibrary fdl;
{
// Function for While's "body".
@ -227,32 +202,37 @@ TEST_F(RearrangeFunctionArgumentForFunctionTest,
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "f1", xla_fdef));
}
{
// Build the XLA computation func.
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
// "arg0", "arg1" -> "while" (While)
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
NameAttrList cond_fn, body_fn;
cond_fn.set_name("f1");
body_fn.set_name("f2");
auto while_op = ops::While(s.WithOpName("while"),
std::initializer_list<Input>{arg0, arg1, arg2},
cond_fn, body_fn);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
}
FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
bool modified;
protobuf::Map<string, tensorflow::AttrValue> attrs;
Status s = RearrangeFunctionArgumentTest("cluster", "cluster_rewritten",
attrs, &fld, &modified);
EXPECT_EQ(s.code(), error::UNIMPLEMENTED);
// Build the XLA computation graph.
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
// "arg0", "arg1" -> "while" (While)
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
NameAttrList cond_fn, body_fn;
cond_fn.set_name("f1");
body_fn.set_name("f2");
auto while_op = ops::While(s.WithOpName("while"),
std::initializer_list<Input>{arg0, arg1, arg2},
cond_fn, body_fn);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
std::vector<std::unique_ptr<FunctionBody>> fbodies;
Status status = RearrangeFunctionArguments(
[&](const NameAttrList &function, const FunctionBody **fbody) {
std::unique_ptr<FunctionBody> new_fbody;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld.Find(function.name()),
AttrSlice(&function.attr()),
&fld, &new_fbody));
*fbody = new_fbody.get();
fbodies.push_back(std::move(new_fbody));
return Status::OK();
},
g.get(), &fld);
EXPECT_EQ(status.code(), error::UNIMPLEMENTED);
}
} // namespace tensorflow

View File

@ -196,6 +196,7 @@ cc_library(
":tf2xla_util",
"//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/jit:xla_cluster_util",
"//tensorflow/compiler/tf2xla:rearrange_function_argument",
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@ -514,12 +515,12 @@ cc_library(
)
cc_library(
name = "rearrange_function_argument_pass",
name = "rearrange_function_argument",
srcs = [
"rearrange_function_argument_pass.cc",
"rearrange_function_argument.cc",
],
hdrs = [
"rearrange_function_argument_pass.h",
"rearrange_function_argument.h",
],
deps = [
"//tensorflow/compiler/tf2xla:tf2xla_util",
@ -535,17 +536,6 @@ cc_library(
],
)
cc_library(
name = "rearrange_function_argument_pass_registration",
srcs = [
"rearrange_function_argument_pass_registration.cc",
],
deps = [
":rearrange_function_argument_pass",
],
alwayslink = 1,
)
cc_library(
name = "functionalize_control_flow_pass_registration",
srcs = [

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/rearrange_function_argument_pass.h"
#include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
#include <algorithm>
@ -158,8 +158,9 @@ Status ReorderOutputEdges(Graph* g, Node* n, int input_count,
// Given mapping between original input index and rearranged input index, change
// "index" attribute for _Arg nodes.
void RearrangeArgNodes(gtl::InlinedVector<Node*, 4>* arg_nodes, // non-absl ok
const std::vector<int>& index_mapping) {
void RearrangeArgNodes(
const gtl::InlinedVector<Node*, 4>* arg_nodes, // non-absl ok
const std::vector<int>& index_mapping) {
for (int i = 0; i < arg_nodes->size(); i++) {
Node* n = (*arg_nodes)[i];
int new_index = index_mapping.at(i);
@ -271,8 +272,10 @@ void RearrangeRetvalNodes(
}
}
Status MaybeRewriteWhileNode(Graph* g, Node* n, FunctionLibraryDefinition* fld,
bool* node_rewritten) {
Status MaybeRewriteWhileNode(
std::function<Status(const NameAttrList&, const FunctionBody**)>
get_function_body_fn,
Graph* g, Node* n, FunctionLibraryDefinition* fld, bool* node_rewritten) {
// Check if this While node needs rewrite.
std::vector<DataType> types;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &types));
@ -303,11 +306,8 @@ Status MaybeRewriteWhileNode(Graph* g, Node* n, FunctionLibraryDefinition* fld,
for (auto const& attr_name : std::vector<string>{"cond", "body"}) {
NameAttrList attr_value;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &attr_value));
const FunctionDef* fdef = fld->Find(attr_value.name());
TF_RET_CHECK(fdef != nullptr);
std::unique_ptr<FunctionBody> fbody;
TF_RETURN_IF_ERROR(
FunctionDefToBodyHelper(*fdef, AttrSlice(), fld, &fbody));
const FunctionBody* fbody;
TF_RETURN_IF_ERROR(get_function_body_fn(attr_value, &fbody));
// Check that resource _Arg nodes for While node are always returned with
// the same index, and we don't have cases like this:
@ -375,8 +375,10 @@ Status MaybeRewriteWhileNode(Graph* g, Node* n, FunctionLibraryDefinition* fld,
return Status::OK();
}
Status MaybeRewriteIfNode(Graph* g, Node* n, FunctionLibraryDefinition* fld,
bool* node_rewritten) {
Status MaybeRewriteIfNode(
std::function<Status(const NameAttrList&, const FunctionBody**)>
get_function_body_fn,
Graph* g, Node* n, FunctionLibraryDefinition* fld, bool* node_rewritten) {
// This node needs rewrite when either of these is true:
// 1) Tin has DT_RESOURCE which requires rearrange;
// 2) Tout has DT_RESOURCE.
@ -428,11 +430,8 @@ Status MaybeRewriteIfNode(Graph* g, Node* n, FunctionLibraryDefinition* fld,
std::vector<string>{"then_branch", "else_branch"}) {
NameAttrList f;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &f));
const FunctionDef* fdef = fld->Find(f.name());
TF_RET_CHECK(fdef != nullptr);
std::unique_ptr<FunctionBody> fbody;
TF_RETURN_IF_ERROR(
FunctionDefToBodyHelper(*fdef, AttrSlice(), fld, &fbody));
const FunctionBody* fbody;
TF_RETURN_IF_ERROR(get_function_body_fn(f, &fbody));
if (input_need_rearrange) {
// Change _Arg node index.
@ -501,95 +500,10 @@ Status MaybeRewriteIfNode(Graph* g, Node* n, FunctionLibraryDefinition* fld,
} // namespace
Status RearrangeFunctionArgumentForFunction(
const string& func_name, const string& new_func_name,
const protobuf::Map<string, tensorflow::AttrValue>& attrs,
FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
std::map<string, absl::optional<string>>* canonicalized_name_to_new_name,
bool* modified) {
*modified = false;
// Convert the function to Graph.
FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle));
Status ret_status = Status::OK();
auto cleanup_handle = gtl::MakeCleanup([&]() {
auto s = flr->ReleaseHandle(handle);
if (!s.ok()) {
ret_status.Update(s);
}
});
const FunctionBody* body = flr->GetFunctionBody(handle);
Graph* g = body->graph;
// If any node has associated functions, rewrite them first.
// Gather nodes with associated functions first, because rewriting those nodes
// might involve node deletion/addition. Avoid modifying nodes while iterating
// it.
std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
nodes_to_associated_functions;
for (auto* n : g->nodes()) {
auto associated_functions = GetAssociatedFunctions(*n, fld);
if (!associated_functions.empty()) {
nodes_to_associated_functions.push_back({n, associated_functions});
}
}
for (auto iter : nodes_to_associated_functions) {
Node* n = iter.first;
auto associated_functions = iter.second;
for (auto& associated_function : associated_functions) {
string name = associated_function.func_name();
string canonicalized_name =
Canonicalize(name, AttrSlice(&associated_function.attrs()));
auto iter = canonicalized_name_to_new_name->find(canonicalized_name);
string new_name;
bool function_modified;
if (iter != canonicalized_name_to_new_name->end()) {
// If we already processed this function, check if it was rewritten. If
// the function was rewritten, the entry will be non-empty. Otherwise
// the entry will be empty.
function_modified = iter->second.has_value();
if (function_modified) {
new_name = iter->second.value();
}
} else {
if (associated_function.type() ==
AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) {
// For SymbolicGradient, `name` is always "SymbolicGradient",
// which is not very informative. Use node name instead.
new_name =
fld->UniqueFunctionName(absl::StrCat(n->name(), "_rearrange_"));
} else {
new_name = fld->UniqueFunctionName(absl::StrCat(name, "_rearrange_"));
}
TF_RETURN_IF_ERROR(RearrangeFunctionArgumentForFunction(
name, new_name, associated_function.attrs(), fld, flr,
canonicalized_name_to_new_name, &function_modified));
if (function_modified) {
// If the function was rewritten, add an non-empty entry. So later we
// know we have processed this function, and it was rewritten into
// another function.
(*canonicalized_name_to_new_name)[canonicalized_name] = new_name;
} else {
// If the function was not rewritten, add an empty entry. So later
// we know we have processed this function, and it does not need to be
// rewritten.
(*canonicalized_name_to_new_name)[canonicalized_name] = absl::nullopt;
}
}
if (function_modified) {
*modified = true;
// Notice that if "n" is a function call, RewriteAssociatedFunction()
// will delete it and create a new node instead, making "n" an invalid
// pointer. That's fine because in that case, associated_functions will
// only have one member and the loop will only run once.
TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
g, n, fld, associated_function, new_name));
}
}
}
Status RearrangeFunctionArguments(
std::function<Status(const NameAttrList&, const FunctionBody**)>
get_function_body_fn,
Graph* g, FunctionLibraryDefinition* fld) {
// Inline StatefulPartitionedCall nodes.
std::vector<Node*> call_nodes;
for (Node* n : g->nodes()) {
@ -598,114 +512,30 @@ Status RearrangeFunctionArgumentForFunction(
}
}
for (Node* n : call_nodes) {
*modified = true;
NameAttrList func_name_attrs;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &func_name_attrs));
const FunctionDef* fdef = fld->Find(func_name_attrs.name());
if (!fdef) {
return errors::InvalidArgument("Cannot find function ",
func_name_attrs.name(), " for node ",
n->DebugString());
}
std::unique_ptr<FunctionBody> fbody;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
*fdef, AttrSlice(&func_name_attrs.attr()), fld, &fbody));
const FunctionBody* fbody;
TF_RETURN_IF_ERROR(get_function_body_fn(func_name_attrs, &fbody));
InlineFunctionBodyOptions opts;
TF_RETURN_IF_ERROR(InlineFunctionBody(*fld, g, n, fbody.get(), opts));
Status s = InlineFunctionBody(*fld, g, n, fbody, opts);
// Inlining might fail because the function is marked with attribute
// _noinline.
s.IgnoreError();
}
// Rewrite If/While nodes.
for (Node* n : g->nodes()) {
if (n->type_string() == "While") {
bool node_rewritten;
TF_RETURN_IF_ERROR(MaybeRewriteWhileNode(g, n, fld, &node_rewritten));
if (node_rewritten) {
*modified = true;
}
TF_RETURN_IF_ERROR(MaybeRewriteWhileNode(get_function_body_fn, g, n, fld,
&node_rewritten));
} else if (n->type_string() == "If") {
bool node_rewritten;
TF_RETURN_IF_ERROR(MaybeRewriteIfNode(g, n, fld, &node_rewritten));
if (node_rewritten) {
*modified = true;
}
}
}
if (*modified) {
// Add rewritten FunctionDef into library.
FunctionDef functionalized_fdef;
TF_RETURN_IF_ERROR(
GraphToFunctionDef(*g, new_func_name, &functionalized_fdef));
if (func_name == new_func_name) {
VLOG(2) << "Replacing function " << func_name;
TF_RETURN_IF_ERROR(
fld->ReplaceFunction(new_func_name, functionalized_fdef));
} else {
VLOG(2) << "Adding function " << new_func_name;
TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
MaybeRewriteIfNode(get_function_body_fn, g, n, fld, &node_rewritten));
}
}
return ret_status;
} // namespace tensorflow
Status RearrangeFunctionArgumentPass::Run(
const GraphOptimizationPassOptions& options) {
Graph* graph = options.graph->get();
if (VLOG_IS_ON(4)) {
DumpGraphToFile("rearrange_function_argument_before", *graph,
options.flib_def);
}
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(
/*device_mgr=*/nullptr, options.session_options->env,
TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions()));
FunctionLibraryRuntime* flr =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
// Find XLA compile ops and its corresponding FunctionDef.
static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
new std::map<string, string>{
// TPUReplicate ops are generated by EncapsulateTPUComputationsPass.
{"TPUReplicate", "computation"},
// XlaLaunch ops are generated by EncapsulateXlaComputationsPass.
{"XlaLaunch", "function"},
};
std::map<string, absl::optional<string>> canonicalized_name_to_new_name;
bool fld_modified = false;
for (Node* n : graph->nodes()) {
auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
if (it == kNodeTypeToFunctionAttrMapping->end()) {
continue;
}
const string func_attr = it->second;
NameAttrList func;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func));
VLOG(2) << "Graph has node " << n->type_string()
<< ". Corresponding function: " << func.name();
string new_func_name = options.flib_def->UniqueFunctionName(
absl::StrCat(func.name(), "_rearrange_"));
bool modified = false;
TF_RETURN_IF_ERROR(RearrangeFunctionArgumentForFunction(
func.name(), new_func_name, func.attr(), options.flib_def, flr,
&canonicalized_name_to_new_name, &modified));
if (modified) {
n->ClearAttr(func_attr);
func.set_name(new_func_name);
n->AddAttr(func_attr, func);
fld_modified = true;
}
}
if (fld_modified) {
TF_RETURN_IF_ERROR(
PruneUnreachableFunctionsFromGraph(**options.graph, options.flib_def));
}
if (VLOG_IS_ON(4)) {
DumpGraphToFile("rearrange_function_argument_after", *graph,
options.flib_def);
}
return Status::OK();
}

View File

@ -0,0 +1,39 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_H_
#define TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_H_
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
// Rewrites the given graph `g` so that it satisfies the argument/return-value
// ordering XlaCompiler requires:
// 1. The functions attached to If/While nodes are rewritten to rearrange their
//    arguments and return values so that all DT_RESOURCE arguments/return
//    values are placed at the end (as required by XlaCompiler). The nodes'
//    input/output edges and _Arg/_Retval indices are adjusted to match.
// 2. StatefulPartitionedCall nodes are inlined into `g`, so their callees do
//    not need to be rearranged.
//
// `get_function_body_fn` resolves a function referenced by a node attribute
// (e.g. If's "then_branch", While's "cond"/"body") to its FunctionBody. The
// callee only borrows the returned pointer; the caller keeps ownership and
// must keep the FunctionBody alive for the duration of this call.
// `fld` is the library used to look up and store rewritten functions.
//
// NOTE(review): a While node whose resource return value is produced by a
// different argument index than the one it was passed in as appears to be
// rejected with error::UNIMPLEMENTED — confirm against MaybeRewriteWhileNode.
Status RearrangeFunctionArguments(
    std::function<Status(const NameAttrList&, const FunctionBody**)>
        get_function_body_fn,
    Graph* g, FunctionLibraryDefinition* fld);
}  // namespace tensorflow
#endif  // TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_H_

View File

@ -1,50 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_PASS_H_
#define TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_PASS_H_
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
// For the function with `func_name`, rewrite any
// StatefulPartitionedCall/If/While node that does not satisfy the rules.
// We will rewrite related FunctionDef to rearrange arguments and return values,
// also adjust node's input/output edges accordingly.
Status RearrangeFunctionArgumentForFunction(
const string& func_name, const string& new_func_name,
const protobuf::Map<string, tensorflow::AttrValue>& attrs,
FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
std::map<string, absl::optional<string>>* canonicalized_name_to_new_name,
bool* modified);
// TF/XLA bridge expects FunctionDef to satisfy the following rules:
// 1. DT_RESOURCE arguments are always in the last;
// 2. Do not return DT_RESOURCE as return values.
// But functions defined by Tensorflow might not satisfy them.
// This rewrite pass rewrites the function for TPUCompile/XlaLaunch node
// to follow the rules, using RearrangeFunctionArgumentForFunction() above.
class RearrangeFunctionArgumentPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_PASS_H_

View File

@ -1,25 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/rearrange_function_argument_pass.h"
namespace tensorflow {
// This pass is required for some AOT backends and all JIT backends, so this
// file exists as a separate lib and will be linked to both AOT and JIT.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 28,
RearrangeFunctionArgumentPass);
} // namespace tensorflow

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
#include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
@ -1097,6 +1098,11 @@ Status XlaCompiler::CompileGraph(
TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes(
graph.get(), options_.flib_def, local_flib_def_.get()));
TF_RETURN_IF_ERROR(RearrangeFunctionArguments(
[this](const NameAttrList& function, const FunctionBody** fbody) {
return FindFunctionBody(function, fbody);
},
graph.get(), local_flib_def_.get()));
if (VLOG_IS_ON(2)) {
VLOG(2) << "XlaCompiler::CompileGraph: "
<< DumpGraphToFile(absl::StrCat("xla_compile_graph_", name), *graph,