Add function to preprocess TF graph before encapsulating XLA computations.

PiperOrigin-RevId: 217571411
This commit is contained in:
Tong Shen 2018-10-17 12:54:57 -07:00 committed by TensorFlower Gardener
parent d48968cc90
commit b021a8b041
6 changed files with 644 additions and 3 deletions

View File

@ -411,7 +411,11 @@ cc_library(
hdrs = ["encapsulate_util.h"],
deps = [
":shape_inference",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],

View File

@ -20,6 +20,10 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
namespace tensorflow {
@ -36,10 +40,319 @@ absl::optional<string> GetStringAttr(const Node& n, const string& attr_name) {
}
}
// Adds a value to the node's list attribute.
template <typename T>
Status AppendToListAttr(Node* n, const string& attr_name, const string& value) {
std::vector<T> attr_value;
Status s = GetNodeAttr(n->attrs(), attr_name, &attr_value);
if (!s.ok() && s.code() != error::NOT_FOUND) {
return s;
}
n->ClearAttr(attr_name);
attr_value.push_back(value);
n->AddAttr(attr_name, attr_value);
return Status::OK();
}
// Replaces attribute value.
template <typename T>
void ReplaceAttr(Node* n, const string& attr_name, const T& value) {
n->ClearAttr(attr_name);
n->AddAttr(attr_name, value);
}
// Step 1a ~ 1d for PreprocessForEncapsulation(). See comments of
// PreprocessForEncapsulation() for details.
Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name,
const string& outside_compilation_attr_name) {
// Gather edges to remove. We should not remove the edge while iterating.
std::vector<const Edge*> edges_to_remove;
for (const Edge* e : g->edges()) {
if (!e->IsControlEdge()) {
continue;
}
auto src_xla_computation =
GetStringAttr(*e->src(), xla_computation_attr_name);
auto dst_xla_computation =
GetStringAttr(*e->dst(), xla_computation_attr_name);
auto src_outside_compilation =
GetStringAttr(*e->src(), outside_compilation_attr_name);
auto dst_outside_compilation =
GetStringAttr(*e->dst(), outside_compilation_attr_name);
if (!src_xla_computation && !dst_xla_computation) {
continue;
} else if (src_xla_computation && !dst_xla_computation) {
if (src_outside_compilation) {
// Case 1d: outside compilation to host computation control edge.
TF_RETURN_IF_ERROR(AppendToListAttr<string>(
e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
}
} else if (!src_xla_computation && dst_xla_computation) {
if (dst_outside_compilation) {
// Case 1d: host computation control to outside compilation edge.
TF_RETURN_IF_ERROR(AppendToListAttr<string>(
e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
}
} else { // src_xla_computation && dst_xla_computation
if (*src_xla_computation != *dst_xla_computation) {
if (src_outside_compilation && dst_outside_compilation) {
// Case 1c: outside compilation to outside compilation control edge.
edges_to_remove.push_back(e);
TF_RETURN_IF_ERROR(AppendToListAttr<string>(
e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
} else if (src_outside_compilation && !dst_outside_compilation) {
// Case 1b: outside compilation to another XLA computaition control
// edge.
TF_RETURN_IF_ERROR(AppendToListAttr<string>(
e->src(), kXlaConnectedToOtherXlaComputationAttrName,
*dst_xla_computation));
} else if (!src_outside_compilation && dst_outside_compilation) {
// Case 1b: another XLA computaition to outside compilation control
// edge.
TF_RETURN_IF_ERROR(AppendToListAttr<string>(
e->dst(), kXlaConnectedFromOtherXlaComputationAttrName,
*src_xla_computation));
}
} else { // *src_xla_computation == *dst_xla_computation
if (src_outside_compilation && dst_outside_compilation) {
if (*src_outside_compilation != *dst_outside_compilation) {
// Case 1c: outside compilation to outside compilation control edge.
edges_to_remove.push_back(e);
TF_RETURN_IF_ERROR(AppendToListAttr<string>(
e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
}
} else if (src_outside_compilation && !dst_outside_compilation) {
// Case 1a: outside compilation to its XLA computation control edge.
ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true);
} else if (!src_outside_compilation && dst_outside_compilation) {
// Case 1a: XLA computation to outside compilation in it control edge.
ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true);
}
}
}
}
for (auto e : edges_to_remove) {
g->RemoveEdge(e);
}
return Status::OK();
}
// Step 2 for PreprocessForEncapsulation(). See comments of
// PreprocessForEncapsulation() for details.
Status ProcessXlaToXlaDataEdges(Graph* g,
const string& xla_computation_attr_name,
const string& outside_compilation_attr_name) {
// Gather edges between XLA computations. Notice that we do not store `Edge*`
// directly because we remove some nodes while adding Identity nodes, and
// those Edge pointers might be invalidated.
struct EdgeInfo {
int dst_input, dst_node_id;
};
std::vector<EdgeInfo> edges;
for (const Edge* e : g->edges()) {
if (e->IsControlEdge()) {
continue;
}
auto src_xla_computation =
GetStringAttr(*e->src(), xla_computation_attr_name);
auto dst_xla_computation =
GetStringAttr(*e->dst(), xla_computation_attr_name);
auto src_outside_compilation =
GetStringAttr(*e->src(), outside_compilation_attr_name);
auto dst_outside_compilation =
GetStringAttr(*e->dst(), outside_compilation_attr_name);
if (!src_xla_computation || !dst_xla_computation) {
continue;
}
if (*src_xla_computation != *dst_xla_computation) {
if (src_outside_compilation || dst_outside_compilation) {
edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
VLOG(4) << "XLA -> XLA edge: " << e->DebugString();
}
} else { // *src_xla_computation == *dst_xla_computation
if (src_outside_compilation && dst_outside_compilation &&
*src_outside_compilation != *dst_outside_compilation) {
edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
VLOG(4) << "XLA -> XLA edge: " << e->DebugString();
}
}
}
// For each XLA -> XLA edge, add an Identity node between src and dst.
for (int i = 0; i < edges.size(); i++) {
Node* dst = g->FindNodeId(edges[i].dst_node_id);
const Edge* e;
TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e));
Node* src = e->src();
int src_output = e->src_output(), dst_input = e->dst_input();
g->RemoveEdge(e);
// Create Identity node, and connect it between `src` and `dst`.
string identity_node_name =
absl::StrCat("bridge_", src->name(), "_", dst->name());
DataType dtype = src->output_type(src_output);
TF_ASSIGN_OR_RETURN(Node * identity_node,
BuildIdentityNode(g, identity_node_name, dtype, src,
/*requested_device=*/absl::nullopt));
identity_node->AddAttr(kBridgeSourceNodeAttrName, src->name());
g->AddEdge(src, src_output, identity_node, 0);
g->AddEdge(identity_node, 0, dst, dst_input);
// Replace `e->dst()` because its input node changed.
NodeDef new_def = dst->def();
*new_def.mutable_input(dst_input) = identity_node->name();
TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def));
// Other edge in `edges` might have `e->dst()` as src or dst
// node. Before removing `e->dst()`, replace those edges with corresponding
// edges for `dst_replace_node`.
for (int j = i + 1; j < edges.size(); j++) {
if (edges[j].dst_node_id == edges[i].dst_node_id) {
edges[j].dst_node_id = dst_replace_node->id();
}
}
}
return Status::OK();
}
// Step 3 for PreprocessForEncapsulation(). See comments of
// PreprocessForEncapsulation() for details.
Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
Graph* g, const string& xla_computation_attr_name,
const string& outside_compilation_attr_name) {
// Gather edges between outside compilation and host computation. Notice that
// we do not store `Edge*` directly because we remove some nodes while adding
// Identity nodes, and those Edge pointers might be invalidated.
struct EdgeInfo {
int dst_input, dst_node_id;
bool is_host_to_outside_compilation;
};
std::vector<EdgeInfo> edges;
for (const Edge* e : g->edges()) {
if (e->IsControlEdge()) {
continue;
}
if (e->src()->attrs().Find(xla_computation_attr_name) == nullptr &&
e->dst()->attrs().Find(xla_computation_attr_name) != nullptr &&
e->dst()->attrs().Find(outside_compilation_attr_name) != nullptr) {
edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(),
/*is_host_to_outside_compilation=*/true});
VLOG(4) << "Host -> oc edge: " << e->DebugString();
} else if (e->dst()->attrs().Find(xla_computation_attr_name) == nullptr &&
e->src()->attrs().Find(xla_computation_attr_name) != nullptr &&
e->src()->attrs().Find(outside_compilation_attr_name) !=
nullptr) {
edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(),
/*is_host_to_outside_compilation=*/false});
VLOG(4) << "Oc -> host edge: " << e->DebugString();
}
}
// Remove the edge from host to outside compilation. Add a placeholder as
// outside compilation node input.
std::map<string, Node*> placeholders;
for (int i = 0; i < edges.size(); i++) {
Node* dst = g->FindNodeId(edges[i].dst_node_id);
const Edge* e;
TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e));
Node* src = e->src();
int src_output = e->src_output(), dst_input = e->dst_input();
g->RemoveEdge(e);
// Find or create placeholder node.
string new_name =
edges[i].is_host_to_outside_compilation
? absl::StrCat(src->name(), "_host_to_oc_placeholder")
: absl::StrCat(src->name(), "_oc_to_host_placeholder");
auto iter = placeholders.find(new_name);
Node* placeholder_node;
if (iter == placeholders.end()) {
NodeDefBuilder placeholder_builder(new_name, "Placeholder");
placeholder_builder.Attr("dtype", src->output_type(src_output));
if (edges[i].is_host_to_outside_compilation) {
placeholder_builder.Attr(kHostToOutsideCompilationOriginalNodeAttrName,
src->name());
placeholder_builder.Attr(kHostToOutsideCompilationSrcOutputAttrName,
src_output);
// If this placeholder node is in outside compilation, we need to set
// `xla_computation_attr_name` and `outside_compilation_attr_name`.
string xla_computation_attr, outside_compilation_attr;
TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), xla_computation_attr_name,
&xla_computation_attr));
TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(),
outside_compilation_attr_name,
&outside_compilation_attr));
placeholder_builder.Attr(xla_computation_attr_name,
xla_computation_attr);
placeholder_builder.Attr(outside_compilation_attr_name,
outside_compilation_attr);
} else {
placeholder_builder.Attr(kOutsideCompilationToHostOriginalNodeAttrName,
src->name());
placeholder_builder.Attr(kOutsideCompilationToHostSrcOutputAttrName,
src_output);
}
NodeDef placeholder_def;
TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def));
Status s;
placeholder_node = g->AddNode(placeholder_def, &s);
TF_RETURN_IF_ERROR(s);
placeholders[new_name] = placeholder_node;
} else {
placeholder_node = iter->second;
}
g->AddEdge(placeholder_node, 0, dst, dst_input);
g->RemoveEdge(e);
// Replace `e->dst()` because its input node changed.
NodeDef new_def = dst->def();
*new_def.mutable_input(dst_input) = placeholder_node->name();
TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def));
// Other edge in `edges` might have `e->dst()` as src or dst
// node. Before removing `e->dst()`, replace those edges with corresponding
// edges for `dst_replace_node`.
for (int j = i + 1; j < edges.size(); j++) {
if (edges[j].dst_node_id == edges[i].dst_node_id) {
edges[j].dst_node_id = dst_replace_node->id();
}
}
}
return Status::OK();
}
} // namespace
const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes";
const char kXlaConnectedToXlaComputationAttrName[] =
"_xla_connected_to_xla_computation";
const char kXlaConnectedFromXlaComputationAttrName[] =
"_xla_connected_from_xla_computation";
const char kXlaConnectedToOtherXlaComputationAttrName[] =
"_xla_connected_to_other_xla_computation";
const char kXlaConnectedFromOtherXlaComputationAttrName[] =
"_xla_connected_from_other_xla_computation";
const char kXlaControlDependenciesAttrName[] = "_xla_control_dependencies";
const char kBridgeSourceNodeAttrName[] = "_xla_bridge_src";
const char kOutsideCompilationToHostOriginalNodeAttrName[] =
"_xla_oc_to_host_node_name";
const char kOutsideCompilationToHostSrcOutputAttrName[] =
"_xla_oc_to_host_src_output";
const char kHostToOutsideCompilationOriginalNodeAttrName[] =
"_xla_host_to_oc_node_name";
const char kHostToOutsideCompilationSrcOutputAttrName[] =
"_xla_host_to_oc_src_output";
Status PerformStaticShapeInferenceBeforeEncapsulation(
Graph* g, const string& xla_computation_attr_name,
const string& outside_compilation_attr_name) {
@ -91,4 +404,16 @@ Status PerformStaticShapeInferenceBeforeEncapsulation(
return Status::OK();
}
Status PreprocessForEncapsulation(Graph* g,
const string& xla_computation_attr_name,
const string& outside_compilation_attr_name) {
TF_RETURN_IF_ERROR(ProcessControlEdges(g, xla_computation_attr_name,
outside_compilation_attr_name));
TF_RETURN_IF_ERROR(ProcessXlaToXlaDataEdges(g, xla_computation_attr_name,
outside_compilation_attr_name));
TF_RETURN_IF_ERROR(ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
g, xla_computation_attr_name, outside_compilation_attr_name));
return Status::OK();
}
} // namespace tensorflow

