Fix for outside compilation.

When generating Placeholder node for host computation/outside compilation nodes, we have to generate different Placeholder node for different src_output edges.

PiperOrigin-RevId: 223544322
This commit is contained in:
Tong Shen 2018-11-30 11:12:00 -08:00 committed by TensorFlower Gardener
parent 5b14577d42
commit 92e236dd24
2 changed files with 53 additions and 14 deletions

View File

@ -241,7 +241,7 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
// Remove the edge from host to outside compilation. Add a placeholder as // Remove the edge from host to outside compilation. Add a placeholder as
// outside compilation node input. // outside compilation node input.
std::map<string, Node*> placeholders; std::map<std::pair<string, int>, Node*> placeholders;
for (int i = 0; i < edges.size(); i++) { for (int i = 0; i < edges.size(); i++) {
Node* dst = g->FindNodeId(edges[i].dst_node_id); Node* dst = g->FindNodeId(edges[i].dst_node_id);
const Edge* e; const Edge* e;
@ -253,9 +253,10 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
// Find or create placeholder node. // Find or create placeholder node.
string new_name = string new_name =
edges[i].is_host_to_outside_compilation edges[i].is_host_to_outside_compilation
? absl::StrCat(src->name(), "_host_to_oc_placeholder") ? absl::StrCat(src->name(), "_host_to_oc_placeholder_", src_output)
: absl::StrCat(src->name(), "_oc_to_host_placeholder"); : absl::StrCat(src->name(), "_oc_to_host_placeholder_", src_output);
auto iter = placeholders.find(new_name); auto placeholder_index = std::make_pair(src->name(), src_output);
auto iter = placeholders.find(placeholder_index);
Node* placeholder_node; Node* placeholder_node;
if (iter == placeholders.end()) { if (iter == placeholders.end()) {
NodeDefBuilder placeholder_builder(new_name, "Placeholder"); NodeDefBuilder placeholder_builder(new_name, "Placeholder");
@ -288,7 +289,7 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
Status s; Status s;
placeholder_node = g->AddNode(placeholder_def, &s); placeholder_node = g->AddNode(placeholder_def, &s);
TF_RETURN_IF_ERROR(s); TF_RETURN_IF_ERROR(s);
placeholders[new_name] = placeholder_node; placeholders[placeholder_index] = placeholder_node;
} else { } else {
placeholder_node = iter->second; placeholder_node = iter->second;
} }
@ -642,7 +643,7 @@ Status PreprocessDataEdgesBetweenOutsideCompilations(
// Remove the edge from host to outside compilation. Add a placeholder as // Remove the edge from host to outside compilation. Add a placeholder as
// outside compilation node input. // outside compilation node input.
std::map<string, Node*> placeholders; std::map<std::pair<string, int>, Node*> placeholders;
for (int i = 0; i < edges.size(); i++) { for (int i = 0; i < edges.size(); i++) {
Node* dst = g->FindNodeId(edges[i].dst_node_id); Node* dst = g->FindNodeId(edges[i].dst_node_id);
const Edge* e; const Edge* e;
@ -652,8 +653,10 @@ Status PreprocessDataEdgesBetweenOutsideCompilations(
g->RemoveEdge(e); g->RemoveEdge(e);
// Find or create placeholder node. // Find or create placeholder node.
string new_name = absl::StrCat(src->name(), "_oc_to_oc_placeholder"); string new_name =
auto iter = placeholders.find(new_name); absl::StrCat(src->name(), "_oc_to_oc_placeholder_", src_output);
auto placeholder_index = std::make_pair(src->name(), src_output);
auto iter = placeholders.find(placeholder_index);
Node* placeholder_node; Node* placeholder_node;
if (iter == placeholders.end()) { if (iter == placeholders.end()) {
NodeDefBuilder placeholder_builder(new_name, "Placeholder"); NodeDefBuilder placeholder_builder(new_name, "Placeholder");
@ -673,7 +676,7 @@ Status PreprocessDataEdgesBetweenOutsideCompilations(
Status s; Status s;
placeholder_node = g->AddNode(placeholder_def, &s); placeholder_node = g->AddNode(placeholder_def, &s);
TF_RETURN_IF_ERROR(s); TF_RETURN_IF_ERROR(s);
placeholders[new_name] = placeholder_node; placeholders[placeholder_index] = placeholder_node;
} else { } else {
placeholder_node = iter->second; placeholder_node = iter->second;
} }

View File

@ -153,23 +153,33 @@ TEST(PreprocessForEncapsulationTest, ControlEdges) {
TEST(PreprocessForEncapsulationTest, DataEdges) { TEST(PreprocessForEncapsulationTest, DataEdges) {
// Build the graph: // Build the graph:
// "const_0" and "const_1" in host computation // "const_0" and "const_1" in host computation
// "identityn0" = ("const_0", "const_1") in host computation 0
// "add0" = "const_0" + "const_1" in XLA computation 0 // "add0" = "const_0" + "const_1" in XLA computation 0
// "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0 // "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0
// "identity0" = "add1" in XLA computation 0 // "identity0" = "add1" in XLA computation 0
// "add2" = "add1" + "identity0" in host computation // "add2" = "add1" + "identity0" in host computation
// "add3" = "add1" + "add2" in XLA computation 1 // "add3" = "add1" + "add2" in XLA computation 1
// "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 1 // "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 0
// "add5" = "identityn0"[0] + "identityn0"[1] in XLA computation 1 &
// outside compilation 0
// "identityn1" = ("identityn0"[0], "identityn0"[1]) in XLA computation 1 &
// outside compilation 0
// "identity1" = "add4" in XLA computation 1 // "identity1" = "add4" in XLA computation 1
// "identity2" = "identity1" in host computation // "identity2" = "identity1" in host computation
tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {}); Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {});
Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {}); Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {});
auto identityn0 =
ops::IdentityN(s.WithOpName("identityn_0"), {const_0, const_1});
Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1); Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1);
Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0); Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0);
Output identity0 = ops::Identity(s.WithOpName("identity0"), add1); Output identity0 = ops::Identity(s.WithOpName("identity0"), add1);
Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0); Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0);
Output add3 = ops::Add(s.WithOpName("add3"), add1, add2); Output add3 = ops::Add(s.WithOpName("add3"), add1, add2);
Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2); Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2);
Output add5 = ops::Add(s.WithOpName("add5"), identityn0[0], identityn0[1]);
auto identityn1 = ops::IdentityN(s.WithOpName("identityn_1"),
{identityn0[0], identityn0[1]});
Output identity1 = ops::Identity(s.WithOpName("identity1"), add4); Output identity1 = ops::Identity(s.WithOpName("identity1"), add4);
Output identity2 = ops::Identity(s.WithOpName("identity2"), add4); Output identity2 = ops::Identity(s.WithOpName("identity2"), add4);
Graph g(OpRegistry::Global()); Graph g(OpRegistry::Global());
@ -180,6 +190,8 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
Node *add0_node = node_index["add0"], *add1_node = node_index["add1"], Node *add0_node = node_index["add0"], *add1_node = node_index["add1"],
*identity0_node = node_index["identity0"], *identity0_node = node_index["identity0"],
*add3_node = node_index["add3"], *add4_node = node_index["add4"], *add3_node = node_index["add3"], *add4_node = node_index["add4"],
*add5_node = node_index["add5"],
*identityn1_node = node_index["identityn_1"],
*identity1_node = node_index["identity1"]; *identity1_node = node_index["identity1"];
add0_node->AddAttr("_xla", "0"); add0_node->AddAttr("_xla", "0");
add1_node->AddAttr("_xla", "0"); add1_node->AddAttr("_xla", "0");
@ -188,6 +200,10 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
add3_node->AddAttr("_xla", "1"); add3_node->AddAttr("_xla", "1");
add4_node->AddAttr("_xla", "1"); add4_node->AddAttr("_xla", "1");
add4_node->AddAttr("_oc", "0"); add4_node->AddAttr("_oc", "0");
add5_node->AddAttr("_xla", "1");
add5_node->AddAttr("_oc", "0");
identityn1_node->AddAttr("_xla", "1");
identityn1_node->AddAttr("_oc", "0");
identity1_node->AddAttr("_xla", "1"); identity1_node->AddAttr("_xla", "1");
TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc")); TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
@ -205,8 +221,9 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
EXPECT_NE(bridge_identity0_add4, nullptr); EXPECT_NE(bridge_identity0_add4, nullptr);
// Step 3: add placeholder for edges between host computation and outside // Step 3: add placeholder for edges between host computation and outside
// compilation. // compilation.
EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder"); EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder_0");
Node *add1_oc_to_host_placeholder = node_index["add1_oc_to_host_placeholder"]; Node *add1_oc_to_host_placeholder =
node_index["add1_oc_to_host_placeholder_0"];
TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(), TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(),
kOutsideCompilationToHostOriginalNodeAttrName, &str)); kOutsideCompilationToHostOriginalNodeAttrName, &str));
EXPECT_EQ(str, "add1"); EXPECT_EQ(str, "add1");
@ -217,15 +234,34 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
add4_node = node_index["add4"]; add4_node = node_index["add4"];
ASSERT_NE(add4_node, nullptr); ASSERT_NE(add4_node, nullptr);
EXPECT_EQ(add4_node->def().input(0), EXPECT_EQ(add4_node->def().input(0),
"bridge_identity0_add4_host_to_oc_placeholder"); "bridge_identity0_add4_host_to_oc_placeholder_0");
Node *identity0_host_to_oc_placeholder = Node *identity0_host_to_oc_placeholder =
node_index["bridge_identity0_add4_host_to_oc_placeholder"]; node_index["bridge_identity0_add4_host_to_oc_placeholder_0"];
TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(), TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
kHostToOutsideCompilationOriginalNodeAttrName, &str)); kHostToOutsideCompilationOriginalNodeAttrName, &str));
EXPECT_EQ(str, "bridge_identity0_add4"); EXPECT_EQ(str, "bridge_identity0_add4");
TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(), TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
kHostToOutsideCompilationSrcOutputAttrName, &i)); kHostToOutsideCompilationSrcOutputAttrName, &i));
EXPECT_EQ(i, 0); EXPECT_EQ(i, 0);
// Check different placeholder nodes are created for different src_output.
Node *placeholder0 = node_index["identityn_0_host_to_oc_placeholder_0"],
*placeholder1 = node_index["identityn_0_host_to_oc_placeholder_1"];
EXPECT_NE(placeholder0, nullptr);
EXPECT_NE(placeholder1, nullptr);
// Check we only have 2 placeholder nodes created for "identityn_0".
int placeholder_count = 0;
for (Node *n : g.nodes()) {
if (HasNodeAttr(n->def(), kHostToOutsideCompilationOriginalNodeAttrName)) {
string attr;
TF_CHECK_OK(GetNodeAttr(
n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, &attr));
if (attr == "identityn_0") {
++placeholder_count;
}
}
}
EXPECT_EQ(placeholder_count, 2);
} }
TEST(PostprocessForEncapsulationTest, ControlEdges) { TEST(PostprocessForEncapsulationTest, ControlEdges) {