Move function argument rearrangement from a graph pass to XlaCompiler.
1. This is only required for XLA, so it makes sense to move it into XlaCompiler; 2. We need it inside XLA compiler so TPU eager mode works (TPU eager mode does not call graph rewrite passes). PiperOrigin-RevId: 248432264
This commit is contained in:
parent
5166de391c
commit
83668b0826
@ -322,7 +322,6 @@ cc_library(
|
||||
deps = [
|
||||
":compilation_passes",
|
||||
"//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
|
||||
"//tensorflow/compiler/tf2xla:rearrange_function_argument_pass_registration",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
],
|
||||
alwayslink = 1,
|
||||
@ -702,7 +701,7 @@ tf_cc_test(
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/cc:sendrecv_ops",
|
||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla:rearrange_function_argument_pass",
|
||||
"//tensorflow/compiler/tf2xla:rearrange_function_argument",
|
||||
"//tensorflow/compiler/tf2xla:side_effect_util",
|
||||
"//tensorflow/compiler/tf2xla:test_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
|
@ -39,10 +39,6 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 25,
|
||||
// third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
|
||||
// FunctionalizeControlFlowPass: 27
|
||||
//
|
||||
// from
|
||||
// third_party/tensorflow/compiler/tf2xla/rearrange_function_argument_pass_registration.cc
|
||||
// RearrangeFunctionArgumentPass: 28
|
||||
//
|
||||
// This pass looks at the graph and all associated FunctionDefs, and turns
|
||||
// traditional control flow structure (Switch/Merge/etc.) into functional
|
||||
// control flow structure (XlaIf/XlaWhile). Following passes must
|
||||
|
@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/rearrange_function_argument_pass.h"
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/cc/ops/array_ops.h"
|
||||
@ -22,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/ops/functional_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/jit/encapsulate_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
@ -37,37 +36,7 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class RearrangeFunctionArgumentForFunctionTest : public ::testing::Test {
|
||||
public:
|
||||
void SetUp() override {
|
||||
SessionOptions session_options;
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
TF_CHECK_OK(DeviceFactory::AddDevices(
|
||||
session_options, "/job:localhost/replica:0/task:0", &devices));
|
||||
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
|
||||
}
|
||||
|
||||
Status RearrangeFunctionArgumentTest(
|
||||
const string &func_name, const string &new_func_name,
|
||||
const protobuf::Map<string, tensorflow::AttrValue> &attrs,
|
||||
FunctionLibraryDefinition *fld, bool *modified) {
|
||||
OptimizerOptions opts;
|
||||
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, fld, opts,
|
||||
/*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
|
||||
std::map<string, absl::optional<string>> canonicalized_name_to_new_name;
|
||||
auto flr = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
|
||||
return RearrangeFunctionArgumentForFunction(
|
||||
func_name, new_func_name, attrs, fld, flr,
|
||||
&canonicalized_name_to_new_name, modified);
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<DeviceMgr> device_mgr_;
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
|
||||
};
|
||||
|
||||
TEST_F(RearrangeFunctionArgumentForFunctionTest, Basic) {
|
||||
TEST(RearrangeFunctionArgumentForFunctionTest, Basic) {
|
||||
FunctionDefLibrary fdl;
|
||||
{
|
||||
// Function for StatefulPartitionedCall's "f", If's
|
||||
@ -113,40 +82,45 @@ TEST_F(RearrangeFunctionArgumentForFunctionTest, Basic) {
|
||||
FunctionDef *xla_fdef = fdl.add_function();
|
||||
TF_CHECK_OK(GraphToFunctionDef(*g, "f3", xla_fdef));
|
||||
}
|
||||
{
|
||||
// Build the XLA computation func.
|
||||
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_INT32)
|
||||
// "arg0", "arg1" -> "if" (If) -> "ret0", "ret1"
|
||||
// "arg0", "arg1" -> "while" (While) -> "ret2", "ret3"
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
|
||||
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
|
||||
NameAttrList f;
|
||||
f.set_name("f1");
|
||||
auto if_op = ops::If(s.WithOpName("if"), arg1,
|
||||
std::initializer_list<Input>{arg0, arg1},
|
||||
{DT_BOOL, DT_RESOURCE}, f, f);
|
||||
auto ret0 = ops::_Retval(s.WithOpName("ret0"), if_op.output[0], 0);
|
||||
auto ret1 = ops::_Retval(s.WithOpName("ret1"), if_op.output[1], 1);
|
||||
NameAttrList cond_fn, body_fn;
|
||||
cond_fn.set_name("f3");
|
||||
body_fn.set_name("f2");
|
||||
auto while_op =
|
||||
ops::While(s.WithOpName("while"),
|
||||
std::initializer_list<Input>{arg0, arg1}, cond_fn, body_fn);
|
||||
auto ret2 = ops::_Retval(s.WithOpName("ret2"), while_op.output[0], 2);
|
||||
auto ret3 = ops::_Retval(s.WithOpName("ret3"), while_op.output[1], 3);
|
||||
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
|
||||
TF_CHECK_OK(s.ToGraph(g.get()));
|
||||
FunctionDef *xla_fdef = fdl.add_function();
|
||||
TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
|
||||
}
|
||||
FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
|
||||
|
||||
bool modified;
|
||||
protobuf::Map<string, tensorflow::AttrValue> attrs;
|
||||
TF_CHECK_OK(RearrangeFunctionArgumentTest("cluster", "cluster_rewritten",
|
||||
attrs, &fld, &modified));
|
||||
// Build the XLA computation graph.
|
||||
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_INT32)
|
||||
// "arg0", "arg1" -> "if" (If) -> "ret0", "ret1"
|
||||
// "arg0", "arg1" -> "while" (While) -> "ret2", "ret3"
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
|
||||
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
|
||||
NameAttrList f;
|
||||
f.set_name("f1");
|
||||
auto if_op = ops::If(s.WithOpName("if"), arg1,
|
||||
std::initializer_list<Input>{arg0, arg1},
|
||||
{DT_BOOL, DT_RESOURCE}, f, f);
|
||||
auto ret0 = ops::_Retval(s.WithOpName("ret0"), if_op.output[0], 0);
|
||||
auto ret1 = ops::_Retval(s.WithOpName("ret1"), if_op.output[1], 1);
|
||||
NameAttrList cond_fn, body_fn;
|
||||
cond_fn.set_name("f3");
|
||||
body_fn.set_name("f2");
|
||||
auto while_op =
|
||||
ops::While(s.WithOpName("while"),
|
||||
std::initializer_list<Input>{arg0, arg1}, cond_fn, body_fn);
|
||||
auto ret2 = ops::_Retval(s.WithOpName("ret2"), while_op.output[0], 2);
|
||||
auto ret3 = ops::_Retval(s.WithOpName("ret3"), while_op.output[1], 3);
|
||||
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
|
||||
TF_CHECK_OK(s.ToGraph(g.get()));
|
||||
|
||||
std::vector<std::unique_ptr<FunctionBody>> fbodies;
|
||||
TF_CHECK_OK(RearrangeFunctionArguments(
|
||||
[&](const NameAttrList &function, const FunctionBody **fbody) {
|
||||
std::unique_ptr<FunctionBody> new_fbody;
|
||||
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld.Find(function.name()),
|
||||
AttrSlice(&function.attr()),
|
||||
&fld, &new_fbody));
|
||||
*fbody = new_fbody.get();
|
||||
fbodies.push_back(std::move(new_fbody));
|
||||
return Status::OK();
|
||||
},
|
||||
g.get(), &fld));
|
||||
|
||||
// Check function f1_rearrange_0, input types should be {DT_BOOL, DT_RESOURCE}
|
||||
// and output types should be {DT_BOOL}.
|
||||
@ -159,10 +133,7 @@ TEST_F(RearrangeFunctionArgumentForFunctionTest, Basic) {
|
||||
EXPECT_EQ(f1_rewritten->signature().output_arg(0).type(), DT_BOOL);
|
||||
|
||||
// Check node "if" input and output edges.
|
||||
std::unique_ptr<FunctionBody> xla_fbody;
|
||||
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
|
||||
AttrSlice(), &fld, &xla_fbody));
|
||||
auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();
|
||||
auto node_name_index = g->BuildNodeNameIndex();
|
||||
const Node *if_node = node_name_index.at("if");
|
||||
ASSERT_NE(if_node, nullptr);
|
||||
const Node *input_node;
|
||||
@ -170,11 +141,13 @@ TEST_F(RearrangeFunctionArgumentForFunctionTest, Basic) {
|
||||
EXPECT_EQ(input_node->name(), "arg1");
|
||||
TF_CHECK_OK(if_node->input_node(2, &input_node));
|
||||
EXPECT_EQ(input_node->name(), "arg0");
|
||||
const Node *ret2_node = xla_fbody->ret_nodes[0];
|
||||
TF_CHECK_OK(ret2_node->input_node(0, &input_node));
|
||||
const Node *ret0_node = node_name_index.at("ret0");
|
||||
ASSERT_NE(ret0_node, nullptr);
|
||||
TF_CHECK_OK(ret0_node->input_node(0, &input_node));
|
||||
EXPECT_EQ(input_node->name(), "if");
|
||||
const Node *ret3_node = xla_fbody->ret_nodes[1];
|
||||
TF_CHECK_OK(ret3_node->input_node(0, &input_node));
|
||||
const Node *ret1_node = node_name_index.at("ret1");
|
||||
ASSERT_NE(ret1_node, nullptr);
|
||||
TF_CHECK_OK(ret1_node->input_node(0, &input_node));
|
||||
EXPECT_EQ(input_node->name(), "arg0");
|
||||
|
||||
// Check node "while" input and output edges.
|
||||
@ -184,16 +157,18 @@ TEST_F(RearrangeFunctionArgumentForFunctionTest, Basic) {
|
||||
EXPECT_EQ(input_node->name(), "arg1");
|
||||
TF_CHECK_OK(while_node->input_node(1, &input_node));
|
||||
EXPECT_EQ(input_node->name(), "arg0");
|
||||
const Node *ret4_node = xla_fbody->ret_nodes[2];
|
||||
TF_CHECK_OK(ret4_node->input_node(0, &input_node));
|
||||
const Node *ret2_node = node_name_index.at("ret2");
|
||||
ASSERT_NE(ret2_node, nullptr);
|
||||
TF_CHECK_OK(ret2_node->input_node(0, &input_node));
|
||||
EXPECT_EQ(input_node->name(), "arg0");
|
||||
const Node *ret5_node = xla_fbody->ret_nodes[3];
|
||||
TF_CHECK_OK(ret5_node->input_node(0, &input_node));
|
||||
const Node *ret3_node = node_name_index.at("ret3");
|
||||
ASSERT_NE(ret3_node, nullptr);
|
||||
TF_CHECK_OK(ret3_node->input_node(0, &input_node));
|
||||
EXPECT_EQ(input_node->name(), "while");
|
||||
}
|
||||
|
||||
TEST_F(RearrangeFunctionArgumentForFunctionTest,
|
||||
WhileResourceRetvalFromDifferentArgUnimplemented) {
|
||||
TEST(RearrangeFunctionArgumentForFunctionTest,
|
||||
WhileResourceRetvalFromDifferentArgUnimplemented) {
|
||||
FunctionDefLibrary fdl;
|
||||
{
|
||||
// Function for While's "body".
|
||||
@ -227,32 +202,37 @@ TEST_F(RearrangeFunctionArgumentForFunctionTest,
|
||||
FunctionDef *xla_fdef = fdl.add_function();
|
||||
TF_CHECK_OK(GraphToFunctionDef(*g, "f1", xla_fdef));
|
||||
}
|
||||
{
|
||||
// Build the XLA computation func.
|
||||
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
|
||||
// "arg0", "arg1" -> "while" (While)
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
|
||||
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
|
||||
Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
|
||||
NameAttrList cond_fn, body_fn;
|
||||
cond_fn.set_name("f1");
|
||||
body_fn.set_name("f2");
|
||||
auto while_op = ops::While(s.WithOpName("while"),
|
||||
std::initializer_list<Input>{arg0, arg1, arg2},
|
||||
cond_fn, body_fn);
|
||||
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
|
||||
TF_CHECK_OK(s.ToGraph(g.get()));
|
||||
FunctionDef *xla_fdef = fdl.add_function();
|
||||
TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
|
||||
}
|
||||
FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
|
||||
|
||||
bool modified;
|
||||
protobuf::Map<string, tensorflow::AttrValue> attrs;
|
||||
Status s = RearrangeFunctionArgumentTest("cluster", "cluster_rewritten",
|
||||
attrs, &fld, &modified);
|
||||
EXPECT_EQ(s.code(), error::UNIMPLEMENTED);
|
||||
// Build the XLA computation graph.
|
||||
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
|
||||
// "arg0", "arg1" -> "while" (While)
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
|
||||
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
|
||||
Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
|
||||
NameAttrList cond_fn, body_fn;
|
||||
cond_fn.set_name("f1");
|
||||
body_fn.set_name("f2");
|
||||
auto while_op = ops::While(s.WithOpName("while"),
|
||||
std::initializer_list<Input>{arg0, arg1, arg2},
|
||||
cond_fn, body_fn);
|
||||
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
|
||||
TF_CHECK_OK(s.ToGraph(g.get()));
|
||||
|
||||
std::vector<std::unique_ptr<FunctionBody>> fbodies;
|
||||
Status status = RearrangeFunctionArguments(
|
||||
[&](const NameAttrList &function, const FunctionBody **fbody) {
|
||||
std::unique_ptr<FunctionBody> new_fbody;
|
||||
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld.Find(function.name()),
|
||||
AttrSlice(&function.attr()),
|
||||
&fld, &new_fbody));
|
||||
*fbody = new_fbody.get();
|
||||
fbodies.push_back(std::move(new_fbody));
|
||||
return Status::OK();
|
||||
},
|
||||
g.get(), &fld);
|
||||
EXPECT_EQ(status.code(), error::UNIMPLEMENTED);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -196,6 +196,7 @@ cc_library(
|
||||
":tf2xla_util",
|
||||
"//tensorflow/compiler/jit:flags",
|
||||
"//tensorflow/compiler/jit:xla_cluster_util",
|
||||
"//tensorflow/compiler/tf2xla:rearrange_function_argument",
|
||||
"//tensorflow/compiler/tf2xla/lib:util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -514,12 +515,12 @@ cc_library(
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "rearrange_function_argument_pass",
|
||||
name = "rearrange_function_argument",
|
||||
srcs = [
|
||||
"rearrange_function_argument_pass.cc",
|
||||
"rearrange_function_argument.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"rearrange_function_argument_pass.h",
|
||||
"rearrange_function_argument.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
@ -535,17 +536,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "rearrange_function_argument_pass_registration",
|
||||
srcs = [
|
||||
"rearrange_function_argument_pass_registration.cc",
|
||||
],
|
||||
deps = [
|
||||
":rearrange_function_argument_pass",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "functionalize_control_flow_pass_registration",
|
||||
srcs = [
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/rearrange_function_argument_pass.h"
|
||||
#include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
@ -158,8 +158,9 @@ Status ReorderOutputEdges(Graph* g, Node* n, int input_count,
|
||||
|
||||
// Given mapping between original input index and rearranged input index, change
|
||||
// "index" attribute for _Arg nodes.
|
||||
void RearrangeArgNodes(gtl::InlinedVector<Node*, 4>* arg_nodes, // non-absl ok
|
||||
const std::vector<int>& index_mapping) {
|
||||
void RearrangeArgNodes(
|
||||
const gtl::InlinedVector<Node*, 4>* arg_nodes, // non-absl ok
|
||||
const std::vector<int>& index_mapping) {
|
||||
for (int i = 0; i < arg_nodes->size(); i++) {
|
||||
Node* n = (*arg_nodes)[i];
|
||||
int new_index = index_mapping.at(i);
|
||||
@ -271,8 +272,10 @@ void RearrangeRetvalNodes(
|
||||
}
|
||||
}
|
||||
|
||||
Status MaybeRewriteWhileNode(Graph* g, Node* n, FunctionLibraryDefinition* fld,
|
||||
bool* node_rewritten) {
|
||||
Status MaybeRewriteWhileNode(
|
||||
std::function<Status(const NameAttrList&, const FunctionBody**)>
|
||||
get_function_body_fn,
|
||||
Graph* g, Node* n, FunctionLibraryDefinition* fld, bool* node_rewritten) {
|
||||
// Check if this While node needs rewrite.
|
||||
std::vector<DataType> types;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &types));
|
||||
@ -303,11 +306,8 @@ Status MaybeRewriteWhileNode(Graph* g, Node* n, FunctionLibraryDefinition* fld,
|
||||
for (auto const& attr_name : std::vector<string>{"cond", "body"}) {
|
||||
NameAttrList attr_value;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &attr_value));
|
||||
const FunctionDef* fdef = fld->Find(attr_value.name());
|
||||
TF_RET_CHECK(fdef != nullptr);
|
||||
std::unique_ptr<FunctionBody> fbody;
|
||||
TF_RETURN_IF_ERROR(
|
||||
FunctionDefToBodyHelper(*fdef, AttrSlice(), fld, &fbody));
|
||||
const FunctionBody* fbody;
|
||||
TF_RETURN_IF_ERROR(get_function_body_fn(attr_value, &fbody));
|
||||
|
||||
// Check that resource _Arg nodes for While node are always returned with
|
||||
// the same index, and we don't have cases like this:
|
||||
@ -375,8 +375,10 @@ Status MaybeRewriteWhileNode(Graph* g, Node* n, FunctionLibraryDefinition* fld,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MaybeRewriteIfNode(Graph* g, Node* n, FunctionLibraryDefinition* fld,
|
||||
bool* node_rewritten) {
|
||||
Status MaybeRewriteIfNode(
|
||||
std::function<Status(const NameAttrList&, const FunctionBody**)>
|
||||
get_function_body_fn,
|
||||
Graph* g, Node* n, FunctionLibraryDefinition* fld, bool* node_rewritten) {
|
||||
// This node needs rewrite when either of these is true:
|
||||
// 1) Tin has DT_RESOURCE which requires rearrange;
|
||||
// 2) Tout has DT_RESOURCE.
|
||||
@ -428,11 +430,8 @@ Status MaybeRewriteIfNode(Graph* g, Node* n, FunctionLibraryDefinition* fld,
|
||||
std::vector<string>{"then_branch", "else_branch"}) {
|
||||
NameAttrList f;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &f));
|
||||
const FunctionDef* fdef = fld->Find(f.name());
|
||||
TF_RET_CHECK(fdef != nullptr);
|
||||
std::unique_ptr<FunctionBody> fbody;
|
||||
TF_RETURN_IF_ERROR(
|
||||
FunctionDefToBodyHelper(*fdef, AttrSlice(), fld, &fbody));
|
||||
const FunctionBody* fbody;
|
||||
TF_RETURN_IF_ERROR(get_function_body_fn(f, &fbody));
|
||||
|
||||
if (input_need_rearrange) {
|
||||
// Change _Arg node index.
|
||||
@ -501,95 +500,10 @@ Status MaybeRewriteIfNode(Graph* g, Node* n, FunctionLibraryDefinition* fld,
|
||||
|
||||
} // namespace
|
||||
|
||||
Status RearrangeFunctionArgumentForFunction(
|
||||
const string& func_name, const string& new_func_name,
|
||||
const protobuf::Map<string, tensorflow::AttrValue>& attrs,
|
||||
FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
|
||||
std::map<string, absl::optional<string>>* canonicalized_name_to_new_name,
|
||||
bool* modified) {
|
||||
*modified = false;
|
||||
|
||||
// Convert the function to Graph.
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle));
|
||||
Status ret_status = Status::OK();
|
||||
auto cleanup_handle = gtl::MakeCleanup([&]() {
|
||||
auto s = flr->ReleaseHandle(handle);
|
||||
if (!s.ok()) {
|
||||
ret_status.Update(s);
|
||||
}
|
||||
});
|
||||
const FunctionBody* body = flr->GetFunctionBody(handle);
|
||||
Graph* g = body->graph;
|
||||
|
||||
// If any node has associated functions, rewrite them first.
|
||||
// Gather nodes with associated functions first, because rewriting those nodes
|
||||
// might involve node deletion/addition. Avoid modifying nodes while iterating
|
||||
// it.
|
||||
std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
|
||||
nodes_to_associated_functions;
|
||||
for (auto* n : g->nodes()) {
|
||||
auto associated_functions = GetAssociatedFunctions(*n, fld);
|
||||
if (!associated_functions.empty()) {
|
||||
nodes_to_associated_functions.push_back({n, associated_functions});
|
||||
}
|
||||
}
|
||||
for (auto iter : nodes_to_associated_functions) {
|
||||
Node* n = iter.first;
|
||||
auto associated_functions = iter.second;
|
||||
for (auto& associated_function : associated_functions) {
|
||||
string name = associated_function.func_name();
|
||||
string canonicalized_name =
|
||||
Canonicalize(name, AttrSlice(&associated_function.attrs()));
|
||||
auto iter = canonicalized_name_to_new_name->find(canonicalized_name);
|
||||
string new_name;
|
||||
bool function_modified;
|
||||
if (iter != canonicalized_name_to_new_name->end()) {
|
||||
// If we already processed this function, check if it was rewritten. If
|
||||
// the function was rewritten, the entry will be non-empty. Otherwise
|
||||
// the entry will be empty.
|
||||
function_modified = iter->second.has_value();
|
||||
if (function_modified) {
|
||||
new_name = iter->second.value();
|
||||
}
|
||||
} else {
|
||||
if (associated_function.type() ==
|
||||
AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) {
|
||||
// For SymbolicGradient, `name` is always "SymbolicGradient",
|
||||
// which is not very informative. Use node name instead.
|
||||
new_name =
|
||||
fld->UniqueFunctionName(absl::StrCat(n->name(), "_rearrange_"));
|
||||
} else {
|
||||
new_name = fld->UniqueFunctionName(absl::StrCat(name, "_rearrange_"));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(RearrangeFunctionArgumentForFunction(
|
||||
name, new_name, associated_function.attrs(), fld, flr,
|
||||
canonicalized_name_to_new_name, &function_modified));
|
||||
if (function_modified) {
|
||||
// If the function was rewritten, add an non-empty entry. So later we
|
||||
// know we have processed this function, and it was rewritten into
|
||||
// another function.
|
||||
(*canonicalized_name_to_new_name)[canonicalized_name] = new_name;
|
||||
} else {
|
||||
// If the function was not rewritten, add an empty entry. So later
|
||||
// we know we have processed this function, and it does not need to be
|
||||
// rewritten.
|
||||
(*canonicalized_name_to_new_name)[canonicalized_name] = absl::nullopt;
|
||||
}
|
||||
}
|
||||
if (function_modified) {
|
||||
*modified = true;
|
||||
|
||||
// Notice that if "n" is a function call, RewriteAssociatedFunction()
|
||||
// will delete it and create a new node instead, making "n" an invalid
|
||||
// pointer. That's fine because in that case, associated_functions will
|
||||
// only have one member and the loop will only run once.
|
||||
TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
|
||||
g, n, fld, associated_function, new_name));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status RearrangeFunctionArguments(
|
||||
std::function<Status(const NameAttrList&, const FunctionBody**)>
|
||||
get_function_body_fn,
|
||||
Graph* g, FunctionLibraryDefinition* fld) {
|
||||
// Inline StatefulPartitionedCall nodes.
|
||||
std::vector<Node*> call_nodes;
|
||||
for (Node* n : g->nodes()) {
|
||||
@ -598,114 +512,30 @@ Status RearrangeFunctionArgumentForFunction(
|
||||
}
|
||||
}
|
||||
for (Node* n : call_nodes) {
|
||||
*modified = true;
|
||||
|
||||
NameAttrList func_name_attrs;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &func_name_attrs));
|
||||
const FunctionDef* fdef = fld->Find(func_name_attrs.name());
|
||||
if (!fdef) {
|
||||
return errors::InvalidArgument("Cannot find function ",
|
||||
func_name_attrs.name(), " for node ",
|
||||
n->DebugString());
|
||||
}
|
||||
std::unique_ptr<FunctionBody> fbody;
|
||||
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
|
||||
*fdef, AttrSlice(&func_name_attrs.attr()), fld, &fbody));
|
||||
const FunctionBody* fbody;
|
||||
TF_RETURN_IF_ERROR(get_function_body_fn(func_name_attrs, &fbody));
|
||||
InlineFunctionBodyOptions opts;
|
||||
TF_RETURN_IF_ERROR(InlineFunctionBody(*fld, g, n, fbody.get(), opts));
|
||||
Status s = InlineFunctionBody(*fld, g, n, fbody, opts);
|
||||
// Inlining might fail because the function is marked with attribute
|
||||
// _noinline.
|
||||
s.IgnoreError();
|
||||
}
|
||||
|
||||
// Rewrite If/While nodes.
|
||||
for (Node* n : g->nodes()) {
|
||||
if (n->type_string() == "While") {
|
||||
bool node_rewritten;
|
||||
TF_RETURN_IF_ERROR(MaybeRewriteWhileNode(g, n, fld, &node_rewritten));
|
||||
if (node_rewritten) {
|
||||
*modified = true;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(MaybeRewriteWhileNode(get_function_body_fn, g, n, fld,
|
||||
&node_rewritten));
|
||||
} else if (n->type_string() == "If") {
|
||||
bool node_rewritten;
|
||||
TF_RETURN_IF_ERROR(MaybeRewriteIfNode(g, n, fld, &node_rewritten));
|
||||
if (node_rewritten) {
|
||||
*modified = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (*modified) {
|
||||
// Add rewritten FunctionDef into library.
|
||||
FunctionDef functionalized_fdef;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GraphToFunctionDef(*g, new_func_name, &functionalized_fdef));
|
||||
if (func_name == new_func_name) {
|
||||
VLOG(2) << "Replacing function " << func_name;
|
||||
TF_RETURN_IF_ERROR(
|
||||
fld->ReplaceFunction(new_func_name, functionalized_fdef));
|
||||
} else {
|
||||
VLOG(2) << "Adding function " << new_func_name;
|
||||
TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
|
||||
MaybeRewriteIfNode(get_function_body_fn, g, n, fld, &node_rewritten));
|
||||
}
|
||||
}
|
||||
|
||||
return ret_status;
|
||||
} // namespace tensorflow
|
||||
|
||||
Status RearrangeFunctionArgumentPass::Run(
|
||||
const GraphOptimizationPassOptions& options) {
|
||||
Graph* graph = options.graph->get();
|
||||
if (VLOG_IS_ON(4)) {
|
||||
DumpGraphToFile("rearrange_function_argument_before", *graph,
|
||||
options.flib_def);
|
||||
}
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
|
||||
new ProcessFunctionLibraryRuntime(
|
||||
/*device_mgr=*/nullptr, options.session_options->env,
|
||||
TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions()));
|
||||
FunctionLibraryRuntime* flr =
|
||||
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
|
||||
|
||||
// Find XLA compile ops and its corresponding FunctionDef.
|
||||
static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
|
||||
new std::map<string, string>{
|
||||
// TPUReplicate ops are generated by EncapsulateTPUComputationsPass.
|
||||
{"TPUReplicate", "computation"},
|
||||
// XlaLaunch ops are generated by EncapsulateXlaComputationsPass.
|
||||
{"XlaLaunch", "function"},
|
||||
};
|
||||
std::map<string, absl::optional<string>> canonicalized_name_to_new_name;
|
||||
bool fld_modified = false;
|
||||
for (Node* n : graph->nodes()) {
|
||||
auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
|
||||
if (it == kNodeTypeToFunctionAttrMapping->end()) {
|
||||
continue;
|
||||
}
|
||||
const string func_attr = it->second;
|
||||
NameAttrList func;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func));
|
||||
VLOG(2) << "Graph has node " << n->type_string()
|
||||
<< ". Corresponding function: " << func.name();
|
||||
string new_func_name = options.flib_def->UniqueFunctionName(
|
||||
absl::StrCat(func.name(), "_rearrange_"));
|
||||
bool modified = false;
|
||||
TF_RETURN_IF_ERROR(RearrangeFunctionArgumentForFunction(
|
||||
func.name(), new_func_name, func.attr(), options.flib_def, flr,
|
||||
&canonicalized_name_to_new_name, &modified));
|
||||
if (modified) {
|
||||
n->ClearAttr(func_attr);
|
||||
func.set_name(new_func_name);
|
||||
n->AddAttr(func_attr, func);
|
||||
|
||||
fld_modified = true;
|
||||
}
|
||||
}
|
||||
if (fld_modified) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
PruneUnreachableFunctionsFromGraph(**options.graph, options.flib_def));
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(4)) {
|
||||
DumpGraphToFile("rearrange_function_argument_after", *graph,
|
||||
options.flib_def);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
39
tensorflow/compiler/tf2xla/rearrange_function_argument.h
Normal file
39
tensorflow/compiler/tf2xla/rearrange_function_argument.h
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_H_
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_H_
|
||||
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// For the given graph `g`:
|
||||
// 1. Rewrite If/While node functions to rearrange arguments and return values,
|
||||
// so that all resource arguments/return values are placed in the end (as
|
||||
// required by XlaCompiler),
|
||||
// 2. Inline StatefulPartitionedCall nodes so we do not need to rearrange
|
||||
// arguments and return values.
|
||||
// `get_function_body_fn` is used to instantiate FunctionDef.
|
||||
// `fld` is used to store rewritten functions.
|
||||
Status RearrangeFunctionArguments(
|
||||
std::function<Status(const NameAttrList&, const FunctionBody**)>
|
||||
get_function_body_fn,
|
||||
Graph* g, FunctionLibraryDefinition* fld);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_H_
|
@ -1,50 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_PASS_H_
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_PASS_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// For the function with `func_name`, rewrite any
|
||||
// StatefulPartitionedCall/If/While node that does not satisfy the rules.
|
||||
// We will rewrite related FunctionDef to rearrange arguments and return values,
|
||||
// also adjust node's input/output edges accordingly.
|
||||
Status RearrangeFunctionArgumentForFunction(
|
||||
const string& func_name, const string& new_func_name,
|
||||
const protobuf::Map<string, tensorflow::AttrValue>& attrs,
|
||||
FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
|
||||
std::map<string, absl::optional<string>>* canonicalized_name_to_new_name,
|
||||
bool* modified);
|
||||
|
||||
// TF/XLA bridge expects FunctionDef to satisfy the following rules:
|
||||
// 1. DT_RESOURCE arguments are always in the last;
|
||||
// 2. Do not return DT_RESOURCE as return values.
|
||||
// But functions defined by Tensorflow might not satisfy them.
|
||||
// This rewrite pass rewrites the function for TPUCompile/XlaLaunch node
|
||||
// to follow the rules, using RearrangeFunctionArgumentForFunction() above.
|
||||
class RearrangeFunctionArgumentPass : public GraphOptimizationPass {
|
||||
public:
|
||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_PASS_H_
|
@ -1,25 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/rearrange_function_argument_pass.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// This pass is required for some AOT backends and all JIT backends, so this
|
||||
// file exists as a separate lib and will be linked to both AOT and JIT.
|
||||
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 28,
|
||||
RearrangeFunctionArgumentPass);
|
||||
|
||||
} // namespace tensorflow
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/sharding_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
|
||||
@ -1097,6 +1098,11 @@ Status XlaCompiler::CompileGraph(
|
||||
|
||||
TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes(
|
||||
graph.get(), options_.flib_def, local_flib_def_.get()));
|
||||
TF_RETURN_IF_ERROR(RearrangeFunctionArguments(
|
||||
[this](const NameAttrList& function, const FunctionBody** fbody) {
|
||||
return FindFunctionBody(function, fbody);
|
||||
},
|
||||
graph.get(), local_flib_def_.get()));
|
||||
if (VLOG_IS_ON(2)) {
|
||||
VLOG(2) << "XlaCompiler::CompileGraph: "
|
||||
<< DumpGraphToFile(absl::StrCat("xla_compile_graph_", name), *graph,
|
||||
|
Loading…
Reference in New Issue
Block a user