View File

@ -44,6 +44,77 @@ Status PerformStaticShapeInferenceBeforeEncapsulation(
Graph* g, const string& xla_computation_attr_name,
const string& outside_compilation_attr_name);
// Attribute indicating that some ops in this node's XLA computation has control
// dependency on this node. Attribute value will always be "true".
extern const char kXlaConnectedToXlaComputationAttrName[];
// Attribute indicating that this node has control dependency on some ops in
// this node's XLA computation. Attribute value will always be "true".
extern const char kXlaConnectedFromXlaComputationAttrName[];
// Attribute indicating that some ops in other XLA computation has control
// dependency on this node. Attribute value will be a list of string (XLA
// computation names).
extern const char kXlaConnectedToOtherXlaComputationAttrName[];
// Attribute indicating that this node has control dependency on some ops in
// other XLA computation. Attribute value will be a list of string (XLA
// computation names).
extern const char kXlaConnectedFromOtherXlaComputationAttrName[];
// Attribute indicating that this node has control dependencies on some other
// nodes. Attribute value will be a list of string (node names).
extern const char kXlaControlDependenciesAttrName[];
// Attribute indicating that this is an Identity node added to act as a bridge
// between different XLA computations. Attribute value will be string (source
// node name).
extern const char kBridgeSourceNodeAttrName[];
// Attribute indicating that this is an Placeholder node added to act as a
// temporary input node for an outside compilation node. Attribute value will be
// string (original input node name).
extern const char kOutsideCompilationToHostOriginalNodeAttrName[];
// Attribute indicating that this is an Placeholder node added to act as a
// temporary input node for an outside compilation node. Attribute value will be
// int (src_output for original edge).
extern const char kOutsideCompilationToHostSrcOutputAttrName[];
// Attribute indicating that this is an Placeholder node added to act as a
// temporary input node for an host node. Attribute value will be string
// (original input node name).
extern const char kHostToOutsideCompilationOriginalNodeAttrName[];
// Attribute indicating that this is an Placeholder node added to act as a
// temporary input node for a host node. Attribute value will be int (src_output
// for original edge).
extern const char kHostToOutsideCompilationSrcOutputAttrName[];
// Preprocesses the graph for encapsulation. It will perform the following
// operations in order:
//
// 1a. For control edges between outside compilation and its XLA computation,
// add attr "kXlaConnected{From, To}XlaComputationAttrName = true" to the
// outside compilation node.
// 1b. For control edges between outside compilation and another XLA
// computation, add attr "kXlaConnected{From, To}OtherXlaComputationAttrName
// = XLA computation node name" to the outside compilation node.
// 1c. For control edges between different outside compilations, remove the edge
// and add attr "kXlaControlDependenciesAttrName = src node name" to dst
// node.
// 1d. For control edges between outside compilation and host computation,
// remove the edge and add attr "kXlaControlDependenciesAttrName = src node
// name" to dst node.
// 2. For data edges between different XLA computations, if either src or dst
// is outside compilation, add an Identity node in between the edge. The
// identity node will have attr kBridgeSourceNodeAttrName.
// 3. For data edges between outside compilation and host computation, remove
// the edge and create a Placeholder node as dst node's input.
Status PreprocessForEncapsulation(Graph* g,
const string& xla_computation_attr_name,
const string& outside_compilation_attr_name);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_

