// STT-tensorflow/tensorflow/compiler/tf2xla/rearrange_function_argument_pass.cc
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/rearrange_function_argument_pass.h"
#include <algorithm>
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"
namespace tensorflow {
namespace {
// Given original input types and argument index mapping, return the new input
// types.
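// For example, with in_types = {DT_INT32, DT_RESOURCE, DT_FLOAT} and
// index_mapping = {0, 2, 1}, each in_types[i] moves to slot index_mapping[i],
// producing {DT_INT32, DT_FLOAT, DT_RESOURCE}.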
std::vector<DataType> ShuffleInputDataTypeAttribute(
const std::vector<DataType>& in_types,
const std::vector<int>& index_mapping) {
std::vector<DataType> result(index_mapping.size());
for (int i = 0; i < in_types.size(); i++) {
result[index_mapping.at(i)] = in_types[i];
}
return result;
}
// Given original input types, check if we need to rewrite the function (by
// checking whether all DT_RESOURCE inputs are at the end). If the function
// needs to be rewritten, `resource_input_count` will be set to the number of
// DT_RESOURCE inputs, and `index_mapping` will hold the mapping from original
// input index to rearranged input index.
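// For example, in_types = {DT_RESOURCE, DT_INT32, DT_RESOURCE, DT_FLOAT}
// needs a rewrite because non-resource types follow a DT_RESOURCE. Here
// `resource_input_count` is 2 and `index_mapping` is {2, 0, 3, 1}:
// non-resource inputs keep their relative order at the front, and resource
// inputs keep theirs at the back.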
Status InputTypesNeedsRearrange(const std::vector<DataType>& in_types,
bool* need_rewrite, int* resource_input_count,
std::vector<int>* index_mapping) {
int first_resource_index = -1;
for (int i = 0; i < in_types.size(); i++) {
DataType type = in_types[i];
if (type == DT_RESOURCE) {
first_resource_index = i;
break;
}
}
if (first_resource_index == -1) {
// No resource input. No need to rewrite.
*need_rewrite = false;
return Status::OK();
}
*need_rewrite = false;
for (int i = first_resource_index + 1; i < in_types.size(); i++) {
if (in_types[i] != DT_RESOURCE) {
*need_rewrite = true;
break;
}
}
if (!*need_rewrite) {
return Status::OK();
}
*resource_input_count = 0;
for (int i = 0; i < in_types.size(); i++) {
DataType type = in_types[i];
if (type == DT_RESOURCE) {
++(*resource_input_count);
}
}
int non_resource_index = 0,
resource_index = in_types.size() - *resource_input_count;
index_mapping->resize(in_types.size());
for (int i = 0; i < in_types.size(); i++) {
if (in_types[i] != DT_RESOURCE) {
(*index_mapping)[i] = non_resource_index;
non_resource_index++;
} else {
(*index_mapping)[i] = resource_index;
resource_index++;
}
}
return Status::OK();
}
// Given the mapping between original input index and rearranged input index,
// reorder input edges for the node.
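// For example, with index_mapping = {2, 0, 3, 1}, the edge feeding input 0 is
// reconnected to input 2, and the edge feeding input 1 is reconnected to
// input 0.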
Status ReorderInputEdges(Graph* g, Node* n,
const std::vector<int>& index_mapping) {
std::vector<const Edge*> input_edges;
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge()) {
continue;
}
input_edges.push_back(e);
}
for (const Edge* e : input_edges) {
Node* src = e->src();
int src_output = e->src_output();
int dst_input = e->dst_input();
int new_dst_input = index_mapping.at(dst_input);
g->RemoveEdge(e);
    g->AddEdge(src, src_output, n, new_dst_input);
}
return Status::OK();
}
// For a While node, given the mapping between original input index and
// rearranged input index, reorder output edges for the node. DT_RESOURCE
// outputs are removed from the node, and their consumers are rewired to the
// source of the node's corresponding input edge.
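// This is valid because a While node's output i carries the final value of
// input i, and resource loop variables must pass through the body unchanged
// (verified in MaybeRewriteWhileNode), so consumers of a resource output can
// read directly from the source of the corresponding input edge.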
Status ReorderOutputEdges(Graph* g, Node* n, int input_count,
int resource_input_count,
const std::vector<int>& index_mapping) {
std::vector<const Edge*> output_edges;
for (const Edge* e : n->out_edges()) {
if (e->IsControlEdge()) {
continue;
}
output_edges.push_back(e);
}
for (const Edge* e : output_edges) {
int src_output = e->src_output();
int new_src_output = index_mapping.at(src_output);
Node* dst = e->dst();
int dst_input = e->dst_input();
g->RemoveEdge(e);
if (new_src_output < input_count - resource_input_count) {
g->AddEdge(n, new_src_output, dst, dst_input);
} else {
const Edge* input_edge;
TF_RETURN_IF_ERROR(n->input_edge(new_src_output, &input_edge));
g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
}
}
return Status::OK();
}
// Given the mapping between original input index and rearranged input index,
// change the "index" attribute for _Arg nodes.
void RearrangeArgNodes(gtl::InlinedVector<Node*, 4>* arg_nodes, // non-absl ok
const std::vector<int>& index_mapping) {
for (int i = 0; i < arg_nodes->size(); i++) {
Node* n = (*arg_nodes)[i];
int new_index = index_mapping.at(i);
n->ClearAttr("index");
n->AddAttr("index", new_index);
}
}
// Given all _Retval nodes in the function, determine whether we need to
// rewrite the function (by checking whether there are DT_RESOURCE return
// values). If we need to rewrite the function, `retval_index_mapping` will
// hold the mapping from original _Retval index to rearranged _Retval index,
// and `resource_retval_to_arg` will hold the mapping from DT_RESOURCE _Retval
// index to its input _Arg index. Here we assume that all DT_RESOURCE _Retval
// nodes come from _Arg nodes directly.
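// For example, for return types {DT_FLOAT, DT_RESOURCE, DT_INT32} where the
// DT_RESOURCE _Retval is fed by _Arg[4], this produces
// retval_index_mapping = {{0, 0}, {2, 1}} and
// resource_retval_to_arg = {{1, 4}}.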
Status CalculateRetvalRearrange(
const gtl::InlinedVector<Node*, 4>& ret_nodes, // non-absl ok
std::map<int, int>* retval_index_mapping,
std::map<int, int>* resource_retval_to_arg) {
for (int i = 0; i < ret_nodes.size(); i++) {
Node* n = ret_nodes[i];
DataType t;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &t));
if (t != DT_RESOURCE) {
int new_retval_index = retval_index_mapping->size();
retval_index_mapping->insert(std::make_pair(i, new_retval_index));
continue;
}
const Edge* e;
TF_RETURN_IF_ERROR(n->input_edge(0, &e));
if (!e->src()->IsArg()) {
return errors::Unimplemented(
"Resource _Retval node's input does not come from _Arg "
"directly: ",
e->DebugString());
}
Node* arg = e->src();
int src_index;
TF_RETURN_IF_ERROR(GetNodeAttr(arg->def(), "index", &src_index));
resource_retval_to_arg->insert(std::make_pair(i, src_index));
}
return Status::OK();
}
// Given original output types and the return value index mapping, return the
// new output types. Notice that DT_RESOURCE types are removed.
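// For example, out_types = {DT_FLOAT, DT_RESOURCE, DT_INT32} with
// index_mapping = {{0, 0}, {2, 1}} yields {DT_FLOAT, DT_INT32}.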
std::vector<DataType> ShuffleOutputDataTypeAttribute(
const std::vector<DataType>& out_types,
const std::map<int, int>& index_mapping) {
std::vector<DataType> result(index_mapping.size());
for (int i = 0; i < out_types.size(); i++) {
auto iter = index_mapping.find(i);
if (iter != index_mapping.end()) {
result[iter->second] = out_types[i];
}
}
return result;
}
// For a StatefulPartitionedCall node, given the mapping between original
// output index and rearranged output index, reorder output edges for the
// node. DT_RESOURCE outputs are removed from the node, and their consumers
// are rewired to the source of the node's corresponding input edge.
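// Forwarding is safe because CalculateRetvalRearrange has already verified
// that each DT_RESOURCE _Retval is fed directly by an _Arg, so the value on
// the node's resource output is identical to the value on the corresponding
// input edge.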
Status RearrangeOutputEdges(Node* n, Graph* g,
const std::map<int, int>& retval_index_mapping,
const std::map<int, int>& resource_retval_to_arg) {
std::vector<const Edge*> out_edges;
for (const Edge* e : n->out_edges()) {
if (!e->IsControlEdge()) {
out_edges.push_back(e);
}
}
for (const Edge* e : out_edges) {
Node* dst = e->dst();
int dst_input = e->dst_input();
int src_output = e->src_output();
auto iter = retval_index_mapping.find(src_output);
if (iter == retval_index_mapping.end()) {
TF_RET_CHECK(resource_retval_to_arg.find(src_output) !=
resource_retval_to_arg.end());
g->RemoveEdge(e);
const Edge* input_edge;
TF_RETURN_IF_ERROR(
n->input_edge(resource_retval_to_arg.at(src_output), &input_edge));
g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
} else {
g->RemoveEdge(e);
g->AddEdge(n, iter->second, dst, dst_input);
}
}
return Status::OK();
}
// Given the mapping between original output index and rearranged output
// index, change the "index" attribute for _Retval nodes. Notice that
// DT_RESOURCE _Retval nodes will be removed.
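// For example, with retval_index_mapping = {{0, 0}, {2, 1}} (a DT_RESOURCE at
// original index 1), _Retval[0] keeps index 0, _Retval[1] is removed from the
// graph, and _Retval[2] is renumbered to index 1.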
void RearrangeRetvalNodes(
const gtl::InlinedVector<Node*, 4>& ret_nodes, // non-absl ok
Graph* g, const std::map<int, int>& retval_index_mapping) {
for (int i = 0; i < ret_nodes.size(); i++) {
Node* n = ret_nodes[i];
auto iter = retval_index_mapping.find(i);
if (iter == retval_index_mapping.end()) {
g->RemoveNode(n);
} else {
n->ClearAttr("index");
n->AddAttr("index", iter->second);
}
}
}
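// Rewrites a While node if its "T" attribute mixes DT_RESOURCE with
// non-resource types: inputs are rearranged so that DT_RESOURCE inputs come
// last, the cond and body functions are rewritten to match the new argument
// order, and resource outputs are replaced by forwarding from the
// corresponding inputs.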
Status MaybeRewriteWhileNode(Graph* g, Node* n, FunctionLibraryDefinition* fld,
bool* node_rewritten) {
  // Check if this While node needs a rewrite.
std::vector<DataType> types;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &types));
bool input_need_rearrange;
int resource_input_count;
std::vector<int> index_mapping;
TF_RETURN_IF_ERROR(InputTypesNeedsRearrange(
types, &input_need_rearrange, &resource_input_count, &index_mapping));
if (!input_need_rearrange) {
*node_rewritten = false;
return Status::OK();
}
*node_rewritten = true;
// Modify "T" attribute for this While node.
std::vector<DataType> new_types =
ShuffleInputDataTypeAttribute(types, index_mapping);
n->ClearAttr("T");
n->AddAttr("T", new_types);
// Reorder input and output edges.
TF_RETURN_IF_ERROR(ReorderInputEdges(g, n, index_mapping));
TF_RETURN_IF_ERROR(ReorderOutputEdges(g, n, types.size(),
resource_input_count, index_mapping));
// Modify cond and body functions.
for (auto const& attr_name : std::vector<string>{"cond", "body"}) {
NameAttrList attr_value;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &attr_value));
const FunctionDef* fdef = fld->Find(attr_value.name());
TF_RET_CHECK(fdef != nullptr);
std::unique_ptr<FunctionBody> fbody;
TF_RETURN_IF_ERROR(
FunctionDefToBodyHelper(*fdef, AttrSlice(), fld, &fbody));
// Check that resource _Arg nodes for While node are always returned with
// the same index, and we don't have cases like this:
// tf.while_loop(
// cond,
// lambda resource_var1, resource_var2: [resource_var2, resource_var1],
// [resource_var1, resource_var2])
if (attr_name == "body") {
for (int i = 0; i < fbody->ret_nodes.size(); i++) {
        Node* ret_node = fbody->ret_nodes[i];
        DataType dtype;
        TF_RETURN_IF_ERROR(GetNodeAttr(ret_node->def(), "T", &dtype));
        if (dtype != DT_RESOURCE) {
          continue;
        }
        Node* input_node;
        TF_RETURN_IF_ERROR(ret_node->input_node(0, &input_node));
while (input_node->IsIdentity()) {
TF_RETURN_IF_ERROR(input_node->input_node(0, &input_node));
}
if (input_node->IsArg()) {
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(input_node->def(), "index", &index));
if (index != i) {
return errors::Unimplemented("While node ", n->DebugString(),
" has resource _Retval[", i,
"] coming from _Arg[", index, "]");
}
} else {
return errors::Unimplemented("Encountered node ",
input_node->DebugString(),
" while tracing _Arg node for _Retval[",
i, "] of while node ", n->DebugString());
}
}
}
RearrangeArgNodes(&fbody->arg_nodes, index_mapping);
if (attr_name == "body") {
for (int i = 0; i < fbody->ret_nodes.size(); i++) {
        Node* ret_node = fbody->ret_nodes[i];
        int new_index = index_mapping.at(i);
        if (new_index < types.size() - resource_input_count) {
          ret_node->ClearAttr("index");
          ret_node->AddAttr("index", new_index);
        } else {
          fbody->graph->RemoveNode(ret_node);
        }
}
}
// Save the new FunctionDef.
FunctionDef new_fdef;
string new_name =
fld->UniqueFunctionName(absl::StrCat(attr_value.name(), "_rearrange_"));
TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, new_name, &new_fdef));
TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef));
// Change node to use rewritten function.
attr_value.set_name(new_name);
n->ClearAttr(attr_name);
n->AddAttr(attr_name, attr_value);
}
return Status::OK();
}
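// Rewrites a StatefulPartitionedCall node if its inputs need rearranging or
// it has DT_RESOURCE outputs: inputs are rearranged so that DT_RESOURCE
// inputs come last, resource outputs are removed (their consumers are rewired
// to the matching input edges), and the called function is rewritten to
// match.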
Status MaybeRewriteCallNode(Graph* g, Node* n, FunctionLibraryDefinition* fld,
bool* node_rewritten) {
  // This node needs a rewrite when either of these is true:
  // 1) Tin has DT_RESOURCE inputs that require rearranging;
  // 2) Tout has DT_RESOURCE.
std::vector<DataType> in_types;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tin", &in_types));
bool input_need_rearrange;
int resource_input_count;
std::vector<int> index_mapping;
TF_RETURN_IF_ERROR(InputTypesNeedsRearrange(
in_types, &input_need_rearrange, &resource_input_count, &index_mapping));
std::vector<DataType> out_types;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tout", &out_types));
bool has_resource_output = std::find(out_types.begin(), out_types.end(),
DT_RESOURCE) != out_types.end();
  if (!input_need_rearrange && !has_resource_output) {
*node_rewritten = false;
return Status::OK();
}
*node_rewritten = true;
string attr_name = "f";
NameAttrList f;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &f));
const FunctionDef* fdef = fld->Find(f.name());
TF_RET_CHECK(fdef != nullptr);
std::unique_ptr<FunctionBody> fbody;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(), fld, &fbody));
if (input_need_rearrange) {
// Reorder input edges.
TF_RETURN_IF_ERROR(ReorderInputEdges(g, n, index_mapping));
// Change Tin attribute.
std::vector<DataType> new_in_types =
ShuffleInputDataTypeAttribute(in_types, index_mapping);
n->ClearAttr("Tin");
n->AddAttr("Tin", new_in_types);
// Change _Arg node index.
RearrangeArgNodes(&fbody->arg_nodes, index_mapping);
}
if (has_resource_output) {
// Resource _Retval must come from resource _Arg directly, or we do not
// support it.
std::map<int, int> resource_retval_to_arg, retval_index_mapping;
TF_RETURN_IF_ERROR(CalculateRetvalRearrange(
fbody->ret_nodes, &retval_index_mapping, &resource_retval_to_arg));
// Rearrange output edges.
TF_RETURN_IF_ERROR(RearrangeOutputEdges(n, g, retval_index_mapping,
resource_retval_to_arg));
// Change Tout attribute for the node.
std::vector<DataType> new_out_types =
ShuffleOutputDataTypeAttribute(out_types, retval_index_mapping);
n->ClearAttr("Tout");
n->AddAttr("Tout", new_out_types);
// Change index for _Retval nodes.
RearrangeRetvalNodes(fbody->ret_nodes, fbody->graph, retval_index_mapping);
}
// Save the new FunctionDef.
FunctionDef new_fdef;
string new_name =
fld->UniqueFunctionName(absl::StrCat(f.name(), "_rearrange_"));
TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, new_name, &new_fdef));
TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef));
// Change node to use rewritten function.
f.set_name(new_name);
n->ClearAttr(attr_name);
n->AddAttr(attr_name, f);
return Status::OK();
}
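// Rewrites an If node if its inputs need rearranging or it has DT_RESOURCE
// outputs. Input 0 (the predicate) stays in place, so edge indices are
// shifted by one relative to _Arg indices and the Tin attribute; both
// then_branch and else_branch are rewritten to match.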
Status MaybeRewriteIfNode(Graph* g, Node* n, FunctionLibraryDefinition* fld,
bool* node_rewritten) {
  // This node needs a rewrite when either of these is true:
  // 1) Tin has DT_RESOURCE inputs that require rearranging;
  // 2) Tout has DT_RESOURCE.
std::vector<DataType> in_types;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tin", &in_types));
bool input_need_rearrange;
int resource_input_count;
std::vector<int> index_mapping;
TF_RETURN_IF_ERROR(InputTypesNeedsRearrange(
in_types, &input_need_rearrange, &resource_input_count, &index_mapping));
std::vector<DataType> out_types;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tout", &out_types));
bool has_resource_output = std::find(out_types.begin(), out_types.end(),
DT_RESOURCE) != out_types.end();
if (!input_need_rearrange && !has_resource_output) {
*node_rewritten = false;
return Status::OK();
}
*node_rewritten = true;
if (input_need_rearrange) {
// Reorder input edges.
std::vector<const Edge*> input_edges;
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge() || e->dst_input() == 0) {
continue;
}
input_edges.push_back(e);
}
for (const Edge* e : input_edges) {
Node* src = e->src();
int src_output = e->src_output();
int dst_input = e->dst_input();
int new_dst_input = index_mapping.at(dst_input - 1) + 1;
g->RemoveEdge(e);
      g->AddEdge(src, src_output, n, new_dst_input);
}
// Change Tin attribute.
std::vector<DataType> new_in_types =
ShuffleInputDataTypeAttribute(in_types, index_mapping);
n->ClearAttr("Tin");
n->AddAttr("Tin", new_in_types);
}
std::map<int, int> resource_retval_to_arg, retval_index_mapping;
for (auto const& attr_name :
std::vector<string>{"then_branch", "else_branch"}) {
NameAttrList f;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &f));
const FunctionDef* fdef = fld->Find(f.name());
TF_RET_CHECK(fdef != nullptr);
std::unique_ptr<FunctionBody> fbody;
TF_RETURN_IF_ERROR(
FunctionDefToBodyHelper(*fdef, AttrSlice(), fld, &fbody));
if (input_need_rearrange) {
// Change _Arg node index.
RearrangeArgNodes(&fbody->arg_nodes, index_mapping);
}
if (has_resource_output) {
// Resource _Retval must come from resource _Arg directly, or we do
// not support it.
TF_RETURN_IF_ERROR(CalculateRetvalRearrange(
fbody->ret_nodes, &retval_index_mapping, &resource_retval_to_arg));
// Change index for _Retval nodes.
RearrangeRetvalNodes(fbody->ret_nodes, fbody->graph,
retval_index_mapping);
}
// Save the new FunctionDef.
FunctionDef new_fdef;
string new_name =
fld->UniqueFunctionName(absl::StrCat(f.name(), "_rearrange_"));
TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, new_name, &new_fdef));
TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef));
// Change node to use rewritten function.
f.set_name(new_name);
n->ClearAttr(attr_name);
n->AddAttr(attr_name, f);
}
if (has_resource_output) {
// Rearrange output edges.
std::vector<const Edge*> out_edges;
for (const Edge* e : n->out_edges()) {
if (!e->IsControlEdge()) {
out_edges.push_back(e);
}
}
for (const Edge* e : out_edges) {
Node* dst = e->dst();
int dst_input = e->dst_input();
int src_output = e->src_output();
auto iter = retval_index_mapping.find(src_output);
if (iter == retval_index_mapping.end()) {
TF_RET_CHECK(resource_retval_to_arg.find(src_output) !=
resource_retval_to_arg.end());
g->RemoveEdge(e);
const Edge* input_edge;
TF_RETURN_IF_ERROR(n->input_edge(
resource_retval_to_arg.at(src_output) + 1, &input_edge));
g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
} else {
g->RemoveEdge(e);
g->AddEdge(n, iter->second, dst, dst_input);
}
}
// Change Tout attribute for the node.
std::vector<DataType> new_out_types =
ShuffleOutputDataTypeAttribute(out_types, retval_index_mapping);
n->ClearAttr("Tout");
n->AddAttr("Tout", new_out_types);
}
return Status::OK();
}
} // namespace
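// Rearranges the arguments of the function `func_name` so that all
// DT_RESOURCE arguments come last and no DT_RESOURCE values are returned,
// recursing into any functions it references. The result is written to
// `new_func_name` (or replaces the original function if the two names are
// equal), and `modified` reports whether anything changed.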
Status RearrangeFunctionArgumentForFunction(
const string& func_name, const string& new_func_name,
const protobuf::Map<string, tensorflow::AttrValue>& attrs,
FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
std::map<string, absl::optional<string>>* canonicalized_name_to_new_name,
bool* modified) {
*modified = false;
// Convert the function to Graph.
FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle));
Status ret_status = Status::OK();
auto cleanup_handle = gtl::MakeCleanup([&]() {
auto s = flr->ReleaseHandle(handle);
if (!s.ok()) {
ret_status.Update(s);
}
});
const FunctionBody* body = flr->GetFunctionBody(handle);
Graph* g = body->graph;
  // If any node has associated functions, rewrite them first.
  // Gather the nodes with associated functions up front, because rewriting
  // those nodes might involve node deletion/addition; avoid modifying the
  // graph while iterating over it.
std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
nodes_to_associated_functions;
for (auto* n : g->nodes()) {
auto associated_functions = GetAssociatedFunctions(*n, fld);
if (!associated_functions.empty()) {
nodes_to_associated_functions.push_back({n, associated_functions});
}
}
for (auto iter : nodes_to_associated_functions) {
Node* n = iter.first;
auto associated_functions = iter.second;
for (auto& associated_function : associated_functions) {
string name = associated_function.func_name();
string canonicalized_name =
Canonicalize(name, AttrSlice(&associated_function.attrs()));
auto iter = canonicalized_name_to_new_name->find(canonicalized_name);
string new_name;
bool function_modified;
if (iter != canonicalized_name_to_new_name->end()) {
// If we already processed this function, check if it was rewritten. If
// the function was rewritten, the entry will be non-empty. Otherwise
// the entry will be empty.
function_modified = iter->second.has_value();
if (function_modified) {
new_name = iter->second.value();
}
} else {
if (associated_function.type() ==
AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) {
// For SymbolicGradient, `name` is always "SymbolicGradient",
// which is not very informative. Use node name instead.
new_name =
fld->UniqueFunctionName(absl::StrCat(n->name(), "_rearrange_"));
} else {
new_name = fld->UniqueFunctionName(absl::StrCat(name, "_rearrange_"));
}
TF_RETURN_IF_ERROR(RearrangeFunctionArgumentForFunction(
name, new_name, associated_function.attrs(), fld, flr,
canonicalized_name_to_new_name, &function_modified));
if (function_modified) {
        // If the function was rewritten, add a non-empty entry so that later
        // we know we have processed this function, and that it was rewritten
        // into another function.
(*canonicalized_name_to_new_name)[canonicalized_name] = new_name;
} else {
        // If the function was not rewritten, add an empty entry so that later
        // we know we have processed this function, and that it does not need
        // to be rewritten.
(*canonicalized_name_to_new_name)[canonicalized_name] = absl::nullopt;
}
}
if (function_modified) {
*modified = true;
// Notice that if "n" is a function call, RewriteAssociatedFunction()
// will delete it and create a new node instead, making "n" an invalid
// pointer. That's fine because in that case, associated_functions will
// only have one member and the loop will only run once.
TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
g, n, fld, associated_function, new_name));
}
}
}
for (Node* n : g->nodes()) {
if (n->type_string() == "While") {
bool node_rewritten;
TF_RETURN_IF_ERROR(MaybeRewriteWhileNode(g, n, fld, &node_rewritten));
if (node_rewritten) {
*modified = true;
}
} else if (n->type_string() == "StatefulPartitionedCall") {
bool node_rewritten;
TF_RETURN_IF_ERROR(MaybeRewriteCallNode(g, n, fld, &node_rewritten));
if (node_rewritten) {
*modified = true;
}
} else if (n->type_string() == "If") {
bool node_rewritten;
TF_RETURN_IF_ERROR(MaybeRewriteIfNode(g, n, fld, &node_rewritten));
if (node_rewritten) {
*modified = true;
}
}
}
if (*modified) {
// Add rewritten FunctionDef into library.
FunctionDef functionalized_fdef;
TF_RETURN_IF_ERROR(
GraphToFunctionDef(*g, new_func_name, &functionalized_fdef));
if (func_name == new_func_name) {
VLOG(2) << "Replacing function " << func_name;
TF_RETURN_IF_ERROR(
fld->ReplaceFunction(new_func_name, functionalized_fdef));
} else {
VLOG(2) << "Adding function " << new_func_name;
TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
}
}
return ret_status;
}
Status RearrangeFunctionArgumentPass::Run(
const GraphOptimizationPassOptions& options) {
Graph* graph = options.graph->get();
if (VLOG_IS_ON(4)) {
DumpGraphToFile("rearrange_function_argument_before", *graph,
options.flib_def);
}
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(
/*device_mgr=*/nullptr, options.session_options->env,
TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions()));
FunctionLibraryRuntime* flr =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
  // Find XLA compile ops and their corresponding FunctionDefs.
static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
new std::map<string, string>{
// TPUReplicate ops are generated by EncapsulateTPUComputationsPass.
{"TPUReplicate", "computation"},
// XlaLaunch ops are generated by EncapsulateXlaComputationsPass.
{"XlaLaunch", "function"},
};
std::map<string, absl::optional<string>> canonicalized_name_to_new_name;
for (Node* n : graph->nodes()) {
auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
if (it == kNodeTypeToFunctionAttrMapping->end()) {
continue;
}
const string func_attr = it->second;
NameAttrList func;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func));
VLOG(2) << "Graph has node " << n->type_string()
<< ". Corresponding function: " << func.name();
string new_func_name = options.flib_def->UniqueFunctionName(
absl::StrCat(func.name(), "_rearrange_"));
bool modified = false;
TF_RETURN_IF_ERROR(RearrangeFunctionArgumentForFunction(
func.name(), new_func_name, func.attr(), options.flib_def, flr,
&canonicalized_name_to_new_name, &modified));
if (modified) {
n->ClearAttr(func_attr);
func.set_name(new_func_name);
n->AddAttr(func_attr, func);
}
}
if (VLOG_IS_ON(4)) {
DumpGraphToFile("rearrange_function_argument_after", *graph,
options.flib_def);
}
return Status::OK();
}
} // namespace tensorflow