View File

@ -47,8 +47,8 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) {
PerformStaticShapeInferenceBeforeEncapsulation(&g, "_xla", "_oc"));
// Check that only "add" node now has _xla_inferred_shapes attr.
std::vector<Node*> nodes_with_inferred_shape;
for (Node* n : g.nodes()) {
std::vector<Node *> nodes_with_inferred_shape;
for (Node *n : g.nodes()) {
if (HasNodeAttr(n->def(), kXlaInferredShapesAttrName)) {
nodes_with_inferred_shape.push_back(n);
}
@ -65,4 +65,175 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) {
EXPECT_EQ(shape_proto.dim(0).size(), 2);
}
TEST(PreprocessForEncapsulationTest, ControlEdges) {
// Build the graph:
// "const_0" and "const_1" in host computation
// "add" = "const_0" + "const_1" in XLA computation 0
// "identity0" = "add" in XLA computation 0 & outside compilation 0
// "identity1" = "identity0" in XLA computation 0
// "identity2" = "identity1" in host computation
// "identity3" = "identity2" in XLA computation 1
// "identity4" = "identity3" in XLA computation 1 & outside compilation 1
// "identity5" = "identity4" in XLA computation 1
// "identity6" = "identity5" in host computation
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {});
Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {});
Output add = ops::Add(s.WithOpName("add"), const_0, const_1);
Output identity0 = ops::Identity(s.WithOpName("identity0"), add);
Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2);
Output identity4 = ops::Identity(s.WithOpName("identity4"), identity3);
Output identity5 = ops::Identity(s.WithOpName("identity5"), identity4);
Graph g(OpRegistry::Global());
TF_CHECK_OK(s.ToGraph(&g));
auto node_index = g.BuildNodeNameIndex();
// Set XLA computation/outside compilation attr, and add control edges.
Node *const0_node = node_index["const_0"], *add_node = node_index["add"],
*identity0_node = node_index["identity0"],
*identity1_node = node_index["identity1"],
*identity2_node = node_index["identity2"],
*identity3_node = node_index["identity3"],
*identity4_node = node_index["identity4"],
*identity5_node = node_index["identity5"];
add_node->AddAttr("_xla", "0");
identity0_node->AddAttr("_xla", "0");
identity0_node->AddAttr("_oc", "0");
identity1_node->AddAttr("_xla", "0");
identity3_node->AddAttr("_xla", "1");
identity4_node->AddAttr("_xla", "1");
identity4_node->AddAttr("_oc", "0");
identity5_node->AddAttr("_xla", "1");
// Case 1a: control edges between outside compilation and its XLA computation.
g.AddControlEdge(add_node, identity0_node);
g.AddControlEdge(identity0_node, identity1_node);
// Case 1b: control edges between outside compilation and another XLA
// computation.
g.AddControlEdge(identity0_node, identity3_node);
g.AddControlEdge(identity1_node, identity4_node);
// Case 1c: control edges between different outside compilations.
g.AddControlEdge(identity0_node, identity4_node);
// Case 1d: control edges between outside compilation and host computation.
g.AddControlEdge(const0_node, identity0_node);
g.AddControlEdge(identity0_node, identity2_node);
TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
// Case 1a: add attr "_xla_connected_{from/to}_xla_computation = true" to the
// outside compilation node.
EXPECT_TRUE(HasNodeAttr(identity0_node->def(),
kXlaConnectedFromXlaComputationAttrName));
EXPECT_TRUE(HasNodeAttr(identity0_node->def(),
kXlaConnectedToXlaComputationAttrName));
// Case 1b: add attr "_xla_control_deps_{from/to} = XLA computation node name"
// to the outside compilation node.
std::vector<string> attr;
TF_CHECK_OK(GetNodeAttr(identity0_node->def(),
kXlaConnectedToOtherXlaComputationAttrName, &attr));
EXPECT_EQ(attr.size(), 1);
EXPECT_EQ(attr[0], "1");
attr.clear();
TF_CHECK_OK(GetNodeAttr(identity4_node->def(),
kXlaConnectedFromOtherXlaComputationAttrName, &attr));
EXPECT_EQ(attr.size(), 1);
EXPECT_EQ(attr[0], "0");
// Case 1c: add attr "_xla_control_deps = src node name" to dst node.
attr.clear();
TF_CHECK_OK(GetNodeAttr(identity4_node->def(),
kXlaControlDependenciesAttrName, &attr));
EXPECT_EQ(attr.size(), 1);
EXPECT_EQ(attr[0], "identity0");
// Case 1d: add attr "_xla_control_deps = src node name" to dst node.
attr.clear();
TF_CHECK_OK(GetNodeAttr(identity0_node->def(),
kXlaControlDependenciesAttrName, &attr));
EXPECT_EQ(attr.size(), 1);
EXPECT_EQ(attr[0], "const_0");
attr.clear();
TF_CHECK_OK(GetNodeAttr(identity2_node->def(),
kXlaControlDependenciesAttrName, &attr));
EXPECT_EQ(attr.size(), 1);
EXPECT_EQ(attr[0], "identity0");
}
TEST(PreprocessForEncapsulationTest, DataEdges) {
// Build the graph:
// "const_0" and "const_1" in host computation
// "add0" = "const_0" + "const_1" in XLA computation 0
// "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0
// "identity0" = "add1" in XLA computation 0
// "add2" = "add1" + "identity0" in host computation
// "add3" = "add1" + "add2" in XLA computation 1
// "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 1
// "identity1" = "add4" in XLA computation 1
// "identity2" = "identity1" in host computation
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {});
Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {});
Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1);
Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0);
Output identity0 = ops::Identity(s.WithOpName("identity0"), add1);
Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0);
Output add3 = ops::Add(s.WithOpName("add3"), add1, add2);
Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2);
Output identity1 = ops::Identity(s.WithOpName("identity1"), add4);
Output identity2 = ops::Identity(s.WithOpName("identity2"), add4);
Graph g(OpRegistry::Global());
TF_CHECK_OK(s.ToGraph(&g));
auto node_index = g.BuildNodeNameIndex();
// Set XLA computation/outside compilation attr.
Node *add0_node = node_index["add0"], *add1_node = node_index["add1"],
*identity0_node = node_index["identity0"],
*add3_node = node_index["add3"], *add4_node = node_index["add4"],
*identity1_node = node_index["identity1"];
add0_node->AddAttr("_xla", "0");
add1_node->AddAttr("_xla", "0");
add1_node->AddAttr("_oc", "0");
identity0_node->AddAttr("_xla", "0");
add3_node->AddAttr("_xla", "1");
add4_node->AddAttr("_xla", "1");
add4_node->AddAttr("_oc", "0");
identity1_node->AddAttr("_xla", "1");
TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
// Check input nodes for related data edges.
node_index = g.BuildNodeNameIndex();
// Step 2: add an Identity node between different XLA computations.
Node *bridge_add1_add3 = node_index["bridge_add1_add3"];
EXPECT_NE(bridge_add1_add3, nullptr);
string str;
TF_CHECK_OK(
GetNodeAttr(bridge_add1_add3->attrs(), kBridgeSourceNodeAttrName, &str));
EXPECT_EQ(str, "add1");
Node *bridge_identity0_add4 = node_index["bridge_identity0_add4"];
EXPECT_NE(bridge_identity0_add4, nullptr);
// Step 3: add placeholder for edges between host computation and outside
// compilation.
EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder");
Node *add1_oc_to_host_placeholder = node_index["add1_oc_to_host_placeholder"];
TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(),
kOutsideCompilationToHostOriginalNodeAttrName, &str));
EXPECT_EQ(str, "add1");
int i;
TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(),
kOutsideCompilationToHostSrcOutputAttrName, &i));
EXPECT_EQ(i, 0);
add4_node = node_index["add4"];
ASSERT_NE(add4_node, nullptr);
EXPECT_EQ(add4_node->def().input(0),
"bridge_identity0_add4_host_to_oc_placeholder");
Node *identity0_host_to_oc_placeholder =
node_index["bridge_identity0_add4_host_to_oc_placeholder"];
TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
kHostToOutsideCompilationOriginalNodeAttrName, &str));
EXPECT_EQ(str, "bridge_identity0_add4");
TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
kHostToOutsideCompilationSrcOutputAttrName, &i));
EXPECT_EQ(i, 0);
}
} // namespace tensorflow

View File

@ -21,7 +21,6 @@ limitations under the License.
#include <unordered_map>
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -465,4 +464,60 @@ Status CachedFunctionHandles::ReleaseAllHandles() {
return result;
}
xla::StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def) {
// Create the replacement node.
Status s;
Node* new_node = g->AddNode(node_def, &s);
if (!s.ok()) {
return s;
}
// Record original node's output edges and remove them first. This is to avoid
// multiple producers for dst nodes' input.
std::vector<OutEdgeInfo> out_edge_info;
std::vector<const Edge*> out_edges;
for (const Edge* edge : n->out_edges()) {
out_edges.push_back(edge);
out_edge_info.push_back(
{edge->dst(), edge->src_output(), edge->dst_input()});
}
for (const Edge* edge : out_edges) {
g->RemoveEdge(edge);
}
// Add original node's input and output edges to the replacement node.
for (const Edge* in_edge : n->in_edges()) {
g->AddEdge(in_edge->src(), in_edge->src_output(), new_node,
in_edge->dst_input());
}
for (const OutEdgeInfo& out_edge : out_edge_info) {
g->AddEdge(new_node, out_edge.src_output, out_edge.dst, out_edge.dst_input);
}
// Remove the original node.
g->RemoveNode(n);
return new_node;
}
xla::StatusOr<Node*> BuildIdentityNode(
Graph* graph, const string& node_name, DataType dtype, const Node* input,
absl::optional<string> requested_device) {
// Create identity node.
NodeDef ndef;
ndef.set_name(node_name);
ndef.set_op("Identity");
if (input) {
ndef.add_input(input->name());
}
if (requested_device) {
ndef.set_device(*requested_device);
}
AddNodeAttr("T", dtype, &ndef);
Status s;
Node* id_node = graph->AddNode(ndef, &s);
TF_RETURN_IF_ERROR(s);
return id_node;
}
} // namespace tensorflow

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_map>
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/graph.pb.h"
@ -168,6 +169,20 @@ class CachedFunctionHandles {
TF_DISALLOW_COPY_AND_ASSIGN(CachedFunctionHandles);
};
// Struct for node's output edge info.
struct OutEdgeInfo {
Node* dst;
int src_output, dst_input;
};
// Replaces node `n` with a new node whose NodeDef is `node_def`.
xla::StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def);
// Helper function that builds an Identity node.
xla::StatusOr<Node*> BuildIdentityNode(Graph* graph, const string& node_name,
DataType dtype, const Node* input,
absl::optional<string> requested_device);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_